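"""Streamlit RAG pipeline: downloads a PDF from the Hugging Face Hub, chunks and
indexes it in a persistent Chroma vector store, and answers questions through a
two-node LangGraph workflow (retrieve -> generate) backed by an LLM served via
OpenRouter."""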
import streamlit as st
import pypdf
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma  # Changed from FAISS
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage
from langchain.schema import Document
from langgraph.graph import StateGraph, END
from huggingface_hub import hf_hub_download
import os
import shutil

from config import (
    EMBEDDING_MODEL, LLM_MODEL, LLM_BASE_URL, LLM_TEMPERATURE,
    CHUNK_SIZE, CHUNK_OVERLAP, BATCH_SIZE, OPENROUTER_API_KEY,
    VECTOR_STORE_PATH, HF_REPO_ID, HF_FILENAME, HF_REPO_TYPE
)
from state import RAGState

# Download the PDF from the Hugging Face Hub
def download_pdf_from_hf():
    try:
        pdf_path = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=HF_FILENAME,
            repo_type=HF_REPO_TYPE
        )
        return pdf_path
    except Exception as e:
        st.error(f"Error downloading PDF from Hugging Face Hub: {str(e)}")
        return None

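# Note: hf_hub_download caches the file in the local Hugging Face cache and
# returns the cached path, so repeated calls do not re-download the PDF.
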
# Initialize components with lazy loading
def initialize_components():
    if 'embeddings' not in st.session_state:
        with st.spinner("Loading embedding model (first-time download may take a moment)..."):
            st.info("Downloading embedding model (approx. 90MB). This is a one-time download.")
            st.session_state.embeddings = HuggingFaceEmbeddings(
                model_name=EMBEDDING_MODEL,
                model_kwargs={'device': 'cpu'}
            )
    if 'llm' not in st.session_state:
        with st.spinner("Connecting to language model..."):
            st.session_state.llm = ChatOpenAI(
                model=LLM_MODEL,
                api_key=OPENROUTER_API_KEY,
                base_url=LLM_BASE_URL,
                temperature=LLM_TEMPERATURE
            )
    if 'text_splitter' not in st.session_state:
        st.session_state.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP
        )
    if 'graph' not in st.session_state:
        build_graph()

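# Each component is created once and cached in st.session_state, which survives
# Streamlit's script reruns, so models are not reloaded on every interaction.
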
# Build the LangGraph workflow
def build_graph():
    workflow = StateGraph(RAGState)

    # Add nodes
    workflow.add_node("retrieve", retrieve_context)
    workflow.add_node("generate", generate_response)

    # Add edges
    workflow.add_edge("retrieve", "generate")
    workflow.add_edge("generate", END)

    # Set entry point
    workflow.set_entry_point("retrieve")

    # Compile the graph
    st.session_state.graph = workflow.compile()

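# The compiled workflow is a simple linear pipeline:
#     entry -> retrieve -> generate -> END
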
# Extract text from a PDF file path
def extract_text_from_pdf(pdf_path):
    try:
        with open(pdf_path, 'rb') as file:
            pdf_reader = pypdf.PdfReader(file)
            text = ""
            total_pages = len(pdf_reader.pages)

            # Progress indicators
            progress_bar = st.progress(0)
            status_text = st.empty()

            for i, page in enumerate(pdf_reader.pages):
                # extract_text() can return None for pages without a text layer
                text += (page.extract_text() or "") + "\n"

                # Update progress
                progress = (i + 1) / total_pages
                progress_bar.progress(progress)
                status_text.text(f"Processing page {i+1}/{total_pages}")

            # Clear progress indicators
            progress_bar.empty()
            status_text.empty()
            return text
    except Exception as e:
        st.error(f"Error reading PDF: {str(e)}")
        return ""

# Process the document from the Hugging Face Hub
def process_hf_document():
    # Reuse the vector store on disk if it already exists
    if os.path.exists(VECTOR_STORE_PATH):
        with st.spinner("Loading existing vector store..."):
            st.session_state.vector_store = Chroma(
                persist_directory=VECTOR_STORE_PATH,
                embedding_function=st.session_state.embeddings
            )
        return True

    # Download the PDF from the Hugging Face Hub
    with st.spinner("Downloading PDF from Hugging Face Hub..."):
        pdf_path = download_pdf_from_hf()
        if not pdf_path:
            return False

    # Extract text from the PDF
    with st.spinner("Extracting text from PDF..."):
        text = extract_text_from_pdf(pdf_path)
        if not text:
            return False

    # Split the text into chunks and build the vector store
    with st.spinner("Creating vector store..."):
        chunks = st.session_state.text_splitter.split_text(text)
        documents = [
            Document(page_content=chunk, metadata={"source": HF_FILENAME})
            for chunk in chunks
        ]
        st.session_state.vector_store = Chroma.from_documents(
            documents=documents,
            embedding=st.session_state.embeddings,
            persist_directory=VECTOR_STORE_PATH
        )
    return True

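# Passing persist_directory makes Chroma write the index to disk, so later runs
# take the fast load path at the top of this function. With recent chromadb
# releases persistence is automatic; older versions needed an explicit
# .persist() call.
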
# Retrieve relevant context from the vector store
def retrieve_context(state: RAGState) -> RAGState:
    if 'vector_store' not in st.session_state:
        state["context"] = ""
        return state

    # Get the last user message
    last_message = state["messages"][-1]
    if isinstance(last_message, HumanMessage):
        query = last_message.content

        # Retrieve the most similar chunks, scoped by source metadata
        docs = st.session_state.vector_store.similarity_search(
            query,
            k=3,
            filter={"source": HF_FILENAME}
        )
        context = "\n\n".join(doc.page_content for doc in docs)
        state["context"] = context
    return state

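# With a single indexed document the metadata filter changes nothing, but it
# keeps retrieval scoped correctly if more sources are ever added to the store.
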
# Generate a response using the LLM with retrieved context
def generate_response(state: RAGState) -> RAGState:
    last_message = state["messages"][-1]
    context = state.get("context", "")

    prompt = f"""Context: {context}
Question: {last_message.content}
Please answer the question based on the provided context. If the context doesn't contain relevant information, say so clearly.
Answer:"""

    try:
        response = st.session_state.llm.invoke([{"role": "user", "content": prompt}])
        ai_message = AIMessage(content=response.content)
        state["messages"].append(ai_message)
    except Exception as e:
        error_message = AIMessage(content=f"Sorry, I encountered an error: {str(e)}")
        state["messages"].append(error_message)
    return state

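# LangChain chat models accept OpenAI-style role dicts as well as message
# objects, so the prompt above is sent as a single user message.
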
# Chat with the RAG system
def chat(message: str) -> str:
    initialize_components()

    # Ensure the document has been processed
    if 'vector_store' not in st.session_state:
        if not process_hf_document():
            return "Error: Failed to process document from Hugging Face Hub"

    if 'graph' not in st.session_state:
        return "Error: Graph not initialized"

    # Create the initial state
    initial_state = RAGState(
        messages=[HumanMessage(content=message)],
        documents=[],
        context=""
    )

    # Run the graph
    with st.spinner("Retrieving context and generating response..."):
        final_state = st.session_state.graph.invoke(initial_state)

    # Return the last AI message
    for msg in reversed(final_state["messages"]):
        if isinstance(msg, AIMessage):
            return msg.content
    return "No response generated"

# Reset the vector store (drop the cached handle and delete the on-disk index)
def reset_vector_store():
    # Clear the in-memory handle first so a stale store is never reused,
    # even when nothing has been persisted to disk yet
    if 'vector_store' in st.session_state:
        del st.session_state.vector_store
    if os.path.exists(VECTOR_STORE_PATH):
        shutil.rmtree(VECTOR_STORE_PATH)
        return True
    return False