""" Base agent class for distributed computing. """ import torch import ray import uuid import asyncio from typing import Dict, Any, Optional from datetime import datetime import logging from .couchdb_client import CouchDBClient from .config import settings logger = logging.getLogger(__name__) @ray.remote class Agent: """Distributed computing agent for tensor operations and model training.""" def __init__(self): self.agent_id = str(uuid.uuid4()) self.db_client = CouchDBClient() self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.current_job: Optional[Dict] = None self._register_agent() self._start_heartbeat() def _register_agent(self): """Register agent with the cluster.""" capabilities = { "device": str(self.device), "cuda_available": torch.cuda.is_available(), "cuda_devices": torch.cuda.device_count() if torch.cuda.is_available() else 0, "memory_available": torch.cuda.get_device_properties(0).total_memory if torch.cuda.is_available() else 0 } success = self.db_client.register_agent(self.agent_id, capabilities) if not success: raise RuntimeError("Failed to register agent") def _start_heartbeat(self): """Start agent heartbeat.""" async def heartbeat_loop(): while True: try: self.db_client.update_heartbeat(self.agent_id) await asyncio.sleep(30) except Exception as e: logger.error(f"Heartbeat error: {e}") await asyncio.sleep(5) asyncio.create_task(heartbeat_loop()) def process_tensors(self, tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """Process tensor operations.""" results = {} for name, tensor in tensors.items(): tensor = tensor.to(self.device) # Perform tensor operations results[name] = self._compute_tensor(tensor) return results def _compute_tensor(self, tensor: torch.Tensor) -> torch.Tensor: """Compute operations on a single tensor.""" # Add custom tensor operations here return tensor async def run(self): """Main agent loop.""" while True: try: # Try to claim a job job = self.db_client.claim_job(self.agent_id) if job: self.current_job = job await self._process_job(job) else: await asyncio.sleep(1) except Exception as e: logger.error(f"Error in agent loop: {e}") await asyncio.sleep(5) async def _process_job(self, job: Dict[str, Any]): """Process a claimed job.""" try: job_type = job['type'] params = job['params'] result = None if job_type == 'gradient_computation': result = await self._compute_gradients(params) elif job_type == 'model_update': result = await self._update_model(params) # Store job results self.db_client.update_job_status( job['_id'], 'completed', result ) except Exception as e: logger.error(f"Job processing error: {e}") self.db_client.update_job_status( job['_id'], 'failed', {'error': str(e)} ) finally: self.current_job = None async def _compute_gradients(self, params: Dict[str, Any]) -> Dict[str, Any]: """Compute gradients for model training.""" try: # Load model checkpoint checkpoint = params.get('checkpoint') if checkpoint: state_dict = torch.load(checkpoint, map_location=self.device) # Compute gradients gradients = self._compute_model_gradients(state_dict, params.get('batch')) # Store gradients in CouchDB gradient_id = self.db_client.store_gradients( self.current_job['_id'], gradients ) return {'gradient_id': gradient_id} except Exception as e: logger.error(f"Gradient computation error: {e}") raise def _compute_model_gradients(self, state_dict: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, Any]: """Compute gradients for a given model state and batch.""" # Convert gradients to serializable format gradients = {} for name, param in state_dict.items(): if param.requires_grad: grad = param.grad if grad is not None: gradients[name] = grad.cpu().numpy().tolist() return gradients async def _update_model(self, params: Dict[str, Any]) -> Dict[str, Any]: """Update model with new parameters.""" try: new_state = params.get('state') if new_state: # Apply model updates state_id = self.db_client.store_model_state(new_state) return {'state_id': state_id} except Exception as e: logger.error(f"Model update error: {e}") raise def shutdown(self): """Shutdown the agent.""" # Update agent status to inactive self.db_client.update_job_status( self.agent_id, 'inactive' ) # Clean up resources if torch.cuda.is_available(): torch.cuda.empty_cache()