Spaces:
Sleeping
Sleeping
import os | |
from typing import Optional, Any, Literal | |
from smolagents.models import OpenAIServerModel, InferenceClientModel | |
class SmoLModelManager:
    """A class to create and manage SmoLAgents model instances with different client backends.

    Supported backends:

    * ``"openai"`` -- builds an ``OpenAIServerModel``. The API key falls back to
      the ``OPENROUTER_API_KEY`` environment variable. The base URL falls back to
      ``OPENROUTER_BASE_URL`` and then to ``https://openrouter.ai/api/v1``.
    * ``"inference"`` -- builds an ``InferenceClientModel``. The API key falls
      back to the ``INFERENCE_API_KEY`` environment variable. No base-URL
      environment variable or default exists for this backend. An explicit
      ``api_base`` is stored on the instance but is not passed to the model.
    """

    def __init__(self,
                 model_id: str,
                 model_client: Literal["openai", "inference"] = "openai",
                 api_key: Optional[str] = None,
                 api_base: Optional[str] = None):
        """
        Initialize with model configuration parameters.

        Args:
            model_id: The model identifier to use.
            model_client: The client backend to use ("openai" or "inference").
            api_key: The API key. If None or empty, the backend's API-key
                environment variable is checked.
            api_base: The API base URL. If None or empty, "openai" checks
                OPENROUTER_BASE_URL and then uses the OpenRouter default.
                "inference" leaves it as None.

        Raises:
            ValueError: If model_id is empty, or model_client is not one of the
                supported backends.
        """
        if not model_id:
            raise ValueError("model_id cannot be empty")
        self.model_id = model_id

        if model_client not in ["openai", "inference"]:
            raise ValueError("model_client must be either 'openai' or 'inference'")
        self.model_client = model_client

        # Client-specific environment variable names and defaults. Every
        # backend defines all three attributes, so the credential lookup below
        # never touches a missing attribute.
        if model_client == "openai":
            self.api_key_env = "OPENROUTER_API_KEY"
            self.api_base_env = "OPENROUTER_BASE_URL"
            self.default_api_base = "https://openrouter.ai/api/v1"
        else:  # inference
            self.api_key_env = "INFERENCE_API_KEY"
            self.api_base_env = None  # this backend reads no base-URL env var
            self.default_api_base = None

        # Store API credentials. Explicit arguments win; otherwise fall back
        # to the environment (and, for the base URL, to the backend default).
        self.api_key = api_key or os.getenv(self.api_key_env)
        if api_base:
            self.api_base = api_base
        elif self.api_base_env is not None:
            self.api_base = os.getenv(self.api_base_env, self.default_api_base)
        else:
            self.api_base = self.default_api_base

    def create_model(self) -> Any:
        """
        Create and return the appropriate model instance based on model_client.

        Returns:
            The configured model instance, or None if no API key is available
            or creation fails.

        Note:
            This method catches exceptions to prevent app crashes. Failures
            are reported with print() rather than raised.
        """
        # Validate API key is available
        if not self.api_key:
            print(f"Warning: No API key provided and {self.api_key_env} environment variable not set")
            return None

        try:
            if self.model_client == "openai":
                return self._create_openai_model()
            else:  # inference
                return self._create_inference_model()
        except Exception as e:
            print(f"Error creating model: {str(e)}")
            return None

    def _create_openai_model(self) -> Any:
        """Create an OpenAIServerModel instance from the stored credentials."""
        try:
            return OpenAIServerModel(
                model_id=self.model_id,
                api_base=self.api_base,
                api_key=self.api_key
            )
        except ImportError:
            # The model class may import its own optional client library
            # lazily at construction time; treat that as a soft failure.
            print("Failed to import OpenAIServerModel. Please ensure smolagents is installed.")
            return None

    def _create_inference_model(self) -> Any:
        """Create an InferenceClientModel instance from the stored credentials."""
        try:
            # InferenceClientModel's constructor parameter is `model_id`.
            # Passing `model=` would not select the model; it would be dropped
            # into **kwargs and the class default model would be used instead.
            return InferenceClientModel(
                model_id=self.model_id,
                api_key=self.api_key
            )
        except ImportError:
            print("Failed to import InferenceClientModel. Please ensure smolagents is installed.")
            return None