from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
import time
from typing import Optional

# Cache locations must be set before importing `datasets` (and transformers),
# which read these environment variables at import time.
os.environ["HF_HOME"] = "./.cache"
os.environ["HF_DATASETS_CACHE"] = "./.cache/datasets"
os.environ["TRANSFORMERS_CACHE"] = "./.cache/transformers"

# For loading your HF dataset
from datasets import load_dataset
# LangChain imports
from langchain.docstore.document import Document
from langchain_text_splitters import TokenTextSplitter
from langchain_chroma import Chroma
from langchain_dartmouth.embeddings import DartmouthEmbeddings
from langchain_dartmouth.llms import ChatDartmouthCloud
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
# FastAPI initialization
app = FastAPI(title="RAG API", description="Simple API for RAG-based question answering")
os.environ["HF_HOME"] = "./.cache"
os.environ["HF_DATASETS_CACHE"] = "./.cache/datasets"
os.environ["TRANSFORMERS_CACHE"] = "./.cache/transformers"
# Global variables
persist_directory = "./chroma_db"
vector_store = None
retriever = None
rag_chain = None
initialization_complete = False
initialization_in_progress = False
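# NOTE: these plain booleans are adequate for one background init thread plus the
# request-path fallback, but they are not race-proof; a threading.Lock would be
# the stricter choice if several workers could enter initialize_rag() at once.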
# Models
class QueryRequest(BaseModel):
query: str
num_results: Optional[int] = 3
class QueryResponse(BaseModel):
answer: str
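# Example wire format (illustrative values):
#   request:  {"query": "Where is the Temple of Dendur?", "num_results": 3}
#   response: {"answer": "..."}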
def initialize_rag():
"""
Loads your HF dataset, splits it into chunks, and creates a Chroma vector store.
"""
global vector_store, retriever, rag_chain, initialization_complete, initialization_in_progress
if initialization_complete:
return
if initialization_in_progress:
while initialization_in_progress:
time.sleep(1)
return
initialization_in_progress = True
try:
# 1. Check if Chroma DB already exists
if os.path.exists(persist_directory) and os.listdir(persist_directory):
print("Loading existing vector store from disk...")
embeddings_model = DartmouthEmbeddings(model_name="bge-large-en-v1-5")
vector_store = Chroma(
persist_directory=persist_directory,
embedding_function=embeddings_model
)
print(f"Loaded vector store with {vector_store._collection.count()} documents")
else:
print("Creating new vector store from HF dataset...")
# 2. Load your Hugging Face dataset
# Replace "username/dataset_name" with your actual dataset name/ID.
# Make sure to pick the right split ("train", "test", etc.).
hf_dataset = load_dataset("shaamil101/met-documents", split="train")
# 3. Convert rows into LangChain `Document` objects
# We assume your dataset columns are: 'filename' and 'content'.
docs = []
for idx, row in enumerate(hf_dataset):
docs.append(
Document(
page_content=row["content"],
metadata={
"filename": row["filename"],
"id": idx
}
)
)
print(f"Loaded {len(docs)} documents from HF dataset")
# 4. Split documents into chunks
splitter = TokenTextSplitter(
chunk_size=400,
chunk_overlap=0,
encoding_name="cl100k_base"
)
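            # cl100k_base is the tiktoken encoding used by recent OpenAI models;
            # chunk_size/chunk_overlap are counted in its tokens, not characters,
            # so 400 tokens is roughly 1,600 characters of English text.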
documents = splitter.split_documents(docs)
print(f"Split into {len(documents)} chunks")
# 5. Create the vector store
embeddings_model = DartmouthEmbeddings(model_name="bge-large-en-v1-5")
vector_store = Chroma.from_documents(
documents=documents,
embedding=embeddings_model,
persist_directory=persist_directory
)
            # langchain_chroma persists automatically when persist_directory is set;
            # the old Chroma `.persist()` method no longer exists in this package.
            print(f"Created and persisted vector store with {len(documents)} chunks")
        # 6. Build a retriever on top of the vector store. Steps 6-9 run for both
        # branches above, so the chain is also built when an existing store is loaded.
        # (retriever/rag_chain are already declared global at the top of this function.)
        retriever = vector_store.as_retriever(search_kwargs={"k": 5})

        # 7. Create your LLM
        llm = ChatDartmouthCloud(model_name="google_genai.gemini-2.0-flash-001")

        # 8. Define a prompt template
        template = """
        You are a helpful assistant that answers questions about the Metropolitan Museum of Art in New York using the provided context.

        Context:
        {context}

        Question: {question}

        Answer the question based only on the provided context.
        If you cannot answer the question with the context, say "I don't have enough information to answer this question."
        """
        prompt = PromptTemplate.from_template(template)

        # 9. Create the RAG chain
        def format_docs(docs):
            # Join only the page content so the prompt isn't cluttered with metadata.
            return "\n\n".join(doc.page_content for doc in docs)

        rag_chain = (
            {"context": retriever | format_docs, "question": RunnablePassthrough()}
            | prompt
            | llm
            | StrOutputParser()
        )
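        # Data flow, for reference: the input question fans out to the retriever
        # (which fetches and formats context) and to RunnablePassthrough (which
        # forwards the question unchanged); the resulting dict fills the prompt,
        # the LLM answers, and StrOutputParser unwraps the message to a plain str.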
initialization_complete = True
print("RAG pipeline initialized successfully!")
except Exception as e:
print(f"Error initializing RAG pipeline: {e}")
raise
finally:
initialization_in_progress = False
@app.get("/")
def read_root():
return {"message": "RAG API is running. Send POST requests to /query endpoint."}
@app.get("/health")
def health_check():
return {
"status": "healthy",
"rag_initialized": initialization_complete
}
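# While the background initialization is still running, /health reports
# "rag_initialized": false; clients can poll it before sending queries.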
@app.post("/query", response_model=QueryResponse)
def process_query(request: QueryRequest):
    # Sync (non-async) handler: FastAPI runs it in a threadpool, so the blocking
    # initialization and chain invocation don't stall the event loop.
    # Initialize RAG if not already done
    if not initialization_complete:
        initialize_rag()
    start_time = time.time()
    try:
        # The chain does its own retrieval; the retriever's k is fixed at
        # construction time, so `num_results` is accepted but not yet applied.
        answer = rag_chain.invoke(request.query)
processing_time = time.time() - start_time
print(f"Processed query in {processing_time:.2f} seconds")
return QueryResponse(answer=answer)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.on_event("startup")
async def startup_event():
    # Kick off initialization in a background thread so the server starts
    # accepting requests immediately. (`on_event` is deprecated in newer FastAPI
    # in favor of lifespan handlers, but still works.)
    import threading
threading.Thread(target=initialize_rag, daemon=True).start()
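
# Local entry point: a minimal sketch for running outside the Spaces runtime.
# Port 7860 is the Hugging Face Spaces convention; adjust as needed locally.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example request once the server is up (illustrative):
#   curl -X POST http://localhost:7860/query \
#        -H "Content-Type: application/json" \
#        -d '{"query": "What is in the Egyptian art collection?"}'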