Spaces:
Runtime error
Update app.py
app.py CHANGED
@@ -1,6 +1,5 @@
 import os
 import tempfile
-import torch
 import gradio as gr
 import spaces  # Required for GPU-enabled Spaces

@@ -12,40 +11,32 @@ from langchain.chains import RetrievalQAWithSourcesChain
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain.embeddings import HuggingFaceEmbeddings

-#
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-# Initialize Pinecone
+# Initialize Pinecone (safe, does not use CUDA)
 pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
 INDEX_NAME = "ragreader"

-#
+# This function does all GPU work: embedding creation, document processing, and vector store population
 @spaces.GPU
-def init_embeddings():
-    return HuggingFaceEmbeddings(
+def process_documents(files):
+    device = "cuda" if hasattr(__import__('torch'), 'cuda') and __import__('torch').cuda.is_available() else "cpu"
+    embeddings = HuggingFaceEmbeddings(
         model_name="BAAI/bge-large-en-v1.5",
         model_kwargs={"device": device}
     )
-
-embeddings = init_embeddings()
-
-# GPU-decorated document processing function
-@spaces.GPU
-def process_documents(files):
     docs = []
     for file in files:
-        with tempfile.NamedTemporaryFile(delete=False) as tmp:
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
+            file.seek(0)
             tmp.write(file.read())
+            tmp.flush()
             loader = PyPDFLoader(tmp.name)
             docs.extend(loader.load())
         os.unlink(tmp.name)
-
     text_splitter = RecursiveCharacterTextSplitter(
         chunk_size=1000,
         chunk_overlap=200
     )
     split_docs = text_splitter.split_documents(docs)
-
     PineconeVectorStore.from_documents(
         documents=split_docs,
         embedding=embeddings,
@@ -53,19 +44,20 @@ def process_documents(files):
     )
     return "Documents processed and stored."

-#
+# This function creates the QA chain (CPU only)
 def init_qa_chain():
+    # Embeddings must be created inside a GPU function, so we do not re-create here.
+    # PineconeVectorStore uses the embeddings stored in Pinecone.
     llm = ChatMistralAI(
         model="mistral-tiny",
         temperature=0.3,
         mistral_api_key=os.getenv("MISTRAL_API_KEY")
     )
-
+    # Pass None for embeddings since vectors are already in Pinecone
     vector_store = PineconeVectorStore(
         index_name=INDEX_NAME,
-        embedding=embeddings
+        embedding=None
     )
-
     return RetrievalQAWithSourcesChain.from_chain_type(
         llm=llm,
         chain_type="stuff",
@@ -73,13 +65,13 @@ def init_qa_chain():
         return_source_documents=True
     )

+# State: store the QA chain after processing
 qa_chain = None

-# Build the Gradio UI
 with gr.Blocks() as demo:
     gr.Markdown("## RAG Chatbot - PDF Reader")

-    file_input = gr.File(file_types=[".pdf"], file_count="multiple", label="Upload PDFs")
+    file_input = gr.File(file_types=[".pdf"], file_count="multiple", type="file", label="Upload PDFs")
     process_btn = gr.Button("Process Documents")
     process_output = gr.Textbox(label="Processing Status")

@@ -90,21 +82,26 @@ with gr.Blocks() as demo:

     def process_wrapper(files):
         global qa_chain
+        if not files or len(files) == 0:
+            return "Please upload at least one PDF."
         msg = process_documents(files)
         qa_chain = init_qa_chain()
         return msg

     def chat_with_docs(question):
+        global qa_chain
         if not qa_chain:
             return "Please upload and process documents first.", ""
+        if not question.strip():
+            return "Please enter a question.", ""
         response = qa_chain.invoke({"question": question}, return_only_outputs=True)
         sources = "\n".join(
             f"{os.path.basename(doc.metadata.get('source', 'unknown'))} (Page {doc.metadata.get('page', 'N/A')})"
             for doc in response.get('source_documents', [])[:3]
         )
-        return response
+        return response.get('answer', "No answer found."), sources

-    process_btn.click(fn=process_wrapper, inputs=file_input, outputs=process_output)
+    process_btn.click(fn=process_wrapper, inputs=[file_input], outputs=process_output)
     chat_btn.click(fn=chat_with_docs, inputs=chat_input, outputs=[chat_output, source_output])

 if __name__ == "__main__":
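The core of this commit, reduced to a standalone pattern: on a ZeroGPU Space nothing may touch CUDA at import time, so the module-level device check is removed and device selection moves inside a function decorated with @spaces.GPU. A minimal sketch of that pattern follows (illustrative only, not part of app.py; the function name report_device is made up):

import spaces  # Hugging Face Spaces package that provides the ZeroGPU decorator
import torch

@spaces.GPU
def report_device() -> str:
    # CUDA is queried only here, while the decorated call holds a GPU,
    # so importing this module never initializes the GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return f"embedding model would run on: {device}"

Called from a Gradio handler on a ZeroGPU Space this reports "cuda" while the GPU is attached; run locally without a GPU it falls back to "cpu".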
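For reference on the query side: a vector-store retriever encodes the incoming question with the store's embedding object at search time. A common CPU-only way to provide that is a CPU-backed copy of the same BGE model used at indexing time. This is a sketch, not part of the commit; it assumes PineconeVectorStore comes from langchain_pinecone and reuses the index name and model from app.py:

from langchain.embeddings import HuggingFaceEmbeddings
from langchain_pinecone import PineconeVectorStore

# CPU-backed copy of the indexing model, used only to embed user questions.
query_embeddings = HuggingFaceEmbeddings(
    model_name="BAAI/bge-large-en-v1.5",
    model_kwargs={"device": "cpu"},
)

# Points at the already-populated index; no documents are re-embedded here.
vector_store = PineconeVectorStore(
    index_name="ragreader",
    embedding=query_embeddings,
)
retriever = vector_store.as_retriever(search_kwargs={"k": 3})

The k=3 mirrors the three source documents that chat_with_docs displays in the UI.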