import gradio as gr
from typing import Dict, List, AsyncGenerator
from dataclasses import dataclass
import httpx
import json
import os

arcee_api_key = os.environ.get("arcee_api_key")
openrouter_api_key = os.environ.get("openrouter_api_key")


@dataclass
class ModelConfig:
    name: str
    base_url: str
    api_key: str


MODEL_CONFIGS = {
    1: ModelConfig(
        name="virtuoso-small",
        base_url="https://models.arcee.ai/v1/chat/completions",
        api_key=arcee_api_key,
    ),
    2: ModelConfig(
        name="virtuoso-medium",
        base_url="https://models.arcee.ai/v1/chat/completions",
        api_key=arcee_api_key,
    ),
    3: ModelConfig(
        name="virtuoso-large",
        base_url="https://models.arcee.ai/v1/chat/completions",
        api_key=arcee_api_key,
    ),
    4: ModelConfig(
        name="anthropic/claude-3.5-sonnet",
        base_url="https://openrouter.ai/api/v1/chat/completions",
        api_key=openrouter_api_key,
    ),
}


class ModelUsageStats:
    """Tracks how often each complexity tier (and thus each model) is used."""

    def __init__(self):
        self.usage_counts = {i: 0 for i in range(1, 5)}
        self.total_queries = 0

    def update(self, complexity: int):
        self.usage_counts[complexity] += 1
        self.total_queries += 1

    def get_stats(self) -> str:
        if self.total_queries == 0:
            return "No queries processed yet."
        model_names = {
            1: "virtuoso-small",
            2: "virtuoso-medium",
            3: "virtuoso-large",
            4: "claude-3.5-sonnet",
        }
        stats = []
        for complexity, count in self.usage_counts.items():
            percentage = (count / self.total_queries) * 100
            stats.append(f"{model_names[complexity]}: {count} uses ({percentage:.1f}%)")
        return "\n".join(stats)


stats = ModelUsageStats()


async def get_complexity(prompt: str) -> int:
    """Ask the external scoring service to rate the prompt's complexity (1-4)."""
    try:
        async with httpx.AsyncClient(http2=True) as client:
            response = await client.post(
                "http://185.216.20.86:8000/complexity",
                headers={"Content-Type": "application/json"},
                json={"prompt": prompt},
                timeout=10,
            )
            response.raise_for_status()
            return response.json()["complexity"]
    except Exception as e:
        print(f"Error getting complexity: {e}")
        return 3  # Default to medium complexity on error


async def get_model_response(
    message: str,
    history: List[Dict[str, str]],
    complexity: int,
    system_message: str = "You are a helpful AI assistant.",
    max_tokens: int = 2048,
    temperature: float = 0.7,
    top_p: float = 0.9,
) -> AsyncGenerator[str, None]:
    model_config = MODEL_CONFIGS[complexity]

    headers = {"Content-Type": "application/json"}
    if "openrouter.ai" in model_config.base_url:
        headers.update({
            "HTTP-Referer": "https://github.com/lucataco/gradio-router",
            "X-Title": "Gradio Router",
            "Authorization": f"Bearer {model_config.api_key}",
        })
    elif "arcee.ai" in model_config.base_url:
        headers.update({"Authorization": f"Bearer {model_config.api_key}"})

    # Build an OpenAI-style message list, stripping the model-attribution
    # footer that chat_wrapper appends to assistant messages.
    messages = [{"role": "system", "content": system_message}]
    for msg in history:
        content = msg["content"]
        if isinstance(content, str):
            content = content.split("\n\n<span class='model-info'>")[0]
        messages.append({"role": msg["role"], "content": content})
    messages.append({"role": "user", "content": message})

    payload = {
        "model": model_config.name,
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "stream": True,
    }

    try:
        collected_chunks = []
        # Both providers speak the OpenAI-compatible SSE streaming protocol;
        # Arcee.ai is called over HTTP/2, OpenRouter over HTTP/1.1.
        use_http2 = "arcee.ai" in model_config.base_url
        async with httpx.AsyncClient(http2=use_http2) as client:
            async with client.stream(
                "POST",
                model_config.base_url,
                headers=headers,
                json=payload,
                timeout=60,
            ) as response:
                response.raise_for_status()
                buffer = []
                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue
                    data = line[len("data: "):]
                    if data.strip() == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data)
                        delta = chunk["choices"][0].get("delta", {})
                        token = delta.get("content")
                        if not token:
                            continue
                        buffer.append(token)
                        # Flush at sentence-ish boundaries or every 10 tokens
                        # so the UI updates smoothly without a yield per token.
                        if len(buffer) >= 10 or any(c in '.,!?\n' for c in buffer[-1]):
                            collected_chunks.extend(buffer)
                            yield "".join(collected_chunks)
                            buffer = []
                    except json.JSONDecodeError:
                        continue
                if buffer:
                    # Yield any remaining content
                    collected_chunks.extend(buffer)
                    yield "".join(collected_chunks)
    except Exception as e:
        error_msg = str(e)
        print(f"Error getting model response: {error_msg}")
        if "401" in error_msg or "464" in error_msg:
            yield "Error: Authentication failed. Please check your API key and try again."
        elif "Internal Server Error" in error_msg:
            yield "Error: The server encountered an internal error. Please try again later."
        else:
            yield f"Error: Unable to get response from {model_config.name}. {error_msg}"


async def chat_wrapper(
    message: str,
    history: List[Dict[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    model_usage_stats: str,  # current stats textbox value, supplied by Gradio; unused
):
    complexity = await get_complexity(message)
    stats.update(complexity)
    model_name = MODEL_CONFIGS[complexity].name

    # Convert history for the model, stripping the attribution footer so it
    # is not fed back into the context.
    model_history = []
    for msg in history:
        if isinstance(msg, dict) and "role" in msg and "content" in msg:
            content = msg["content"]
            if isinstance(content, str):
                content = content.split("\n\n<span class='model-info'>")[0]
            model_history.append({"role": msg["role"], "content": content})

    # Stream partial responses into the chatbot, tagging each with the model used.
    async for response in get_model_response(
        message, model_history, complexity, system_message, max_tokens, temperature, top_p
    ):
        response_with_info = f"{response}\n\n<span class='model-info'>Model: {model_name}</span>"
        # Update stats display
        stats_text = stats.get_stats()
        yield [
            *history,
            {"role": "user", "content": message},
            {"role": "assistant", "content": response_with_info},
        ], stats_text


with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="indigo",
        neutral_hue="slate",
        font=("Inter", "system-ui", "sans-serif"),
    ),
    css="""
    .container { max-width: 1000px; margin: auto; padding: 2rem; }
    .title {
        text-align: center; font-size: 2.5rem; font-weight: 600; margin: 1rem 0;
        background: linear-gradient(to right, var(--primary-500), var(--secondary-500));
        -webkit-background-clip: text; -webkit-text-fill-color: transparent;
    }
    .subtitle {
        text-align: center; font-size: 1.1rem; color: var(--neutral-700);
        margin-bottom: 2rem; font-weight: 400;
    }
    .model-info {
        font-style: italic; color: var(--neutral-500); font-size: 0.85em;
        margin-top: 1em; padding-top: 0.5em;
        border-top: 1px solid var(--neutral-200); opacity: 0.8;
    }
    .stats-box {
        margin-top: 1rem; padding: 1rem; border-radius: 0.75rem;
        background: color-mix(in srgb, var(--background-fill) 80%, transparent);
        border: 1px solid var(--neutral-200);
        font-family: monospace; white-space: pre-line;
    }
    .message.assistant { padding-bottom: 1.5em !important; }
    """,
) as demo:
    with gr.Column(elem_classes="container"):
        gr.Markdown("# AI Model Router", elem_classes="title")
        gr.Markdown(
            "Your message will be routed to the appropriate AI model based on complexity.",
            elem_classes="subtitle",
        )
        chatbot = gr.Chatbot(
            value=[],
            bubble_full_width=False,
            show_label=False,
            height=450,
            container=True,
            type="messages",
        )
        with gr.Row():
            txt = gr.Textbox(
                show_label=False,
                placeholder="Enter your message here...",
                container=False,
                scale=7,
            )
            clear = gr.ClearButton([txt, chatbot], scale=1, variant="secondary", size="sm")
        with gr.Accordion("Advanced Settings", open=False):
            system_message = gr.Textbox(value="You are a helpful AI assistant.", label="System message")
            max_tokens = gr.Slider(minimum=16, maximum=4096, value=2048, step=1, label="Max Tokens")
            temperature = gr.Slider(minimum=0, maximum=2, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.1, label="Top P")
        stats_display = gr.Textbox(
            value=stats.get_stats(),
            label="Model Usage Statistics",
            interactive=False,
            elem_classes="stats-box",
        )

    # Set up event handler for streaming, then clear the input box
    txt.submit(
        chat_wrapper,
        [txt, chatbot, system_message, max_tokens, temperature, top_p, stats_display],
        [chatbot, stats_display],
    ).then(
        lambda: "",
        None,
        [txt],
    )

if __name__ == "__main__":
    demo.queue().launch()
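
# ---------------------------------------------------------------------------
# Note: get_complexity() depends on an external scoring service that accepts
# POST {"prompt": "..."} and returns JSON like {"complexity": <int 1-4>}.
# A minimal sketch of a compatible server follows (hypothetical; the actual
# service behind the IP above is not part of this file, and the length-based
# heuristic is purely illustrative):
#
#     from fastapi import FastAPI
#     from pydantic import BaseModel
#
#     app = FastAPI()
#
#     class PromptIn(BaseModel):
#         prompt: str
#
#     @app.post("/complexity")
#     def complexity(body: PromptIn):
#         # Toy heuristic: longer prompts route to larger models.
#         n = len(body.prompt.split())
#         return {"complexity": 1 if n < 20 else 2 if n < 100 else 3 if n < 300 else 4}
# ---------------------------------------------------------------------------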