ritampatra commited on
Commit
e0b9cc5
·
verified ·
1 Parent(s): d154f38

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -48
app.py CHANGED
@@ -1,60 +1,64 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModel
 
3
  import faiss
4
- import numpy as np
5
  import torch
6
  from PyPDF2 import PdfReader
7
 
8
- # Load PDF and extract text from it
9
- def load_document(file):
10
- pdf = PdfReader(file)
11
- text = ''
12
- for page_num in range(len(pdf.pages)):
13
- page = pdf.pages[page_num]
14
- text += page.extract_text()
15
  return text
16
 
17
- # Embed the document using Hugging Face model
18
- def embed_text(text):
19
- # Load tokenizer and model from Hugging Face
20
- tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-mpnet-base-v2")
21
- model = AutoModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")
22
-
23
- # Tokenize and embed text
24
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
25
- with torch.no_grad():
26
- outputs = model(**inputs)
27
- embeddings = outputs.last_hidden_state.mean(dim=1) # Mean pooling to get the embedding
28
- return embeddings.squeeze().numpy()
29
-
30
- # Initialize FAISS index
31
- def initialize_faiss(embedding_size):
32
- index = faiss.IndexFlatL2(embedding_size)
33
- return index
34
 
35
- # Add document embeddings to FAISS index
36
- def add_to_index(index, embeddings):
37
- index.add(embeddings)
 
 
38
 
39
- # Search the FAISS index for the best matching text
40
- def search_index(index, query_embedding, texts, top_k=3):
41
- distances, indices = index.search(np.array([query_embedding]), top_k)
42
- return [texts[i] for i in indices[0]]
 
 
43
 
44
- # Process the document and build the FAISS index
45
- def process_document(file):
46
- text = load_document(file)
47
- chunks = [text[i:i + 512] for i in range(0, len(text), 512)] # Split text into chunks
48
- embeddings = np.vstack([embed_text(chunk) for chunk in chunks]) # Create embeddings for each chunk
49
- faiss_index = initialize_faiss(embeddings.shape[1]) # Initialize FAISS index
50
- add_to_index(faiss_index, embeddings) # Add embeddings to FAISS index
51
- return faiss_index, chunks
 
 
 
 
 
 
 
52
 
53
- # Answer query by searching FAISS index
54
- def query_document(query, faiss_index, document_chunks):
55
- query_embedding = embed_text(query) # Embed query
56
- results = search_index(faiss_index, query_embedding, document_chunks) # Search for the best matching chunks
57
- return "\n\n".join(results) # Return the matching document parts
 
 
 
 
 
 
 
 
58
 
59
  # Gradio interface
60
  def chatbot_interface():
@@ -64,7 +68,7 @@ def chatbot_interface():
64
  # Function to handle document upload
65
  def upload_file(file):
66
  nonlocal faiss_index, document_chunks
67
- faiss_index, document_chunks = process_document(file)
68
  return "Document uploaded and indexed. You can now ask questions."
69
 
70
  # Function to handle user queries
@@ -76,7 +80,7 @@ def chatbot_interface():
76
  # Gradio UI
77
  upload = gr.File(label="Upload a PDF document")
78
  question = gr.Textbox(label="Ask a question about the document")
79
- answer = gr.Textbox(label="Answer", readonly=True)
80
 
81
  # Gradio app layout
82
  with gr.Blocks() as demo:
 
1
  import gradio as gr
2
+ import os
3
+ from transformers import pipeline
4
  import faiss
 
5
  import torch
6
  from PyPDF2 import PdfReader
7
 
8
+ # Function to extract text from a PDF file
9
+ def extract_text_from_pdf(pdf_file):
10
+ pdf_reader = PdfReader(pdf_file)
11
+ text = ""
12
+ for page_num in range(len(pdf_reader.pages)):
13
+ text += pdf_reader.pages[page_num].extract_text()
 
14
  return text
15
 
16
+ # Function to split text into chunks
17
+ def split_text_into_chunks(text, chunk_size=500):
18
+ return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
+ # Function to embed text chunks using a pre-trained model
21
+ def embed_text_chunks(text_chunks, model_name="sentence-transformers/all-MiniLM-L6-v2"):
22
+ embedder = pipeline("feature-extraction", model=model_name)
23
+ embeddings = [embedder(chunk)[0][0] for chunk in text_chunks]
24
+ return torch.tensor(embeddings)
25
 
26
+ # Function to build FAISS index for document chunks
27
+ def build_faiss_index(embeddings):
28
+ d = embeddings.shape[1] # Dimension of embeddings
29
+ index = faiss.IndexFlatL2(d)
30
+ index.add(embeddings.numpy())
31
+ return index
32
 
33
+ # Function to process uploaded document
34
+ def process_document(pdf_file):
35
+ # Extract text from the PDF
36
+ text = extract_text_from_pdf(pdf_file)
37
+
38
+ # Split text into chunks
39
+ document_chunks = split_text_into_chunks(text)
40
+
41
+ # Embed document chunks
42
+ embeddings = embed_text_chunks(document_chunks)
43
+
44
+ # Build FAISS index
45
+ faiss_index = build_faiss_index(embeddings)
46
+
47
+ return faiss_index, document_chunks
48
 
49
+ # Function to query the FAISS index for a question
50
+ def query_document(query, faiss_index, document_chunks, model_name="sentence-transformers/all-MiniLM-L6-v2"):
51
+ embedder = pipeline("feature-extraction", model=model_name)
52
+
53
+ # Embed the query
54
+ query_embedding = embedder(query)[0][0]
55
+ query_embedding = torch.tensor(query_embedding).unsqueeze(0).numpy()
56
+
57
+ # Search the FAISS index
58
+ _, I = faiss_index.search(query_embedding, k=1)
59
+
60
+ # Get the most relevant chunk
61
+ return document_chunks[I[0][0]]
62
 
63
  # Gradio interface
64
  def chatbot_interface():
 
68
  # Function to handle document upload
69
  def upload_file(file):
70
  nonlocal faiss_index, document_chunks
71
+ faiss_index, document_chunks = process_document(file.name)
72
  return "Document uploaded and indexed. You can now ask questions."
73
 
74
  # Function to handle user queries
 
80
  # Gradio UI
81
  upload = gr.File(label="Upload a PDF document")
82
  question = gr.Textbox(label="Ask a question about the document")
83
+ answer = gr.Textbox(label="Answer", interactive=False) # Updated to interactive=False
84
 
85
  # Gradio app layout
86
  with gr.Blocks() as demo: