Spaces:
Running
Running
import random | |
from poke_env import AccountConfiguration, ServerConfiguration | |
from poke_env.player.random_player import RandomPlayer | |
from agents import OpenAIAgent, GeminiAgent, MistralAgent, MaxDamagePlayer | |
# Connection settings shared by every agent created below.
# poke_env's ServerConfiguration takes (websocket URL, authentication URL):
# battles run on the custom server's websocket, while the second URL is the
# Pokemon Showdown action.php endpoint.
CUSTOM_SERVER_URL = "wss://jofthomas.com/showdown/websocket"
CUSTOM_ACTION_URL = "https://play.pokemonshowdown.com/action.php?"
custom_config = ServerConfiguration(
    CUSTOM_SERVER_URL,
    CUSTOM_ACTION_URL,
)
# Trainer avatars per agent type; create_agent picks one at random for each
# new agent. The 'random' entry is not referenced by create_agent, which no
# longer builds random agents.
AGENT_AVATARS = dict(
    openai=['giovanni', 'lusamine', 'guzma'],
    mistral=['alder', 'lance', 'cynthia'],
    gemini=['steven', 'diantha', 'leon'],
    maxdamage=['red'],
    random=['youngster'],
)
def create_agent(agent_type: str, api_key: str = None, model: str = None, username_suffix: str = None):
    """
    Factory function to create different types of Pokemon agents.

    Args:
        agent_type (str): Type of agent ('openai', 'gemini', 'mistral', 'maxdamage'),
            matched case-insensitively.
        api_key (str, optional): API key; required for the LLM-backed agents
            ('openai', 'gemini', 'mistral'), ignored for 'maxdamage'.
        model (str, optional): Model name for LLM-backed agents; when omitted or
            empty, the entry from ``get_default_models()`` is used.
        username_suffix (str, optional): Suffix appended to the username for
            uniqueness; a random 4-digit number is used when omitted or empty.

    Returns:
        Player: A Pokemon battle agent configured for ``custom_config``.

    Raises:
        ValueError: If ``agent_type`` is not supported, or an LLM-backed agent
            is requested without an API key.
    """
    if not username_suffix:
        username_suffix = str(random.randint(1000, 9999))

    agent_type = agent_type.lower()

    # The LLM-backed agents take identical constructor arguments; only the
    # class and the display label (used in the username and error message) differ.
    llm_agents = {
        'openai': (OpenAIAgent, "OpenAI"),
        'gemini': (GeminiAgent, "Gemini"),
        'mistral': (MistralAgent, "Mistral"),
    }

    if agent_type in llm_agents:
        agent_cls, label = llm_agents[agent_type]
        if not api_key:
            raise ValueError(f"API key required for {label} agent")
        return agent_cls(
            account_configuration=AccountConfiguration(f"{label}-{username_suffix}", None),
            server_configuration=custom_config,
            api_key=api_key,
            # Defaults come from get_default_models() so they cannot drift apart.
            model=model or get_default_models()[agent_type],
            avatar=random.choice(AGENT_AVATARS[agent_type]),
            max_concurrent_battles=1,
            battle_delay=0.1,
            save_replays="battle_replays",
        )

    if agent_type == 'maxdamage':
        return MaxDamagePlayer(
            account_configuration=AccountConfiguration(f"MaxDamage-{username_suffix}", None),
            server_configuration=custom_config,
            max_concurrent_battles=1,
            save_replays="battle_replays",
            avatar=random.choice(AGENT_AVATARS['maxdamage']),
        )

    # 'random' is deliberately unsupported: RandomPlayer agents were removed to
    # prevent automatic random move selection.
    supported = ", ".join(get_supported_agent_types())
    raise ValueError(f"Unknown agent type: {agent_type}. Supported types: {supported}")
def get_supported_agent_types():
    """
    List the agent type names that create_agent accepts.

    Returns:
        list: A new list of supported agent type strings on every call.
    """
    supported = "openai gemini mistral maxdamage".split()
    return supported
def get_default_models():
    """
    Map each LLM-backed agent type to its default model name.

    Returns:
        dict: A new mapping of agent type -> default model on every call.
    """
    return dict(
        openai='gpt-4o',
        gemini='gemini-1.5-flash',
        mistral='mistral-large-latest',
    )