from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
import time
from typing import Optional

# Point the Hugging Face caches at a writable local directory.
# These need to be set before `datasets`/`transformers` are imported,
# since both read their cache locations at import time.
os.environ["HF_HOME"] = "./.cache"
os.environ["HF_DATASETS_CACHE"] = "./.cache/datasets"
os.environ["TRANSFORMERS_CACHE"] = "./.cache/transformers"

# For loading your HF dataset
from datasets import load_dataset

# LangChain imports
from langchain.docstore.document import Document
from langchain_text_splitters import TokenTextSplitter
from langchain_chroma import Chroma
from langchain_dartmouth.embeddings import DartmouthEmbeddings
from langchain_dartmouth.llms import ChatDartmouthCloud
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

# FastAPI initialization
app = FastAPI(title="RAG API", description="Simple API for RAG-based question answering")

# Global variables
persist_directory = "./chroma_db"
vector_store = None
retriever = None
rag_chain = None
initialization_complete = False
initialization_in_progress = False

# Models
class QueryRequest(BaseModel):
    query: str
    num_results: Optional[int] = 3

class QueryResponse(BaseModel):
    answer: str
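
# Illustrative request/response shapes for the /query endpoint (example values only):
#   request:  {"query": "When was the museum founded?", "num_results": 3}
#   response: {"answer": "..."}
# Note: num_results is accepted but not wired into the chain below, whose retriever
# is configured with a fixed k=5.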

def initialize_rag():
    """
    Loads your HF dataset, splits it into chunks, and creates a Chroma vector store.
    """
    global vector_store, retriever, rag_chain, initialization_complete, initialization_in_progress
    if initialization_complete:
        return
    if initialization_in_progress:
        # Another caller is already building the index; wait until it finishes.
        while initialization_in_progress:
            time.sleep(1)
        return
    initialization_in_progress = True
    try:
        # 1. Check if a Chroma DB already exists on disk
        if os.path.exists(persist_directory) and os.listdir(persist_directory):
            print("Loading existing vector store from disk...")
            embeddings_model = DartmouthEmbeddings(model_name="bge-large-en-v1-5")
            vector_store = Chroma(
                persist_directory=persist_directory,
                embedding_function=embeddings_model
            )
            print(f"Loaded vector store with {vector_store._collection.count()} documents")
        else:
            print("Creating new vector store from HF dataset...")
            # 2. Load your Hugging Face dataset
            # Replace "username/dataset_name" with your actual dataset name/ID.
            # Make sure to pick the right split ("train", "test", etc.).
            hf_dataset = load_dataset("shaamil101/met-documents", split="train")

            # 3. Convert rows into LangChain `Document` objects
            # We assume your dataset columns are: 'filename' and 'content'.
            docs = []
            for idx, row in enumerate(hf_dataset):
                docs.append(
                    Document(
                        page_content=row["content"],
                        metadata={
                            "filename": row["filename"],
                            "id": idx
                        }
                    )
                )
            print(f"Loaded {len(docs)} documents from HF dataset")

            # 4. Split documents into chunks
            splitter = TokenTextSplitter(
                chunk_size=400,
                chunk_overlap=0,
                encoding_name="cl100k_base"
            )
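            # chunk_size is measured in tokens: TokenTextSplitter tokenizes the text with
            # the cl100k_base tiktoken encoding, so each chunk holds at most 400 tokens and,
            # with chunk_overlap=0, neighbouring chunks share no tokens.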
            documents = splitter.split_documents(docs)
            print(f"Split into {len(documents)} chunks")

            # 5. Create the vector store (Chroma persists to persist_directory automatically)
            embeddings_model = DartmouthEmbeddings(model_name="bge-large-en-v1-5")
            vector_store = Chroma.from_documents(
                documents=documents,
                embedding=embeddings_model,
                persist_directory=persist_directory
            )
            print(f"Created and persisted vector store with {len(documents)} chunks")

        # 6. Build a retriever on top of the vector store (needed in both branches above)
        retriever = vector_store.as_retriever(search_kwargs={"k": 5})
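        # k=5: every query pulls the five most similar chunks into the prompt as context.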

        # 7. Create your LLM
        llm = ChatDartmouthCloud(model_name="google_genai.gemini-2.0-flash-001")

        # 8. Define a prompt template
        template = """
        You are a helpful assistant that answers questions about the Metropolitan Museum of Art in New York using the provided context.

        Context:
        {context}

        Question: {question}

        Answer the question based only on the provided context.
        If you cannot answer the question with the context, say "I don't have enough information to answer this question."
        """
        prompt = PromptTemplate.from_template(template)

        # 9. Create the RAG chain
        rag_chain = (
            {"context": retriever, "question": RunnablePassthrough()}
            | prompt
            | llm
            | StrOutputParser()
        )
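        # The dict at the head of the chain fans the incoming question out: the retriever
        # fills {context} with the top-k chunks, while RunnablePassthrough() forwards the
        # raw question string into {question} before both reach the prompt.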

        initialization_complete = True
        print("RAG pipeline initialized successfully!")
    except Exception as e:
        print(f"Error initializing RAG pipeline: {e}")
        raise
    finally:
        initialization_in_progress = False


@app.get("/")
def read_root():
    return {"message": "RAG API is running. Send POST requests to /query endpoint."}


@app.get("/health")
def health_check():
    return {
        "status": "healthy",
        "rag_initialized": initialization_complete
    }


@app.post("/query", response_model=QueryResponse)
async def process_query(request: QueryRequest):
    # Initialize RAG if not already done
    if not initialization_complete:
        initialize_rag()
    start_time = time.time()
    try:
        # Generate the answer; the chain retrieves the relevant chunks itself
        answer = rag_chain.invoke(request.query)
        processing_time = time.time() - start_time
        print(f"Processed query in {processing_time:.2f} seconds")
        return QueryResponse(answer=answer)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.on_event("startup")
async def startup_event():
    # Initialize in a background thread so startup isn't blocked while the index builds
    import threading
    threading.Thread(target=initialize_rag, daemon=True).start()
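

# Optional local entry point (a sketch: on Spaces the server is normally started by the
# Space's own Docker/uvicorn command, and 7860 is the port Spaces expects).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example query once the server is up (illustrative values only):
#   curl -X POST http://localhost:7860/query \
#        -H "Content-Type: application/json" \
#        -d '{"query": "Where is the Temple of Dendur displayed?"}'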