import streamlit as st
import pypdf
from io import BytesIO
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma # Changed from FAISS
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage
from langchain.schema import Document
from langgraph.graph import StateGraph, END
from huggingface_hub import hf_hub_download
import os
import shutil
from config import (
    EMBEDDING_MODEL, LLM_MODEL, LLM_BASE_URL, LLM_TEMPERATURE,
    CHUNK_SIZE, CHUNK_OVERLAP, BATCH_SIZE, OPENROUTER_API_KEY,
    VECTOR_STORE_PATH, HF_REPO_ID, HF_FILENAME, HF_REPO_TYPE
)
from state import RAGState
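
# The two local imports above are not shown in this file. For reference,
# config.py is expected to expose plain constants and state.py a RAGState
# schema, roughly along these lines -- illustrative sketches inferred from
# how they are used below, not the actual definitions:
#
#     # config.py (values are placeholder guesses)
#     EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
#     LLM_BASE_URL = "https://openrouter.ai/api/v1"
#     CHUNK_SIZE = 1000
#     CHUNK_OVERLAP = 200
#     VECTOR_STORE_PATH = "./chroma_db"
#
#     # state.py
#     from typing import List, TypedDict
#     from langchain_core.messages import BaseMessage
#     from langchain.schema import Document
#
#     class RAGState(TypedDict):
#         messages: List[BaseMessage]
#         documents: List[Document]
#         context: str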

# Download PDF from Hugging Face Hub
def download_pdf_from_hf():
    try:
        pdf_path = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=HF_FILENAME,
            repo_type=HF_REPO_TYPE
        )
        return pdf_path
    except Exception as e:
        st.error(f"Error downloading PDF from Hugging Face Hub: {str(e)}")
        return None
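
# Note: hf_hub_download caches the file in the local Hugging Face cache
# (~/.cache/huggingface/hub by default), so repeated calls return the cached
# path instead of re-downloading.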

# Initialize components with lazy loading
def initialize_components():
    if 'embeddings' not in st.session_state:
        with st.spinner("Loading embedding model (first-time download may take a moment)..."):
            st.info("Downloading embedding model (approx. 90MB). This is a one-time download.")
            st.session_state.embeddings = HuggingFaceEmbeddings(
                model_name=EMBEDDING_MODEL,
                model_kwargs={'device': 'cpu'}
            )
    if 'llm' not in st.session_state:
        with st.spinner("Connecting to language model..."):
            st.session_state.llm = ChatOpenAI(
                model=LLM_MODEL,
                api_key=OPENROUTER_API_KEY,
                base_url=LLM_BASE_URL,
                temperature=LLM_TEMPERATURE
            )
    if 'text_splitter' not in st.session_state:
        st.session_state.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP
        )
    if 'graph' not in st.session_state:
        build_graph()
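
# Design note: st.session_state is scoped to a single browser session, so each
# visitor pays the model-loading cost once. To share the embedding model across
# sessions, the loader could instead be wrapped in st.cache_resource.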

# Build the LangGraph workflow
def build_graph():
    workflow = StateGraph(RAGState)

    # Add nodes
    workflow.add_node("retrieve", retrieve_context)
    workflow.add_node("generate", generate_response)

    # Add edges
    workflow.add_edge("retrieve", "generate")
    workflow.add_edge("generate", END)

    # Set entry point
    workflow.set_entry_point("retrieve")

    # Compile the graph
    st.session_state.graph = workflow.compile()

# Extract text from PDF file path
def extract_text_from_pdf(pdf_path):
    try:
        with open(pdf_path, 'rb') as file:
            pdf_reader = pypdf.PdfReader(file)
            text = ""
            total_pages = len(pdf_reader.pages)

            # Progress indicators
            progress_bar = st.progress(0)
            status_text = st.empty()

            for i, page in enumerate(pdf_reader.pages):
                # extract_text() can return None for image-only pages
                text += (page.extract_text() or "") + "\n"

                # Update progress
                progress = (i + 1) / total_pages
                progress_bar.progress(progress)
                status_text.text(f"Processing page {i+1}/{total_pages}")

            # Clear progress indicators
            progress_bar.empty()
            status_text.empty()

            return text
    except Exception as e:
        st.error(f"Error reading PDF: {str(e)}")
        return ""

# Process document from Hugging Face Hub
def process_hf_document():
    # Check if vector store already exists
    if os.path.exists(VECTOR_STORE_PATH):
        with st.spinner("Loading existing vector store..."):
            st.session_state.vector_store = Chroma(
                persist_directory=VECTOR_STORE_PATH,
                embedding_function=st.session_state.embeddings
            )
        return True

    # Download PDF from Hugging Face Hub
    with st.spinner("Downloading PDF from Hugging Face Hub..."):
        pdf_path = download_pdf_from_hf()
        if not pdf_path:
            return False

    # Extract text from PDF
    with st.spinner("Extracting text from PDF..."):
        text = extract_text_from_pdf(pdf_path)
        if not text:
            return False

    # Create documents and vector store
    with st.spinner("Creating vector store..."):
        documents = []
        chunks = st.session_state.text_splitter.split_text(text)
        for chunk in chunks:
            documents.append(Document(
                page_content=chunk,
                metadata={"source": HF_FILENAME}
            ))
        st.session_state.vector_store = Chroma.from_documents(
            documents=documents,
            embedding=st.session_state.embeddings,
            persist_directory=VECTOR_STORE_PATH
        )
    return True
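
# Note: with Chroma 0.4+ (as wrapped by langchain), passing persist_directory
# to Chroma.from_documents persists the collection automatically; no explicit
# .persist() call is needed.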

# Retrieve relevant context from vector store
def retrieve_context(state: RAGState) -> RAGState:
    if 'vector_store' not in st.session_state:
        state["context"] = ""
        return state

    # Get the last user message
    last_message = state["messages"][-1]
    if isinstance(last_message, HumanMessage):
        query = last_message.content

        # Retrieve similar documents with metadata
        docs = st.session_state.vector_store.similarity_search(
            query,
            k=3,
            filter={"source": HF_FILENAME}
        )
        context = "\n\n".join([doc.page_content for doc in docs])
        state["context"] = context
    else:
        # No user query to search with; leave the context empty
        state["context"] = ""
    return state
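
# The same retrieval can be expressed through LangChain's retriever interface,
# which is handy if the node should later swap in MMR or score thresholds -- a
# sketch, equivalent to the direct similarity_search call above:
#
#     retriever = st.session_state.vector_store.as_retriever(
#         search_kwargs={"k": 3, "filter": {"source": HF_FILENAME}}
#     )
#     docs = retriever.invoke(query)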

# Generate response using LLM with context
def generate_response(state: RAGState) -> RAGState:
    last_message = state["messages"][-1]
    context = state.get("context", "")

    prompt = f"""Context: {context}

Question: {last_message.content}

Please answer the question based on the provided context. If the context doesn't contain relevant information, say so clearly.

Answer:"""

    try:
        response = st.session_state.llm.invoke([{"role": "user", "content": prompt}])
        ai_message = AIMessage(content=response.content)
        state["messages"].append(ai_message)
    except Exception as e:
        error_message = AIMessage(content=f"Sorry, I encountered an error: {str(e)}")
        state["messages"].append(error_message)
    return state

# Chat with the RAG system
def chat(message: str) -> str:
    initialize_components()

    # Ensure document is processed
    if 'vector_store' not in st.session_state:
        if not process_hf_document():
            return "Error: Failed to process document from Hugging Face Hub"

    if 'graph' not in st.session_state:
        return "Error: Graph not initialized"

    # Create initial state
    initial_state = RAGState(
        messages=[HumanMessage(content=message)],
        documents=[],
        context=""
    )

    # Run the graph
    with st.spinner("Retrieving context and generating response..."):
        final_state = st.session_state.graph.invoke(initial_state)

    # Return the last AI message
    for msg in reversed(final_state["messages"]):
        if isinstance(msg, AIMessage):
            return msg.content
    return "No response generated"

# Reset vector store
def reset_vector_store():
    # Drop the in-memory handle first so a stale store is never reused
    if 'vector_store' in st.session_state:
        del st.session_state.vector_store
    if os.path.exists(VECTOR_STORE_PATH):
        shutil.rmtree(VECTOR_STORE_PATH)
        return True
    return False
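
# Example wiring from a Streamlit entry point (a minimal sketch -- the actual
# UI presumably lives elsewhere, e.g. app.py; module and widget names here are
# illustrative assumptions):
#
#     import streamlit as st
#     from rag_app import chat, reset_vector_store  # module name assumed
#
#     st.title("PDF Q&A")
#     if st.sidebar.button("Reset document index"):
#         reset_vector_store()
#     if question := st.chat_input("Ask about the document"):
#         st.chat_message("user").write(question)
#         st.chat_message("assistant").write(chat(question))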