Crystalcareai committed
Commit af60662 · verified · 1 Parent(s): 982d629

Update app.py

Files changed (1):
  1. app.py +314 -645

app.py CHANGED
@@ -1,675 +1,344 @@
  import gradio as gr
  import json
- import os
- import requests
- from cryptography.fernet import Fernet
- from huggingface_hub import HfApi
- from datetime import datetime
- from dataclasses import dataclass, asdict
- from typing import List, Dict, Tuple, Optional
  import asyncio
- import aiohttp
- from concurrent.futures import ThreadPoolExecutor
- import httpx
-
- @dataclass
- class ModelComparison:
-     name: str
-     nick1: str
-     endpoint1: str
-     api_key1: str  # This will store encrypted key
-     model1: str
-     nick2: str
-     endpoint2: str
-     api_key2: str  # This will store encrypted key
-     model2: str
-     active: bool = True
-     created_at: str = None

-     def __post_init__(self):
-         if self.created_at is None:
-             self.created_at = datetime.now().isoformat()
-
-     def to_dict(self) -> dict:
-         return asdict(self)
-
-     @classmethod
-     def from_dict(cls, data: dict) -> 'ModelComparison':
-         return cls(**data)

  @dataclass
- class Vote:
-     comparison_id: str
-     message: str
-     model1_response: List[Dict]
-     model2_response: List[Dict]
-     winner: str
-     vote_info: str = ""
-     timestamp: str = None
-
-     def __post_init__(self):
-         if self.timestamp is None:
-             self.timestamp = datetime.now().isoformat()
-
-     def to_dict(self) -> dict:
-         return asdict(self)

-     @classmethod
-     def from_dict(cls, data: dict) -> 'Vote':
-         return cls(**data)
-
- class VoteManager:
      def __init__(self):
-         self.votes: List[Vote] = []
-
-     def add_vote(self, vote: Vote):
-         self.votes.append(vote)

-     def get_votes_for_comparison(self, comparison_id: str) -> List[Vote]:
-         return [vote for vote in self.votes if vote.comparison_id == comparison_id]

-     def get_results(self, comparison: ModelComparison) -> dict:
-         comparison_votes = self.get_votes_for_comparison(comparison.name)
-         total_votes = len(comparison_votes)
-         model1_votes = len([vote for vote in comparison_votes if vote.winner == comparison.nick1])
-         model2_votes = len([vote for vote in comparison_votes if vote.winner == comparison.nick2])
-         even_votes = len([vote for vote in comparison_votes if vote.winner == "Even"])
-
-         return {
-             "total_votes": total_votes,
-             "model1_name": comparison.nick1,
-             "model2_name": comparison.nick2,
-             "model1_votes": model1_votes,
-             "model2_votes": model2_votes,
-             "even_votes": even_votes,
-             "recent_votes": [vote.to_dict() for vote in comparison_votes[-5:]]
-         }
-
-     def to_dict_list(self) -> List[dict]:
-         return [vote.to_dict() for vote in self.votes]
-
-     @classmethod
-     def from_dict_list(cls, data: List[dict]) -> 'VoteManager':
-         manager = cls()
-         manager.votes = [Vote.from_dict(vote_data) for vote_data in data]
-         return manager
-
- class ModelComparisonApp:
-     def __init__(self):
-         # Initialize encryption key from environment variable
-         self.encryption_key = os.environ.get('ENCRYPTION_KEY')
-         if not self.encryption_key:
-             raise ValueError("ENCRYPTION_KEY environment variable not set")

-         self.fernet = Fernet(self.encryption_key.encode())
-         self.hf_token = os.environ.get('HF_TOKEN')
-         self.api = HfApi(token=self.hf_token)
-         self.dataset_repo_id = os.environ.get('REPO_ID')

-         # Initialize datasets
-         self.models_file = "models.json"
-         self.votes_file = "votes.json"
-         self.comparisons: Dict[str, ModelComparison] = {}
-         self.vote_manager = VoteManager()
-         self.load_data()
-
-     def load_data(self):
-         """Load model configurations and votes from datasets"""
-         try:
-             model_file_path = self.api.hf_hub_download(
-                 repo_id=self.dataset_repo_id,
-                 filename=self.models_file,
-                 repo_type="dataset",
-             )
-             with open(model_file_path, "r") as f:
-                 models_data = json.load(f)
-             self.comparisons = {
-                 name: ModelComparison.from_dict(data)
-                 for name, data in models_data.items()
-             }
-         except Exception as e:
-             print(f"Error loading models: {e}")
-             self.comparisons = {}
-
-         try:
-             votes_file_path = self.api.hf_hub_download(
-                 repo_id=self.dataset_repo_id,
-                 filename=self.votes_file,
-                 repo_type="dataset",
              )
-             with open(votes_file_path, "r") as f:
-                 votes_data = json.load(f)
-             self.vote_manager = VoteManager.from_dict_list(votes_data)
-         except Exception as e:
-             print(f"Error loading votes: {e}")
-             self.vote_manager = VoteManager()
-

-     def update_results_and_status(self, comparison_id):
-         if not comparison_id:
-             return [
-                 [[0, "", 0, "", 0]],
-                 "No comparison selected",
-                 "Activate",
-             ]
-
-         results = self.get_comparison_results(comparison_id)
-         is_active = self.comparisons[comparison_id].active

-         return [
-             results["total_votes"],
-             results["model1_name"],
-             results["model2_name"],
-             results["model1_votes"],
-             results["model2_votes"],
-             results["even_votes"],
-             "Active" if is_active else "Inactive",
-             "Deactivate" if is_active else "Activate"
-         ]
-
-     def load_fresh_data(self):
-         """Force a fresh data load from HuggingFace"""
-         self.load_data()  # This reloads from HuggingFace

-     def refresh_results(self, comparison_id: str):
-         """Refresh results for a specific comparison with fresh data"""
-         self.load_fresh_data()
-         return self.update_results_and_status(comparison_id)
-
-     def save_data(self, file_name: str, data: dict):
-         """Save data to HuggingFace dataset"""
-         self.api.upload_file(
-             path_or_fileobj=json.dumps(data, indent=2).encode(),
-             path_in_repo=file_name,
-             repo_id=self.dataset_repo_id,
-             repo_type="dataset"
-         )
-
-     def save_models(self):
-         """Save model comparisons to file"""
-         models_dict = {
-             name: comparison.to_dict()
-             for name, comparison in self.comparisons.items()
          }
-         self.save_data(self.models_file, models_dict)
-
-     def save_votes(self):
-         """Save votes to file"""
-         self.save_data(self.votes_file, self.vote_manager.to_dict_list())
-
-     def encrypt_api_key(self, api_key: str) -> str:
-         """Encrypt API key using Fernet"""
-         return self.fernet.encrypt(api_key.encode()).decode()
-
-     def decrypt_api_key(self, encrypted_key: str) -> str:
-         """Decrypt API key using Fernet"""
-         return self.fernet.decrypt(encrypted_key.encode()).decode()
-
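The removed app derived its Fernet cipher from an `ENCRYPTION_KEY` environment variable. For reference, such a key comes from the standard `cryptography` API; a minimal sketch (the export step is an assumed deployment detail, not part of this commit):

    from cryptography.fernet import Fernet

    key = Fernet.generate_key()  # urlsafe base64-encoded 32-byte key
    print(key.decode())          # store this value in the ENCRYPTION_KEY env var
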
-     def get_active_comparisons(self) -> List[str]:
-         """Get list of active comparison names"""
-         return [
-             name for name, comparison in self.comparisons.items()
-             if comparison.active
-         ]
-
-     def get_all_comparisons(self) -> List[str]:
-         """Get list of all comparison names"""
-         return list(self.comparisons.keys())
-
-     def add_model_comparison(self, name, nick1, endpoint1, api_key1, model1, nick2, endpoint2, api_key2, model2):
-         """Add a new model comparison configuration"""
-         if name in self.comparisons:
-             return f"Model comparison '{name}' already exists", None, None
-
-         comparison = ModelComparison(
-             name=name,
-             nick1=nick1,
-             endpoint1=endpoint1,
-             api_key1=self.encrypt_api_key(api_key1),
-             model1=model1,
-             nick2=nick2,
-             endpoint2=endpoint2,
-             api_key2=self.encrypt_api_key(api_key2),
-             model2=model2
          )

-         self.comparisons[name] = comparison
-         self.save_models()
-
-         active_comparisons = self.get_active_comparisons()
-         all_comparisons = self.get_all_comparisons()
-         return (
-             "Model comparison added successfully!",
-             gr.update(choices=active_comparisons, value=active_comparisons[0]),
-             gr.update(choices=all_comparisons, value=all_comparisons[0]),
-             gr.update(visible=True)
          )
-
-     async def get_model_response_async(
-         self,
-         session: aiohttp.ClientSession,
-         endpoint: str,
-         api_key: str,
-         model: str,
-         message: str
-     ) -> Tuple[str, Optional[str]]:
-         """Get response from a model using OpenAI-compatible API asynchronously,
-         now with streaming support for models.arcee.ai endpoints."""
-         try:
-             headers = {
-                 "Authorization": f"Bearer {api_key}",
-                 "Content-Type": "application/json"
-             }
-
-             payload = {
-                 "model": model,
-                 "messages": [{"role": "user", "content": message}],
-                 "temperature": 0.7
-             }
-
-             if not endpoint.endswith("/chat/completions"):
-                 if endpoint.endswith("/"):
-                     endpoint = f"{endpoint}chat/completions"
-                 else:
-                     endpoint = f"{endpoint}/chat/completions"
-
-             # For models from models.arcee.ai, switch to streaming via httpx + HTTP/2
-             if "models.arcee.ai" in endpoint:
-                 collected_chunks = []
-                 async with httpx.AsyncClient(http2=True) as client:
-                     async with client.stream(
-                         "POST",
-                         endpoint,
-                         headers=headers,
-                         json={**payload, "stream": True},  # enable streaming
-                         timeout=30.0
-                     ) as response:
-                         if response.status_code != 200:
-                             error_data = await response.aread()
-                             return "", f"Error: HTTP {response.status_code}, {error_data.decode('utf-8')}"
-
-                         buffer = []
-                         async for line in response.aiter_lines():
-                             if line.startswith("data: "):
-                                 # parse partial chunks
-                                 line_data = line.replace("data: ", "").strip()
-                                 if line_data == "[DONE]":
-                                     break  # end of stream
-                                 try:
-                                     json_response = json.loads(line_data)
-                                     delta = json_response["choices"][0].get("delta", {})
-                                     if "content" in delta:
-                                         buffer.append(delta["content"])
-                                         # If we see any punctuation or have a decent buffer, flush it
-                                         if len(buffer) >= 10 or any(c in ".,!?\n" for c in buffer[-1]):
-                                             collected_chunks.extend(buffer)
-                                             buffer = []
-                                 except json.JSONDecodeError:
-                                     continue
-                         # Flush any remaining
-                         if buffer:
-                             collected_chunks.extend(buffer)
-
-                 return "".join(collected_chunks), None
-             else:
-                 # Original aiohttp approach for non-arcee endpoints
-                 async with session.post(endpoint, headers=headers, json=payload) as response:
-                     if response.status != 200:
-                         error_msg = f"Error: HTTP {response.status}"
-                         try:
-                             error_data = await response.json()
-                             if 'error' in error_data:
-                                 error_msg = f"Error: {error_data['error'].get('message', str(error_data))}"
-                         except:
-                             pass
-                         return "", error_msg
-
-                     response_data = await response.json()
-                     return response_data["choices"][0]["message"]["content"], None
-
-         except Exception as e:
-             return "", f"Error: {str(e)}"
-
-     async def compare_models_async(self, comparison_id: str, message: str):
-         """Compare two models concurrently and get their responses"""
-         config = self.comparisons[comparison_id]

-         async with aiohttp.ClientSession() as session:
-             # Create tasks for both API calls
-             task1 = self.get_model_response_async(
-                 session,
-                 config.endpoint1,
-                 self.decrypt_api_key(config.api_key1),
-                 config.model1,
-                 message
              )
-
-             task2 = self.get_model_response_async(
-                 session,
-                 config.endpoint2,
-                 self.decrypt_api_key(config.api_key2),
-                 config.model2,
-                 message
              )
-
-             # Run both tasks concurrently
-             response1, error1 = await task1
-             response2, error2 = await task2
-
-             # Format responses, including error messages if any
-             response1_formatted = [
-                 {"role": "user", "content": message},
-                 {"role": "assistant", "content": response1 if not error1 else f"⚠️ {error1}"}
-             ]

-             response2_formatted = [
-                 {"role": "user", "content": message},
-                 {"role": "assistant", "content": response2 if not error2 else f"⚠️ {error2}"}
-             ]

-             return (
-                 gr.update(type='messages', value=response1_formatted),
-                 gr.update(type='messages', value=response2_formatted),
-                 gr.update(interactive=True)
              )
-
-     def compare_models(self, comparison_id: str, message: str):
-         """Synchronous wrapper for the async comparison function"""
-         loop = asyncio.new_event_loop()
-         asyncio.set_event_loop(loop)
-         try:
-             return loop.run_until_complete(
-                 self.compare_models_async(comparison_id, message)
-             )
-         finally:
-             loop.close()
-
-     def toggle_comparison_status(self, comparison_id: str) -> tuple[str, str]:
-         """Toggle a model comparison between active and inactive states"""
-         if comparison_id not in self.comparisons:
-             return "Comparison not found!", "Deactivate"
-
-         self.comparisons[comparison_id].active = not self.comparisons[comparison_id].active
-         self.save_models()

-         new_status = "active" if self.comparisons[comparison_id].active else "inactive"
-         new_button_text = "Deactivate" if self.comparisons[comparison_id].active else "Activate"
-
-         return f"Comparison is now {new_status}!", new_button_text
-
-     def add_vote(self, comparison_id: str, message: str, response1_output: List[Dict], response2_output: List[Dict], winner: str, vote_info: str) -> str:
-         """Record a vote for model comparison"""
-
-         if winner is None or winner == "":
-             return "Please select a voting option", gr.Button(), gr.Textbox(), None
-
-         config = self.comparisons[comparison_id]
-
-         if winner == "Response 1":
-             winner = config.nick1
-         elif winner == "Response 2":
-             winner = config.nick2
-
-         vote = Vote(
-             comparison_id=comparison_id,
-             message=message,
-             model1_response=response1_output,
-             model2_response=response2_output,
-             winner=winner,
-             vote_info=vote_info,
-         )
-
-         self.vote_manager.add_vote(vote)
-         self.save_votes()
-         return "Vote recorded successfully!", gr.update(interactive=False), gr.update(value=""), gr.update(value=None)
-
-     def get_comparison_results(self, comparison_id: str) -> Optional[dict]:
-         """Get voting results for a specific comparison"""
-         if not comparison_id or comparison_id not in self.comparisons:
-             return None
-
-         comparison = self.comparisons[comparison_id]
-         return self.vote_manager.get_results(comparison)
-
-     def update_comparison_dropdown(self):
-         active_comparisons = self.get_active_comparisons()
-
-         return gr.update(
-             choices=active_comparisons,
-             value=active_comparisons[0] if active_comparisons else None
          )

-     def create_interface(self):
-         """Create Gradio interface"""
-         with gr.Blocks() as interface:
-             gr.Markdown("# Model Comparison Tool")
-
-             # Get initial active comparisons
-             active_comparisons = self.get_active_comparisons()
-             all_comparisons = self.get_all_comparisons()
-             first_active = active_comparisons[0] if active_comparisons else None
-             first_comparison = all_comparisons[0] if all_comparisons else None
-
-             # Store button and outputs we'll need to reference later
-             add_btn = None
-             add_output = None
-
-             with gr.Tab("Compare Models", id="compare_models_tab") as compare_models_tab:
-                 comparison_dropdown = gr.Dropdown(
-                     choices=active_comparisons,
-                     label="Select Comparison",
-                     value=first_active,
-                     interactive=True,
-                     filterable=False
-                 )
-                 with gr.Row():
-                     response1_output = gr.Chatbot(label="Model 1", type='messages')
-                     response2_output = gr.Chatbot(label="Model 2", type='messages')
-
-                 message_input = gr.Textbox(
-                     label="Enter your message",
-                     lines=4,
-                     max_lines=10,
-                     placeholder="Press Enter for new line, Shift + Enter to submit",
-                     autofocus=True,
-                 )
-                 compare_btn = gr.Button("Send", visible=bool(active_comparisons))
-
-                 with gr.Row():
-                     vote_radio = gr.Radio(["Response 1", "Even", "Response 2"], label="Vote for better response")
-
-                 with gr.Row():
-                     vote_info = gr.Textbox(
-                         label="Vote Explanation",
-                         lines=2,
-                         max_lines=5,
-                         placeholder="Optional: Add a reason for your vote",
-                     )
-                 vote_btn = gr.Button("Submit Vote", interactive=False)
-                 vote_output = gr.Textbox(label="Vote Result")
-
-             compare_models_tab.select(
-                 fn=self.update_comparison_dropdown,
-                 inputs=[],
-                 outputs=comparison_dropdown
-             )
-
-             with gr.Tab("Add Model Comparison"):
-                 name = gr.Textbox(label="Comparison Name")
-
-                 with gr.Row():
-                     nick1 = gr.Textbox(label="Model 1 nickname")
-                     nick2 = gr.Textbox(label="Model 2 nickname")
-
-                 with gr.Row():
-                     endpoint1 = gr.Textbox(label="Endpoint 1")
-                     endpoint2 = gr.Textbox(label="Endpoint 2")
-
-                 with gr.Row():
-                     api_key1 = gr.Textbox(label="API Key 1", type="password")
-                     api_key2 = gr.Textbox(label="API Key 2", type="password")
-
-                 with gr.Row():
-                     model1 = gr.Textbox(label="Model 1")
-                     model2 = gr.Textbox(label="Model 2")
-
-                 add_btn = gr.Button("Add Comparison")
-                 add_output = gr.Textbox(label="Result")
-
-             with gr.Tab("Results", id="results_tab") as results_tab:
-                 self.load_fresh_data()
-
-                 results_comparison_dropdown = gr.Dropdown(
-                     choices=all_comparisons,  # Show all comparisons
-                     label="Select Comparison",
-                     value=first_comparison
-                 )
-
-                 initial_results = self.get_comparison_results(first_comparison)
-
-                 with gr.Row():
-                     total_votes = gr.Textbox(label="Total votes", value=initial_results["total_votes"] if initial_results else 0, interactive=False)
-
-                 with gr.Row():
-                     model1_name = gr.Textbox(label="Model 1", value=initial_results["model1_name"] if initial_results else "", interactive=False)
-                     model2_name = gr.Textbox(label="Model 2", value=initial_results["model2_name"] if initial_results else "", interactive=False)
-
-                 with gr.Row():
-                     model1_votes = gr.Textbox(label="Model 1 votes", value=initial_results["model1_votes"] if initial_results else 0, interactive=False)
-                     model2_votes = gr.Textbox(label="Model 2 votes", value=initial_results["model2_votes"] if initial_results else 0, interactive=False)
-
-                 with gr.Row():
-                     even_votes = gr.Textbox(label="Even votes", value=initial_results["even_votes"] if initial_results else 0, interactive=False)
-
-                 # Add status indicator
-                 status_text = gr.Textbox(
-                     label="Status",
-                     value="Active" if first_comparison and self.comparisons[first_comparison].active else "Inactive",
-                     interactive=False
-                 )
-
-                 toggle_btn = gr.Button(
-                     "Deactivate" if first_comparison and self.comparisons[first_comparison].active else "Activate"
-                 )
-                 toggle_output = gr.Textbox(label="Toggle Result")
-
-             results_tab.select(
-                 fn=lambda x: self.refresh_results(x),
-                 inputs=[results_comparison_dropdown],
-                 outputs=[
-                     total_votes,
-                     model1_name,
-                     model2_name,
-                     model1_votes,
-                     model2_votes,
-                     even_votes,
-                     status_text,
-                     toggle_btn
-                 ]
-             )
-
-             # Update component interactions
-             results_comparison_dropdown.change(
-                 fn=self.refresh_results,
-                 inputs=[results_comparison_dropdown],
-                 outputs=[
-                     total_votes,
-                     model1_name,
-                     model2_name,
-                     model1_votes,
-                     model2_votes,
-                     even_votes,
-                     status_text,
-                     toggle_btn
-                 ]
-             )
-
-             toggle_btn.click(
-                 fn=self.toggle_comparison_status,
-                 inputs=[results_comparison_dropdown],
-                 outputs=[toggle_output, toggle_btn]
-             ).then(
-                 fn=lambda: (
-                     gr.update(choices=self.get_active_comparisons()),
-                     gr.update(choices=self.get_all_comparisons())
-                 ),
-                 inputs=[],
-                 outputs=[comparison_dropdown, results_comparison_dropdown]
-             ).then(  # Add another refresh after toggle
-                 fn=self.refresh_results,
-                 inputs=[results_comparison_dropdown],
-                 outputs=[
-                     total_votes,
-                     model1_name,
-                     model2_name,
-                     model1_votes,
-                     model2_votes,
-                     even_votes,
-                     status_text,
-                     toggle_btn
-                 ]
-             )
-
-             comparison_dropdown.change(
-                 fn=lambda: (
-                     gr.update(value=""),
-                     gr.update(value=""),
-                     gr.update(interactive=False),
-                     gr.update(value=""),
-                     gr.update(value=None)
-                 ),
-                 inputs=[],
-                 outputs=[response1_output, response2_output, vote_btn, vote_info, vote_radio]
-             )
-
-             # Set up comparison tab interactions
-             compare_btn.click(
-                 fn=self.compare_models,
-                 inputs=[comparison_dropdown, message_input],
-                 outputs=[response1_output, response2_output, vote_btn]
-             )
-
-             message_input.submit(
-                 fn=self.compare_models,
-                 inputs=[comparison_dropdown, message_input],
-                 outputs=[response1_output, response2_output, vote_btn]
-             )
-
-             vote_btn.click(
-                 fn=self.add_vote,
-                 inputs=[comparison_dropdown, message_input, response1_output, response2_output, vote_radio, vote_info],
-                 outputs=[vote_output, vote_btn, vote_info, vote_radio]
-             )
-
-             # .then(
-             #     fn=self.refresh_results,  # Use the refresh method that forces data reload
-             #     inputs=[results_comparison_dropdown],
-             #     outputs=[
-             #         total_votes,
-             #         model1_name,
-             #         model2_name,
-             #         model1_votes,
-             #         model2_votes,
-             #         even_votes,
-             #         status_text,
-             #         toggle_btn
-             #     ]
-             # )
-
-             # Set up add model comparison tab interactions
-             add_btn.click(
-                 fn=self.add_model_comparison,
-                 inputs=[name, nick1, endpoint1, api_key1, model1, nick2, endpoint2, api_key2, model2],
-                 outputs=[add_output, comparison_dropdown, results_comparison_dropdown, compare_btn]
-             )
-
-         return interface
-
- def main():
-     app = ModelComparisonApp()
-     interface = app.create_interface()
-     interface.launch()
-
  if __name__ == "__main__":
-     main()

  import gradio as gr
+ from huggingface_hub import InferenceClient
+ from typing import Dict, List, Optional, Generator, AsyncGenerator
+ from dataclasses import dataclass
+ import httpx
  import json
  import asyncio
+ import openai
+ import os
+
+ arcee_api_key = os.environ.get("arcee_api_key")
+ openrouter_api_key = os.environ.get("openrouter_api_key")

  @dataclass
+ class ModelConfig:
+     name: str
+     base_url: str
+     api_key: str

+ MODEL_CONFIGS = {
+     1: ModelConfig(
+         name="virtuoso-small",
+         base_url="https://models.arcee.ai/v1/chat/completions",
+         api_key=arcee_api_key
+     ),
+     2: ModelConfig(
+         name="virtuoso-medium",
+         base_url="https://models.arcee.ai/v1/chat/completions",
+         api_key=arcee_api_key
+     ),
+     3: ModelConfig(
+         name="virtuoso-large",
+         base_url="https://models.arcee.ai/v1/chat/completions",
+         api_key=arcee_api_key
+     ),
+     4: ModelConfig(
+         name="anthropic/claude-3.5-sonnet",
+         base_url="https://openrouter.ai/api/v1/chat/completions",
+         api_key=openrouter_api_key
+     )
+ }
+
+ class ModelUsageStats:
      def __init__(self):
+         self.usage_counts = {i: 0 for i in range(1, 5)}
+         self.total_queries = 0

+     def update(self, complexity: int):
+         self.usage_counts[complexity] += 1
+         self.total_queries += 1

+     def get_stats(self) -> str:
+         if self.total_queries == 0:
+             return "No queries processed yet."

+         model_names = {
+             1: "virtuoso-small",
+             2: "virtuoso-medium",
+             3: "virtuoso-large",
+             4: "claude-3-sonnet"
+         }

+         stats = []
+         for complexity, count in self.usage_counts.items():
+             percentage = (count / self.total_queries) * 100
+             stats.append(f"{model_names[complexity]}: {count} uses ({percentage:.1f}%)")
+         return "\n".join(stats)
+
+ stats = ModelUsageStats()
+
+ async def get_complexity(prompt: str) -> int:
+     try:
+         async with httpx.AsyncClient(http2=True) as client:
+             response = await client.post(
+                 "http://185.216.20.86:8000/complexity",
+                 headers={"Content-Type": "application/json"},
+                 json={"prompt": prompt},
+                 timeout=10
              )
+             response.raise_for_status()
+             return response.json()["complexity"]
+     except Exception as e:
+         print(f"Error getting complexity: {e}")
+         return 3  # Default to medium complexity on error
+
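The router assumes the complexity service accepts a JSON body with a `prompt` field and replies with `{"complexity": n}`, where n is 1-4 and indexes MODEL_CONFIGS. A minimal standalone sketch of the same call (assuming the service is reachable and `httpx[http2]` is installed):

    import asyncio
    import httpx

    async def check_complexity():
        # Same endpoint and payload shape as get_complexity() above.
        async with httpx.AsyncClient(http2=True) as client:
            r = await client.post(
                "http://185.216.20.86:8000/complexity",
                json={"prompt": "Prove that sqrt(2) is irrational."},
                timeout=10,
            )
            r.raise_for_status()
            print(r.json())  # expected shape: {"complexity": 1..4}

    asyncio.run(check_complexity())
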
+ async def get_model_response(message: str, history: List[Dict[str, str]], complexity: int) -> AsyncGenerator[str, None]:
+     model_config = MODEL_CONFIGS[complexity]
+
+     headers = {
+         "Content-Type": "application/json"
+     }
+
+     if "openrouter.ai" in model_config.base_url:
+         headers.update({
+             "HTTP-Referer": "https://github.com/lucataco/gradio-router",
+             "X-Title": "Gradio Router",
+             "Authorization": f"Bearer {model_config.api_key}"
+         })
+     elif "arcee.ai" in model_config.base_url:
+         headers.update({
+             "Authorization": f"Bearer {model_config.api_key}"
+         })
+
+     try:
+         collected_chunks = []
+         # For Arcee.ai models, use direct API call with HTTP/2
+         if "arcee.ai" in model_config.base_url:
+             messages = [{"role": "system", "content": "You are a helpful AI assistant."}]
+             for msg in history:
+                 # Clean content
+                 content = msg["content"]
+                 if isinstance(content, str):
+                     content = content.split("\n\n<div")[0]
+                 messages.append({"role": msg["role"], "content": content})
+             messages.append({"role": "user", "content": message})

+             async with httpx.AsyncClient(http2=True) as client:
+                 async with client.stream(
+                     "POST",
+                     model_config.base_url,
+                     headers=headers,
+                     json={
+                         "model": model_config.name,
+                         "messages": messages,
+                         "temperature": 0.7,
+                         "stream": True
+                     },
+                     timeout=30.0
+                 ) as response:
+                     response.raise_for_status()
+                     buffer = []
+                     async for line in response.aiter_lines():
+                         if line.startswith("data: "):
+                             try:
+                                 json_response = json.loads(line.replace("data: ", ""))
+                                 if json_response.get('choices') and json_response['choices'][0].get('delta', {}).get('content'):
+                                     buffer.append(json_response['choices'][0]['delta']['content'])
+                                     if len(buffer) >= 10 or any(c in '.,!?\n' for c in buffer[-1]):
+                                         collected_chunks.extend(buffer)
+                                         yield "".join(collected_chunks)
+                                         buffer = []
+                             except json.JSONDecodeError:
+                                 continue
+                     if buffer:  # Yield any remaining content
+                         collected_chunks.extend(buffer)
+                         yield "".join(collected_chunks)
+
+         # For OpenRouter models, use direct API call with streaming
+         else:
+             messages = [{"role": "system", "content": "You are a helpful AI assistant."}]
+             for msg in history:
+                 content = msg["content"]
+                 if isinstance(content, str):
+                     content = content.split("\n\n<div")[0]
+                 messages.append({"role": msg["role"], "content": content})
+             messages.append({"role": "user", "content": message})
+
+             async with httpx.AsyncClient(http2=True) as client:
+                 async with client.stream(
+                     "POST",
+                     model_config.base_url,
+                     headers=headers,
+                     json={
+                         "model": model_config.name,
+                         "messages": messages,
+                         "temperature": 0.7,
+                         "stream": True
+                     },
+                     timeout=30.0
+                 ) as response:
+                     response.raise_for_status()
+                     buffer = []
+                     async for line in response.aiter_lines():
+                         if line.startswith("data: "):
+                             try:
+                                 json_response = json.loads(line.replace("data: ", ""))
+                                 if json_response.get('choices') and json_response['choices'][0].get('delta', {}).get('content'):
+                                     buffer.append(json_response['choices'][0]['delta']['content'])
+                                     if len(buffer) >= 10 or any(c in '.,!?\n' for c in buffer[-1]):
+                                         collected_chunks.extend(buffer)
+                                         yield "".join(collected_chunks)
+                                         buffer = []
+                             except json.JSONDecodeError:
+                                 continue
+                     if buffer:  # Yield any remaining content
+                         collected_chunks.extend(buffer)
+                         yield "".join(collected_chunks)
+
+     except Exception as e:
+         error_msg = str(e)
+         print(f"Error getting model response: {error_msg}")
+         if "464" in error_msg:
+             yield "Error: Authentication failed. Please check your API key and try again."
+         elif "Internal Server Error" in error_msg:
+             yield "Error: The server encountered an internal error. Please try again later."
+         else:
+             yield f"Error: Unable to get response from {model_config.name}. {error_msg}"
+
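Note that get_model_response yields cumulative text: each yield is the full response so far ("".join(collected_chunks)), not a delta, which is why chat_wrapper below can simply overwrite full_response on every iteration. A sketch of consuming the generator outside Gradio, assuming the API-key environment variables are set:

    import asyncio

    async def demo_stream():
        # complexity=1 routes to the virtuoso-small entry in MODEL_CONFIGS
        async for partial in get_model_response("Hello!", [], 1):
            print(partial)  # cumulative text, grows at each buffer flush

    asyncio.run(demo_stream())
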
+ async def chat_wrapper(
+     message: str,
+     history: List[Dict[str, str]],
+     system_message: str,
+     max_tokens: int,
+     temperature: float,
+     top_p: float,
+     model_usage_stats: str,
+ ):
+     complexity = await get_complexity(message)
+     stats.update(complexity)
+     model_name = MODEL_CONFIGS[complexity].name
+
+     # Convert history for model
+     model_history = []
+     for msg in history:
+         if isinstance(msg, dict) and "role" in msg and "content" in msg:
+             # Clean content
+             content = msg["content"]
+             if isinstance(content, str):
+                 content = content.split("\n\n<div")[0]
+             model_history.append({"role": msg["role"], "content": content})
+
+     # Stream the response
+     full_response = ""
+     async for partial_response in get_model_response(message, model_history, complexity):
+         full_response = partial_response
+         response_with_info = f"{full_response}\n\n<div class='model-info'>Model: {model_name}</div>"

+         # Update stats display
+         stats_text = stats.get_stats()

+         yield [
+             *history,
+             {"role": "user", "content": message},
+             {"role": "assistant", "content": response_with_info}
+         ], stats_text
+
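Both history-cleaning loops strip the model-info footer that chat_wrapper appends before messages are sent back upstream; splitting on the literal "\n\n<div" recovers the bare assistant text. For example:

    raw = "Paris.\n\n<div class='model-info'>Model: virtuoso-small</div>"
    print(raw.split("\n\n<div")[0])  # -> "Paris."
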
+ with gr.Blocks(
+     theme=gr.themes.Soft(
+         primary_hue="blue",
+         secondary_hue="indigo",
+         neutral_hue="slate",
+         font=("Inter", "system-ui", "sans-serif")
+     ),
+     css="""
+     .container {
+         max-width: 1000px;
+         margin: auto;
+         padding: 2rem;
      }
+     .title {
+         text-align: center;
+         font-size: 2.5rem;
+         font-weight: 600;
+         margin: 1rem 0;
+         background: linear-gradient(to right, var(--primary-500), var(--secondary-500));
+         -webkit-background-clip: text;
+         -webkit-text-fill-color: transparent;
+     }
+     .subtitle {
+         text-align: center;
+         font-size: 1.1rem;
+         color: var(--neutral-700);
+         margin-bottom: 2rem;
+         font-weight: 400;
+     }
+     .model-info {
+         font-style: italic;
+         color: var(--neutral-500);
+         font-size: 0.85em;
+         margin-top: 1em;
+         padding-top: 0.5em;
+         border-top: 1px solid var(--neutral-200);
+         opacity: 0.8;
+     }
+     .stats-box {
+         margin-top: 1rem;
+         padding: 1rem;
+         border-radius: 0.75rem;
+         background: color-mix(in srgb, var(--background-fill) 80%, transparent);
+         border: 1px solid var(--neutral-200);
+         font-family: monospace;
+         white-space: pre-line;
+     }
+     .message.assistant {
+         padding-bottom: 1.5em !important;
+     }
+     """
+ ) as demo:
+     with gr.Column(elem_classes="container"):
+         gr.Markdown("# AI Model Router", elem_classes="title")
+         gr.Markdown(
+             "Your message will be routed to the appropriate AI model based on complexity.",
+             elem_classes="subtitle"
          )

+         chatbot = gr.Chatbot(
+             value=[],
+             bubble_full_width=False,
+             show_label=False,
+             height=450,
+             container=True,
+             type="messages"
          )

+         with gr.Row():
+             txt = gr.Textbox(
+                 show_label=False,
+                 placeholder="Enter your message here...",
+                 container=False,
+                 scale=7
              )
+             clear = gr.ClearButton(
+                 [txt, chatbot],
+                 scale=1,
+                 variant="secondary",
+                 size="sm"
              )

+         with gr.Accordion("Advanced Settings", open=False):
+             system_message = gr.Textbox(value="You are a helpful AI assistant.", label="System message")
+             max_tokens = gr.Slider(minimum=16, maximum=4096, value=2048, step=1, label="Max Tokens")
+             temperature = gr.Slider(minimum=0, maximum=2, value=0.7, step=0.1, label="Temperature")
+             top_p = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.1, label="Top P")

+         stats_display = gr.Textbox(
+             value=stats.get_stats(),
+             label="Model Usage Statistics",
+             interactive=False,
+             elem_classes="stats-box"
          )

+     # Set up event handler for streaming
+     txt.submit(
+         chat_wrapper,
+         [txt, chatbot, system_message, max_tokens, temperature, top_p, stats_display],
+         [chatbot, stats_display],
+     ).then(
+         lambda: "",
+         None,
+         [txt],
      )

  if __name__ == "__main__":
+     demo.queue().launch()
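
A note on the last line: demo.queue() enables Gradio's request queue, which is what allows the async-generator handler (chat_wrapper) to stream partial chatbot updates to the client. The new app reads its credentials from the arcee_api_key and openrouter_api_key environment variables, so both are presumably required at launch.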