Spaces:
Sleeping
Sleeping
# main.py - FastAPI version of PuppyCompanion | |
import os | |
import logging | |
import json | |
import asyncio | |
from datetime import datetime | |
from typing import List, Dict, Any | |
from contextlib import asynccontextmanager | |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.responses import HTMLResponse, FileResponse | |
from pydantic import BaseModel | |
from dotenv import load_dotenv | |
# Import your existing modules | |
from rag_system import RAGSystem | |
from agent_workflow import AgentWorkflow | |
# Load environment variables | |
load_dotenv() | |
# Logging configuration | |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') | |
logger = logging.getLogger(__name__) | |
# Global variables | |
global_agent = None | |
global_qdrant_client = None | |
global_retriever = None | |
global_documents = None | |
initialization_completed = False | |
# Path to preprocessed data | |
PREPROCESSED_CHUNKS_PATH = "all_books_preprocessed_chunks.json" | |
# Pydantic models | |
class QuestionRequest(BaseModel): | |
question: str | |
class ChatResponse(BaseModel): | |
response: str | |
sources: List[Dict[str, Any]] = [] | |
tool_used: str = "" | |
# WebSocket connection manager | |
class ConnectionManager: | |
def __init__(self): | |
self.active_connections: List[WebSocket] = [] | |
async def connect(self, websocket: WebSocket): | |
await websocket.accept() | |
self.active_connections.append(websocket) | |
def disconnect(self, websocket: WebSocket): | |
self.active_connections.remove(websocket) | |
async def send_log(self, message: str, log_type: str = "info"): | |
timestamp = datetime.now().strftime("%H:%M:%S") | |
log_data = { | |
"timestamp": timestamp, | |
"message": message, | |
"type": log_type | |
} | |
for connection in self.active_connections: | |
try: | |
await connection.send_text(json.dumps(log_data)) | |
except: | |
pass | |
manager = ConnectionManager() | |
def load_preprocessed_chunks(file_path="all_books_preprocessed_chunks.json"): | |
"""Load preprocessed chunks from a JSON file.""" | |
global global_documents | |
if global_documents is not None: | |
logger.info("Using cached document chunks") | |
return global_documents | |
logger.info(f"Loading preprocessed chunks from {file_path}") | |
try: | |
with open(file_path, 'r', encoding='utf-8') as f: | |
data = json.load(f) | |
from langchain_core.documents import Document | |
documents = [] | |
for item in data: | |
doc = Document( | |
page_content=item['page_content'], | |
metadata=item['metadata'] | |
) | |
documents.append(doc) | |
logger.info(f"Loaded {len(documents)} document chunks") | |
global_documents = documents | |
return documents | |
except Exception as e: | |
logger.error(f"Error loading preprocessed chunks: {str(e)}") | |
raise | |
def initialize_retriever(documents): | |
"""Create a retriever from documents using a shared Qdrant client.""" | |
global global_qdrant_client, global_retriever | |
# Return existing retriever if already initialized | |
if global_retriever is not None: | |
logger.info("Using existing global retriever") | |
return global_retriever | |
logger.info("Creating retriever from documents") | |
try: | |
# Use langchain_qdrant to create a vector store | |
from qdrant_client import QdrantClient | |
from langchain_qdrant import QdrantVectorStore | |
from langchain_openai import OpenAIEmbeddings | |
# Create embedding object | |
embeddings = OpenAIEmbeddings() | |
logger.info("Created OpenAI embeddings object") | |
# Create a persistent path for embeddings storage | |
qdrant_path = "/tmp/qdrant_storage" | |
logger.info(f"Using persistent Qdrant storage path: {qdrant_path}") | |
# Create directory for Qdrant storage | |
os.makedirs(qdrant_path, exist_ok=True) | |
# Create or reuse global Qdrant client | |
if global_qdrant_client is None: | |
client = QdrantClient(path=qdrant_path) | |
global_qdrant_client = client | |
logger.info("Created new global Qdrant client with persistent storage") | |
else: | |
client = global_qdrant_client | |
logger.info("Using existing global Qdrant client") | |
# Check if collection already exists | |
try: | |
collections = client.get_collections() | |
collection_exists = any(collection.name == "puppies" for collection in collections.collections) | |
logger.info(f"Collection 'puppies' exists: {collection_exists}") | |
except Exception as e: | |
collection_exists = False | |
logger.info(f"Could not check collections, assuming none exist: {e}") | |
# OpenAI embeddings dimension | |
embedding_dim = 1536 | |
# Create collection only if it doesn't exist | |
if not collection_exists: | |
from qdrant_client.http import models | |
client.create_collection( | |
collection_name="puppies", | |
vectors_config=models.VectorParams( | |
size=embedding_dim, | |
distance=models.Distance.COSINE | |
) | |
) | |
logger.info("Created new collection 'puppies'") | |
else: | |
logger.info("Using existing collection 'puppies'") | |
# Create vector store | |
vector_store = QdrantVectorStore( | |
client=client, | |
collection_name="puppies", | |
embedding=embeddings | |
) | |
# Add documents only if collection was just created (to avoid duplicates) | |
if not collection_exists: | |
vector_store.add_documents(documents) | |
logger.info(f"Added {len(documents)} documents to vector store") | |
else: | |
logger.info("Using existing embeddings in vector store") | |
# Create retriever | |
retriever = vector_store.as_retriever(search_kwargs={"k": 5}) | |
logger.info("Created retriever") | |
# Store global retriever | |
global_retriever = retriever | |
return retriever | |
except Exception as e: | |
logger.error(f"Error creating retriever: {str(e)}") | |
raise | |
async def initialize_system(): | |
"""Initialize the RAG system and agent""" | |
global global_agent, initialization_completed | |
if initialization_completed: | |
return global_agent | |
await manager.send_log("Starting system initialization...", "info") | |
try: | |
# Load documents | |
await manager.send_log("Loading document chunks...", "info") | |
documents = load_preprocessed_chunks() | |
await manager.send_log(f"Loaded {len(documents)} document chunks", "success") | |
# Create retriever | |
await manager.send_log("Creating retriever...", "info") | |
retriever = initialize_retriever(documents) | |
await manager.send_log("Retriever ready", "success") | |
# Create RAG system | |
await manager.send_log("Setting up RAG system...", "info") | |
rag_system = RAGSystem(retriever) | |
rag_tool = rag_system.create_rag_tool() | |
await manager.send_log("RAG system ready", "success") | |
# Create agent workflow | |
await manager.send_log("Initializing agent workflow...", "info") | |
agent = AgentWorkflow(rag_tool) | |
await manager.send_log("Agent workflow ready", "success") | |
global_agent = agent | |
initialization_completed = True | |
await manager.send_log("System initialization completed!", "success") | |
return agent | |
except Exception as e: | |
await manager.send_log(f"Error during initialization: {str(e)}", "error") | |
raise | |
async def lifespan(app: FastAPI): | |
"""Manage application lifespan""" | |
# Startup | |
try: | |
await initialize_system() | |
logger.info("System initialized successfully") | |
except Exception as e: | |
logger.error(f"Failed to initialize system: {e}") | |
raise # ⚠️ IMPORTANT: Arrêter l'application si l'initialisation échoue | |
yield | |
# Shutdown - cleanup if needed | |
logger.info("Application shutdown") | |
# FastAPI app with lifespan | |
app = FastAPI( | |
title="PuppyCompanion", | |
description="AI Assistant for Puppy Care", | |
lifespan=lifespan | |
) | |
async def get_index(): | |
"""Serve the main HTML page""" | |
return FileResponse("static/index.html") | |
async def get_favicon(): | |
"""Return a 204 No Content for favicon to avoid 404 errors""" | |
from fastapi import Response | |
return Response(status_code=204) | |
async def websocket_endpoint(websocket: WebSocket): | |
"""WebSocket endpoint for real-time logs""" | |
await manager.connect(websocket) | |
try: | |
while True: | |
# Keep connection alive | |
await asyncio.sleep(1) | |
except WebSocketDisconnect: | |
manager.disconnect(websocket) | |
async def chat_endpoint(request: QuestionRequest): | |
"""Main chat endpoint""" | |
global global_agent | |
if not initialization_completed or not global_agent: | |
await manager.send_log("System not initialized, starting initialization...", "warning") | |
try: | |
global_agent = await initialize_system() | |
except Exception as e: | |
raise HTTPException(status_code=500, detail="System initialization failed") | |
question = request.question | |
await manager.send_log(f"New question: {question}", "info") | |
try: | |
# Process question with agent | |
await manager.send_log("Processing with agent workflow...", "info") | |
result = global_agent.process_question(question) | |
# Extract response and metadata | |
response_content = global_agent.get_final_response(result) | |
# Parse tool usage and send detailed info to debug console | |
tool_used = "Unknown" | |
sources = [] | |
if "[Using RAG tool]" in response_content: | |
tool_used = "RAG Tool" | |
await manager.send_log("Used RAG tool - Knowledge base search", "tool") | |
# Send detailed RAG chunks to debug console | |
if "context" in result: | |
await manager.send_log(f"Retrieved {len(result['context'])} chunks from knowledge base:", "info") | |
for i, doc in enumerate(result["context"], 1): | |
source_name = doc.metadata.get('source', 'Unknown') | |
page = doc.metadata.get('page', 'N/A') | |
chapter = doc.metadata.get('chapter', '') | |
# Create detailed chunk info for console | |
if chapter: | |
chunk_header = f"Chunk {i} - {source_name} (Chapter: {chapter}, Page: {page})" | |
else: | |
chunk_header = f"Chunk {i} - {source_name} (Page: {page})" | |
await manager.send_log(chunk_header, "source") | |
# Send chunk content preview | |
content_preview = doc.page_content[:200] + "..." if len(doc.page_content) > 200 else doc.page_content | |
await manager.send_log(f"Content: {content_preview}", "chunk") | |
# Collect for sources array (minimal info) | |
source_info = { | |
"chunk": i, | |
"source": source_name, | |
"page": page, | |
"chapter": chapter | |
} | |
sources.append(source_info) | |
elif "[Using Tavily tool]" in response_content: | |
tool_used = "Tavily Tool" | |
await manager.send_log("Used Tavily tool - Web search", "tool") | |
# Extract Tavily sources from response content and send to debug console | |
lines = response_content.split('\n') | |
tavily_sources_count = 0 | |
for line in lines: | |
line_stripped = line.strip() | |
# Look for Tavily source lines like "- *Source 1 - domain.com: Title*" | |
if (line_stripped.startswith('- *Source') and ':' in line_stripped): | |
tavily_sources_count += 1 | |
# Extract and format for debug console | |
try: | |
# Remove markdown formatting for clean display | |
clean_source = line_stripped.replace('- *', '').replace('*', '') | |
await manager.send_log(f"{clean_source}", "source") | |
except: | |
await manager.send_log(f"{line_stripped}", "source") | |
if tavily_sources_count > 0: | |
await manager.send_log(f"Found {tavily_sources_count} web sources", "info") | |
else: | |
await manager.send_log("Searched the web for current information", "source") | |
elif "out of scope" in response_content.lower(): | |
tool_used = "Out of Scope" | |
await manager.send_log("Question outside scope (not dog-related)", "warning") | |
# Clean response content - REMOVE ALL source references for mobile interface | |
clean_response = response_content | |
# Remove tool markers | |
clean_response = clean_response.replace("[Using RAG tool]", "").replace("[Using Tavily tool]", "").strip() | |
# Remove ALL source-related lines with comprehensive patterns | |
lines = clean_response.split('\n') | |
cleaned_lines = [] | |
for line in lines: | |
line_stripped = line.strip() | |
# Skip lines that are source references (comprehensive patterns) | |
skip_line = False | |
# Pattern 1: Lines starting with * containing Source/Chunk/Based on | |
if (line_stripped.startswith('*') and | |
('Chunk' in line or 'Source' in line or 'Based on' in line or 'Basé sur' in line)): | |
skip_line = True | |
# Pattern 2: Lines starting with - * containing Source/Chunk/Based on | |
if (line_stripped.startswith('- *') and | |
('Chunk' in line or 'Source' in line or 'Based on' in line or 'Basé sur' in line)): | |
skip_line = True | |
# Pattern 3: Lines that are just chunk references like "- *Chunk 1 - filename*" | |
if (line_stripped.startswith('- *Chunk') and line_stripped.endswith('*')): | |
skip_line = True | |
# Pattern 4: Lines that start with "- *Based on" | |
if line_stripped.startswith('- *Based on'): | |
skip_line = True | |
# Add line only if it's not a source reference and not empty | |
if not skip_line and line_stripped: | |
cleaned_lines.append(line) | |
# Final clean response for mobile interface | |
final_response = '\n'.join(cleaned_lines).strip() | |
# Additional cleanup - remove any remaining source markers at start | |
while final_response.startswith('- *') or final_response.startswith('*'): | |
# Find the end of the line to remove | |
if '\n' in final_response: | |
final_response = final_response.split('\n', 1)[1].strip() | |
else: | |
final_response = "" | |
break | |
# Ensure we have a response | |
if not final_response: | |
final_response = "I apologize, but I couldn't generate a proper response to your question." | |
await manager.send_log(f"Clean response ready for mobile interface", "success") | |
return ChatResponse( | |
response=final_response, | |
sources=sources, # Minimal info for API, detailed info already sent to debug console | |
tool_used=tool_used | |
) | |
except Exception as e: | |
await manager.send_log(f"Error processing question: {str(e)}", "error") | |
raise HTTPException(status_code=500, detail=str(e)) | |
async def health_check(): | |
"""Health check endpoint""" | |
return { | |
"status": "healthy", | |
"initialized": initialization_completed, | |
"timestamp": datetime.now().isoformat() | |
} | |
# Mount static files | |
app.mount("/static", StaticFiles(directory="static"), name="static") | |
if __name__ == "__main__": | |
import uvicorn | |
uvicorn.run(app, host="0.0.0.0", port=7860) |