""" Coordinator for distributed model training. """ import torch from transformers import AutoModelForCausalLM, AutoTokenizer from typing import Dict, List, Any, Optional import asyncio import logging from huggingface_hub import snapshot_download import os import ray from .couchdb_client import CouchDBClient from .config import settings from .tensor_ops import TensorOps logger = logging.getLogger(__name__) class Coordinator: """Coordinator for distributed training of OpenPeerLLM.""" def __init__(self): self.db_client = CouchDBClient() self.model_id = settings.MODEL_ID self.batch_size = settings.BATCH_SIZE self.gradient_accumulation_steps = settings.GRADIENT_ACCUMULATION_STEPS self._initialize_model() def _initialize_model(self): """Initialize the model and tokenizer.""" try: # Download model and tokenizer from Hugging Face cache_dir = snapshot_download(self.model_id) self.model = AutoModelForCausalLM.from_pretrained(cache_dir) self.tokenizer = AutoTokenizer.from_pretrained(cache_dir) # Store initial model state initial_state = { 'model_state': self.model.state_dict(), 'step': 0, 'epoch': 0 } self.db_client.store_model_state(initial_state) except Exception as e: logger.error(f"Failed to initialize model: {e}") raise async def coordinate_training(self, training_config: Dict[str, Any]): """Coordinate distributed training across agents.""" try: num_epochs = training_config.get('num_epochs', 1) steps_per_epoch = training_config.get('steps_per_epoch', 100) for epoch in range(num_epochs): logger.info(f"Starting epoch {epoch}") await self._train_epoch(epoch, steps_per_epoch) # Save checkpoint after each epoch self._save_checkpoint(epoch) except Exception as e: logger.error(f"Training coordination error: {e}") raise async def _train_epoch(self, epoch: int, steps_per_epoch: int): """Train for one epoch.""" for step in range(steps_per_epoch): # Get active agents active_agents = self.db_client.get_active_agents() if not active_agents: logger.warning("No active agents available") await asyncio.sleep(5) continue # Distribute gradient computation jobs gradient_jobs = await self._distribute_gradient_computation( active_agents, self.batch_size ) # Collect and process gradients gradients = await self._collect_gradients(gradient_jobs) if gradients: # Update model with collected gradients self._update_model_parameters(gradients) # Distribute updated model state to agents await self._distribute_model_update() async def _distribute_gradient_computation( self, agents: List[Dict[str, Any]], batch_size: int ) -> List[str]: """Distribute gradient computation jobs to available agents.""" job_ids = [] # Get current model state current_state = self.db_client.get_latest_model_state() if not current_state: raise RuntimeError("No model state available") # Create gradient computation jobs for agent in agents: job_id = self.db_client.create_job( 'gradient_computation', { 'batch_size': batch_size, 'state': current_state['state'] } ) job_ids.append(job_id) return job_ids async def _collect_gradients(self, job_ids: List[str]) -> Optional[List[Dict[str, Any]]]: """Collect gradients from completed jobs.""" all_gradients = [] timeout = 300 # 5 minutes timeout async def wait_for_job(job_id: str) -> Optional[Dict[str, Any]]: start_time = asyncio.get_event_time() while True: if asyncio.get_event_time() - start_time > timeout: logger.warning(f"Job {job_id} timed out") return None job = self.db_client.get_job(job_id) if job['status'] == 'completed': gradient_id = job['result']['gradient_id'] return 
self.db_client.get_gradients(gradient_id) elif job['status'] == 'failed': logger.error(f"Job {job_id} failed: {job.get('result', {}).get('error')}") return None await asyncio.sleep(1) # Wait for all gradient computations to complete gradient_tasks = [wait_for_job(job_id) for job_id in job_ids] gradients = await asyncio.gather(*gradient_tasks) # Filter out None results (failed jobs) return [g for g in gradients if g is not None] def _update_model_parameters(self, gradients: List[Dict[str, Any]]): """Update model parameters with collected gradients.""" try: # Average gradients from all workers avg_gradients = TensorOps.average_gradients([ {k: torch.tensor(v) for k, v in g.items()} for g in gradients ]) # Apply gradient clipping clipped_gradients = TensorOps.gradient_clipping(avg_gradients, max_norm=1.0) # Update model parameters with torch.no_grad(): for name, param in self.model.named_parameters(): if name in clipped_gradients: param.sub_(clipped_gradients[name] * self.model.config.learning_rate) except Exception as e: logger.error(f"Error updating model parameters: {e}") raise async def _distribute_model_update(self): """Distribute updated model state to all agents.""" try: # Store updated model state state = { 'model_state': self.model.state_dict(), 'timestamp': datetime.utcnow().isoformat() } state_id = self.db_client.store_model_state(state) # Create model update jobs for all active agents active_agents = self.db_client.get_active_agents() for agent in active_agents: self.db_client.create_job( 'model_update', { 'state_id': state_id, 'state': state } ) except Exception as e: logger.error(f"Error distributing model update: {e}") raise def _save_checkpoint(self, epoch: int): """Save a checkpoint of the current model state.""" try: checkpoint_dir = os.path.join(os.getcwd(), 'checkpoints') os.makedirs(checkpoint_dir, exist_ok=True) checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt") torch.save({ 'epoch': epoch, 'model_state_dict': self.model.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict() if hasattr(self, 'optimizer') else None }, checkpoint_path) logger.info(f"Saved checkpoint for epoch {epoch}") except Exception as e: logger.error(f"Error saving checkpoint: {e}") raise
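

# Usage sketch: a minimal, illustrative way to drive the coordinator, assuming
# the project's `settings` (MODEL_ID, BATCH_SIZE, GRADIENT_ACCUMULATION_STEPS)
# are populated and a CouchDB instance is reachable by CouchDBClient. The real
# entry point may live elsewhere; the configuration keys below mirror those
# read in coordinate_training().
if __name__ == "__main__":
    async def _run_demo() -> None:
        coordinator = Coordinator()
        await coordinator.coordinate_training({
            'num_epochs': 1,
            'steps_per_epoch': 10,
        })

    asyncio.run(_run_demo())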