# GAIA_agent / model.py
# Author: ItzRoBeerT
# Commit b352179: "Added model manager"
import os
from typing import Optional, Any, Literal
from smolagents.models import OpenAIServerModel, InferenceClientModel
class SmoLModelManager:
    """A class to create and manage SmoLAgents model instances with different client backends.

    Supported backends, selected with ``model_client``:

    * ``"openai"``: builds an ``OpenAIServerModel``.
      - The API key falls back to the ``OPENROUTER_API_KEY`` environment variable.
      - The base URL falls back to ``OPENROUTER_BASE_URL``, then to
        ``https://openrouter.ai/api/v1``.
    * ``"inference"``: builds an ``InferenceClientModel``.
      - The API key falls back to the ``INFERENCE_API_KEY`` environment variable.
      - No base URL is used; ``api_base`` is stored as given and defaults to ``None``.
    """

    def __init__(self,
                 model_id: str,
                 model_client: Literal["openai", "inference"] = "openai",
                 api_key: Optional[str] = None,
                 api_base: Optional[str] = None):
        """
        Initialize with model configuration parameters.

        Args:
            model_id: The model identifier to use.
            model_client: The client backend to use ("openai" or "inference").
            api_key: The API key. If None or empty, it is read from the
                backend's API-key environment variable.
            api_base: The API base URL. For "openai", if None or empty, it is
                read from OPENROUTER_BASE_URL and then defaults to the
                OpenRouter URL. The "inference" backend does not use it.

        Raises:
            ValueError: If model_id is empty or model_client is not a
                supported backend.
        """
        if not model_id:
            raise ValueError("model_id cannot be empty")
        self.model_id = model_id

        if model_client not in ("openai", "inference"):
            raise ValueError("model_client must be either 'openai' or 'inference'")
        self.model_client = model_client

        # Set client-specific environment variable names and defaults.
        if model_client == "openai":
            self.api_key_env = "OPENROUTER_API_KEY"
            self.api_base_env = "OPENROUTER_BASE_URL"
            self.default_api_base = "https://openrouter.ai/api/v1"
        else:  # inference
            self.api_key_env = "INFERENCE_API_KEY"
            # This backend has no base-URL setting. These attributes must still
            # exist: without them, resolving api_base below raises
            # AttributeError whenever api_base is not passed explicitly.
            self.api_base_env = None
            self.default_api_base = None

        # Store API credentials; explicit (truthy) arguments take precedence
        # over the environment.
        self.api_key = api_key or os.getenv(self.api_key_env)
        if self.api_base_env is not None:
            self.api_base = api_base or os.getenv(self.api_base_env, self.default_api_base)
        else:
            # os.getenv(None) would raise TypeError, so skip the env lookup.
            self.api_base = api_base

    def create_model(self) -> Any:
        """
        Create and return the appropriate model instance based on model_client.

        Returns:
            The configured model instance, or None if no API key is available
            or if building the model raises.

        Note:
            This method catches exceptions, printing a message instead of
            raising, so that a misconfigured model does not crash the app.
        """
        # Validate API key is available
        if not self.api_key:
            print(f"Warning: No API key provided and {self.api_key_env} environment variable not set")
            return None
        try:
            if self.model_client == "openai":
                return self._create_openai_model()
            else:  # inference
                return self._create_inference_model()
        except Exception as e:
            print(f"Error creating model: {str(e)}")
            return None

    def _create_openai_model(self) -> Any:
        """Create an OpenAIServerModel instance (None on ImportError)."""
        try:
            return OpenAIServerModel(
                model_id=self.model_id,
                api_base=self.api_base,
                api_key=self.api_key
            )
        except ImportError:
            # NOTE(review): presumably raised when smolagents imports its
            # client library lazily and that dependency is missing — confirm.
            print("Failed to import OpenAIServerModel. Please ensure smolagents is installed.")
            return None

    def _create_inference_model(self) -> Any:
        """Create an InferenceClientModel instance (None on ImportError)."""
        try:
            # InferenceClientModel names its model parameter `model_id`. A
            # `model=` kwarg is not consumed as the model id and would leave
            # the library's default model in place.
            return InferenceClientModel(
                model_id=self.model_id,
                api_key=self.api_key
            )
        except ImportError:
            # NOTE(review): presumably raised when smolagents imports its
            # client library lazily and that dependency is missing — confirm.
            print("Failed to import InferenceClientModel. Please ensure smolagents is installed.")
            return None