# Source: Hugging Face Space by Talha812, commit ab1a53e ("Update app.py").
# app.py
import os
import json
import faiss
import numpy as np
import PyPDF2
import requests
import streamlit as st
from groq import Groq
# Constants
# Public Google Drive direct-download link for the source PDF document.
PDF_URL = "https://drive.google.com/uc?export=download&id=1YWX-RYxgtcKO1QETnz1N3rboZUhRZwcH"
# Dimensionality of each embedding vector stored in the FAISS index.
# NOTE(review): 768 matches common sentence-transformer models, but the
# embeddings here are placeholders — confirm when real embeddings are wired in.
VECTOR_DIM = 768
# Number of whitespace-delimited words per text chunk.
CHUNK_SIZE = 512
# Function to download and extract text from the PDF
def extract_text_from_pdf(url):
    """Download a PDF from *url* and return its extracted text.

    The file is written to ``document.pdf`` in the working directory
    (side effect kept for compatibility), then read back with PyPDF2.

    Args:
        url: Direct-download URL of the PDF.

    Returns:
        The concatenated text of all pages, separated by newlines.

    Raises:
        requests.HTTPError: if the server responds with an error status.
    """
    # Fail fast on network problems instead of hanging indefinitely,
    # and surface HTTP errors rather than saving an error page as a PDF.
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    with open("document.pdf", "wb") as f:
        f.write(response.content)
    with open("document.pdf", "rb") as f:
        reader = PyPDF2.PdfReader(f)
        # extract_text() may return None for image-only pages; guard so
        # str.join does not raise TypeError.
        text = "\n".join(page.extract_text() or "" for page in reader.pages)
    return text
# Function to split text into chunks
def create_chunks(text, chunk_size):
    """Split *text* into chunks of at most *chunk_size* words each.

    Words are whitespace-delimited; the final chunk may be shorter than
    *chunk_size*. An empty or whitespace-only *text* yields an empty list.
    """
    words = text.split()
    chunks = []
    for start in range(0, len(words), chunk_size):
        chunks.append(" ".join(words[start:start + chunk_size]))
    return chunks
# Function to create FAISS vector store
def create_faiss_index(chunks, vector_dim):
    """Build a FAISS flat-L2 index holding one vector per chunk.

    Uses a GPU-backed index when FAISS reports an available GPU,
    otherwise a plain CPU index.

    NOTE(review): the vectors are *random placeholders* (one per chunk),
    so similarity search over this index is meaningless until real
    embeddings replace them.

    Returns:
        A ``(index, embeddings)`` pair: the FAISS index and the raw
        float32 vectors that were added to it.
    """
    if faiss.get_num_gpus() > 0:
        st.write("Using GPU for FAISS indexing.")
        gpu_resources = faiss.StandardGpuResources()  # Initialize GPU resources
        cpu_index = faiss.IndexFlatL2(vector_dim)
        index = faiss.index_cpu_to_gpu(gpu_resources, 0, cpu_index)
    else:
        st.write("Using CPU for FAISS indexing.")
        index = faiss.IndexFlatL2(vector_dim)
    # One random float32 vector per chunk — placeholder embeddings.
    embeddings = np.random.rand(len(chunks), vector_dim).astype('float32')
    index.add(embeddings)
    return index, embeddings
# Initialize Groq API client
def get_groq_client():
    """Create a Groq API client from the ``GROQ_API_KEY`` environment variable.

    Returns:
        A configured :class:`Groq` client.

    Raises:
        RuntimeError: if ``GROQ_API_KEY`` is unset or empty — failing fast
            here instead of deferring a confusing authentication error to
            the first API call (the original passed ``api_key=None`` silently).
    """
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise RuntimeError("GROQ_API_KEY environment variable is not set.")
    return Groq(api_key=api_key)
# Query Groq model
def query_model(client, question):
    """Send *question* to the Groq chat model and return the reply text.

    The question is sent as a single user message to the
    ``llama-3.3-70b-versatile`` model; the first completion choice's
    message content is returned.
    """
    response = client.chat.completions.create(
        model="llama-3.3-70b-versatile",
        messages=[{"role": "user", "content": question}],
    )
    return response.choices[0].message.content
# Streamlit app
def main():
    """Streamlit UI: extract text → chunk → index → query, one button per step.

    Intermediate artifacts (text, chunks, FAISS index) are stashed in
    ``st.session_state`` so each step survives Streamlit's script reruns.

    NOTE(review): despite the title, no retrieval happens — the question is
    sent to the model as-is, without any of the indexed chunks as context.
    """
    st.title("RAG-Based ChatBot")

    # Step 1: download the PDF and pull its text into session state.
    st.header("Step 1: Extract Text")
    if st.button("Extract Text from PDF"):
        extracted = extract_text_from_pdf(PDF_URL)
        st.session_state["text"] = extracted
        st.success("Text extracted successfully!")

    # Step 2: split the extracted text into fixed-size word chunks.
    st.header("Step 2: Create Chunks")
    if "text" in st.session_state and st.button("Create Chunks"):
        text_chunks = create_chunks(st.session_state["text"], CHUNK_SIZE)
        st.session_state["chunks"] = text_chunks
        st.success(f"Created {len(text_chunks)} chunks.")

    # Step 3: build the FAISS vector store over the chunks.
    st.header("Step 3: Create Vector Database")
    if "chunks" in st.session_state and st.button("Create Vector Database"):
        vec_index, _ = create_faiss_index(st.session_state["chunks"], VECTOR_DIM)
        st.session_state["index"] = vec_index
        st.success("FAISS vector database created.")

    # Step 4: forward the user's question to the Groq-hosted model.
    st.header("Step 4: Query the Model")
    question = st.text_input("Ask a question about the document:")
    if question and "index" in st.session_state:
        groq_client = get_groq_client()
        st.write("Answer:", query_model(groq_client, question))

if __name__ == "__main__":
    main()