import os
import time
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from llama_index.core import (
    StorageContext,
    load_index_from_storage,
    VectorStoreIndex,
    SimpleDirectoryReader,
    ChatPromptTemplate,
    Settings,
)
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.templating import Jinja2Templates
import uuid
import datetime
import json
import re
from deep_translator import GoogleTranslator
import ollama  # Import Ollama for inference

# Define Pydantic model for incoming request body
class MessageRequest(BaseModel):
    message: str
    language: str
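# Example payload (illustrative values):
# {"message": "What time is check-out?", "language": "fr-FR"}
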
# Initialize FastAPI app
app = FastAPI()

# Register the security-header middleware; without this decorator the
# function is never attached to the app
@app.middleware("http")
async def add_security_headers(request: Request, call_next):
    response = await call_next(request)
    # Allow the app to be embedded in iframes from any origin
    response.headers["Content-Security-Policy"] = "frame-ancestors *; frame-src *; object-src *;"
    response.headers["X-Frame-Options"] = "ALLOWALL"
    return response

# Allow CORS requests from any domain
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
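# Note: browsers do not honor credentialed cross-origin requests when the
# allowed origin is the wildcard "*"; list explicit origins if cookies or
# auth headers must cross domains.
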
# Static files setup
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="static")

# Configure LlamaIndex settings
Settings.llm = None  # No LlamaIndex-managed LLM; generation is delegated to Ollama below
Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")

PERSIST_DIR = "db"
PDF_DIRECTORY = "data"

# Ensure directories exist
os.makedirs(PDF_DIRECTORY, exist_ok=True)
os.makedirs(PERSIST_DIR, exist_ok=True)

chat_history = []
current_chat_history = []

def data_ingestion_from_directory():
    """Read every document in PDF_DIRECTORY, embed it, and persist the index to disk."""
    documents = SimpleDirectoryReader(PDF_DIRECTORY).load_data()
    index = VectorStoreIndex.from_documents(documents)
    index.storage_context.persist(persist_dir=PERSIST_DIR)

def initialize():
    start_time = time.time()
    data_ingestion_from_directory()  # Process PDF ingestion at startup
    print(f"Data ingestion time: {time.time() - start_time} seconds")

initialize()  # Run initialization tasks
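# Note: initialize() re-embeds the full document set on every startup; if the
# corpus rarely changes, loading the persisted index from PERSIST_DIR instead
# would avoid the ingestion cost.
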
def handle_query(query):
    """Answers a query with Ollama, grounded in the indexed hotel data."""
    chat_text_qa_msgs = [
        (
            "user",
            """
            You are the Hotel voice chatbot, and your name is Hotel Helper. Your goal is to provide accurate, professional, and helpful answers based on hotel data. Always keep responses concise (10-15 words max).
            {context_str}
            Question:
            {query_str}
            """,
        )
    ]
    text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)

    storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
    index = load_index_from_storage(storage_context)

    # Fold prior turns into the context so follow-up questions resolve correctly
    context_str = ""
    for past_query, response in reversed(current_chat_history):
        if past_query.strip():
            context_str += f"User asked: '{past_query}'\nBot answered: '{response}'\n"

    # Retrieve relevant passages from the index and format them into the
    # prompt, since Ollama is called directly rather than through a
    # LlamaIndex query engine
    for node in index.as_retriever().retrieve(query):
        context_str += node.get_content() + "\n"
    prompt = text_qa_template.format(context_str=context_str, query_str=query)

    print(query)  # Debug: log the incoming query
    try:
        response = ollama.chat(model="llama3", messages=[{"role": "user", "content": prompt}])
        response_text = response.get("message", {}).get("content", "Sorry, I couldn't find an answer.")
    except Exception as e:
        print(f"Ollama Error: {e}")
        response_text = "Sorry, something went wrong with the AI model."

    current_chat_history.append((query, response_text))
    return response_text

# Route paths below are inferred from the function and template names and may
# need adjusting to match the original frontend
@app.get("/chat/{id}", response_class=HTMLResponse)
async def load_chat(request: Request, id: str):
    return templates.TemplateResponse("index.html", {"request": request, "user_id": id})

@app.get("/voice/{id}", response_class=HTMLResponse)
async def load_voice_chat(request: Request, id: str):
    return templates.TemplateResponse("voice.html", {"request": request, "user_id": id})

@app.post("/chat")  # Inferred path; the client posts MessageRequest JSON here
async def chat(request: MessageRequest):
    """Handles chat requests and translates responses if needed."""
    message = request.message
    language_code = request.language.split('-')[0]  # e.g. "fr-FR" -> "fr"
    response = handle_query(message)  # Process the message

    # Translate the English answer into the requested language
    try:
        translator = GoogleTranslator(source="en", target=language_code)
        response_translated = translator.translate(response)
    except Exception as e:
        print(f"Translation error: {e}")
        response_translated = "Sorry, I couldn't translate the response."

    message_data = {
        "sender": "User",
        "message": message,
        "response": response,
        "timestamp": datetime.datetime.now().isoformat(),
    }
    chat_history.append(message_data)
    return {"response": response_translated}

@app.get("/")  # Inferred root path
def read_root(request: Request):
    return templates.TemplateResponse("home.html", {"request": request})