File size: 2,646 Bytes
d7d1d4e
 
 
 
 
 
a25d048
 
 
d7d1d4e
 
0c663c0
d7d1d4e
 
 
 
 
 
 
 
 
0c663c0
 
 
d7d1d4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# instance_provider.py
import os
import time
from typing import Dict, Optional
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.openai import OpenAIProvider
from dotenv import load_dotenv

# Load variables from a .env file into the process environment at import
# time, so the CEREBRAS_* settings that InstanceProvider reads via os.getenv
# are available before any provider is constructed.
load_dotenv()

class InstanceProvider:
    """Manages multiple Cerebras API instances with failover support.

    One ``OpenAIModel`` is built per API key listed in ``CEREBRAS_API_KEYS``.
    A key that has an error reported against it via ``report_error`` is
    locked for ``LOCK_DURATION`` seconds; while locked, ``get_instance``
    skips it and hands out the next unlocked key instead.
    """

    def __init__(self, lock_duration: float = 1800):
        """Create the provider and build one model per configured API key.

        Args:
            lock_duration: Seconds a key stays locked after ``report_error``.
                Defaults to 1800 (30 minutes).
        """
        # api_key -> {'model': OpenAIModel, 'error_count': int}
        self.instances: Dict[str, dict] = {}
        # api_key -> time.monotonic() reading taken when the key was locked.
        # A monotonic clock keeps lock durations correct even if the system
        # wall clock is adjusted while the process is running.
        self.locked_keys: Dict[str, float] = {}
        self.LOCK_DURATION = lock_duration
        self._initialize_instances()

    def _initialize_instances(self):
        """Build one model per key in the comma-separated CEREBRAS_API_KEYS.

        Surrounding whitespace is stripped and blank entries (e.g. from a
        trailing comma) are skipped; a key listed twice yields one entry.
        Every model shares CEREBRAS_BASE_URL and CEREBRAS_MODEL, which are
        passed through as-is (``None`` when the variables are unset).
        """
        api_keys = os.getenv("CEREBRAS_API_KEYS", "").split(",")
        base_url = os.getenv("CEREBRAS_BASE_URL")
        model_name = os.getenv("CEREBRAS_MODEL")

        for key in api_keys:
            key = key.strip()
            if not key:
                continue
            self.instances[key] = {
                'model': OpenAIModel(
                    model_name,
                    provider=OpenAIProvider(base_url=base_url, api_key=key),
                ),
                'error_count': 0,
            }

    def _clean_locked_keys(self):
        """Unlock every key whose lock is older than ``LOCK_DURATION`` seconds."""
        now = time.monotonic()
        # Collect first, then delete: mutating a dict while iterating it raises.
        expired = [
            key for key, locked_at in self.locked_keys.items()
            if now - locked_at > self.LOCK_DURATION
        ]
        for key in expired:
            del self.locked_keys[key]

    def get_instance(self) -> OpenAIModel:
        """Return the model of the first configured key that is not locked.

        Expired locks are cleared first. Keys are tried in configuration
        order, so the same key is returned on every call until it is locked
        via ``report_error``; failover to the next key happens only then.

        Raises:
            RuntimeError: if no keys are configured, or if every configured
                key is currently locked.
        """
        self._clean_locked_keys()

        for key, instance_data in self.instances.items():
            if key not in self.locked_keys:
                return instance_data['model']

        # Distinguish misconfiguration from temporary exhaustion so the
        # caller gets an actionable message.
        if not self.instances:
            raise RuntimeError("No API keys configured (set CEREBRAS_API_KEYS)")
        raise RuntimeError("All API keys exhausted or temporarily locked")

    def report_error(self, api_key: str):
        """Record an error for ``api_key`` and lock it for ``LOCK_DURATION`` seconds.

        Reporting an already-locked key restarts its lock timer. Keys that
        were never configured are ignored.
        """
        instance_data = self.instances.get(api_key)
        if instance_data is None:
            return
        instance_data['error_count'] += 1
        self.locked_keys[api_key] = time.monotonic()

    def get_api_key_for_model(self, model: OpenAIModel) -> Optional[str]:
        """Return the API key ``model`` was built with, or None if unknown.

        Matches by identity: pass back the exact object returned by
        ``get_instance``.
        """
        return next(
            (key for key, data in self.instances.items() if data['model'] is model),
            None,
        )