Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import requests
|
3 |
+
import json
|
4 |
+
import uuid
|
5 |
+
import concurrent.futures
|
6 |
+
from requests.exceptions import ChunkedEncodingError
|
7 |
+
from src.tarot import tarot_cards
|
8 |
+
import os
|
9 |
+
|
10 |
+
# Define the endpoints
|
11 |
+
host = os.getenv("BACKEND_URL")
|
12 |
+
default_endpoint = f"{host}/chat/default"
|
13 |
+
memory_endpoint = f"{host}/chat/memory"
|
14 |
+
history_endpoint = f"{host}/view_history"
|
15 |
+
|
16 |
+
# Define the request payload structure
|
17 |
+
class ChatRequest:
    """Request payload for the default (memory-less) chat endpoint.

    Mirrors the JSON body expected by ``/chat/default``: a session
    identifier, the message text, model selection, sampling temperature,
    and the seer persona name.
    """

    def __init__(self, session_id, messages, model_id, temperature, seer_name):
        # Plain field capture; the payload is serialized with json.dumps
        # at request time rather than by this class.
        self.session_id = session_id
        self.seer_name = seer_name
        self.model_id = model_id
        self.temperature = temperature
        self.messages = messages
|
24 |
+
|
25 |
+
class ChatRequestWithMemory(ChatRequest):
    """ChatRequest extended for the memory-enabled chat endpoint.

    Adds ``summary_threshold`` — the turn count after which the backend
    is asked to summarize older history — on top of the base fields.
    """

    def __init__(self, session_id, messages, model_id, temperature, seer_name, summary_threshold):
        # Delegate the shared payload fields to the base class, then
        # attach the memory-specific knob.
        super().__init__(session_id, messages, model_id, temperature, seer_name)
        self.summary_threshold = summary_threshold
|
29 |
+
|
30 |
+
def compare_chatbots(session_id, messages, model_id, temperature, seer_name, summary_threshold, tarot_card):
    """Send the same user message to both chat endpoints concurrently.

    The default endpoint gets the base payload; the memory endpoint gets the
    same payload plus ``summary_threshold``. Session IDs are suffixed with
    ``_default`` / ``_memory`` so the two bots keep separate server-side state.

    Returns:
        tuple[str, str]: (default bot reply, memory bot reply); on failure the
        corresponding element is an ``"Error: ..."`` string so the UI can
        render it instead of crashing.
    """
    # Fields shared by both endpoints; each payload adds its own session_id
    # (and the memory payload its threshold) on top.
    common_fields = {
        "messages": messages,
        "model_id": model_id,
        "temperature": temperature,
        "seer_name": seer_name,
        "tarot_card": tarot_card,
    }
    payload_default = json.dumps({**common_fields, "session_id": f"{session_id}_default"})
    payload_memory = json.dumps(
        {**common_fields, "session_id": f"{session_id}_memory", "summary_threshold": summary_threshold}
    )
    headers = {"Content-Type": "application/json"}

    def call_endpoint(url, payload):
        # Best-effort call: every failure mode is folded into a displayable
        # string rather than an exception, so one bot failing does not take
        # down the comparison.
        try:
            # Bounded timeout so a dead backend cannot hang the UI forever.
            response = requests.post(url, headers=headers, data=payload, timeout=120)
            if response.status_code == 200:
                return response.text
            return f"Error: {response.status_code} - {response.text}"
        except ChunkedEncodingError:
            # Streamed response was cut off mid-transfer.
            return "Error: Response ended prematurely"
        except requests.RequestException as exc:
            # Connection errors, timeouts, etc. — surface to the UI.
            return f"Error: {exc}"

    # Fire both requests in parallel so total latency is the max of the two
    # calls rather than their sum.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_default = executor.submit(call_endpoint, default_endpoint, payload_default)
        future_memory = executor.submit(call_endpoint, memory_endpoint, payload_memory)
        return future_default.result(), future_memory.result()
|
76 |
+
|
77 |
+
# Function to handle chat interaction
|
78 |
+
def chat_interaction(session_id, message, model_id, temperature, seer_name, summary_threshold, chat_history_default, chat_history_memory, tarot_card):
    """Run one chat turn against both bots and record it in both histories.

    Returns (cleared message box, updated default history, updated memory
    history, cleared tarot selection), matching the Gradio ``outputs`` wiring.
    """
    reply_default, reply_memory = compare_chatbots(
        session_id, message, model_id, temperature, seer_name, summary_threshold, tarot_card
    )

    # Each history is a list of (user message, bot reply) pairs for gr.Chatbot.
    chat_history_default.append((message, reply_default))
    chat_history_memory.append((message, reply_memory))

    # Reset the textbox and tarot-card selection for the next turn.
    return "", chat_history_default, chat_history_memory, []
