from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
import time
from typing import Optional

# For loading your HF dataset
from datasets import load_dataset

# LangChain imports
from langchain.docstore.document import Document
from langchain_text_splitters import TokenTextSplitter
from langchain_chroma import Chroma
from langchain_dartmouth.embeddings import DartmouthEmbeddings
from langchain_dartmouth.llms import ChatDartmouthCloud
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

# FastAPI initialization
app = FastAPI(title="RAG API", description="Simple API for RAG-based question answering")

# Global variables
persist_directory = "./chroma_db"
vector_store = None
retriever = None
rag_chain = None
initialization_complete = False
initialization_in_progress = False

# Models
class QueryRequest(BaseModel):
    query: str
    num_results: Optional[int] = 3

class QueryResponse(BaseModel):
    answer: str
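# Illustrative example (not from the original code): a POST to /query might send
#   {"query": "When was the Met founded?", "num_results": 3}
# and receive back a QueryResponse body such as
#   {"answer": "..."}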
def initialize_rag():
    """
    Loads your HF dataset, splits it into chunks, and creates a Chroma vector store.
    """
    global vector_store, retriever, rag_chain, initialization_complete, initialization_in_progress

    if initialization_complete:
        return

    if initialization_in_progress:
        while initialization_in_progress:
            time.sleep(1)
        return

    initialization_in_progress = True

    try:
        # 1. Check if Chroma DB already exists
        if os.path.exists(persist_directory) and os.listdir(persist_directory):
            print("Loading existing vector store from disk...")
            embeddings_model = DartmouthEmbeddings(model_name="bge-large-en-v1-5")
            vector_store = Chroma(
                persist_directory=persist_directory,
                embedding_function=embeddings_model
            )
            print(f"Loaded vector store with {vector_store._collection.count()} documents")
        else:
            print("Creating new vector store from HF dataset...")

            # 2. Load your Hugging Face dataset
            # Replace "username/dataset_name" with your actual dataset name/ID.
            # Make sure to pick the right split ("train", "test", etc.).
            hf_dataset = load_dataset("shaamil101/met-documents", split="train")

            # 3. Convert rows into LangChain `Document` objects
            # We assume your dataset columns are: 'filename' and 'content'.
            docs = []
            for idx, row in enumerate(hf_dataset):
                docs.append(
                    Document(
                        page_content=row["content"],
                        metadata={
                            "filename": row["filename"],
                            "id": idx
                        }
                    )
                )
            print(f"Loaded {len(docs)} documents from HF dataset")

            # 4. Split documents into chunks
            splitter = TokenTextSplitter(
                chunk_size=400,
                chunk_overlap=0,
                encoding_name="cl100k_base"
            )
            documents = splitter.split_documents(docs)
            print(f"Split into {len(documents)} chunks")
            # 5. Create the vector store
            embeddings_model = DartmouthEmbeddings(model_name="bge-large-en-v1-5")
            vector_store = Chroma.from_documents(
                documents=documents,
                embedding=embeddings_model,
                persist_directory=persist_directory
            )
            # langchain_chroma persists automatically when a persist_directory is given,
            # so no explicit persist() call is needed.
            print(f"Created and persisted vector store with {len(documents)} documents")

        # 6. Build a retriever on top of the vector store
        retriever = vector_store.as_retriever(search_kwargs={"k": 5})
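        # Note: search_kwargs={"k": 5} fixes how many chunks the retriever returns per
        # query; it is set once here and is independent of the request's num_results field.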
        # 7. Create your LLM
        llm = ChatDartmouthCloud(model_name="google_genai.gemini-2.0-flash-001")

        # 8. Define a prompt template
        template = """
        You are a helpful assistant that answers questions about the Metropolitan Museum of Art in New York using the provided context.

        Context:
        {context}

        Question: {question}

        Answer the question based only on the provided context.
        If you cannot answer the question with the context, say "I don't have enough information to answer this question."
        """
        prompt = PromptTemplate.from_template(template)
        # 9. Create the RAG chain
        rag_chain = (
            {"context": retriever, "question": RunnablePassthrough()}
            | prompt
            | llm
            | StrOutputParser()
        )
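        # When invoked with a plain question string, the dict above fans it out: the
        # retriever fetches matching chunks for {context} while RunnablePassthrough
        # forwards the question unchanged; the prompt, LLM, and string parser then run in sequence.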
        initialization_complete = True
        print("RAG pipeline initialized successfully!")

    except Exception as e:
        print(f"Error initializing RAG pipeline: {e}")
        raise
    finally:
        initialization_in_progress = False
@app.get("/")
def read_root():
    return {"message": "RAG API is running. Send POST requests to /query endpoint."}

@app.get("/health")
def health_check():
    global initialization_complete
    return {
        "status": "healthy",
        "rag_initialized": initialization_complete
    }
@app.post("/query", response_model=QueryResponse)
async def process_query(request: QueryRequest):
    # Initialize RAG if not already done
    if not initialization_complete:
        initialize_rag()

    start_time = time.time()
    try:
        # Generate the answer; the chain retrieves its own context via the
        # retriever configured above, so no separate retrieval call is needed here.
        answer = rag_chain.invoke(request.query)

        processing_time = time.time() - start_time
        print(f"Processed query in {processing_time:.2f} seconds")

        return QueryResponse(answer=answer)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.on_event("startup")
async def startup_event():
    # Optionally initialize in a separate thread so the API can start serving immediately
    import threading
    threading.Thread(target=initialize_rag, daemon=True).start()
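# --- Usage sketch (illustrative, not part of the original app) ---
# Assuming this file is saved as app.py, the API could be started with, e.g.:
#   uvicorn app:app --host 0.0.0.0 --port 7860
# and queried from Python roughly like this:
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/query",
#       json={"query": "What kinds of collections does the Met hold?", "num_results": 3},
#   )
#   print(resp.json()["answer"])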