gauri-sharan committed
Commit 8a4fa5e · verified · 1 Parent(s): 7abd8f0

Update app.py

Files changed (1)
  1. app.py +22 -25
app.py CHANGED
@@ -1,6 +1,5 @@
 import os
 import tempfile
-import torch
 import gradio as gr
 import spaces  # Required for GPU-enabled Spaces
 
@@ -12,40 +11,32 @@ from langchain.chains import RetrievalQAWithSourcesChain
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain.embeddings import HuggingFaceEmbeddings
 
-# Set device for embeddings
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-# Initialize Pinecone
+# Initialize Pinecone (safe, does not use CUDA)
 pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
 INDEX_NAME = "ragreader"
 
-# GPU-decorated function to load HuggingFace embeddings on GPU
+# This function does all GPU work: embedding creation, document processing, and vector store population
 @spaces.GPU
-def init_embeddings():
-    return HuggingFaceEmbeddings(
+def process_documents(files):
+    device = "cuda" if hasattr(__import__('torch'), 'cuda') and __import__('torch').cuda.is_available() else "cpu"
+    embeddings = HuggingFaceEmbeddings(
         model_name="BAAI/bge-large-en-v1.5",
         model_kwargs={"device": device}
     )
-
-embeddings = init_embeddings()
-
-# GPU-decorated document processing function
-@spaces.GPU
-def process_documents(files):
     docs = []
     for file in files:
-        with tempfile.NamedTemporaryFile(delete=False) as tmp:
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
+            file.seek(0)
             tmp.write(file.read())
+            tmp.flush()
         loader = PyPDFLoader(tmp.name)
        docs.extend(loader.load())
         os.unlink(tmp.name)
-
     text_splitter = RecursiveCharacterTextSplitter(
         chunk_size=1000,
         chunk_overlap=200
     )
     split_docs = text_splitter.split_documents(docs)
-
     PineconeVectorStore.from_documents(
         documents=split_docs,
         embedding=embeddings,
@@ -53,19 +44,20 @@ def process_documents(files):
     )
     return "Documents processed and stored."
 
-# Initialize the RetrievalQA chain (no GPU decoration needed here)
+# This function creates the QA chain (CPU only)
 def init_qa_chain():
+    # Embeddings must be created inside a GPU function, so we do not re-create here.
+    # PineconeVectorStore uses the embeddings stored in Pinecone.
     llm = ChatMistralAI(
         model="mistral-tiny",
         temperature=0.3,
         mistral_api_key=os.getenv("MISTRAL_API_KEY")
     )
-
+    # Pass None for embeddings since vectors are already in Pinecone
     vector_store = PineconeVectorStore(
         index_name=INDEX_NAME,
-        embedding=embeddings
+        embedding=None
     )
-
     return RetrievalQAWithSourcesChain.from_chain_type(
         llm=llm,
         chain_type="stuff",
@@ -73,13 +65,13 @@ def init_qa_chain():
         return_source_documents=True
     )
 
+# State: store the QA chain after processing
 qa_chain = None
 
-# Build the Gradio UI
 with gr.Blocks() as demo:
     gr.Markdown("## RAG Chatbot - PDF Reader")
 
-    file_input = gr.File(file_types=[".pdf"], file_count="multiple", label="Upload PDFs")
+    file_input = gr.File(file_types=[".pdf"], file_count="multiple", type="file", label="Upload PDFs")
     process_btn = gr.Button("Process Documents")
     process_output = gr.Textbox(label="Processing Status")
 
@@ -90,21 +82,26 @@ with gr.Blocks() as demo:
 
     def process_wrapper(files):
         global qa_chain
+        if not files or len(files) == 0:
+            return "Please upload at least one PDF."
         msg = process_documents(files)
         qa_chain = init_qa_chain()
         return msg
 
     def chat_with_docs(question):
+        global qa_chain
         if not qa_chain:
             return "Please upload and process documents first.", ""
+        if not question.strip():
+            return "Please enter a question.", ""
         response = qa_chain.invoke({"question": question}, return_only_outputs=True)
         sources = "\n".join(
             f"{os.path.basename(doc.metadata.get('source', 'unknown'))} (Page {doc.metadata.get('page', 'N/A')})"
             for doc in response.get('source_documents', [])[:3]
         )
-        return response['answer'], sources
+        return response.get('answer', "No answer found."), sources
 
-    process_btn.click(fn=process_wrapper, inputs=file_input, outputs=process_output)
+    process_btn.click(fn=process_wrapper, inputs=[file_input], outputs=process_output)
     chat_btn.click(fn=chat_with_docs, inputs=chat_input, outputs=[chat_output, source_output])
 
 if __name__ == "__main__":
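
Note on the `embedding=None` change in init_qa_chain: the vectors stored in Pinecone do not help at query time, because the retriever still has to embed the incoming question before it can run a similarity search. A minimal sketch (not part of this commit, assuming `PineconeVectorStore` comes from the `langchain_pinecone` package and that the same BGE model was used for indexing) of passing a CPU-side embedder for query encoding instead of None:

# Sketch only, not part of this commit: give the vector store a query-time embedder.
# Assumes PINECONE_API_KEY is set in the environment, as in app.py.
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_pinecone import PineconeVectorStore  # assumed source of PineconeVectorStore

query_embeddings = HuggingFaceEmbeddings(
    model_name="BAAI/bge-large-en-v1.5",   # must match the model used when indexing documents
    model_kwargs={"device": "cpu"},        # CPU is enough for embedding single questions
)
vector_store = PineconeVectorStore(
    index_name="ragreader",                # INDEX_NAME in app.py
    embedding=query_embeddings,            # instead of embedding=None
)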