|
87 |
+
|
88 |
+
# Function to reload session ID and clear chat history
|
89 |
+
def reload_session_and_clear_chat():
    """Mint a fresh session ID pair and wipe both chat histories.

    Returns (new base session ID, its ``_memory``-suffixed twin, empty
    default history, empty memory history).
    """
    fresh_id = str(uuid.uuid4())
    # The memory bot namespaces its server-side state with a "_memory" suffix.
    return fresh_id, f"{fresh_id}_memory", [], []
|
93 |
+
|
94 |
+
# Function to load chat history
|
95 |
+
def load_chat_history(session_id):
    """Fetch the stored conversation for *session_id* from the backend.

    Returns:
        The parsed JSON history on success, or a dict with an ``"error"``
        key describing the failure — the UI renders either shape in the
        same JSON component.
    """
    try:
        # params= URL-encodes the session ID (the old f-string interpolation
        # broke on IDs containing reserved characters); the timeout keeps a
        # dead backend from hanging the UI thread.
        response = requests.get(history_endpoint, params={"session_id": session_id}, timeout=30)
        if response.status_code == 200:
            return response.json()
        return {"error": f"Error: {response.status_code} - {response.text}"}
    except Exception as e:  # boundary handler: surface any failure to the UI
        return {"error": str(e)}
|
104 |
+
|
105 |
+
# Create the Gradio interface
|
106 |
+
# Create the Gradio interface: two chatbots side by side fed by one input,
# plus settings and a history viewer for the memory bot.
with gr.Blocks() as demo:
    gr.Markdown("# Chatbot Comparison")

    with gr.Row():
        with gr.Column():
            gr.Markdown("## Default Chatbot")
            chatbot_default = gr.Chatbot(elem_id="chatbot_default")

        with gr.Column():
            gr.Markdown("## Memory Chatbot")
            chatbot_memory = gr.Chatbot(elem_id="chatbot_memory")

    with gr.Row():
        message = gr.Textbox(label="Message", show_label=False, scale=3)
        submit_button = gr.Button("Submit", scale=1, variant="primary")

    # One shared base ID; the memory bot's state lives under "<id>_memory".
    session_id_default = str(uuid.uuid4())

    model_id_choices = [
        "llama-3.1-8b-instant",
        "llama-3.1-70b-versatile",
        "typhoon-v1.5-instruct",
        "typhoon-v1.5x-70b-instruct",
        "gemma2-9b-it",
    ]

    with gr.Accordion("Settings", open=False):
        reload_button = gr.Button("Reload Session", scale=1, variant="secondary")
        session_id = gr.Textbox(label="Session ID", value=session_id_default)
        model_id = gr.Dropdown(label="Model ID", choices=model_id_choices, value=model_id_choices[0])
        temperature = gr.Slider(0, 1, step=0.1, label="Temperature", value=0.5)
        seer_name = gr.Textbox(label="Seer Name", value="แม่หมอแพตตี้")
        tarot_card = gr.Dropdown(label="Tarot Card", value=[], choices=tarot_cards, multiselect=True)
        summary_threshold = gr.Number(label="Summary Threshold", value=5)

    with gr.Accordion("View History of Memory Chatbot", open=False):
        session_id_memory = gr.Textbox(label="Session ID", value=f"{session_id_default}_memory")
        load_history_button = gr.Button("Load Chat History", scale=1, variant="secondary")
        chat_history_json = gr.JSON(label="Chat History")

    # chat_interaction's signature already matches this wiring exactly, so it
    # is registered directly — the previous pass-through lambdas added nothing.
    chat_inputs = [session_id, message, model_id, temperature, seer_name,
                   summary_threshold, chatbot_default, chatbot_memory, tarot_card]
    chat_outputs = [message, chatbot_default, chatbot_memory, tarot_card]

    submit_button.click(chat_interaction, inputs=chat_inputs, outputs=chat_outputs)
    message.submit(chat_interaction, inputs=chat_inputs, outputs=chat_outputs)

    reload_button.click(
        reload_session_and_clear_chat,
        inputs=[],
        outputs=[session_id, session_id_memory, chatbot_default, chatbot_memory],
    )

    load_history_button.click(
        load_chat_history,
        inputs=[session_id_memory],
        outputs=[chat_history_json],
    )

# Launch the interface (hide the auto-generated API docs).
demo.launch(show_api=False)
|