saksornr commited on
Commit
d8bc601
·
verified ·
1 Parent(s): a28debb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +175 -0
app.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests
3
+ import json
4
+ import uuid
5
+ import concurrent.futures
6
+ from requests.exceptions import ChunkedEncodingError
7
+ from src.tarot import tarot_cards
8
+ import os
9
+
10
+ # Define the endpoints
11
+ host = os.getenv("BACKEND_URL")
12
+ default_endpoint = f"{host}/chat/default"
13
+ memory_endpoint = f"{host}/chat/memory"
14
+ history_endpoint = f"{host}/view_history"
15
+
16
+ # Define the request payload structure
17
+ class ChatRequest:
18
+ def __init__(self, session_id, messages, model_id, temperature, seer_name):
19
+ self.session_id = session_id
20
+ self.messages = messages
21
+ self.model_id = model_id
22
+ self.temperature = temperature
23
+ self.seer_name = seer_name
24
+
25
+ class ChatRequestWithMemory(ChatRequest):
26
+ def __init__(self, session_id, messages, model_id, temperature, seer_name, summary_threshold):
27
+ super().__init__(session_id, messages, model_id, temperature, seer_name)
28
+ self.summary_threshold = summary_threshold
29
+
30
+ def compare_chatbots(session_id, messages, model_id, temperature, seer_name, summary_threshold, tarot_card):
31
+ # Convert messages list to a single string
32
+ # Prepare the payloads
33
+ print("tarot_card", tarot_card)
34
+ payload_default = json.dumps({
35
+ "session_id": session_id + "_default",
36
+ "messages": messages,
37
+ "model_id": model_id,
38
+ "temperature": temperature,
39
+ "tarot_card": tarot_card,
40
+ "seer_name": seer_name,
41
+ })
42
+ payload_memory = json.dumps({
43
+ "session_id": session_id + "_memory",
44
+ "messages": messages,
45
+ "model_id": model_id,
46
+ "temperature": temperature,
47
+ "seer_name": seer_name,
48
+ "tarot_card": tarot_card,
49
+ "summary_threshold": summary_threshold,
50
+ })
51
+ headers = {
52
+ 'Content-Type': 'application/json'
53
+ }
54
+
55
+ def call_endpoint(url, payload):
56
+ try:
57
+ response = requests.request("POST", url, headers=headers, data=payload)
58
+ if response.status_code == 200:
59
+ try:
60
+ return response.text
61
+ except requests.exceptions.JSONDecodeError:
62
+ return "Error: Response is not valid JSON"
63
+ else:
64
+ return f"Error: {response.status_code} - {response.text}"
65
+ except ChunkedEncodingError:
66
+ return "Error: Response ended prematurely"
67
+
68
+ with concurrent.futures.ThreadPoolExecutor() as executor:
69
+ future_default = executor.submit(call_endpoint, default_endpoint, payload_default)
70
+ future_memory = executor.submit(call_endpoint, memory_endpoint, payload_memory)
71
+
72
+ response_default_text = future_default.result()
73
+ response_memory_text = future_memory.result()
74
+
75
+ return response_default_text, response_memory_text
76
+
77
+ # Function to handle chat interaction
78
+ def chat_interaction(session_id, message, model_id, temperature, seer_name, summary_threshold, chat_history_default, chat_history_memory, tarot_card):
79
+ response_default, response_memory = compare_chatbots(session_id, message, model_id, temperature, seer_name, summary_threshold, tarot_card)
80
+
81
+ chat_history_default.append((message, response_default))
82
+ chat_history_memory.append((message, response_memory))
83
+
84
+ message = ""
85
+ tarot_card = []
86
+ return message, chat_history_default, chat_history_memory, tarot_card
87
+
88
+ # Function to reload session ID and clear chat history
89
+ def reload_session_and_clear_chat():
90
+ new_session_id = str(uuid.uuid4())
91
+ new_session_id_memory = f"{new_session_id}_memory"
92
+ return new_session_id, new_session_id_memory, [], []
93
+
94
+ # Function to load chat history
95
+ def load_chat_history(session_id):
96
+ try:
97
+ response = requests.get(f"{history_endpoint}?session_id={session_id}")
98
+ if response.status_code == 200:
99
+ return response.json()
100
+ else:
101
+ return {"error": f"Error: {response.status_code} - {response.text}"}
102
+ except Exception as e:
103
+ return {"error": str(e)}
104
+
105
+ # Create the Gradio interface
106
+ with gr.Blocks() as demo:
107
+ gr.Markdown("# Chatbot Comparison")
108
+
109
+ with gr.Row():
110
+ with gr.Column():
111
+ gr.Markdown("## Default Chatbot")
112
+ chatbot_default = gr.Chatbot(elem_id="chatbot_default")
113
+
114
+ with gr.Column():
115
+ gr.Markdown("## Memory Chatbot")
116
+ chatbot_memory = gr.Chatbot(elem_id="chatbot_memory")
117
+
118
+ with gr.Row():
119
+ message = gr.Textbox(label="Message", show_label=False, scale=3)
120
+ submit_button = gr.Button("Submit", scale=1, variant="primary")
121
+
122
+ session_id_default = str(uuid.uuid4())
123
+
124
+ model_id_choices = [
125
+ "llama-3.1-8b-instant",
126
+ "llama-3.1-70b-versatile",
127
+ "typhoon-v1.5-instruct",
128
+ "typhoon-v1.5x-70b-instruct",
129
+ "gemma2-9b-it",
130
+ ]
131
+
132
+ with gr.Accordion("Settings", open=False):
133
+ reload_button = gr.Button("Reload Session", scale=1, variant="secondary")
134
+ session_id = gr.Textbox(label="Session ID", value=session_id_default)
135
+ model_id = gr.Dropdown(label="Model ID", choices=model_id_choices, value=model_id_choices[0])
136
+ temperature = gr.Slider(0, 1, step=0.1, label="Temperature", value=0.5)
137
+ seer_name = gr.Textbox(label="Seer Name", value="แม่หมอแพตตี้")
138
+ tarot_card = gr.Dropdown(label="Tarot Card", value=[], choices=tarot_cards, multiselect=True)
139
+ summary_threshold = gr.Number(label="Summary Threshold", value=5)
140
+
141
+ with gr.Accordion("View History of Memory Chatbot", open=False):
142
+ session_id_memory = gr.Textbox(label="Session ID", value=f"{session_id_default}_memory")
143
+ load_history_button = gr.Button("Load Chat History", scale=1, variant="secondary") # New button
144
+ chat_history_json = gr.JSON(label="Chat History") # New JSON field
145
+
146
+ submit_button.click(
147
+ lambda session_id, message, model_id, temperature, seer_name, summary_threshold, chatbot_default, chatbot_memory, tarot_card: chat_interaction(
148
+ session_id, message, model_id, temperature, seer_name, summary_threshold, chatbot_default, chatbot_memory, tarot_card
149
+ ),
150
+ inputs=[session_id, message, model_id, temperature, seer_name, summary_threshold, chatbot_default, chatbot_memory, tarot_card],
151
+ outputs=[message, chatbot_default, chatbot_memory, tarot_card]
152
+ )
153
+
154
+ message.submit(
155
+ lambda session_id, message, model_id, temperature, seer_name, summary_threshold, chatbot_default, chatbot_memory, tarot_card: chat_interaction(
156
+ session_id, message, model_id, temperature, seer_name, summary_threshold, chatbot_default, chatbot_memory, tarot_card
157
+ ),
158
+ inputs=[session_id, message, model_id, temperature, seer_name, summary_threshold, chatbot_default, chatbot_memory, tarot_card],
159
+ outputs=[message, chatbot_default, chatbot_memory, tarot_card]
160
+ )
161
+
162
+ reload_button.click(
163
+ reload_session_and_clear_chat,
164
+ inputs=[],
165
+ outputs=[session_id, session_id_memory, chatbot_default, chatbot_memory]
166
+ )
167
+
168
+ load_history_button.click(
169
+ load_chat_history,
170
+ inputs=[session_id_memory],
171
+ outputs=[chat_history_json]
172
+ )
173
+
174
+ # Launch the interface
175
+ demo.launch(show_api=False)