# Spaces status: Sleeping
import functools
import os

import faiss
import gradio as gr
import torch
from PyPDF2 import PdfReader
from transformers import pipeline
def extract_text_from_pdf(pdf_file):
    """Extract the text of every page of a PDF.

    Args:
        pdf_file: A filesystem path or binary stream, passed straight to
            ``PdfReader``.

    Returns:
        The text of all pages, in page order, joined with newlines. A page
        whose extraction yields nothing contributes an empty string.
    """
    reader = PdfReader(pdf_file)
    # `or ""` defensively treats a falsy/None extraction result as empty text.
    # Joining with "\n" keeps the last word of one page from fusing with the
    # first word of the next, which plain concatenation would do.
    return "\n".join(page.extract_text() or "" for page in reader.pages)
def split_text_into_chunks(text, chunk_size=500, overlap=0):
    """Split ``text`` into consecutive fixed-size character windows.

    Args:
        text: The string to split.
        chunk_size: Maximum number of characters per chunk; must be positive.
        overlap: Number of characters each chunk shares with the previous one,
            with ``0 <= overlap < chunk_size``. The default of 0 yields
            back-to-back, non-overlapping chunks.

    Returns:
        A list of substrings in order. It is empty when ``text`` is empty.
        Every chunk except possibly the last has exactly ``chunk_size``
        characters.

    Raises:
        ValueError: If ``chunk_size`` is not positive or ``overlap`` is
            outside ``[0, chunk_size)``.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must be in [0, chunk_size), got {overlap}"
        )
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        # Stop once a window reaches the end of the text. With overlap this
        # avoids emitting a trailing chunk fully contained in the previous one.
        if start + chunk_size >= len(text):
            break
    return chunks
# Default sentence-embedding model. It is shared by document and query
# embedding so that both live in the same vector space.
_DEFAULT_MODEL = "sentence-transformers/all-MiniLM-L6-v2"


@functools.lru_cache(maxsize=4)
def _get_embedder(model_name):
    """Return a feature-extraction pipeline for ``model_name``, built once.

    Constructing a ``pipeline`` loads the model weights. Caching it avoids
    reloading the model on every upload and every question.
    """
    return pipeline("feature-extraction", model=model_name)


def _embed_text(embedder, text):
    """Return one mean-pooled, L2-normalised float32 vector for ``text``."""
    # For a single input string the pipeline returns nested lists; [0] selects
    # that string's per-token vectors (num_tokens x hidden).
    token_vectors = torch.tensor(embedder(text)[0], dtype=torch.float32)
    # all-MiniLM-L6-v2 is a sentence-transformers model whose sentence
    # embedding is the mean over tokens followed by L2 normalisation. The
    # first ([CLS]) token vector on its own is not that embedding.
    return torch.nn.functional.normalize(token_vectors.mean(dim=0), dim=0)


def embed_text_chunks(text_chunks, model_name=_DEFAULT_MODEL):
    """Embed each text chunk as one row of a float32 tensor.

    Args:
        text_chunks: Sequence of strings to embed.
        model_name: Hugging Face model id used for feature extraction.

    Returns:
        A ``torch.Tensor`` of shape ``(len(text_chunks), dim)``.

    Raises:
        ValueError: If ``text_chunks`` is empty.
    """
    if not text_chunks:
        raise ValueError("No text chunks to embed.")
    embedder = _get_embedder(model_name)
    return torch.stack([_embed_text(embedder, chunk) for chunk in text_chunks])


def build_faiss_index(embeddings):
    """Build an exact L2 FAISS index over the rows of ``embeddings``.

    Args:
        embeddings: A 2-D ``(n, dim)`` tensor or array-like with ``n > 0``.

    Returns:
        A ``faiss.IndexFlatL2`` containing the ``n`` vectors.

    Raises:
        ValueError: If ``embeddings`` is not a non-empty 2-D array.
    """
    vectors = torch.as_tensor(embeddings, dtype=torch.float32)
    if vectors.ndim != 2 or vectors.shape[0] == 0:
        raise ValueError("embeddings must be a non-empty 2-D (n, dim) array")
    index = faiss.IndexFlatL2(vectors.shape[1])
    # FAISS expects a contiguous float32 numpy array on the CPU.
    index.add(vectors.detach().cpu().contiguous().numpy())
    return index


def process_document(pdf_file):
    """Extract, chunk, embed and index a PDF.

    Args:
        pdf_file: Path or binary stream of the PDF to index.

    Returns:
        A ``(faiss_index, document_chunks)`` tuple. Row ``i`` of the index
        corresponds to ``document_chunks[i]``.

    Raises:
        ValueError: If the PDF yields no extractable text.
    """
    text = extract_text_from_pdf(pdf_file)
    if not text.strip():
        raise ValueError("No extractable text found in the PDF.")
    document_chunks = split_text_into_chunks(text)
    embeddings = embed_text_chunks(document_chunks)
    faiss_index = build_faiss_index(embeddings)
    return faiss_index, document_chunks


def query_document(query, faiss_index, document_chunks, model_name=_DEFAULT_MODEL):
    """Return the chunk whose embedding is nearest to the query's.

    Args:
        query: The user's question.
        faiss_index: Index built from the embeddings of ``document_chunks``.
        document_chunks: Chunk strings, in the same order as the index rows.
        model_name: Embedding model id; must match the one used for the index.

    Returns:
        The single most relevant chunk string.

    Raises:
        ValueError: If the index returns no result (e.g. it is empty).
    """
    # Reuse the document embedding path so the query is pooled and normalised
    # exactly like the indexed chunks.
    query_embedding = embed_text_chunks([query], model_name).numpy()
    _, indices = faiss_index.search(query_embedding, k=1)
    best = int(indices[0][0])
    # FAISS reports "no result" as -1; indexing the list with -1 would
    # silently return the last chunk instead.
    if best < 0:
        raise ValueError("The index is empty; nothing to search.")
    return document_chunks[best]
def chatbot_interface():
    """Build and launch the Gradio app for PDF upload and question answering.

    The indexed document is kept in a ``gr.State`` that is passed into and
    returned from the event handlers. This way each browser session sees only
    its own upload.
    """
    with gr.Blocks() as demo:
        # Created first so the title renders at the top of the page.
        gr.Markdown("# Document Chatbot")

        # Holds (faiss_index, document_chunks); None until a document has been
        # indexed. Handlers receive and return it instead of assigning
        # ``state.value``. That attribute lives on this single component
        # object, so assigning it would share one document across every session.
        state = gr.State(None)

        def upload_file(file):
            """Index the uploaded PDF; return (status message, new state)."""
            if file is None:
                # ``change`` also fires when the user clears the file input.
                return "No document loaded.", None
            # Depending on the Gradio version the value is either a filepath
            # string or a tempfile-like object exposing ``.name``.
            path = file if isinstance(file, str) else file.name
            try:
                faiss_index, document_chunks = process_document(path)
            except ValueError as err:
                # Show indexing problems in the status box rather than as a
                # bare server error.
                return f"Could not index document: {err}", None
            return (
                "Document uploaded and indexed. You can now ask questions.",
                (faiss_index, document_chunks),
            )

        def ask_question(query, session):
            """Answer ``query`` from this session's indexed document."""
            if session is None:
                return "Please upload a document first."
            faiss_index, document_chunks = session
            return query_document(query, faiss_index, document_chunks)

        upload = gr.File(label="Upload a PDF document", file_types=[".pdf"])
        status = gr.Textbox(label="Status", interactive=False)
        question = gr.Textbox(label="Ask a question about the document")
        answer = gr.Textbox(label="Answer", interactive=False)

        upload.change(upload_file, inputs=upload, outputs=[status, state])
        question.submit(ask_question, inputs=[question, state], outputs=answer)

    demo.launch()
# Script entry point: build and launch the UI only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    chatbot_interface()