import os
import tempfile
import torch
import gradio as gr
import spaces  # Required for GPU-enabled Spaces
from pinecone import Pinecone
from langchain_pinecone import PineconeVectorStore
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_mistralai import ChatMistralAI
from langchain.chains import RetrievalQAWithSourcesChain
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Set device for embeddings
device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize Pinecone (langchain-pinecone also reads PINECONE_API_KEY from the environment)
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
INDEX_NAME = "ragreader"


# GPU-decorated function to load HuggingFace embeddings on GPU
@spaces.GPU
def init_embeddings():
    return HuggingFaceEmbeddings(
        model_name="BAAI/bge-large-en-v1.5",
        model_kwargs={"device": device},
    )


embeddings = init_embeddings()


# GPU-decorated document processing function
@spaces.GPU
def process_documents(files):
    docs = []
    for file in files:
        # Gradio's File component passes uploads as temp-file paths
        # (older versions pass file objects that expose .name instead).
        src_path = file if isinstance(file, str) else file.name
        # Copy each upload into its own temporary file so PyPDFLoader reads
        # a stable, fully written location, then clean it up afterwards.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
            with open(src_path, "rb") as src:
                tmp.write(src.read())
            tmp_path = tmp.name
        loader = PyPDFLoader(tmp_path)
        docs.extend(loader.load())
        os.unlink(tmp_path)

    # Split pages into overlapping chunks for retrieval
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
    )
    split_docs = text_splitter.split_documents(docs)

    # Embed the chunks and upsert them into the Pinecone index
    PineconeVectorStore.from_documents(
        documents=split_docs,
        embedding=embeddings,
        index_name=INDEX_NAME,
    )
    return "Documents processed and stored."


# Initialize the RetrievalQA chain (no GPU decoration needed here)
def init_qa_chain():
    llm = ChatMistralAI(
        model="mistral-tiny",
        temperature=0.3,
        mistral_api_key=os.getenv("MISTRAL_API_KEY"),
    )
    vector_store = PineconeVectorStore(
        index_name=INDEX_NAME,
        embedding=embeddings,
    )
    return RetrievalQAWithSourcesChain.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vector_store.as_retriever(search_kwargs={"k": 3}),
        return_source_documents=True,
    )


qa_chain = None

# Build the Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## RAG Chatbot - PDF Reader")
    file_input = gr.File(file_types=[".pdf"], file_count="multiple", label="Upload PDFs")
    process_btn = gr.Button("Process Documents")
    process_output = gr.Textbox(label="Processing Status")
    chat_input = gr.Textbox(label="Ask a question about your documents")
    chat_btn = gr.Button("Submit Question")
    chat_output = gr.Textbox(label="Answer")
    source_output = gr.Textbox(label="Relevant Sources")

    def process_wrapper(files):
        # Index the uploads, then build (or rebuild) the QA chain
        global qa_chain
        msg = process_documents(files)
        qa_chain = init_qa_chain()
        return msg

    def chat_with_docs(question):
        if not qa_chain:
            return "Please upload and process documents first.", ""
        response = qa_chain.invoke({"question": question}, return_only_outputs=True)
        # Summarize up to three retrieved documents as "file (Page n)" lines
        sources = "\n".join(
            f"{os.path.basename(doc.metadata.get('source', 'unknown'))} (Page {doc.metadata.get('page', 'N/A')})"
            for doc in response.get("source_documents", [])[:3]
        )
        return response["answer"], sources

    process_btn.click(fn=process_wrapper, inputs=file_input, outputs=process_output)
    chat_btn.click(fn=chat_with_docs, inputs=chat_input, outputs=[chat_output, source_output])

if __name__ == "__main__":
    demo.launch()