import gradio as gr
from typing import Dict, List, AsyncGenerator
from dataclasses import dataclass
import httpx
import json
import os

arcee_api_key = os.environ.get("arcee_api_key")
openrouter_api_key = os.environ.get("openrouter_api_key")
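# On Hugging Face Spaces these keys would typically be set as repository secrets
# (exposed to the app as environment variables). For a local run, something like
# the following is assumed (variable names taken from the lookups above):
#   export arcee_api_key=...
#   export openrouter_api_key=...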
@dataclass
class ModelConfig:
    name: str
    base_url: str
    api_key: str
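# Complexity scores 1-4 (returned by the external classifier below) map to
# progressively more capable models; 4 routes to Claude 3.5 Sonnet via OpenRouter.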
MODEL_CONFIGS = {
    1: ModelConfig(
        name="virtuoso-small",
        base_url="https://models.arcee.ai/v1/chat/completions",
        api_key=arcee_api_key
    ),
    2: ModelConfig(
        name="virtuoso-medium",
        base_url="https://models.arcee.ai/v1/chat/completions",
        api_key=arcee_api_key
    ),
    3: ModelConfig(
        name="virtuoso-large",
        base_url="https://models.arcee.ai/v1/chat/completions",
        api_key=arcee_api_key
    ),
    4: ModelConfig(
        name="anthropic/claude-3.5-sonnet",
        base_url="https://openrouter.ai/api/v1/chat/completions",
        api_key=openrouter_api_key
    )
}
class ModelUsageStats:
    def __init__(self):
        self.usage_counts = {i: 0 for i in range(1, 5)}
        self.total_queries = 0

    def update(self, complexity: int):
        self.usage_counts[complexity] += 1
        self.total_queries += 1

    def get_stats(self) -> str:
        if self.total_queries == 0:
            return "No queries processed yet."
        model_names = {
            1: "virtuoso-small",
            2: "virtuoso-medium",
            3: "virtuoso-large",
            4: "claude-3.5-sonnet"
        }
        stats = []
        for complexity, count in self.usage_counts.items():
            percentage = (count / self.total_queries) * 100
            stats.append(f"{model_names[complexity]}: {count} uses ({percentage:.1f}%)")
        return "\n".join(stats)

stats = ModelUsageStats()
async def get_complexity(prompt: str) -> int:
    try:
        async with httpx.AsyncClient(http2=True) as client:
            response = await client.post(
                "http://185.216.20.86:8000/complexity",
                headers={"Content-Type": "application/json"},
                json={"prompt": prompt},
                timeout=10
            )
            response.raise_for_status()
            return response.json()["complexity"]
    except Exception as e:
        print(f"Error getting complexity: {e}")
        return 3  # Fall back to complexity 3 (virtuoso-large) on error
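# The classifier's /complexity endpoint is assumed to return JSON of the form
# {"complexity": <int 1-4>}, matching the keys of MODEL_CONFIGS; a failed
# request falls back to 3 above.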
async def get_model_response(message: str, history: List[Dict[str, str]], complexity: int) -> AsyncGenerator[str, None]:
    model_config = MODEL_CONFIGS[complexity]
    headers = {
        "Content-Type": "application/json"
    }
    if "openrouter.ai" in model_config.base_url:
        headers.update({
            "HTTP-Referer": "https://github.com/lucataco/gradio-router",
            "X-Title": "Gradio Router",
            "Authorization": f"Bearer {model_config.api_key}"
        })
    elif "arcee.ai" in model_config.base_url:
        headers.update({
            "Authorization": f"Bearer {model_config.api_key}"
        })
    try:
        collected_chunks = []
        # Both Arcee.ai and OpenRouter expose OpenAI-compatible chat-completions
        # endpoints, so a single HTTP/2 streaming request path serves either provider.
        messages = [{"role": "system", "content": "You are a helpful AI assistant."}]
        for msg in history:
            # Strip the appended model-info <div> before resending history
            content = msg["content"]
            if isinstance(content, str):
                content = content.split("\n\n<div")[0]
            messages.append({"role": msg["role"], "content": content})
        messages.append({"role": "user", "content": message})
        async with httpx.AsyncClient(http2=True) as client:
            async with client.stream(
                "POST",
                model_config.base_url,
                headers=headers,
                json={
                    "model": model_config.name,
                    "messages": messages,
                    "temperature": 0.7,
                    "stream": True
                },
                timeout=30.0
            ) as response:
                response.raise_for_status()
                buffer = []
                async for line in response.aiter_lines():
                    if line.startswith("data: "):
                        try:
                            json_response = json.loads(line.replace("data: ", ""))
                            if json_response.get('choices') and json_response['choices'][0].get('delta', {}).get('content'):
                                buffer.append(json_response['choices'][0]['delta']['content'])
                                # Flush roughly every 10 chunks, or sooner when the
                                # latest chunk contains punctuation or a newline
                                if len(buffer) >= 10 or any(c in '.,!?\n' for c in buffer[-1]):
                                    collected_chunks.extend(buffer)
                                    yield "".join(collected_chunks)
                                    buffer = []
                        except json.JSONDecodeError:
                            continue
                if buffer:  # Yield any remaining content
                    collected_chunks.extend(buffer)
                    yield "".join(collected_chunks)
    except Exception as e:
        error_msg = str(e)
        print(f"Error getting model response: {error_msg}")
        if "464" in error_msg:
            yield "Error: Authentication failed. Please check your API key and try again."
        elif "Internal Server Error" in error_msg:
            yield "Error: The server encountered an internal error. Please try again later."
        else:
            yield f"Error: Unable to get response from {model_config.name}. {error_msg}"
async def chat_wrapper(
    message: str,
    history: List[Dict[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    model_usage_stats: str,
):
    # Note: system_message, max_tokens, temperature and top_p come from the UI
    # but are not currently forwarded to the model request.
    complexity = await get_complexity(message)
    stats.update(complexity)
    model_name = MODEL_CONFIGS[complexity].name
    # Convert history for the model, stripping the appended model-info <div>
    model_history = []
    for msg in history:
        if isinstance(msg, dict) and "role" in msg and "content" in msg:
            content = msg["content"]
            if isinstance(content, str):
                content = content.split("\n\n<div")[0]
            model_history.append({"role": msg["role"], "content": content})
    # Stream the response, re-yielding the growing transcript plus updated stats
    full_response = ""
    async for partial_response in get_model_response(message, model_history, complexity):
        full_response = partial_response
        response_with_info = f"{full_response}\n\n<div class='model-info'>Model: {model_name}</div>"
        stats_text = stats.get_stats()
        yield [
            *history,
            {"role": "user", "content": message},
            {"role": "assistant", "content": response_with_info}
        ], stats_text
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="indigo",
        neutral_hue="slate",
        font=("Inter", "system-ui", "sans-serif")
    ),
    css="""
    .container {
        max-width: 1000px;
        margin: auto;
        padding: 2rem;
    }
    .title {
        text-align: center;
        font-size: 2.5rem;
        font-weight: 600;
        margin: 1rem 0;
        background: linear-gradient(to right, var(--primary-500), var(--secondary-500));
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
    }
    .subtitle {
        text-align: center;
        font-size: 1.1rem;
        color: var(--neutral-700);
        margin-bottom: 2rem;
        font-weight: 400;
    }
    .model-info {
        font-style: italic;
        color: var(--neutral-500);
        font-size: 0.85em;
        margin-top: 1em;
        padding-top: 0.5em;
        border-top: 1px solid var(--neutral-200);
        opacity: 0.8;
    }
    .stats-box {
        margin-top: 1rem;
        padding: 1rem;
        border-radius: 0.75rem;
        background: color-mix(in srgb, var(--background-fill) 80%, transparent);
        border: 1px solid var(--neutral-200);
        font-family: monospace;
        white-space: pre-line;
    }
    .message.assistant {
        padding-bottom: 1.5em !important;
    }
    """
) as demo:
    with gr.Column(elem_classes="container"):
        gr.Markdown("# AI Model Router", elem_classes="title")
        gr.Markdown(
            "Your message will be routed to the appropriate AI model based on complexity.",
            elem_classes="subtitle"
        )
        chatbot = gr.Chatbot(
            value=[],
            bubble_full_width=False,
            show_label=False,
            height=450,
            container=True,
            type="messages"
        )
        with gr.Row():
            txt = gr.Textbox(
                show_label=False,
                placeholder="Enter your message here...",
                container=False,
                scale=7
            )
            clear = gr.ClearButton(
                [txt, chatbot],
                scale=1,
                variant="secondary",
                size="sm"
            )
        with gr.Accordion("Advanced Settings", open=False):
            system_message = gr.Textbox(value="You are a helpful AI assistant.", label="System message")
            max_tokens = gr.Slider(minimum=16, maximum=4096, value=2048, step=1, label="Max Tokens")
            temperature = gr.Slider(minimum=0, maximum=2, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.1, label="Top P")
        stats_display = gr.Textbox(
            value=stats.get_stats(),
            label="Model Usage Statistics",
            interactive=False,
            elem_classes="stats-box"
        )
        # Submit streams the chat response, then clears the input box
        txt.submit(
            chat_wrapper,
            [txt, chatbot, system_message, max_tokens, temperature, top_p, stats_display],
            [chatbot, stats_display],
        ).then(
            lambda: "",
            None,
            [txt],
        )

if __name__ == "__main__":
    demo.queue().launch()