File size: 3,380 Bytes
8c3d4ad
e0b9cc5
 
879e1ad
5ab0b92
879e1ad
8c3d4ad
e0b9cc5
 
 
 
 
 
879e1ad
8c3d4ad
e0b9cc5
 
 
5ab0b92
e0b9cc5
 
 
 
 
5ab0b92
e0b9cc5
 
 
 
 
 
8c3d4ad
e0b9cc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c3d4ad
e0b9cc5
 
 
 
 
 
 
 
 
 
 
 
 
879e1ad
 
8c3d4ad
caa7c9a
33311a4
8c3d4ad
caa7c9a
 
 
33311a4
caa7c9a
8c3d4ad
caa7c9a
 
33311a4
 
caa7c9a
 
8c3d4ad
caa7c9a
 
 
33311a4
8c3d4ad
caa7c9a
8c3d4ad
caa7c9a
 
879e1ad
8c3d4ad
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gradio as gr
import os
from transformers import pipeline
import faiss
import torch
from PyPDF2 import PdfReader

# Function to extract text from a PDF file
def extract_text_from_pdf(pdf_file):
    """Return the text of every page of a PDF joined into one string.

    Args:
        pdf_file: Whatever ``PdfReader`` accepts as its first argument; the
            caller in this file passes the uploaded file's path.

    Returns:
        The extracted text of all pages in page order, with no separator
        inserted between pages. A page that yields no text contributes an
        empty string.
    """
    pdf_reader = PdfReader(pdf_file)
    # `or ""` guards against a page whose extraction yields None/empty
    # (e.g. image-only pages), which would otherwise break concatenation.
    return "".join(page.extract_text() or "" for page in pdf_reader.pages)

# Function to split text into chunks
def split_text_into_chunks(text, chunk_size=500):
    """Cut ``text`` into consecutive, non-overlapping slices.

    Args:
        text: The sequence (normally a string) to slice.
        chunk_size: Maximum length of each slice; the last slice may be
            shorter.

    Returns:
        A list of slices in their original order; empty when ``text`` is
        empty.
    """
    pieces = []
    for start in range(0, len(text), chunk_size):
        stop = start + chunk_size
        pieces.append(text[start:stop])
    return pieces

# Function to embed text chunks using a pre-trained model
def embed_text_chunks(text_chunks, model_name="sentence-transformers/all-MiniLM-L6-v2"):
    """Turn each text chunk into one embedding vector.

    Args:
        text_chunks: Iterable of strings to embed.
        model_name: Model identifier handed to the ``feature-extraction``
            pipeline.

    Returns:
        A ``torch`` tensor built from the list of per-chunk vectors, in the
        same order as ``text_chunks``.
    """
    feature_extractor = pipeline("feature-extraction", model=model_name)

    vectors = []
    for piece in text_chunks:
        features = feature_extractor(piece)
        # Keep only the [0][0] entry of the pipeline output as the chunk's
        # vector. NOTE(review): presumably this is the first token's hidden
        # state rather than a pooled sentence embedding -- confirm that is
        # intended. query_document must use the same selection.
        vectors.append(features[0][0])

    return torch.tensor(vectors)

# Function to build FAISS index for document chunks
def build_faiss_index(embeddings):
    """Build a flat L2 FAISS index containing the given embeddings.

    Args:
        embeddings: A 2-D ``torch`` tensor with one embedding per row.

    Returns:
        A ``faiss.IndexFlatL2`` holding every row of ``embeddings``.

    Raises:
        ValueError: If ``embeddings`` is not 2-D (for example the 1-D empty
            tensor produced when there were no chunks to embed).
    """
    if embeddings.ndim != 2:
        raise ValueError(
            f"embeddings must be a 2-D tensor, got shape {tuple(embeddings.shape)}"
        )

    d = embeddings.shape[1]  # Dimension of embeddings
    index = faiss.IndexFlatL2(d)

    # FAISS expects a contiguous float32 array in host memory; detach/cpu make
    # the conversion safe for tensors that track gradients or live on a GPU.
    vectors = embeddings.detach().cpu().float().contiguous().numpy()
    index.add(vectors)
    return index

