# ritampatra — "Update app.py", commit 33311a4 (verified)
# NOTE(review): the three lines above the imports were Hugging Face web-page
# residue, converted to this comment so the file is valid Python.
import gradio as gr
import os
from transformers import pipeline
import faiss
import torch
from PyPDF2 import PdfReader
# Function to extract text from a PDF file
def extract_text_from_pdf(pdf_file):
    """Return the concatenated text of every page in *pdf_file*.

    Parameters
    ----------
    pdf_file : str or file-like
        Path or binary stream accepted by ``PyPDF2.PdfReader``.

    Returns
    -------
    str
        All extractable page text joined together (may be "" for scanned
        or image-only PDFs).
    """
    pdf_reader = PdfReader(pdf_file)
    # extract_text() can return None for pages with no extractable text,
    # so fall back to "" to keep the concatenation safe; "".join avoids
    # the quadratic cost of repeated string +=.
    return "".join(page.extract_text() or "" for page in pdf_reader.pages)
# Function to split text into chunks
def split_text_into_chunks(text, chunk_size=500):
    """Split *text* into consecutive chunks of at most *chunk_size* characters.

    The final chunk may be shorter; an empty string yields an empty list.
    """
    chunks = []
    start = 0
    while start < len(text):
        chunks.append(text[start:start + chunk_size])
        start += chunk_size
    return chunks
# Function to embed text chunks using a pre-trained model
def embed_text_chunks(text_chunks, model_name="sentence-transformers/all-MiniLM-L6-v2"):
    """Embed each text chunk with a feature-extraction pipeline.

    Returns a 2-D tensor with one row per chunk.
    """
    embedder = pipeline("feature-extraction", model=model_name)
    vectors = []
    for chunk in text_chunks:
        # [0][0] keeps only the first token's embedding of the first sequence.
        # NOTE(review): all-MiniLM-L6-v2 is normally used with mean pooling
        # over all tokens; if that is changed here, query_document must be
        # changed identically so query and chunk vectors stay comparable.
        vectors.append(embedder(chunk)[0][0])
    return torch.tensor(vectors)
# Function to build FAISS index for document chunks
def build_faiss_index(embeddings):
    """Build an exact L2 (IndexFlatL2) FAISS index over row-vector embeddings.

    Parameters
    ----------
    embeddings : torch.Tensor
        Shape (num_chunks, dim); one embedding per document chunk.

    Returns
    -------
    faiss.IndexFlatL2
        Index containing all rows of *embeddings*.
    """
    d = embeddings.shape[1]  # Dimension of embeddings
    index = faiss.IndexFlatL2(d)
    # FAISS only accepts float32 input; cast defensively in case the
    # tensor was constructed as float64 upstream.
    index.add(embeddings.numpy().astype("float32"))
    return index
# Function to process uploaded document
def process_document(pdf_file):
    """Index a PDF end-to-end: extract text, chunk it, embed the chunks,
    and build a FAISS index over the embeddings.

    Returns
    -------
    tuple
        (faiss_index, document_chunks) — the searchable index and the
        chunk strings it was built from, in matching order.
    """
    raw_text = extract_text_from_pdf(pdf_file)
    chunks = split_text_into_chunks(raw_text)
    index = build_faiss_index(embed_text_chunks(chunks))
    return index, chunks
# Function to query the FAISS index for a question
def query_document(query, faiss_index, document_chunks, model_name="sentence-transformers/all-MiniLM-L6-v2"):
    """Return the single document chunk most similar to *query*.

    Parameters
    ----------
    query : str
        The user's question.
    faiss_index : faiss.Index
        Index built by ``build_faiss_index`` over *document_chunks*.
    document_chunks : list[str]
        Chunk strings in the same order they were added to the index.
    model_name : str
        Must be the same model used to embed the chunks.

    Returns
    -------
    str
        The nearest chunk by L2 distance.
    """
    # NOTE(review): a fresh pipeline is constructed on every query; caching
    # it (e.g. module-level or functools.lru_cache) would cut per-query
    # latency substantially.
    embedder = pipeline("feature-extraction", model=model_name)
    # Embed the query the same way the chunks were embedded (first token of
    # the first sequence), as a float32 row vector as FAISS requires.
    query_embedding = embedder(query)[0][0]
    query_embedding = torch.tensor(query_embedding).unsqueeze(0).numpy().astype("float32")
    # Search for the single nearest chunk (k=1).
    _, indices = faiss_index.search(query_embedding, k=1)
    return document_chunks[indices[0][0]]
# Gradio interface
def chatbot_interface():
    """Build and launch the Gradio UI for the document chatbot.

    Blocking call: ``demo.launch()`` starts the web server.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Document Chatbot")

        # Per-session storage for (faiss_index, document_chunks). Gradio
        # state must be wired through event handlers as an input/output;
        # assigning to ``state.value`` inside a callback does not persist
        # per-session state.
        state = gr.State()

        # Gradio UI
        upload = gr.File(label="Upload a PDF document")
        question = gr.Textbox(label="Ask a question about the document")
        answer = gr.Textbox(label="Answer", interactive=False)

        # Function to handle document upload
        def upload_file(file, session):
            faiss_index, document_chunks = process_document(file.name)
            return (
                "Document uploaded and indexed. You can now ask questions.",
                (faiss_index, document_chunks),
            )

        # Function to handle user queries
        def ask_question(query, session):
            if session is None:
                return "Please upload a document first."
            faiss_index, document_chunks = session
            return query_document(query, faiss_index, document_chunks)

        # Route the upload status into the answer box so the message is
        # actually displayed (outputs=None would discard the return value),
        # and thread the session state through both handlers.
        upload.change(upload_file, inputs=[upload, state], outputs=[answer, state])
        question.submit(ask_question, inputs=[question, state], outputs=answer)
    demo.launch()
# Start the chatbot interface
if __name__ == "__main__":
    # Builds the UI and launches the blocking Gradio server.
    chatbot_interface()