import functools
import logging
import os
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.distributed as dist

from torch.distributed.fsdp import FullyShardedDataParallel, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Pipeline parallelism is optional; import each symbol in its own guard so a
# missing ``balance`` helper does not mask an otherwise usable ``Pipe``.
try:
    from torch.distributed.pipeline.sync import Pipe
except Exception:
    Pipe = None

try:
    from torch.distributed._pipeline.sync import balance
except Exception:
    balance = None

from .model import BitTransformerLM, LoggingTransformerEncoderLayer
from .error_handling import with_error_recovery, safe_operation
from .types import DeviceType, WorldSize, ProcessRank


@with_error_recovery(max_retries=2)
def setup_distributed(rank: ProcessRank = 0,
                      world_size: WorldSize = 1,
                      backend: str = "nccl",
                      init_method: str = "tcp://localhost:23456") -> bool:
    """Initialize distributed training environment."""
    if world_size <= 1:
        return False

    try:
        dist.init_process_group(
            backend=backend,
            init_method=init_method,
            world_size=world_size,
            rank=rank,
        )
        logging.info(f"Initialized distributed training: rank {rank}/{world_size}")
        return True
    except Exception as e:
        logging.error(f"Failed to initialize distributed training: {e}")
        return False


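# Example (sketch): wiring ``setup_distributed`` to a ``torchrun`` launch. The
# RANK/WORLD_SIZE environment variables are the standard torchrun names; adapt
# them if a different launcher is used.
#
#   rank = int(os.environ.get("RANK", 0))
#   world_size = int(os.environ.get("WORLD_SIZE", 1))
#   if setup_distributed(rank, world_size, backend="nccl", init_method="env://"):
#       try:
#           ...  # build model, wrap it, train
#       finally:
#           cleanup_distributed()

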
def wrap_fsdp(model: BitTransformerLM,
              sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD,
              **kwargs) -> FullyShardedDataParallel:
    """Return an optimized FSDP-wrapped model with transformer-aware sharding."""
    device = kwargs.pop("device_id", None)
    if device is None and torch.cuda.is_available():
        device = torch.cuda.current_device()

    # transformer_auto_wrap_policy must be bound to the transformer layer class
    # it should shard; callers may supply their own policy via kwargs.
    auto_wrap_policy = kwargs.pop(
        "auto_wrap_policy",
        functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={LoggingTransformerEncoderLayer},
        ),
    )

    fsdp_config = {
        "sharding_strategy": sharding_strategy,
        "cpu_offload": kwargs.pop("cpu_offload", None),
        "mixed_precision": kwargs.pop("mixed_precision", None),
        "auto_wrap_policy": auto_wrap_policy,
        "backward_prefetch": kwargs.pop("backward_prefetch", None),
        "forward_prefetch": kwargs.pop("forward_prefetch", False),
        "limit_all_gathers": kwargs.pop("limit_all_gathers", True),
        "use_orig_params": kwargs.pop("use_orig_params", True),
        **kwargs,
    }

    # Drop unset options so FSDP falls back to its own defaults.
    fsdp_config = {k: v for k, v in fsdp_config.items() if v is not None}

    if device is not None:
        model = model.to(device)
        fsdp_config["device_id"] = device

    return FullyShardedDataParallel(model, **fsdp_config)


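# Example (sketch): FSDP wrapping with bf16 mixed precision. ``MixedPrecision``
# and ``CPUOffload`` are the standard FSDP config dataclasses; the settings
# shown are illustrative, not tuned values.
#
#   from torch.distributed.fsdp import MixedPrecision, CPUOffload
#
#   fsdp_model = wrap_fsdp(
#       model,
#       sharding_strategy=ShardingStrategy.FULL_SHARD,
#       mixed_precision=MixedPrecision(param_dtype=torch.bfloat16,
#                                      reduce_dtype=torch.bfloat16),
#       cpu_offload=CPUOffload(offload_params=False),
#   )

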
class OptimizedPipeline(nn.Module):
    """Enhanced pipeline parallelism with BitTransformerLM optimizations."""

    def __init__(self,
                 model: BitTransformerLM,
                 num_stages: int = 1,
                 chunks: int = 1,
                 checkpoint: bool = True):
        super().__init__()

        if Pipe is None:
            raise RuntimeError("Pipeline parallelism not available in this build")

        self.num_stages = num_stages
        self.chunks = chunks
        self.checkpoint = checkpoint
        # Pipe expects a checkpoint mode string rather than a bool.
        self._checkpoint_mode = "except_last" if checkpoint else "never"

        if num_stages > 1:
            self.pipeline_model = self._create_pipeline_stages(model, num_stages)
        else:
            self.pipeline_model = Pipe(
                nn.Sequential(model),
                chunks=chunks,
                checkpoint=self._checkpoint_mode,
            )

    def _create_pipeline_stages(self, model: BitTransformerLM, num_stages: int) -> Pipe:
        """Create optimized pipeline stages for BitTransformerLM."""
        layers = []

        # Embedding and positional encoding open the pipeline.
        if hasattr(model, 'embedding'):
            layers.append(model.embedding)
        if hasattr(model, 'pos_encoding'):
            layers.append(model.pos_encoding)

        # Transformer blocks make up the bulk of the stages.
        if hasattr(model, 'layers'):
            layers.extend(model.layers)
        elif hasattr(model, 'transformer'):
            layers.extend(model.transformer.layers)

        # Output projection closes the pipeline.
        if hasattr(model, 'output_projection'):
            layers.append(model.output_projection)

        if balance is not None:
            partitions = balance(len(layers), num_stages)
        else:
            # Fall back to an even split, giving the remainder to the last stage.
            layers_per_stage = len(layers) // num_stages
            partitions = [layers_per_stage] * num_stages
            partitions[-1] += len(layers) % num_stages

        stages = []
        start_idx = 0
        for partition_size in partitions:
            end_idx = start_idx + partition_size
            stage_layers = layers[start_idx:end_idx]
            stages.append(nn.Sequential(*stage_layers))
            start_idx = end_idx

        return Pipe(
            nn.Sequential(*stages),
            chunks=self.chunks,
            checkpoint=self._checkpoint_mode,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the pipeline."""
        output = self.pipeline_model(x)
        # torch.distributed.pipeline.sync.Pipe returns an RRef; unwrap it if so.
        if hasattr(output, "local_value"):
            return output.local_value()
        return output


def make_pipeline(model: BitTransformerLM,
                  chunks: int = 1,
                  num_stages: int = 1,
                  checkpoint: bool = True) -> OptimizedPipeline:
    """Create an optimized pipeline with advanced parallelism features."""
    return OptimizedPipeline(
        model=model,
        num_stages=num_stages,
        chunks=chunks,
        checkpoint=checkpoint,
    )


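# Example (sketch): a two-stage pipeline on a single node. Note that
# torch.distributed.pipeline.sync.Pipe requires the RPC framework to be
# initialized first; the worker name and input tensor here are illustrative.
#
#   import torch.distributed.rpc as rpc
#
#   rpc.init_rpc("worker0", rank=0, world_size=1)
#   pipe_model = make_pipeline(model, chunks=4, num_stages=2)
#   logits = pipe_model(bit_sequences)  # micro-batched across the stages

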
class DistributedTrainingManager:
    """Manages distributed training configuration and optimization."""

    def __init__(self,
                 world_size: WorldSize,
                 rank: ProcessRank,
                 use_pipeline: bool = False,
                 use_fsdp: bool = True):
        self.world_size = world_size
        self.rank = rank
        self.use_pipeline = use_pipeline
        self.use_fsdp = use_fsdp
        self.is_distributed = world_size > 1

        self.logger = logging.getLogger(__name__)

    def setup_model(self,
                    model: BitTransformerLM,
                    pipeline_stages: int = 1,
                    fsdp_config: Optional[Dict[str, Any]] = None) -> nn.Module:
        """Set up model for distributed training."""
        if not self.is_distributed:
            return model

        with safe_operation("distributed_model_setup"):
            if self.use_pipeline and pipeline_stages > 1:
                self.logger.info(f"Setting up pipeline parallelism with {pipeline_stages} stages")
                return make_pipeline(
                    model,
                    chunks=2,
                    num_stages=pipeline_stages,
                )

            elif self.use_fsdp:
                self.logger.info("Setting up FSDP for data parallelism")
                fsdp_config = fsdp_config or {}
                return wrap_fsdp(model, **fsdp_config)

            else:
                self.logger.info("Using standard DistributedDataParallel")
                if torch.cuda.is_available():
                    # Assumes one process per GPU with rank as the local device index.
                    model = model.to(self.rank)
                    return nn.parallel.DistributedDataParallel(model, device_ids=[self.rank])
                return nn.parallel.DistributedDataParallel(model)
    def optimize_communication(self, model: nn.Module) -> None:
        """Apply communication optimizations for distributed training."""
        if not self.is_distributed:
            return

        if isinstance(model, nn.parallel.DistributedDataParallel):
            # Bucket size is normally fixed via the ``bucket_cap_mb`` constructor
            # argument; only adjust it here if the private setter is available.
            if hasattr(model, "_set_ddp_bucket_cap_mb"):
                model._set_ddp_bucket_cap_mb(25)

            # Compress gradients to fp16 during all-reduce when the hook API exists.
            try:
                if hasattr(model, "register_comm_hook"):
                    from torch.distributed.algorithms.ddp_comm_hooks import default_hooks
                    model.register_comm_hook(
                        dist.group.WORLD,
                        default_hooks.fp16_compress_hook,
                    )
            except ImportError:
                pass

    @contextmanager
    def training_context(self):
        """Context manager for distributed training setup."""
        try:
            if self.is_distributed:
                self.logger.info("Entering distributed training context")
                if torch.cuda.is_available():
                    # Treat rank as the local CUDA device index (single-node assumption).
                    torch.cuda.set_device(self.rank)
            yield
        finally:
            if self.is_distributed:
                self.logger.info("Exiting distributed training context")


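# Example (sketch): manager-driven setup. The training loop itself is elided;
# the method names used below are the ones defined above.
#
#   manager = DistributedTrainingManager(world_size=world_size, rank=rank, use_fsdp=True)
#   dist_model = manager.setup_model(model, fsdp_config={"forward_prefetch": True})
#   manager.optimize_communication(dist_model)
#   with manager.training_context():
#       ...  # run the training loop on dist_model

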
def cleanup_distributed():
    """Clean up distributed training environment."""
    if dist.is_initialized():
        dist.destroy_process_group()
        logging.info("Distributed training cleaned up")


def get_distributed_config() -> Dict[str, Any]:
    """Get current distributed training configuration."""
    if not dist.is_initialized():
        return {"distributed": False}

    return {
        "distributed": True,
        "world_size": dist.get_world_size(),
        "rank": dist.get_rank(),
        "backend": dist.get_backend(),
        "local_rank": int(os.environ["LOCAL_RANK"]) if "LOCAL_RANK" in os.environ else None,
    }


def all_reduce_tensor(tensor: torch.Tensor,
                      op: dist.ReduceOp = dist.ReduceOp.SUM) -> torch.Tensor:
    """All-reduce operation on tensor across all processes."""
    if not dist.is_initialized():
        return tensor

    dist.all_reduce(tensor, op=op)
    return tensor


def gather_tensors(tensor: torch.Tensor,
                   dst: int = 0) -> Optional[List[torch.Tensor]]:
    """Gather tensors from all processes to destination rank."""
    if not dist.is_initialized():
        return [tensor]

    if dist.get_rank() == dst:
        tensor_list = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
        dist.gather(tensor, tensor_list, dst=dst)
        return tensor_list
    else:
        dist.gather(tensor, dst=dst)
        return None


def broadcast_tensor(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """Broadcast tensor from source rank to all processes."""
    if not dist.is_initialized():
        return tensor

    dist.broadcast(tensor, src=src)
    return tensor


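# Example (sketch): averaging a per-rank scalar metric with the helpers above.
# ``loss`` is an illustrative tensor already on the local device.
#
#   metric = loss.detach().clone()
#   all_reduce_tensor(metric, op=dist.ReduceOp.SUM)
#   metric /= dist.get_world_size()  # mean across ranks

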
class PipelineScheduler:
    """Advanced scheduler for pipeline parallelism with load balancing."""

    def __init__(self, num_stages: int, world_size: int):
        self.num_stages = num_stages
        self.world_size = world_size
        self.stage_times = [0.0] * num_stages
        self.load_balance_enabled = True

    def update_stage_timing(self, stage_id: int, execution_time: float):
        """Update the execution-time estimate for a pipeline stage."""
        if 0 <= stage_id < self.num_stages:
            # Exponential moving average keeps the estimate stable across steps.
            alpha = 0.1
            self.stage_times[stage_id] = (1 - alpha) * self.stage_times[stage_id] + alpha * execution_time

    def get_optimal_chunks(self, batch_size: int) -> int:
        """Calculate optimal number of chunks based on stage timing."""
        if not self.load_balance_enabled:
            return max(1, batch_size // 8)

        max_stage_time = max(self.stage_times) if any(self.stage_times) else 1.0
        avg_stage_time = sum(self.stage_times) / len(self.stage_times) if self.stage_times else 1.0

        # The more imbalanced the stages, the more chunks are needed to keep the
        # pipeline full; cap at the batch size so chunks are never empty.
        imbalance_factor = max_stage_time / max(avg_stage_time, 1e-6)
        optimal_chunks = max(2, min(batch_size, int(4 * imbalance_factor)))

        return optimal_chunks


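# Example (sketch): feeding measured stage times back into the scheduler. The
# timings below are illustrative; in practice they would come from per-stage
# profiling hooks.
#
#   scheduler = PipelineScheduler(num_stages=4, world_size=4)
#   for stage_id, elapsed in enumerate([0.012, 0.018, 0.011, 0.015]):
#       scheduler.update_stage_timing(stage_id, elapsed)
#   chunks = scheduler.get_optimal_chunks(batch_size=32)

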
def efficient_gradient_sync(model: nn.Module, gradient_clipping: float = 1.0):
    """Perform memory-efficient gradient synchronization across processes."""
    if not dist.is_initialized():
        return

    # Clip locally before averaging so each rank bounds its own contribution.
    if gradient_clipping > 0:
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)

        if dist.get_rank() == 0:
            logging.debug(f"Gradient norm before clipping: {total_norm.item():.4f}")

    # Bucketed all-reduce is not implemented; this constant is currently unused.
    bucket_size_mb = 25
    parameters = list(model.parameters())

    for param in parameters:
        if param.grad is not None:
            # Average each gradient across all ranks.
            dist.all_reduce(param.grad, async_op=False)
            param.grad /= dist.get_world_size()


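# Example (sketch): manual gradient averaging in a training step when the model
# is NOT wrapped in DDP/FSDP (those wrappers already synchronize gradients).
# ``criterion``, ``batch``, ``targets``, and ``optimizer`` are illustrative names.
#
#   loss = criterion(model(batch), targets)
#   loss.backward()
#   efficient_gradient_sync(model, gradient_clipping=1.0)
#   optimizer.step()
#   optimizer.zero_grad()

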
class DistributedMemoryManager:
    """Manages memory efficiently across distributed processes."""

    def __init__(self, enable_cpu_offload: bool = False):
        self.enable_cpu_offload = enable_cpu_offload
        self.memory_stats = {}
        self.peak_memory = 0

    def monitor_memory(self):
        """Monitor GPU memory usage for this process."""
        if torch.cuda.is_available():
            current_memory = torch.cuda.memory_allocated()
            max_memory = torch.cuda.max_memory_allocated()

            self.memory_stats = {
                "current_gb": current_memory / 1e9,
                "peak_gb": max_memory / 1e9,
                "rank": dist.get_rank() if dist.is_initialized() else 0,
            }

            self.peak_memory = max(self.peak_memory, current_memory)

    def optimize_memory_usage(self):
        """Apply memory optimizations based on current usage."""
        if torch.cuda.is_available():
            # Release cached blocks when allocations approach device capacity.
            total_memory = torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory
            if torch.cuda.memory_allocated() > 0.8 * total_memory:
                torch.cuda.empty_cache()
                logging.info("Cleared CUDA cache due to high memory usage")

    def get_memory_report(self) -> Dict[str, float]:
        """Get comprehensive memory usage report."""
        self.monitor_memory()
        return self.memory_stats


# Module-level singletons; reconfigured by setup_advanced_distributed_training().
pipeline_scheduler = PipelineScheduler(num_stages=1, world_size=1)
memory_manager = DistributedMemoryManager()


def setup_advanced_distributed_training(
    rank: ProcessRank,
    world_size: WorldSize,
    enable_memory_monitoring: bool = True,
    enable_pipeline_scheduling: bool = True,
) -> Dict[str, Any]:
    """Set up advanced distributed training with optimizations."""
    global pipeline_scheduler, memory_manager

    success = setup_distributed(rank, world_size)
    if not success:
        return {"distributed": False}

    if enable_pipeline_scheduling:
        pipeline_scheduler = PipelineScheduler(num_stages=world_size, world_size=world_size)

    if enable_memory_monitoring:
        memory_manager = DistributedMemoryManager()
        memory_manager.monitor_memory()

    config = get_distributed_config()
    config.update({
        "pipeline_scheduling": enable_pipeline_scheduling,
        "memory_monitoring": enable_memory_monitoring,
        "advanced_features": True,
    })

    logging.info(f"Advanced distributed training initialized on rank {rank}")
    return config
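

# Example (sketch): end-to-end per-process workflow, assuming one process per
# GPU launched by torchrun. Model construction and the training loop are elided.
#
#   config = setup_advanced_distributed_training(rank, world_size)
#   if config["distributed"]:
#       manager = DistributedTrainingManager(world_size, rank, use_fsdp=True)
#       dist_model = manager.setup_model(model)
#       with manager.training_context():
#           ...  # training loop; memory_manager.get_memory_report() gives usage stats
#       cleanup_distributed()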