# Function to process uploaded document
def process_document(pdf_file):
    """Extract, chunk, embed and index the contents of a PDF.

    Args:
        pdf_file: The PDF to process, forwarded to ``extract_text_from_pdf``.

    Returns:
        A ``(faiss_index, document_chunks)`` tuple: the index built from the
        chunk embeddings and the list of text chunks.

    Raises:
        ValueError: If no text could be extracted from the PDF, so there is
            nothing to index.
    """
    # Extract text from the PDF
    text = extract_text_from_pdf(pdf_file)

    # Split text into chunks
    document_chunks = split_text_into_chunks(text)
    if not document_chunks:
        # Fail early with a clear message instead of an obscure error while
        # building the index from zero embeddings.
        raise ValueError("No extractable text found in the document.")

    # Embed document chunks
    embeddings = embed_text_chunks(document_chunks)

    # Build FAISS index
    faiss_index = build_faiss_index(embeddings)

    return faiss_index, document_chunks

# Feature-extraction pipelines already built by query_document, keyed by model
# name, so the model is loaded once per process instead of once per question.
_QUERY_EMBEDDERS = {}


# Function to query the FAISS index for a question
def query_document(query, faiss_index, document_chunks, model_name="sentence-transformers/all-MiniLM-L6-v2"):
    """Return the document chunk whose embedding is nearest to the query.

    Args:
        query: The question text to embed and search for.
        faiss_index: Index to search; its entries must correspond by position
            to ``document_chunks``.
        document_chunks: The text chunks that were indexed.
        model_name: Model identifier handed to the ``feature-extraction``
            pipeline; it should match the model used to build the index.

    Returns:
        The single chunk ranked first by the index search.

    Raises:
        ValueError: If the search returns no hit (a negative label).
    """
    embedder = _QUERY_EMBEDDERS.get(model_name)
    if embedder is None:
        embedder = pipeline("feature-extraction", model=model_name)
        _QUERY_EMBEDDERS[model_name] = embedder

    # Embed the query with the same [0][0] selection used when the document
    # chunks were embedded, then shape it as a single-row batch for FAISS.
    query_embedding = embedder(query)[0][0]
    query_embedding = torch.tensor(query_embedding).unsqueeze(0).numpy()

    # Search the FAISS index
    _, I = faiss_index.search(query_embedding, k=1)

    # A negative label means no neighbour was found; using it as a list index
    # would silently return the last chunk.
    best = int(I[0][0])
    if best < 0:
        raise ValueError("The index returned no match for the query.")

    # Get the most relevant chunk
    return document_chunks[best]

# Gradio interface
def chatbot_interface():
    """Build the Gradio Blocks UI for the document chatbot and launch it.

    The indexed document (FAISS index plus its text chunks) is held in a
    ``gr.State`` that is wired as an output of the upload handler and an input
    of the question handler, rather than being written onto the component
    object from inside a handler.
    """

    # Function to handle document upload
    def upload_file(file):
        """Index the uploaded PDF; return (new state value, status message)."""
        if file is None:
            # The change event also fires when the file is cleared; drop the
            # stored index so stale content is not queried.
            return None, "Please upload a document first."
        # Accept both a plain path string and a file object exposing `.name`.
        path = getattr(file, "name", file)
        indexed = process_document(path)
        return indexed, "Document uploaded and indexed. You can now ask questions."

    # Function to handle user queries
    def ask_question(query, indexed):
        """Answer a question from the stored (index, chunks) pair, if any."""
        if indexed is None:
            return "Please upload a document first."
        faiss_index, document_chunks = indexed
        return query_document(query, faiss_index, document_chunks)

    with gr.Blocks() as demo:
        # Components render in creation order, so the title comes first.
        gr.Markdown("# Document Chatbot")

        state = gr.State()  # Holds the (FAISS index, document chunks) pair

        # Gradio UI
        upload = gr.File(label="Upload a PDF document")
        status = gr.Textbox(label="Status", interactive=False)
        question = gr.Textbox(label="Ask a question about the document")
        answer = gr.Textbox(label="Answer", interactive=False)

        # Event wiring: state flows out of the upload and into each question.
        upload.change(upload_file, inputs=upload, outputs=[state, status])
        question.submit(ask_question, inputs=[question, state], outputs=answer)

    demo.launch()

# Start the chatbot interface
if __name__ == "__main__":
    # Launch the UI only when this file is run as a script, not on import.
    chatbot_interface()