# app.py — saksornr's Hugging Face Space (commit d8bc601, 7.22 kB)
import concurrent.futures
import json
import os
import uuid
from dataclasses import dataclass

import gradio as gr
import requests
from requests.exceptions import ChunkedEncodingError

from src.tarot import tarot_cards
# Backend configuration: the API host is injected through the BACKEND_URL
# environment variable (e.g. "https://my-backend.example.com").
# NOTE(review): if BACKEND_URL is unset, `host` is None and the URLs below
# become the literal string "None/chat/..." — confirm deployment always sets it.
host = os.getenv("BACKEND_URL")
default_endpoint = f"{host}/chat/default"
memory_endpoint = f"{host}/chat/memory"
history_endpoint = f"{host}/view_history"
# Define the request payload structure
@dataclass
class ChatRequest:
    """Payload schema for a request to the default chat endpoint.

    NOTE(review): the HTTP payloads in this file are currently built as plain
    dicts; this class mirrors that shape. The dataclass keeps the original
    positional constructor (same field order) and adds __repr__/__eq__ for free.
    """
    session_id: str    # unique conversation identifier
    messages: str      # chat message content forwarded to the backend
    model_id: str      # backend model name, e.g. "llama-3.1-8b-instant"
    temperature: float # sampling temperature (UI exposes the range 0–1)
    seer_name: str     # display name of the fortune-teller persona
class ChatRequestWithMemory(ChatRequest):
    """Payload schema for the memory chat endpoint.

    Extends ChatRequest with `summary_threshold`, the turn count after which
    the backend summarizes older history.
    """

    def __init__(self, session_id, messages, model_id, temperature, seer_name, summary_threshold):
        super().__init__(
            session_id=session_id,
            messages=messages,
            model_id=model_id,
            temperature=temperature,
            seer_name=seer_name,
        )
        self.summary_threshold = summary_threshold
def compare_chatbots(session_id, messages, model_id, temperature, seer_name, summary_threshold, tarot_card):
    """Send the same chat turn to both backends concurrently.

    Posts `messages` to the default endpoint and the memory endpoint (the
    latter additionally carries `summary_threshold`) and returns the two raw
    response bodies as ``(default_text, memory_text)``. Failures are returned
    as human-readable "Error: ..." strings instead of raising, so the UI can
    render them inline in the chat panes.
    """
    print("tarot_card", tarot_card)  # debug trace of the selected cards
    # Each backend gets its own session suffix so their histories stay independent.
    payload_default = json.dumps({
        "session_id": session_id + "_default",
        "messages": messages,
        "model_id": model_id,
        "temperature": temperature,
        "tarot_card": tarot_card,
        "seer_name": seer_name,
    })
    payload_memory = json.dumps({
        "session_id": session_id + "_memory",
        "messages": messages,
        "model_id": model_id,
        "temperature": temperature,
        "seer_name": seer_name,
        "tarot_card": tarot_card,
        "summary_threshold": summary_threshold,
    })
    headers = {
        'Content-Type': 'application/json'
    }

    def call_endpoint(url, payload):
        # Returns the raw response text on success, or an "Error: ..." string
        # on HTTP error status or any network-level failure.
        try:
            # A timeout keeps the UI from hanging forever on a stalled
            # backend (the original request had none). 120 s allows for
            # slow LLM generation — tune as needed.
            response = requests.post(url, headers=headers, data=payload, timeout=120)
            if response.status_code == 200:
                return response.text
            return f"Error: {response.status_code} - {response.text}"
        except ChunkedEncodingError:
            return "Error: Response ended prematurely"
        except requests.RequestException as e:
            # Connection errors, timeouts, bad URLs, ... — surface them in
            # the same error-string style instead of crashing the handler.
            return f"Error: {e}"

    # Query both backends in parallel so the slower one doesn't double latency.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_default = executor.submit(call_endpoint, default_endpoint, payload_default)
        future_memory = executor.submit(call_endpoint, memory_endpoint, payload_memory)
        return future_default.result(), future_memory.result()
# Function to handle chat interaction
def chat_interaction(session_id, message, model_id, temperature, seer_name, summary_threshold, chat_history_default, chat_history_memory, tarot_card):
    """Handle one chat turn: query both backends and update both chat panes.

    Appends the (user message, bot reply) pair to each history in place and
    returns cleared values for the textbox and the tarot-card selection so
    the UI is ready for the next turn.
    """
    reply_default, reply_memory = compare_chatbots(
        session_id, message, model_id, temperature, seer_name, summary_threshold, tarot_card
    )
    chat_history_default.append((message, reply_default))
    chat_history_memory.append((message, reply_memory))
    # Empty string clears the message box; empty list clears the card picker.
    return "", chat_history_default, chat_history_memory, []
