import gradio as gr
from os import getenv
from dotenv import load_dotenv
import os
import re
from model import ModelManager
# Set the environment variable to silence HuggingFace tokenizers parallelism warnings (optional dependency)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from groq import AsyncClient
from fastrtc import WebRTC, ReplyOnPause, audio_to_bytes, AdditionalOutputs
import numpy as np
import asyncio
from elevenlabs.client import ElevenLabs
from elevenlabs import VoiceSettings
from langchain_text_splitters.markdown import MarkdownHeaderTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_core.messages import HumanMessage, AIMessage
from agent import RestaurantAgent
# Import the agent tools
from tools import create_menu_info_tool, create_send_to_kitchen_tool
from utils.logger import log_info, log_warn, log_error, log_success, log_debug
load_dotenv()
# Clients and models start as None; they are initialized at runtime once API keys are provided
groq_client = None
eleven_client = None
llm = None
waiter_agent = None
# region RAG
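# Load the restaurant menu (markdown), split it on its headers, and index the chunks
# in an in-memory vector store so the menu tool can retrieve the relevant sections.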
md_path = "data/carta.md"
with open(md_path, "r", encoding="utf-8") as file:
md_content = file.read()
splitter = MarkdownHeaderTextSplitter(
headers_to_split_on=[
("#", "seccion_principal"),
("##", "categoria"),
],
strip_headers=False)
splits = splitter.split_text(md_content)
embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-m3", model_kwargs={"device": "cpu"})
vector_store = InMemoryVectorStore.from_documents(splits, embeddings)
retriever = vector_store.as_retriever(search_kwargs={"k": 4})
# endregion
# Initialize tools to None
guest_info_tool = None
send_to_kitchen_tool = None
tools = None
# region FUNCTIONS
# Function to initialize all components with provided API keys
def initialize_components(openrouter_key, groq_key, elevenlabs_key, model_name):
global groq_client, eleven_client, llm, waiter_agent, guest_info_tool, send_to_kitchen_tool, tools
log_info("Initializing components with provided API keys...")
# Initialize clients with provided keys
if groq_key:
groq_client = AsyncClient(api_key=groq_key)
if elevenlabs_key:
eleven_client = ElevenLabs(api_key=elevenlabs_key)
if openrouter_key:
# Initialize LLM
model_manager = ModelManager(
api_key=openrouter_key,
api_base=getenv("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1"),
model_name=model_name,
helicone_api_key=getenv("HELICONE_API_KEY", "")
)
llm = model_manager.create_model()
# Initialize tools
guest_info_tool = create_menu_info_tool(retriever)
send_to_kitchen_tool = create_send_to_kitchen_tool(llm=llm)
tools = [guest_info_tool, send_to_kitchen_tool]
# Initialize the agent
waiter_agent = RestaurantAgent(
llm=llm,
tools=tools
)
log_success("Components initialized successfully.")
else:
log_warn("OpenRouter API key is required for LLM initialization.")
return {
"groq_client": groq_client is not None,
"eleven_client": eleven_client is not None,
"llm": llm is not None,
"agent": waiter_agent is not None
}
def extract_user_response(assistant_text):
    """Extract only the user-facing reply, stripping any <think>...</think> reasoning block."""
    # Remove the <think> block (if present) and keep everything that follows it
    think_pattern = r"<think>.*?</think>\s*"
    user_response = re.sub(think_pattern, "", assistant_text, flags=re.DOTALL)
    return user_response.strip()
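# Text path: the agent is invoked directly on the typed message; no STT/TTS is involved.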
async def handle_text_input(message, history, openrouter_key, groq_key, elevenlabs_key, model_name):
"""Handles text input, generates response, updates chat history."""
global waiter_agent, llm
    # (Re)initialize components if the agent is missing or the selected model changed
if waiter_agent is None or llm is None or model_name != getattr(llm, "model_name", ""):
status = initialize_components(openrouter_key, groq_key, elevenlabs_key, model_name)
if not status["agent"]:
return history + [
{"role": "user", "content": message},
{"role": "assistant", "content": "Error: Could not initialize the agent. Please check your API keys."}
]
current_history = history if isinstance(history, list) else []
log_info("-" * 20)
log_info(f"Received text input: '{message}', current history: {current_history}")
try:
        # 1. Append the user's message to the history
        user_message = {"role": "user", "content": message}
        history_with_user = current_history + [user_message]
        # 2. Invoke the agent with the user's query
        log_info("Starting LangGraph processing...")
        # Convert the chat history into LangChain message objects
langchain_messages = []
for msg in current_history:
if msg["role"] == "user":
langchain_messages.append(HumanMessage(content=msg["content"]))
elif msg["role"] == "assistant":
langchain_messages.append(AIMessage(content=msg["content"]))
langchain_messages.append(HumanMessage(content=message))
graph_result = waiter_agent.invoke(langchain_messages)
log_debug(f"Resultado del agente: {graph_result}")
messages = graph_result.get("messages", [])
assistant_text = ""
for msg in reversed(messages):
            # LangChain can return different message classes; keep the last AIMessage
if hasattr(msg, "__class__") and msg.__class__.__name__ == "AIMessage":
assistant_text = msg.content
break
if not assistant_text:
log_warn("No se encontró respuesta del asistente en los mensajes.")
assistant_text = "Lo siento, no sé cómo responder a eso."
log_info(f"Assistant text: '{assistant_text}'")
        # Keep only the user-visible part of the response
user_visible_response = extract_user_response(assistant_text)
log_info(f"User visible response: '{user_visible_response}'")
        # 3. Append the assistant's message (visible part only) to the history
assistant_message = {"role": "assistant", "content": user_visible_response}
final_history = history_with_user + [assistant_message]
log_success("Tarea completada con éxito.")
return final_history
except Exception as e:
log_error(f"Error in handle_text_input function: {e}")
import traceback
traceback.print_exc()
return current_history + [
{"role": "user", "content": message},
{"role": "assistant", "content": f"Error: {str(e)}"}
]
async def response(audio: tuple[int, np.ndarray], history, openrouter_key, groq_key, elevenlabs_key, model_name):
"""Handles audio input, generates response, yields UI updates and audio."""
global waiter_agent, llm, groq_client, eleven_client
    # (Re)initialize components if anything is missing or the selected model changed
if waiter_agent is None or llm is None or groq_client is None or eleven_client is None or model_name != getattr(llm, "model_name", ""):
status = initialize_components(openrouter_key, groq_key, elevenlabs_key, model_name)
if not status["groq_client"]:
yield AdditionalOutputs(history + [{"role": "assistant", "content": "Error: Groq API key is required for audio processing."}])
return
if not status["eleven_client"]:
yield AdditionalOutputs(history + [{"role": "assistant", "content": "Error: ElevenLabs API key is required for audio processing."}])
return
if not status["agent"]:
yield AdditionalOutputs(history + [{"role": "assistant", "content": "Error: Could not initialize the agent. Please check your OpenRouter API key."}])
return
current_history = history if isinstance(history, list) else []
log_info("-" * 20)
log_info(f"Received audio, current history: {current_history}")
try:
        # 1. Transcribe the audio to text
audio_bytes = audio_to_bytes(audio)
transcript = await groq_client.audio.transcriptions.create(
file=("audio-file.mp3", audio_bytes),
model="whisper-large-v3-turbo",
response_format="verbose_json",
)
user_text = transcript.text.strip()
log_info(f"Transcription: '{user_text}'")
        # 2. Append the transcribed message to the history
user_message = {"role": "user", "content": user_text}
history_with_user = current_history + [user_message]
log_info(f"Yielding user message update to UI: {history_with_user}")
yield AdditionalOutputs(history_with_user)
await asyncio.sleep(0.04) # Permite que la UI se actualice antes de continuar
        # 4. Invoke the agent with the user's query
        log_info("Starting LangGraph processing...")
langchain_messages = []
for msg in current_history:
if msg["role"] == "user":
langchain_messages.append(HumanMessage(content=msg["content"]))
elif msg["role"] == "assistant":
langchain_messages.append(AIMessage(content=msg["content"]))
langchain_messages.append(HumanMessage(content=user_text))
graph_result = waiter_agent.invoke(langchain_messages)
log_debug(f"Resultado del agente: {graph_result}")
# Extraer la respuesta del último mensaje del asistente
messages = graph_result.get("messages", [])
assistant_text = ""
        # Look for the last AIMessage returned by the graph
for msg in reversed(messages):
if hasattr(msg, "__class__") and msg.__class__.__name__ == "AIMessage":
assistant_text = msg.content
break
if not assistant_text:
log_warn("No se encontró respuesta del asistente en los mensajes.")
assistant_text = "Lo siento, no sé cómo responder a eso."
log_info(f"Assistant text: '{assistant_text}'")
user_visible_response = extract_user_response(assistant_text)
log_info(f"User visible response: '{user_visible_response}'")
        # 5. Append the assistant's message to the history
assistant_message = {"role": "assistant", "content": user_visible_response}
final_history = history_with_user + [assistant_message]
        # 6. Generate the spoken response
        log_info("Generating TTS...")
        TARGET_SAMPLE_RATE = 24000  # Target sample rate; must match the "pcm_24000" output format below
tts_stream_generator = eleven_client.text_to_speech.convert(
text=user_visible_response,
voice_id="Nh2zY9kknu6z4pZy6FhD",
model_id="eleven_flash_v2_5",
output_format="pcm_24000",
voice_settings=VoiceSettings(
stability=0.0,
similarity_boost=1.0,
style=0.0,
use_speaker_boost=True,
speed=1.1,
)
)
        # --- Process the TTS chunks as they arrive ---
log_info("Receiving and processing TTS audio chunks...")
audio_chunks = []
total_bytes = 0
for chunk in tts_stream_generator:
total_bytes += len(chunk)
            # Convert the current PCM (int16) chunk to float32 normalized to [-1, 1] (divide by 2**15)
if chunk:
audio_int16 = np.frombuffer(chunk, dtype=np.int16)
audio_float32 = audio_int16.astype(np.float32) / 32768.0
                audio_float32 = np.clip(audio_float32, -1.0, 1.0)  # Keep samples within range
audio_chunks.append(audio_float32)
log_info(f"Received {total_bytes} bytes of TTS audio in total.")
        # Concatenate all processed chunks
if audio_chunks:
final_audio = np.concatenate(audio_chunks)
log_info(f"Processed {len(final_audio)} audio samples.")
else:
log_warn("Warning: TTS returned empty audio stream.")
final_audio = np.array([], dtype=np.float32)
        # Build the final (sample_rate, samples) tuple
tts_output_tuple = (TARGET_SAMPLE_RATE, final_audio)
log_debug(f"TTS output: {tts_output_tuple}")
log_success("Tarea completada con éxito.")
yield tts_output_tuple
yield AdditionalOutputs(final_history)
except Exception as e:
log_error(f"Error in response function: {e}")
import traceback
traceback.print_exc()
        # Yield silence in the same (sample_rate, samples) format as the success path
        yield (24000, np.array([], dtype=np.float32))
yield AdditionalOutputs(current_history + [{"role": "assistant", "content": f"Error: {str(e)}"}])
# endregion
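# Gradio UI: API key inputs, chat history, a text box, and a WebRTC widget for voice.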
with gr.Blocks() as demo:
gr.Markdown("# WAIter Chatbot")
gr.Markdown("See your order being uploaded [here](https://kitchen-dashboard-seven.vercel.app/)")
with gr.Row():
text_openrouter_api_key = gr.Textbox(
label="OpenRouter API Key (required)",
placeholder="Enter your OpenRouter API key",
value=getenv("OPENROUTER_API_KEY") or "",
type="password",
)
text_groq_api_key = gr.Textbox(
label="Groq API Key (required for audio)",
placeholder="Enter your Groq API key",
value=getenv("GROQ_API_KEY") or "",
type="password",
)
text_elevenlabs_api_key = gr.Textbox(
label="Elevenlabs API Key (required for audio)",
placeholder="Enter your Elevenlabs API key",
value=getenv("ELEVENLABS_API_KEY") or "",
type="password",
)
chatbot = gr.Chatbot(
label="Agent",
type="messages",
value=[],
avatar_images=(
None, # User avatar
"https://em-content.zobj.net/source/twitter/376/hugging-face_1f917.png", # Assistant
),
)
with gr.Row():
model_dropdown = gr.Dropdown(
label="Select Model",
choices=["google/gemini-2.5-flash-preview-05-20"],
)
text_input = gr.Textbox(
label="Type your message",
placeholder="Type here and press Enter...",
show_label=True,
)
audio = WebRTC(
label="Speak Here",
mode="send-receive",
modality="audio",
)
text_input.submit(
fn=handle_text_input,
inputs=[
text_input,
chatbot,
text_openrouter_api_key,
text_groq_api_key,
text_elevenlabs_api_key,
model_dropdown
],
outputs=[chatbot],
api_name="submit_text"
).then(
fn=lambda: "", # Limpiar el campo de texto
outputs=[text_input]
)
    # Handle the audio input stream
audio.stream(
fn=ReplyOnPause(
response,
can_interrupt=True,
),
inputs=[audio, chatbot, text_openrouter_api_key, text_groq_api_key, text_elevenlabs_api_key, model_dropdown],
outputs=[audio],
)
    # Update the conversation history shown in the chatbot
audio.on_additional_outputs(
        fn=lambda history_update: history_update,  # Forward the updated history
outputs=[chatbot], # Actualiza el chatbot
)
if __name__ == "__main__":
demo.launch()