|
|
|
import os |
|
import time |
|
from typing import Dict, Optional |
|
from pydantic_ai.models.openai import OpenAIModel |
|
from pydantic_ai.providers.openai import OpenAIProvider |
|
from dotenv import load_dotenv |
|
|
|
# Load variables from a local .env file into the process environment so the
# CEREBRAS_* settings read by InstanceProvider via os.getenv are available.
load_dotenv()
|
|
|
class InstanceProvider:
    """Manage multiple Cerebras API instances with failover support.

    One ``OpenAIModel`` is built per API key listed in the comma-separated
    ``CEREBRAS_API_KEYS`` environment variable. :meth:`get_instance` hands
    out models round-robin, skipping keys that are currently locked. A key
    is locked by :meth:`report_error` and becomes available again once more
    than ``LOCK_DURATION`` seconds have passed.
    """

    def __init__(self, lock_duration: float = 1800):
        """Build one model per configured API key.

        Args:
            lock_duration: Seconds a key stays locked after an error is
                reported for it. Defaults to 1800 (30 minutes).
        """
        # API key -> {'model': OpenAIModel, 'error_count': int}
        self.instances: Dict[str, dict] = {}
        # API key -> time.time() timestamp of the most recent reported error.
        self.locked_keys: Dict[str, float] = {}
        self.LOCK_DURATION = lock_duration
        # Index into ``instances`` where the next rotation search begins.
        self._next_index = 0
        self._initialize_instances()

    def _initialize_instances(self):
        """Create one model per non-blank key in ``CEREBRAS_API_KEYS``.

        ``CEREBRAS_BASE_URL`` and ``CEREBRAS_MODEL`` are read as-is; when a
        variable is unset, ``None`` is passed through unchanged.
        """
        api_keys = os.getenv("CEREBRAS_API_KEYS", "").split(",")
        base_url = os.getenv("CEREBRAS_BASE_URL")
        model_name = os.getenv("CEREBRAS_MODEL")

        for key in api_keys:
            key = key.strip()
            # Skip blanks produced by an empty variable or stray commas.
            if not key:
                continue
            # NOTE(review): when CEREBRAS_BASE_URL is unset, base_url is None
            # and the provider presumably falls back to its own default
            # endpoint -- confirm keys are never sent to a non-Cerebras host.
            self.instances[key] = {
                'model': OpenAIModel(
                    model_name,
                    provider=OpenAIProvider(base_url=base_url, api_key=key),
                ),
                'error_count': 0,
            }

    def _clean_locked_keys(self):
        """Unlock keys that have been locked for more than LOCK_DURATION seconds."""
        # Wall-clock time: a system clock jump can shorten or extend a lock.
        now = time.time()
        expired = [
            key for key, locked_at in self.locked_keys.items()
            if now - locked_at > self.LOCK_DURATION
        ]
        # Delete after the scan; mutating a dict while iterating it raises.
        for key in expired:
            del self.locked_keys[key]

    def get_instance(self) -> "OpenAIModel":
        """Return the model for the next unlocked key, rotating round-robin.

        The search starts just after the key returned by the previous call
        and wraps around, so requests are spread over all healthy keys
        instead of always hitting the first one.

        Raises:
            RuntimeError: If no keys are configured or every key is locked.
        """
        self._clean_locked_keys()

        keys = list(self.instances)
        count = len(keys)
        # Modulo keeps the cursor valid even if ``instances`` shrank; with
        # zero keys the loop body never runs, so there is no division by zero.
        for offset in range(count):
            index = (self._next_index + offset) % count
            key = keys[index]
            if key not in self.locked_keys:
                self._next_index = index + 1
                return self.instances[key]['model']

        raise RuntimeError("All API keys exhausted or temporarily locked")

    def report_error(self, api_key: str):
        """Count an error for ``api_key`` and lock it for LOCK_DURATION seconds.

        Keys that are not managed by this provider are ignored. Reporting an
        already-locked key restarts its lock period from now.
        """
        instance_data = self.instances.get(api_key)
        if instance_data is None:
            return
        instance_data['error_count'] += 1
        self.locked_keys[api_key] = time.time()

    def get_api_key_for_model(self, model: "OpenAIModel") -> Optional[str]:
        """Return the API key owning ``model``, or None if it is not ours."""
        # Identity, not equality: callers pass back the exact instance that
        # get_instance handed out.
        return next(
            (key for key, data in self.instances.items() if data['model'] is model),
            None,
        )