from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
import time
from typing import Optional

# Point the Hugging Face caches at a local directory.
# These must be set *before* importing `datasets`, or they are ignored.
os.environ["HF_HOME"] = "./.cache"
os.environ["HF_DATASETS_CACHE"] = "./.cache/datasets"
os.environ["TRANSFORMERS_CACHE"] = "./.cache/transformers"

# For loading your HF dataset
from datasets import load_dataset

# LangChain imports
from langchain_core.documents import Document
from langchain_text_splitters import TokenTextSplitter
from langchain_chroma import Chroma
from langchain_dartmouth.embeddings import DartmouthEmbeddings
from langchain_dartmouth.llms import ChatDartmouthCloud
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

# FastAPI initialization
app = FastAPI(title="RAG API", description="Simple API for RAG-based question answering")

# Global variables
persist_directory = "./chroma_db"
vector_store = None
retriever = None
rag_chain = None
initialization_complete = False
initialization_in_progress = False

# Models
class QueryRequest(BaseModel):
    query: str
    num_results: Optional[int] = 3  # Not wired into the chain yet; the retriever uses a fixed k=5.

class QueryResponse(BaseModel):
    answer: str

def initialize_rag():
    """
    Loads the HF dataset, splits it into chunks, and creates a Chroma vector store.
    """
    global vector_store, retriever, rag_chain, initialization_complete, initialization_in_progress

    if initialization_complete:
        return
    if initialization_in_progress:
        # Another thread is already initializing; wait for it to finish.
        while initialization_in_progress:
            time.sleep(1)
        return

    initialization_in_progress = True
    try:
        # 1. Check if a Chroma DB already exists on disk
        if os.path.exists(persist_directory) and os.listdir(persist_directory):
            print("Loading existing vector store from disk...")
            embeddings_model = DartmouthEmbeddings(model_name="bge-large-en-v1-5")
            vector_store = Chroma(
                persist_directory=persist_directory,
                embedding_function=embeddings_model
            )
            print(f"Loaded vector store with {vector_store._collection.count()} documents")
        else:
            print("Creating new vector store from HF dataset...")

            # 2. Load your Hugging Face dataset
            # Replace "username/dataset_name" with your actual dataset name/ID.
            # Make sure to pick the right split ("train", "test", etc.).
            hf_dataset = load_dataset("shaamil101/met-documents", split="train")

            # 3. Convert rows into LangChain `Document` objects
            # We assume your dataset columns are: 'filename' and 'content'.
            docs = []
            for idx, row in enumerate(hf_dataset):
                docs.append(
                    Document(
                        page_content=row["content"],
                        metadata={"filename": row["filename"], "id": idx}
                    )
                )
            print(f"Loaded {len(docs)} documents from HF dataset")

            # 4. Split documents into chunks
            splitter = TokenTextSplitter(
                chunk_size=400,
                chunk_overlap=0,
                encoding_name="cl100k_base"
            )
            documents = splitter.split_documents(docs)
            print(f"Split into {len(documents)} chunks")

            # 5. Create the vector store. Chroma persists automatically when
            # `persist_directory` is set; langchain_chroma's Chroma has no
            # explicit persist() method anymore.
            embeddings_model = DartmouthEmbeddings(model_name="bge-large-en-v1-5")
            vector_store = Chroma.from_documents(
                documents=documents,
                embedding=embeddings_model,
                persist_directory=persist_directory
            )
            print(f"Created and persisted vector store with {len(documents)} chunks")

        # 6. Build a retriever on top of the vector store
        retriever = vector_store.as_retriever(search_kwargs={"k": 5})

        # 7. Create the LLM
        llm = ChatDartmouthCloud(model_name="google_genai.gemini-2.0-flash-001")

        # 8. Define a prompt template
        template = """
You are a helpful assistant that answers questions about the Metropolitan Museum of Art in New York using the provided context.
Context: {context}

Question: {question}

Answer the question based only on the provided context. If you cannot answer the question with the context, say "I don't have enough information to answer this question."
"""
        prompt = PromptTemplate.from_template(template)

        # 9. Create the RAG chain
        rag_chain = (
            {"context": retriever, "question": RunnablePassthrough()}
            | prompt
            | llm
            | StrOutputParser()
        )

        initialization_complete = True
        print("RAG pipeline initialized successfully!")
    except Exception as e:
        print(f"Error initializing RAG pipeline: {e}")
        raise
    finally:
        initialization_in_progress = False

@app.get("/")
def read_root():
    return {"message": "RAG API is running. Send POST requests to /query endpoint."}

@app.get("/health")
def health_check():
    return {
        "status": "healthy",
        "rag_initialized": initialization_complete
    }

# Sync handler on purpose: FastAPI runs it in a worker thread, so the blocking
# chain call does not stall the event loop.
@app.post("/query", response_model=QueryResponse)
def process_query(request: QueryRequest):
    # Initialize RAG if not already done
    if not initialization_complete:
        initialize_rag()

    start_time = time.time()
    try:
        # The chain retrieves the relevant documents itself, so no separate
        # retriever call is needed here.
        answer = rag_chain.invoke(request.query)

        processing_time = time.time() - start_time
        print(f"Processed query in {processing_time:.2f} seconds")

        return QueryResponse(answer=answer)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.on_event("startup")
async def startup_event():
    # Kick off the (potentially slow) initialization in a background thread
    # so the server starts responding immediately.
    import threading
    threading.Thread(target=initialize_rag, daemon=True).start()
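
# --- Example usage (a sketch; assumes this file is saved as `main.py`, that
# `uvicorn` is installed, and that valid Dartmouth API credentials are set in
# the environment; the query text is illustrative) ---
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
#   curl -X POST http://localhost:8000/query \
#        -H "Content-Type: application/json" \
#        -d '{"query": "When was the Met founded?"}'
#
# The /health endpoint reports whether the background initialization started
# at startup has finished:
#
#   curl http://localhost:8000/health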