Cloud-Agents / cloud_agents / tensor_ops.py
"""
Tensor operations for distributed computing.
"""
import torch
import numpy as np
from typing import Dict, List, Optional, Union, Tuple


class TensorOps:
    """Utility class for distributed tensor operations."""

    @staticmethod
    def split_tensor(tensor: torch.Tensor, num_parts: int) -> List[torch.Tensor]:
        """Split a tensor into multiple parts for distributed processing."""
        # torch.chunk splits along dim 0 by default; the last chunk may be smaller.
        return torch.chunk(tensor, num_parts)

    @staticmethod
    def merge_tensors(tensors: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
        """Merge multiple tensors back into a single tensor."""
        return torch.cat(tensors, dim=dim)

    @staticmethod
    def average_gradients(gradients: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Average gradients from multiple workers."""
        # Assumes every worker reports the same keys, with matching shapes and devices.
        avg_gradients = {}
        for key in gradients[0].keys():
            avg_gradients[key] = torch.mean(torch.stack([g[key] for g in gradients]), dim=0)
        return avg_gradients

    @staticmethod
    def serialize_tensor(tensor: torch.Tensor) -> Dict[str, Union[List, str]]:
        """Serialize a tensor into a JSON-friendly dict for storage/transmission."""
        # detach() first: .numpy() raises if the tensor still requires grad.
        return {
            'data': tensor.detach().cpu().numpy().tolist(),
            'shape': list(tensor.shape),
            'dtype': str(tensor.dtype)
        }

    @staticmethod
    def deserialize_tensor(tensor_dict: Dict[str, Union[List, str]]) -> torch.Tensor:
        """Deserialize a tensor from its storage/transmission format."""
        data = np.array(tensor_dict['data'])
        shape = tensor_dict['shape']
        # str(torch.float32) == 'torch.float32', so the dtype name is the last component.
        dtype = getattr(torch, tensor_dict['dtype'].split('.')[-1])
        return torch.tensor(data, dtype=dtype).reshape(shape)

    @staticmethod
    def gradient_clipping(gradients: Dict[str, torch.Tensor], max_norm: float) -> Dict[str, torch.Tensor]:
        """Clip gradients by global norm to prevent exploding gradients."""
        # clip_grad_norm_ rescales .grad attributes, which raw gradient tensors lack,
        # so the global-norm scaling is applied to the tensors directly.
        grads = [g for g in gradients.values() if g is not None]
        if not grads:
            return gradients
        total_norm = torch.norm(torch.stack([g.norm() for g in grads])).item()
        clip_coef = max_norm / (total_norm + 1e-6)
        if clip_coef < 1.0:
            gradients = {k: v * clip_coef if v is not None else v for k, v in gradients.items()}
        return gradients

    @staticmethod
    def reduce_precision(tensor: torch.Tensor, bits: int = 16) -> torch.Tensor:
        """Reduce tensor precision for efficient transmission."""
        if bits == 16:
            return tensor.half()
        elif bits == 32:
            return tensor.float()
        else:
            raise ValueError(f"Unsupported precision bits: {bits} (expected 16 or 32)")

    @staticmethod
    def shard_tensor(tensor: torch.Tensor, shard_size: int) -> List[torch.Tensor]:
        """Shard a tensor along dim 0 into pieces of at most shard_size rows."""
        return [tensor[i:i + shard_size] for i in range(0, tensor.size(0), shard_size)]

    @staticmethod
    def compute_parameter_norm(parameters: Dict[str, torch.Tensor]) -> float:
        """Compute the total norm of all parameters."""
        total_norm = 0.0
        for param in parameters.values():
            total_norm += param.norm().item() ** 2
        return total_norm ** 0.5
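

# Example usage: a minimal illustrative sketch, not part of the original module.
# The tensor shapes, the number of parts, and the fake per-worker gradients below
# are arbitrary values chosen only to exercise the helpers above.
if __name__ == "__main__":
    x = torch.arange(12, dtype=torch.float32).reshape(4, 3)

    # Split into chunks for workers, then merge the results back together.
    parts = TensorOps.split_tensor(x, num_parts=2)
    assert torch.equal(TensorOps.merge_tensors(list(parts)), x)

    # Round-trip a tensor through the JSON-friendly serialization format.
    restored = TensorOps.deserialize_tensor(TensorOps.serialize_tensor(x))
    assert torch.equal(restored, x)

    # Average fake gradients from two workers, then clip them by global norm.
    worker_grads = [{'w': torch.ones(3, 3)}, {'w': 3.0 * torch.ones(3, 3)}]
    avg = TensorOps.average_gradients(worker_grads)            # every entry == 2.0
    clipped = TensorOps.gradient_clipping(avg, max_norm=1.0)
    print("clipped grad norm:", TensorOps.compute_parameter_norm(clipped))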