# RAG/app.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
import time
from typing import Optional

# Hugging Face cache locations must be set before `datasets` is imported,
# since the library reads these variables at import time.
os.environ["HF_HOME"] = "./.cache"
os.environ["HF_DATASETS_CACHE"] = "./.cache/datasets"
os.environ["TRANSFORMERS_CACHE"] = "./.cache/transformers"

# For loading your HF dataset
from datasets import load_dataset
# LangChain imports
from langchain.docstore.document import Document
from langchain_text_splitters import TokenTextSplitter
from langchain_chroma import Chroma
from langchain_dartmouth.embeddings import DartmouthEmbeddings
from langchain_dartmouth.llms import ChatDartmouthCloud
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
# FastAPI initialization
app = FastAPI(title="RAG API", description="Simple API for RAG-based question answering")
os.environ["HF_HOME"] = "./.cache"
os.environ["HF_DATASETS_CACHE"] = "./.cache/datasets"
os.environ["TRANSFORMERS_CACHE"] = "./.cache/transformers"
# Global variables
persist_directory = "./chroma_db"
vector_store = None
retriever = None
rag_chain = None
initialization_complete = False
initialization_in_progress = False
# Models
class QueryRequest(BaseModel):
    query: str
    num_results: Optional[int] = 3


class QueryResponse(BaseModel):
    answer: str
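
# Illustrative request/response shapes for /query (values are made up):
#   request:  {"query": "What is the Temple of Dendur?", "num_results": 3}
#   response: {"answer": "The Temple of Dendur is an Egyptian temple ..."}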

def initialize_rag():
    """
    Loads your HF dataset, splits it into chunks, and creates a Chroma vector store.
    """
    global vector_store, retriever, rag_chain, initialization_complete, initialization_in_progress
    if initialization_complete:
        return
    if initialization_in_progress:
        # Another caller is already initializing; wait for it to finish.
        while initialization_in_progress:
            time.sleep(1)
        return
    initialization_in_progress = True
    try:
        # 1. Check if Chroma DB already exists
        if os.path.exists(persist_directory) and os.listdir(persist_directory):
            print("Loading existing vector store from disk...")
            embeddings_model = DartmouthEmbeddings(model_name="bge-large-en-v1-5")
            vector_store = Chroma(
                persist_directory=persist_directory,
                embedding_function=embeddings_model
            )
            print(f"Loaded vector store with {vector_store._collection.count()} documents")
        else:
            print("Creating new vector store from HF dataset...")
            # 2. Load your Hugging Face dataset
            # Replace "username/dataset_name" with your actual dataset name/ID.
            # Make sure to pick the right split ("train", "test", etc.).
            hf_dataset = load_dataset("shaamil101/met-documents", split="train")
            # 3. Convert rows into LangChain `Document` objects
            # We assume your dataset columns are: 'filename' and 'content'.
            docs = []
            for idx, row in enumerate(hf_dataset):
                docs.append(
                    Document(
                        page_content=row["content"],
                        metadata={
                            "filename": row["filename"],
                            "id": idx
                        }
                    )
                )
            print(f"Loaded {len(docs)} documents from HF dataset")
            # 4. Split documents into chunks
            splitter = TokenTextSplitter(
                chunk_size=400,
                chunk_overlap=0,
                encoding_name="cl100k_base"
            )
            documents = splitter.split_documents(docs)
            print(f"Split into {len(documents)} chunks")
            # 5. Create the vector store
            embeddings_model = DartmouthEmbeddings(model_name="bge-large-en-v1-5")
            vector_store = Chroma.from_documents(
                documents=documents,
                embedding=embeddings_model,
                persist_directory=persist_directory
            )
            # langchain_chroma persists automatically when persist_directory is
            # set; the separate vector_store.persist() call is no longer part
            # of this API and would raise an AttributeError.
            print(f"Created and persisted vector store with {len(documents)} chunks")
        # 6. Build a retriever on top of the vector store
        # (retriever is already declared global at the top of this function)
        retriever = vector_store.as_retriever(search_kwargs={"k": 5})
        # 7. Create your LLM
        llm = ChatDartmouthCloud(model_name="google_genai.gemini-2.0-flash-001")
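
        # Quick sanity check (illustrative, not executed here): the retriever
        # is a Runnable and returns up to k=5 Document chunks.
        #   docs = retriever.invoke("Egyptian art at the Met")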
        # 8. Define a prompt template
        template = """
        You are a helpful assistant that answers questions about the Metropolitan Museum of Art in New York using the provided context.

        Context:
        {context}

        Question: {question}

        Answer the question based only on the provided context.
        If you cannot answer the question with the context, say "I don't have enough information to answer this question."
        """
        prompt = PromptTemplate.from_template(template)
        # 9. Create the RAG chain
        # The dict maps the incoming question string to the retriever (whose
        # matching chunks fill {context}) and, via RunnablePassthrough, to the
        # prompt's {question} slot.
        rag_chain = (
            {"context": retriever, "question": RunnablePassthrough()}
            | prompt
            | llm
            | StrOutputParser()
        )
        initialization_complete = True
        print("RAG pipeline initialized successfully!")
    except Exception as e:
        print(f"Error initializing RAG pipeline: {e}")
        raise
    finally:
        initialization_in_progress = False
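
# Illustrative smoke test (not executed at import time): once initialized,
# the chain maps a question string straight to an answer string.
#   initialize_rag()
#   print(rag_chain.invoke("When was the Met founded?"))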
@app.get("/")
def read_root():
return {"message": "RAG API is running. Send POST requests to /query endpoint."}
@app.get("/health")
def health_check():
global initialization_complete
return {
"status": "healthy",
"rag_initialized": initialization_complete
}
@app.post("/query", response_model=QueryResponse)
async def process_query(request: QueryRequest):
# Initialize RAG if not already done
if not initialization_complete:
initialize_rag()
start_time = time.time()
try:
# Retrieve relevant documents
docs = retriever.get_relevant_documents(request.query, k=request.num_results)
# Generate answer
answer = rag_chain.invoke(request.query)
processing_time = time.time() - start_time
print(f"Processed query in {processing_time:.2f} seconds")
return QueryResponse(answer=answer)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
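
# Example call against a running instance (host and port are assumptions):
#   curl -X POST http://localhost:8000/query \
#        -H "Content-Type: application/json" \
#        -d '{"query": "What is on display in the Egyptian wing?"}'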
@app.on_event("startup")
async def startup_event():
# Optionally initialize in a separate thread
import threading
threading.Thread(target=initialize_rag, daemon=True).start()
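
# To run locally without Docker (a sketch, assuming uvicorn is installed):
#   uvicorn app:app --host 0.0.0.0 --port 8000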