import gradio as gr
from huggingface_hub import InferenceClient
from typing import Dict, List, Optional, Generator, AsyncGenerator
from dataclasses import dataclass
import httpx
import json
import asyncio
import openai
import os

arcee_api_key = os.environ.get("arcee_api_key")
openrouter_api_key = os.environ.get("openrouter_api_key")

@dataclass
class ModelConfig:
    name: str
    base_url: str
    api_key: str

# Map each complexity level (1-4) to the model endpoint that should handle it.
MODEL_CONFIGS = {
    1: ModelConfig(
        name="virtuoso-small",
        base_url="https://models.arcee.ai/v1/chat/completions",
        api_key=arcee_api_key
    ),
    2: ModelConfig(
        name="virtuoso-medium",
        base_url="https://models.arcee.ai/v1/chat/completions",
        api_key=arcee_api_key
    ),
    3: ModelConfig(
        name="virtuoso-large",
        base_url="https://models.arcee.ai/v1/chat/completions",
        api_key=arcee_api_key
    ),
    4: ModelConfig(
        name="anthropic/claude-3.5-sonnet",
        base_url="https://openrouter.ai/api/v1/chat/completions",
        api_key=openrouter_api_key
    )
}

class ModelUsageStats:
    """Tracks how many queries were routed to each complexity level."""

    def __init__(self):
        self.usage_counts = {i: 0 for i in range(1, 5)}
        self.total_queries = 0

    def update(self, complexity: int):
        self.usage_counts[complexity] += 1
        self.total_queries += 1

    def get_stats(self) -> str:
        if self.total_queries == 0:
            return "No queries processed yet."
        model_names = {
            1: "virtuoso-small",
            2: "virtuoso-medium",
            3: "virtuoso-large",
            4: "claude-3.5-sonnet"
        }
        stats = []
        for complexity, count in self.usage_counts.items():
            percentage = (count / self.total_queries) * 100
            stats.append(f"{model_names[complexity]}: {count} uses ({percentage:.1f}%)")
        return "\n".join(stats)

stats = ModelUsageStats()

async def get_complexity(prompt: str) -> int:
    """Ask the external complexity-scoring service to rate the prompt (1-4)."""
    try:
        async with httpx.AsyncClient(http2=True) as client:
            response = await client.post(
                "http://185.216.20.86:8000/complexity",
                headers={"Content-Type": "application/json"},
                json={"prompt": prompt},
                timeout=10
            )
            response.raise_for_status()
            return response.json()["complexity"]
    except Exception as e:
        print(f"Error getting complexity: {e}")
        return 3  # Default to medium complexity on error

async def get_model_response(message: str, history: List[Dict[str, str]], complexity: int) -> AsyncGenerator[str, None]:
    """Stream a response from the model selected for the given complexity level."""
    model_config = MODEL_CONFIGS[complexity]
    headers = {
        "Content-Type": "application/json"
    }
    # OpenRouter requires referer/title headers in addition to the bearer token.
    if "openrouter.ai" in model_config.base_url:
        headers.update({
            "HTTP-Referer": "https://github.com/lucataco/gradio-router",
            "X-Title": "Gradio Router",
            "Authorization": f"Bearer {model_config.api_key}"
        })
    elif "arcee.ai" in model_config.base_url:
        headers.update({
            "Authorization": f"Bearer {model_config.api_key}"
        })

    try:
        collected_chunks = []
        # For Arcee.ai models, use direct API call with HTTP/2
        if "arcee.ai" in model_config.base_url:
            messages = [{"role": "system", "content": "You are a helpful AI assistant."}]
            for msg in history:
                # Clean content
                content = msg["content"]
                if isinstance(content, str):
                    content = content.split("\n\n