# "Spaces: Sleeping" — Hugging Face Space status banner captured during
# extraction; not part of the application source.
# app.py
import io
import json
import os

import faiss
import numpy as np
import PyPDF2
import requests
import streamlit as st
from groq import Groq
# Constants
# Publicly shared Google Drive direct-download link for the source PDF.
PDF_URL = "https://drive.google.com/uc?export=download&id=1YWX-RYxgtcKO1QETnz1N3rboZUhRZwcH"
# Dimensionality of the embedding vectors stored in the FAISS index.
VECTOR_DIM = 768
# Maximum number of whitespace-separated words per text chunk.
CHUNK_SIZE = 512
def extract_text_from_pdf(url):
    """Download the PDF at *url* and return the concatenated text of all pages.

    Fixes over the previous version:
    - fails loudly on a bad download (``raise_for_status`` + a request timeout)
      instead of handing an HTML error page to PyPDF2;
    - parses from an in-memory buffer rather than writing ``document.pdf``
      to the working directory;
    - tolerates image-only pages, where ``extract_text()`` returns ``None``.

    Raises:
        requests.HTTPError: if the server responds with a non-2xx status.
    """
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    reader = PyPDF2.PdfReader(io.BytesIO(response.content))
    # extract_text() may return None for pages with no extractable text.
    return "\n".join(page.extract_text() or "" for page in reader.pages)
# Function to split text into chunks
def create_chunks(text, chunk_size):
    """Split *text* into chunks of at most *chunk_size* whitespace-separated words."""
    words = text.split()
    return [
        " ".join(words[start:start + chunk_size])
        for start in range(0, len(words), chunk_size)
    ]
# Function to create FAISS vector store
def create_faiss_index(chunks, vector_dim):
    """Build a flat L2 FAISS index holding one vector per chunk.

    NOTE(review): the vectors are random placeholders (see comment below), so
    similarity search over this index is meaningless until real embeddings
    are substituted.

    Returns:
        (index, embeddings): the populated FAISS index and the raw vectors.
    """
    flat_index = faiss.IndexFlatL2(vector_dim)
    if faiss.get_num_gpus() > 0:
        # CUDA device visible: move the flat index onto GPU 0.
        st.write("Using GPU for FAISS indexing.")
        gpu_resources = faiss.StandardGpuResources()
        index = faiss.index_cpu_to_gpu(gpu_resources, 0, flat_index)
    else:
        st.write("Using CPU for FAISS indexing.")
        index = flat_index
    embeddings = np.random.rand(len(chunks), vector_dim).astype('float32')  # Replace with real embeddings
    index.add(embeddings)
    return index, embeddings
# Initialize Groq API client
def get_groq_client():
    """Construct a Groq client from the GROQ_API_KEY environment variable."""
    api_key = os.environ.get("GROQ_API_KEY")
    return Groq(api_key=api_key)
# Query Groq model
def query_model(client, question):
    """Send *question* as a single user turn and return the model's reply text."""
    conversation = [{"role": "user", "content": question}]
    completion = client.chat.completions.create(
        messages=conversation,
        model="llama-3.3-70b-versatile",
    )
    return completion.choices[0].message.content
# Streamlit app
def main():
    """Drive the four-step Streamlit UI: extract, chunk, index, query."""
    st.title("RAG-Based ChatBot")

    # Step 1: download the PDF and pull out its text.
    st.header("Step 1: Extract Text")
    if st.button("Extract Text from PDF"):
        st.session_state["text"] = extract_text_from_pdf(PDF_URL)
        st.success("Text extracted successfully!")

    # Step 2: split the extracted text into fixed-size word chunks.
    st.header("Step 2: Create Chunks")
    if "text" in st.session_state and st.button("Create Chunks"):
        chunks = create_chunks(st.session_state["text"], CHUNK_SIZE)
        st.session_state["chunks"] = chunks
        st.success(f"Created {len(chunks)} chunks.")

    # Step 3: embed the chunks (placeholder vectors) and index them.
    st.header("Step 3: Create Vector Database")
    if "chunks" in st.session_state and st.button("Create Vector Database"):
        index, _ = create_faiss_index(st.session_state["chunks"], VECTOR_DIM)
        st.session_state["index"] = index
        st.success("FAISS vector database created.")

    # Step 4: send the question straight to the LLM.
    # NOTE(review): the index is only checked for existence — no similarity
    # search is performed and no chunks are passed to the model, so answers
    # are not actually grounded in the document.
    st.header("Step 4: Query the Model")
    question = st.text_input("Ask a question about the document:")
    if question and "index" in st.session_state:
        client = get_groq_client()
        answer = query_model(client, question)
        st.write("Answer:", answer)


if __name__ == "__main__":
    main()