import concurrent.futures
import json
import os
import uuid

import gradio as gr
import requests
from requests.exceptions import ChunkedEncodingError

from src.tarot import tarot_cards

# Backend endpoints (BACKEND_URL must be set in the environment)
host = os.getenv("BACKEND_URL")
default_endpoint = f"{host}/chat/default"
memory_endpoint = f"{host}/chat/memory"
history_endpoint = f"{host}/view_history"


# Request payload structures. These mirror the JSON bodies built in
# compare_chatbots() and document the backend contract.
class ChatRequest:
    def __init__(self, session_id, messages, model_id, temperature, seer_name):
        self.session_id = session_id
        self.messages = messages
        self.model_id = model_id
        self.temperature = temperature
        self.seer_name = seer_name


class ChatRequestWithMemory(ChatRequest):
    def __init__(self, session_id, messages, model_id, temperature, seer_name, summary_threshold):
        super().__init__(session_id, messages, model_id, temperature, seer_name)
        self.summary_threshold = summary_threshold


def compare_chatbots(session_id, messages, model_id, temperature, seer_name, summary_threshold, tarot_card):
    # Prepare one payload per chatbot; the session IDs are suffixed so the
    # two bots keep separate histories on the backend.
    print("tarot_card", tarot_card)
    payload_default = json.dumps({
        "session_id": session_id + "_default",
        "messages": messages,
        "model_id": model_id,
        "temperature": temperature,
        "tarot_card": tarot_card,
        "seer_name": seer_name,
    })
    payload_memory = json.dumps({
        "session_id": session_id + "_memory",
        "messages": messages,
        "model_id": model_id,
        "temperature": temperature,
        "seer_name": seer_name,
        "tarot_card": tarot_card,
        "summary_threshold": summary_threshold,
    })
    headers = {"Content-Type": "application/json"}

    def call_endpoint(url, payload):
        try:
            response = requests.post(url, headers=headers, data=payload)
            if response.status_code == 200:
                return response.text
            return f"Error: {response.status_code} - {response.text}"
        except ChunkedEncodingError:
            return "Error: Response ended prematurely"

    # Query both endpoints in parallel so one slow bot doesn't block the other
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_default = executor.submit(call_endpoint, default_endpoint, payload_default)
        future_memory = executor.submit(call_endpoint, memory_endpoint, payload_memory)
        response_default_text = future_default.result()
        response_memory_text = future_memory.result()

    return response_default_text, response_memory_text


# Handle one round of chat: call both bots, append the replies to both
# histories, then clear the message box and tarot selection
def chat_interaction(session_id, message, model_id, temperature, seer_name, summary_threshold,
                     chat_history_default, chat_history_memory, tarot_card):
    response_default, response_memory = compare_chatbots(
        session_id, message, model_id, temperature, seer_name, summary_threshold, tarot_card
    )
    chat_history_default.append((message, response_default))
    chat_history_memory.append((message, response_memory))
    return "", chat_history_default, chat_history_memory, []


# Generate a fresh session ID and clear both chat histories
def reload_session_and_clear_chat():
    new_session_id = str(uuid.uuid4())
    new_session_id_memory = f"{new_session_id}_memory"
    return new_session_id, new_session_id_memory, [], []


# Fetch the stored history for a session from the backend
def load_chat_history(session_id):
    try:
        response = requests.get(f"{history_endpoint}?session_id={session_id}")
        if response.status_code == 200:
            return response.json()
        return {"error": f"Error: {response.status_code} - {response.text}"}
    except Exception as e:
        return {"error": str(e)}


# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Chatbot Comparison")
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Default Chatbot")
            chatbot_default = gr.Chatbot(elem_id="chatbot_default")
        with gr.Column():
            gr.Markdown("## Memory Chatbot")
            chatbot_memory = gr.Chatbot(elem_id="chatbot_memory")
    with gr.Row():
        message = gr.Textbox(label="Message", show_label=False, scale=3)
        submit_button = gr.Button("Submit", scale=1, variant="primary")

    session_id_default = str(uuid.uuid4())
    model_id_choices = [
        "llama-3.1-8b-instant",
        "llama-3.1-70b-versatile",
        "typhoon-v1.5-instruct",
        "typhoon-v1.5x-70b-instruct",
        "gemma2-9b-it",
    ]

    with gr.Accordion("Settings", open=False):
        reload_button = gr.Button("Reload Session", scale=1, variant="secondary")
        session_id = gr.Textbox(label="Session ID", value=session_id_default)
        model_id = gr.Dropdown(label="Model ID", choices=model_id_choices, value=model_id_choices[0])
        temperature = gr.Slider(0, 1, step=0.1, label="Temperature", value=0.5)
        seer_name = gr.Textbox(label="Seer Name", value="แม่หมอแพตตี้")
        tarot_card = gr.Dropdown(label="Tarot Card", value=[], choices=tarot_cards, multiselect=True)
        summary_threshold = gr.Number(label="Summary Threshold", value=5)

    with gr.Accordion("View History of Memory Chatbot", open=False):
        session_id_memory = gr.Textbox(label="Session ID", value=f"{session_id_default}_memory")
        load_history_button = gr.Button("Load Chat History", scale=1, variant="secondary")
        chat_history_json = gr.JSON(label="Chat History")

    # The Submit button and pressing Enter in the textbox share one handler
    chat_inputs = [session_id, message, model_id, temperature, seer_name, summary_threshold,
                   chatbot_default, chatbot_memory, tarot_card]
    chat_outputs = [message, chatbot_default, chatbot_memory, tarot_card]
    submit_button.click(chat_interaction, inputs=chat_inputs, outputs=chat_outputs)
    message.submit(chat_interaction, inputs=chat_inputs, outputs=chat_outputs)

    reload_button.click(
        reload_session_and_clear_chat,
        inputs=[],
        outputs=[session_id, session_id_memory, chatbot_default, chatbot_memory],
    )
    load_history_button.click(
        load_chat_history,
        inputs=[session_id_memory],
        outputs=[chat_history_json],
    )

# Launch the interface
demo.launch(show_api=False)