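"""RAG chatbot Space: upload PDFs, embed them with BAAI/bge-large-en-v1.5 on GPU
(via @spaces.GPU), store the vectors in Pinecone, and answer questions with
Mistral through a RetrievalQAWithSourcesChain behind a Gradio UI."""
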
import os

import torch
import gradio as gr
import spaces  # Required for GPU-enabled Spaces
from pinecone import Pinecone
from langchain_pinecone import PineconeVectorStore
from langchain_community.document_loaders import PyPDFLoader
from langchain_mistralai import ChatMistralAI
from langchain.chains import RetrievalQAWithSourcesChain
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings

# Initialize Pinecone (safe, does not use CUDA)
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
INDEX_NAME = "ragreader"
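# Assumes the "ragreader" index already exists in Pinecone and was created with a
# dimension matching the embedding model (1024 for BAAI/bge-large-en-v1.5).
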
# This function does all GPU work: embedding creation, document processing, and vector store population
@spaces.GPU
def process_documents(files):
    # Use the GPU when one is available, otherwise fall back to CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    embeddings = HuggingFaceEmbeddings(
        model_name="BAAI/bge-large-en-v1.5",
        model_kwargs={"device": device}
    )

    # gr.File(type="filepath") passes a list of paths to the uploaded PDFs,
    # which PyPDFLoader can read directly.
    docs = []
    for path in files:
        loader = PyPDFLoader(path)
        docs.extend(loader.load())

    # Split pages into overlapping chunks (1000 characters, 200 overlap)
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200
    )
    split_docs = text_splitter.split_documents(docs)

    # Embed the chunks and upsert them into the Pinecone index
    PineconeVectorStore.from_documents(
        documents=split_docs,
        embedding=embeddings,
        index_name=INDEX_NAME
    )
    return "Documents processed and stored."

# This function creates the QA chain (CPU only)
def init_qa_chain():
    # The retriever still needs an embedding model to encode each incoming query;
    # it must match the model used at indexing time. CPU inference is fine here,
    # so this function can run outside the @spaces.GPU context.
    embeddings = HuggingFaceEmbeddings(
        model_name="BAAI/bge-large-en-v1.5",
        model_kwargs={"device": "cpu"}
    )
    llm = ChatMistralAI(
        model="mistral-tiny",
        temperature=0.3,
        mistral_api_key=os.getenv("MISTRAL_API_KEY")
    )
    vector_store = PineconeVectorStore(
        index_name=INDEX_NAME,
        embedding=embeddings
    )
    return RetrievalQAWithSourcesChain.from_chain_type(
        llm=llm,
        chain_type="stuff",  # stuff all retrieved chunks into a single prompt
        retriever=vector_store.as_retriever(search_kwargs={"k": 3}),
        return_source_documents=True
    )

# State: store the QA chain after processing
qa_chain = None
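# (Per-process state: rebuilt each time a new batch of documents is processed.)
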
with gr.Blocks() as demo:
    gr.Markdown("## RAG Chatbot - PDF Reader")
    file_input = gr.File(
        file_types=[".pdf"],
        file_count="multiple",
        type="filepath",  # the handler receives a list of file paths
        label="Upload PDFs"
    )
    process_btn = gr.Button("Process Documents")
    process_output = gr.Textbox(label="Processing Status")
    chat_input = gr.Textbox(label="Ask a question about your documents")
    chat_btn = gr.Button("Submit Question")
    chat_output = gr.Textbox(label="Answer")
    source_output = gr.Textbox(label="Relevant Sources")

    def process_wrapper(files):
        global qa_chain
        if not files:
            return "Please upload at least one PDF."
        msg = process_documents(files)
        qa_chain = init_qa_chain()
        return msg

    def chat_with_docs(question):
        global qa_chain
        if not qa_chain:
            return "Please upload and process documents first.", ""
        if not question.strip():
            return "Please enter a question.", ""
        response = qa_chain.invoke({"question": question}, return_only_outputs=True)
        sources = "\n".join(
            f"{os.path.basename(doc.metadata.get('source', 'unknown'))} (Page {doc.metadata.get('page', 'N/A')})"
            for doc in response.get("source_documents", [])[:3]
        )
        return response.get("answer", "No answer found."), sources

    process_btn.click(fn=process_wrapper, inputs=[file_input], outputs=process_output)
    chat_btn.click(fn=chat_with_docs, inputs=chat_input, outputs=[chat_output, source_output])

if __name__ == "__main__":
demo.launch()