Update app.py
app.py CHANGED
Old version (removed lines are prefixed with "-"):

@@ -1,37 +1,33 @@
 import os
 import time
 from fastapi import FastAPI, Request
-from fastapi.responses import HTMLResponse
 from fastapi.staticfiles import StaticFiles
-from llama_index.core import
-
 from llama_index.embeddings.huggingface import HuggingFaceEmbedding
 from pydantic import BaseModel
-from fastapi.responses import JSONResponse
-import uuid  # for generating unique IDs
-import datetime
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.templating import Jinja2Templates
-
 import json
 import re
 from deep_translator import GoogleTranslator
-
-
 
 # Define Pydantic model for incoming request body
 class MessageRequest(BaseModel):
     message: str
     language: str
 
-
-llm_client = InferenceClient(
-    model=repo_id,
-    token=os.getenv("HF_TOKEN"),
-)
-
-os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN")
-
 app = FastAPI()
 
 @app.middleware("http")
@@ -50,29 +46,16 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-
-async def favicon():
-    return HTMLResponse("")  # or serve a real favicon if you have one
-
 app.mount("/static", StaticFiles(directory="static"), name="static")
-
 templates = Jinja2Templates(directory="static")
 
-# Configure
-Settings.llm =
-
-    tokenizer_name="meta-llama/Meta-Llama-3-8B",
-    context_window=3000,
-    token=os.getenv("HF_TOKEN"),
-    max_new_tokens=512,
-    generate_kwargs={"temperature": 0.3},
-)
-Settings.embed_model = HuggingFaceEmbedding(
-    model_name="BAAI/bge-small-en-v1.5"
-)
 
 PERSIST_DIR = "db"
-PDF_DIRECTORY =
 
 # Ensure directories exist
 os.makedirs(PDF_DIRECTORY, exist_ok=True)
@@ -91,31 +74,15 @@ def initialize():
     data_ingestion_from_directory()  # Process PDF ingestion at startup
     print(f"Data ingestion time: {time.time() - start_time} seconds")
 
-def split_name(full_name):
-    # Split the name by spaces
-    words = full_name.strip().split()
-
-    # Logic for determining first name and last name
-    if len(words) == 1:
-        first_name = ''
-        last_name = words[0]
-    elif len(words) == 2:
-        first_name = words[0]
-        last_name = words[1]
-    else:
-        first_name = words[0]
-        last_name = ' '.join(words[1:])
-
-    return first_name, last_name
-
 initialize()  # Run initialization tasks
 
 def handle_query(query):
     chat_text_qa_msgs = [
         (
             "user",
             """
-        You are the Hotel voice chatbot and your name is
         {context_str}
         Question:
         {query_str}
@@ -123,63 +90,63 @@ def handle_query(query):
         )
     ]
     text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)
-
     storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
     index = load_index_from_storage(storage_context)
     context_str = ""
     for past_query, response in reversed(current_chat_history):
         if past_query.strip():
             context_str += f"User asked: '{past_query}'\nBot answered: '{response}'\n"
 
     query_engine = index.as_query_engine(text_qa_template=text_qa_template, context_str=context_str)
     print(query)
-
-
-
-
-
-
-
-
-    current_chat_history.append((query,
-    return
 
 @app.get("/ch/{id}", response_class=HTMLResponse)
 async def load_chat(request: Request, id: str):
     return templates.TemplateResponse("index.html", {"request": request, "user_id": id})
 @app.get("/voice/{id}", response_class=HTMLResponse)
-async def
     return templates.TemplateResponse("voice.html", {"request": request, "user_id": id})
 
-
-
 @app.post("/chat/")
 async def chat(request: MessageRequest):
-
     language = request.language
     language_code = request.language.split('-')[0]
     response = handle_query(message)  # Process the message
-
     try:
-        translator = GoogleTranslator(source=
-
-        #response1 = translator.translate(response, dest=language_code).text
-        print(response1)
     except Exception as e:
-        # Handle translation errors
         print(f"Translation error: {e}")
-
-
     message_data = {
         "sender": "User",
         "message": message,
         "response": response,
-        "timestamp": datetime.datetime.now().isoformat()
     }
     chat_history.append(message_data)
-
 
 @app.get("/")
 def read_root(request: Request):
     return templates.TemplateResponse("home.html", {"request": request})
-
New version (added lines are prefixed with "+"):

@@ -1,37 +1,33 @@
 import os
 import time
 from fastapi import FastAPI, Request
+from fastapi.responses import HTMLResponse, JSONResponse
 from fastapi.staticfiles import StaticFiles
+from llama_index.core import (
+    StorageContext,
+    load_index_from_storage,
+    VectorStoreIndex,
+    SimpleDirectoryReader,
+    ChatPromptTemplate,
+    Settings,
+)
 from llama_index.embeddings.huggingface import HuggingFaceEmbedding
 from pydantic import BaseModel
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.templating import Jinja2Templates
+import uuid
+import datetime
 import json
 import re
 from deep_translator import GoogleTranslator
+import ollama  # Import Ollama for inference
 
 # Define Pydantic model for incoming request body
 class MessageRequest(BaseModel):
     message: str
     language: str
 
+# Initialize FastAPI app
 app = FastAPI()
 
 @app.middleware("http")

@@ -50,29 +46,16 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+# Static files setup
 app.mount("/static", StaticFiles(directory="static"), name="static")
 templates = Jinja2Templates(directory="static")
 
+# Configure LlamaIndex settings
+Settings.llm = None  # No need for Hugging Face anymore
+Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
 
 PERSIST_DIR = "db"
+PDF_DIRECTORY = "data"
 
 # Ensure directories exist
 os.makedirs(PDF_DIRECTORY, exist_ok=True)

@@ -91,31 +74,15 @@ def initialize():
     data_ingestion_from_directory()  # Process PDF ingestion at startup
     print(f"Data ingestion time: {time.time() - start_time} seconds")
 
 initialize()  # Run initialization tasks
 
 def handle_query(query):
+    """Handles queries using Ollama's local inference"""
     chat_text_qa_msgs = [
         (
             "user",
             """
+        You are the Hotel voice chatbot, and your name is Hotel Helper. Your goal is to provide accurate, professional, and helpful answers based on hotel data. Always keep responses concise (10-15 words max).
         {context_str}
         Question:
         {query_str}

@@ -123,63 +90,63 @@ def handle_query(query):
         )
     ]
     text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)
+
     storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
     index = load_index_from_storage(storage_context)
     context_str = ""
+
     for past_query, response in reversed(current_chat_history):
         if past_query.strip():
             context_str += f"User asked: '{past_query}'\nBot answered: '{response}'\n"
 
     query_engine = index.as_query_engine(text_qa_template=text_qa_template, context_str=context_str)
+
     print(query)
+
+    try:
+        response = ollama.chat(model="llama3", messages=[{"role": "user", "content": query}])  # Using Ollama
+        response_text = response.get("message", {}).get("content", "Sorry, I couldn't find an answer.")
+    except Exception as e:
+        print(f"Ollama Error: {e}")
+        response_text = "Sorry, something went wrong with the AI model."
+
+    current_chat_history.append((query, response_text))
+    return response_text
 
 @app.get("/ch/{id}", response_class=HTMLResponse)
 async def load_chat(request: Request, id: str):
     return templates.TemplateResponse("index.html", {"request": request, "user_id": id})
+
 @app.get("/voice/{id}", response_class=HTMLResponse)
+async def load_voice_chat(request: Request, id: str):
     return templates.TemplateResponse("voice.html", {"request": request, "user_id": id})
 
 @app.post("/chat/")
 async def chat(request: MessageRequest):
+    """Handles chat requests and translates responses if needed."""
+    message = request.message
     language = request.language
     language_code = request.language.split('-')[0]
+
     response = handle_query(message)  # Process the message
+
     try:
+        translator = GoogleTranslator(source="en", target=language_code)
+        response_translated = translator.translate(response)
     except Exception as e:
         print(f"Translation error: {e}")
+        response_translated = "Sorry, I couldn't translate the response."
+
     message_data = {
         "sender": "User",
         "message": message,
         "response": response,
+        "timestamp": datetime.datetime.now().isoformat(),
     }
     chat_history.append(message_data)
+
+    return {"response": response_translated}
 
 @app.get("/")
 def read_root(request: Request):
     return templates.TemplateResponse("home.html", {"request": request})
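For reference, a minimal client-side sketch of how the /chat/ endpoint in the new version can be exercised once the app is running. The base URL and port, the sample message and language, and the use of the requests library are assumptions for illustration only; they are not part of this commit.

import requests  # third-party HTTP client, assumed available in the caller's environment

# The payload fields mirror the MessageRequest model defined in app.py.
payload = {
    "message": "What time is check-in?",  # hypothetical user question
    "language": "fr-FR",                  # the app takes "fr" from this as the translation target
}

# Replace host/port with wherever the FastAPI app is actually served.
resp = requests.post("http://localhost:8000/chat/", json=payload, timeout=30)
resp.raise_for_status()
print(resp.json()["response"])  # the translated reply returned by the endpoint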