# Function to reload session ID and clear chat history
def reload_session_and_clear_chat():
    """Start a fresh conversation: mint a new session ID and wipe both chats.

    Returns (new_session_id, new_memory_session_id, [], []) — the memory
    chatbot stores its history under the "<id>_memory" key.
    """
    fresh_id = str(uuid.uuid4())
    return fresh_id, f"{fresh_id}_memory", [], []
# Function to load chat history
def load_chat_history(session_id):
    """Fetch the stored history for `session_id` from the backend.

    Returns the decoded JSON history on success, or a dict with a single
    "error" key describing the failure — the UI renders either one in the
    same JSON widget.
    """
    try:
        # params= URL-encodes the session ID (the original manual string
        # concatenation would break on characters like '&' or spaces);
        # the timeout keeps the UI responsive if the backend stalls.
        response = requests.get(
            history_endpoint,
            params={"session_id": session_id},
            timeout=30,
        )
        if response.status_code == 200:
            return response.json()
        return {"error": f"Error: {response.status_code} - {response.text}"}
    except Exception as e:
        # Broad catch is deliberate: this is a best-effort viewer and any
        # failure should surface as an error payload, not a crash.
        return {"error": str(e)}
# Create the Gradio interface
# ---- Gradio UI -------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Chatbot Comparison")

    # Two chat panes side by side: the plain backend vs. the memory backend.
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Default Chatbot")
            chatbot_default = gr.Chatbot(elem_id="chatbot_default")
        with gr.Column():
            gr.Markdown("## Memory Chatbot")
            chatbot_memory = gr.Chatbot(elem_id="chatbot_memory")

    with gr.Row():
        message = gr.Textbox(label="Message", show_label=False, scale=3)
        submit_button = gr.Button("Submit", scale=1, variant="primary")

    session_id_default = str(uuid.uuid4())
    model_id_choices = [
        "llama-3.1-8b-instant",
        "llama-3.1-70b-versatile",
        "typhoon-v1.5-instruct",
        "typhoon-v1.5x-70b-instruct",
        "gemma2-9b-it",
    ]

    with gr.Accordion("Settings", open=False):
        reload_button = gr.Button("Reload Session", scale=1, variant="secondary")
        session_id = gr.Textbox(label="Session ID", value=session_id_default)
        model_id = gr.Dropdown(label="Model ID", choices=model_id_choices, value=model_id_choices[0])
        temperature = gr.Slider(0, 1, step=0.1, label="Temperature", value=0.5)
        seer_name = gr.Textbox(label="Seer Name", value="แม่หมอแพตตี้")
        tarot_card = gr.Dropdown(label="Tarot Card", value=[], choices=tarot_cards, multiselect=True)
        summary_threshold = gr.Number(label="Summary Threshold", value=5)

    with gr.Accordion("View History of Memory Chatbot", open=False):
        session_id_memory = gr.Textbox(label="Session ID", value=f"{session_id_default}_memory")
        load_history_button = gr.Button("Load Chat History", scale=1, variant="secondary")
        chat_history_json = gr.JSON(label="Chat History")

    # Both the Submit button and pressing Enter in the textbox send a turn.
    # chat_interaction's parameter order matches this inputs list exactly, so
    # it is wired directly (the original wrapped it in an identical lambda).
    turn_inputs = [
        session_id, message, model_id, temperature, seer_name,
        summary_threshold, chatbot_default, chatbot_memory, tarot_card,
    ]
    turn_outputs = [message, chatbot_default, chatbot_memory, tarot_card]
    submit_button.click(chat_interaction, inputs=turn_inputs, outputs=turn_outputs)
    message.submit(chat_interaction, inputs=turn_inputs, outputs=turn_outputs)

    reload_button.click(
        reload_session_and_clear_chat,
        inputs=[],
        outputs=[session_id, session_id_memory, chatbot_default, chatbot_memory],
    )
    load_history_button.click(
        load_chat_history,
        inputs=[session_id_memory],
        outputs=[chat_history_json],
    )

# Launch the interface (hide the auto-generated API docs page).
demo.launch(show_api=False)