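"""Model-Routing demo app.

Scores each chat message with an external complexity classifier and routes it
to one of four models: Arcee's virtuoso-small/-medium/-large, or Claude 3.5
Sonnet via OpenRouter.
"""
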
import gradio as gr
from typing import Dict, List, AsyncGenerator
from dataclasses import dataclass
import httpx
import json
import os

# API keys are read from the environment; both must be set for routing to work.
arcee_api_key = os.environ.get("arcee_api_key")
openrouter_api_key = os.environ.get("openrouter_api_key")


@dataclass
class ModelConfig:
    name: str
    base_url: str
    api_key: str
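
# Complexity score (1-4) -> model. Scores 1-3 map to progressively larger
# Arcee-hosted virtuoso models; 4 routes to Claude 3.5 Sonnet via OpenRouter.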
MODEL_CONFIGS = {
    1: ModelConfig(
        name="virtuoso-small",
        base_url="https://models.arcee.ai/v1/chat/completions",
        api_key=arcee_api_key,
    ),
    2: ModelConfig(
        name="virtuoso-medium",
        base_url="https://models.arcee.ai/v1/chat/completions",
        api_key=arcee_api_key,
    ),
    3: ModelConfig(
        name="virtuoso-large",
        base_url="https://models.arcee.ai/v1/chat/completions",
        api_key=arcee_api_key,
    ),
    4: ModelConfig(
        name="anthropic/claude-3.5-sonnet",
        base_url="https://openrouter.ai/api/v1/chat/completions",
        api_key=openrouter_api_key,
    ),
}
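

# Tracks how often each model is selected, for the usage-statistics panel.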
class ModelUsageStats:
    def __init__(self):
        self.usage_counts = {i: 0 for i in MODEL_CONFIGS}
        self.total_queries = 0

    def update(self, complexity: int):
        self.usage_counts[complexity] += 1
        self.total_queries += 1

    def get_stats(self) -> str:
        if self.total_queries == 0:
            return "No queries processed yet."
        lines = []
        for complexity, count in self.usage_counts.items():
            percentage = (count / self.total_queries) * 100
            # Read names from MODEL_CONFIGS so the display stays in sync
            # with the routing table.
            lines.append(f"{MODEL_CONFIGS[complexity].name}: {count} uses ({percentage:.1f}%)")
        return "\n".join(lines)


stats = ModelUsageStats()
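

# Ask the external classifier to score the prompt. The service is expected to
# return JSON of the form {"complexity": <int 1-4>}, matching MODEL_CONFIGS keys.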
async def get_complexity(prompt: str) -> int:
    try:
        async with httpx.AsyncClient(http2=True) as client:
            response = await client.post(
                "http://185.216.20.86:8000/complexity",
                headers={"Content-Type": "application/json"},
                json={"prompt": prompt},
                timeout=10,
            )
            response.raise_for_status()
            return response.json()["complexity"]
    except Exception as e:
        print(f"Error getting complexity: {e}")
        return 3  # Fall back to complexity 3 (virtuoso-large) on error
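

# Stream a completion from the model chosen for this complexity. Arcee and
# OpenRouter both expose OpenAI-compatible chat-completions endpoints, so one
# SSE streaming path serves every model; only the auth headers differ.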
async def get_model_response(
    message: str,
    history: List[Dict[str, str]],
    complexity: int,
    system_message: str = "You are a helpful AI assistant.",
    temperature: float = 0.7,
    max_tokens: int = 2048,
    top_p: float = 0.9,
) -> AsyncGenerator[str, None]:
    model_config = MODEL_CONFIGS[complexity]
    headers = {"Content-Type": "application/json"}
    if "openrouter.ai" in model_config.base_url:
        # OpenRouter asks for these identifying headers alongside the API key.
        headers.update({
            "HTTP-Referer": "https://github.com/lucataco/gradio-router",
            "X-Title": "Gradio Router",
            "Authorization": f"Bearer {model_config.api_key}",
        })
    elif "arcee.ai" in model_config.base_url:
        headers.update({"Authorization": f"Bearer {model_config.api_key}"})
    try:
        # Build an OpenAI-style message list from the system prompt and history.
        messages = [{"role": "system", "content": system_message}]
        for msg in history:
            content = msg["content"]
            if isinstance(content, str):
                # Strip the model-info footer appended to assistant turns.
                content = content.split("\n\n<div")[0]
            messages.append({"role": msg["role"], "content": content})
        messages.append({"role": "user", "content": message})

        collected_chunks = []
        async with httpx.AsyncClient(http2=True) as client:
            async with client.stream(
                "POST",
                model_config.base_url,
                headers=headers,
                json={
                    "model": model_config.name,
                    "messages": messages,
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                    "top_p": top_p,
                    "stream": True,
                },
                timeout=30.0,
            ) as response:
                response.raise_for_status()
                buffer = []
                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue
                    try:
                        json_response = json.loads(line.removeprefix("data: "))
                    except json.JSONDecodeError:
                        # Skip non-JSON lines such as the final "[DONE]" sentinel.
                        continue
                    choices = json_response.get("choices")
                    if choices and choices[0].get("delta", {}).get("content"):
                        buffer.append(choices[0]["delta"]["content"])
                        # Flush roughly every 10 chunks, or sooner at punctuation,
                        # to keep the UI responsive without flooding it with updates.
                        if len(buffer) >= 10 or any(c in ".,!?\n" for c in buffer[-1]):
                            collected_chunks.extend(buffer)
                            yield "".join(collected_chunks)
                            buffer = []
                if buffer:  # Yield any remaining buffered content
                    collected_chunks.extend(buffer)
                    yield "".join(collected_chunks)
    except Exception as e:
        error_msg = str(e)
        print(f"Error getting model response: {error_msg}")
        if any(code in error_msg for code in ("401", "403", "464")):
            yield "Error: Authentication failed. Please check your API key and try again."
        elif "Internal Server Error" in error_msg:
            yield "Error: The server encountered an internal error. Please try again later."
        else:
            yield f"Error: Unable to get response from {model_config.name}. {error_msg}"
async def chat_wrapper(
    message: str,
    history: List[Dict[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    model_usage_stats: str,  # current stats display text; unused, present to match the input list
):
    complexity = await get_complexity(message)
    stats.update(complexity)
    model_name = MODEL_CONFIGS[complexity].name

    # Convert the Gradio history into plain role/content dicts for the model.
    model_history = []
    for msg in history:
        if isinstance(msg, dict) and "role" in msg and "content" in msg:
            content = msg["content"]
            if isinstance(content, str):
                # Strip the model-info footer from earlier assistant turns.
                content = content.split("\n\n<div")[0]
            model_history.append({"role": msg["role"], "content": content})

    # Stream the response, re-rendering the chat on every partial update.
    full_response = ""
    async for partial_response in get_model_response(
        message,
        model_history,
        complexity,
        system_message=system_message,
        temperature=temperature,
        max_tokens=int(max_tokens),
        top_p=top_p,
    ):
        full_response = partial_response
        response_with_info = f"{full_response}\n\n<div class='model-info'>Model: {model_name}</div>"
        stats_text = stats.get_stats()
        yield [
            *history,
            {"role": "user", "content": message},
            {"role": "assistant", "content": response_with_info},
        ], stats_text
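

# Build the UI: chat window, input row, advanced settings, and stats panel.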
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="indigo",
        neutral_hue="slate",
        font=("Inter", "system-ui", "sans-serif"),
    ),
    css="""
    .container {
        max-width: 1000px;
        margin: auto;
        padding: 2rem;
    }
    .title {
        text-align: center;
        font-size: 2.5rem;
        font-weight: 600;
        margin: 1rem 0;
        background: linear-gradient(to right, var(--primary-500), var(--secondary-500));
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
    }
    .subtitle {
        text-align: center;
        font-size: 1.1rem;
        color: var(--neutral-700);
        margin-bottom: 2rem;
        font-weight: 400;
    }
    .model-info {
        font-style: italic;
        color: var(--neutral-500);
        font-size: 0.85em;
        margin-top: 1em;
        padding-top: 0.5em;
        border-top: 1px solid var(--neutral-200);
        opacity: 0.8;
    }
    .stats-box {
        margin-top: 1rem;
        padding: 1rem;
        border-radius: 0.75rem;
        background: color-mix(in srgb, var(--background-fill) 80%, transparent);
        border: 1px solid var(--neutral-200);
        font-family: monospace;
        white-space: pre-line;
    }
    .message.assistant {
        padding-bottom: 1.5em !important;
    }
    """,
) as demo:
    with gr.Column(elem_classes="container"):
        gr.Markdown("# AI Model Router", elem_classes="title")
        gr.Markdown(
            "Your message will be routed to the appropriate AI model based on complexity.",
            elem_classes="subtitle",
        )

        chatbot = gr.Chatbot(
            value=[],
            bubble_full_width=False,
            show_label=False,
            height=450,
            container=True,
            type="messages",
        )

        with gr.Row():
            txt = gr.Textbox(
                show_label=False,
                placeholder="Enter your message here...",
                container=False,
                scale=7,
            )
            clear = gr.ClearButton(
                [txt, chatbot],
                scale=1,
                variant="secondary",
                size="sm",
            )

        with gr.Accordion("Advanced Settings", open=False):
            system_message = gr.Textbox(value="You are a helpful AI assistant.", label="System message")
            max_tokens = gr.Slider(minimum=16, maximum=4096, value=2048, step=1, label="Max Tokens")
            temperature = gr.Slider(minimum=0, maximum=2, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.1, label="Top P")

        stats_display = gr.Textbox(
            value=stats.get_stats(),
            label="Model Usage Statistics",
            interactive=False,
            elem_classes="stats-box",
        )
    # Set up event handler for streaming, then clear the input box
    txt.submit(
        chat_wrapper,
        [txt, chatbot, system_message, max_tokens, temperature, top_p, stats_display],
        [chatbot, stats_display],
    ).then(
        lambda: "",
        None,
        [txt],
    )

if __name__ == "__main__":
    demo.queue().launch()