from typing import Optional, Any, Sequence, List
from dataclasses import dataclass
import os
import math
import yaml
import shutil

import torch
import torch.distributed as dist
from torch import nn
from torch.utils.data import DataLoader
import tqdm

import wandb
import coolname
import hydra
import pydantic
from omegaconf import DictConfig
from adam_atan2 import AdamATan2

from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata
from utils.functions import load_model_class, get_model_source_path
from models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed


class LossConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')

    name: str


class ArchConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')

    name: str
    loss: LossConfig


class PretrainConfig(pydantic.BaseModel):
    # Config
    arch: ArchConfig
    # Data
    data_path: str

    # Hyperparams
    global_batch_size: int
    epochs: int

    lr: float
    lr_min_ratio: float
    lr_warmup_steps: int

    weight_decay: float
    beta1: float
    beta2: float

    # Puzzle embedding
    puzzle_emb_lr: float
    puzzle_emb_weight_decay: float

    # Names
    project_name: Optional[str] = None
    run_name: Optional[str] = None
    checkpoint_path: Optional[str] = None

    # Extras
    seed: int = 0
    checkpoint_every_eval: bool = False
    eval_interval: Optional[int] = None
    eval_save_outputs: List[str] = []


@dataclass
class TrainState:
    model: nn.Module
    optimizers: Sequence[torch.optim.Optimizer]
    optimizer_lrs: Sequence[float]
    carry: Any

    step: int
    total_steps: int


def create_dataloader(config: PretrainConfig, split: str, rank: int, world_size: int, **kwargs):
    dataset = PuzzleDataset(PuzzleDatasetConfig(
        seed=config.seed,
        dataset_path=config.data_path,
        rank=rank,
        num_replicas=world_size,
        **kwargs
    ), split=split)
    dataloader = DataLoader(
        dataset,
        batch_size=None,
        num_workers=1,
        prefetch_factor=8,
        pin_memory=True,
        persistent_workers=True
    )
    return dataloader, dataset.metadata


def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int):
    model_cfg = dict(
        **config.arch.__pydantic_extra__,  # type: ignore
        batch_size=config.global_batch_size // world_size,
        vocab_size=train_metadata.vocab_size,
        seq_len=train_metadata.seq_len,
        num_puzzle_identifiers=train_metadata.num_puzzle_identifiers,
        causal=False  # Non-autoregressive
    )

    # Instantiate model with loss head
    model_cls = load_model_class(config.arch.name)
    loss_head_cls = load_model_class(config.arch.loss.name)

    with torch.device("cuda"):
        model: nn.Module = model_cls(model_cfg)
        model = loss_head_cls(model, **config.arch.loss.__pydantic_extra__)  # type: ignore
        if "DISABLE_COMPILE" not in os.environ:
            model = torch.compile(model, dynamic=False)  # type: ignore

    # Broadcast parameters from rank 0
    if world_size > 1:
        with torch.no_grad():
            for param in list(model.parameters()) + list(model.buffers()):
                dist.broadcast(param, src=0)

    # Optimizers and lr
    optimizers = [
        CastedSparseEmbeddingSignSGD_Distributed(
            model.model.puzzle_emb.buffers(),  # type: ignore
            lr=0,  # Needs to be set by scheduler
            weight_decay=config.puzzle_emb_weight_decay,
            world_size=world_size
        ),
        AdamATan2(
            model.parameters(),
            lr=0,  # Needs to be set by scheduler
            weight_decay=config.weight_decay,
            betas=(config.beta1, config.beta2)
        )
    ]
    optimizer_lrs = [
        config.puzzle_emb_lr,
        config.lr
    ]

    return model, optimizers, optimizer_lrs


def cosine_schedule_with_warmup_lr_lambda(
    current_step: int, *, base_lr: float, num_warmup_steps: int, num_training_steps: int, min_ratio: float = 0.0, num_cycles: float = 0.5
):
    if current_step < num_warmup_steps:
        return base_lr * float(current_step) / float(max(1, num_warmup_steps))

    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    return base_lr * (min_ratio + max(0.0, (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))))


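# Sanity check of the schedule above (hypothetical numbers, not taken from any config):
# with base_lr=1e-4, num_warmup_steps=100, num_training_steps=1000, min_ratio=0.1:
#   step 50    -> 5e-5    (linear warmup: base_lr * 50 / 100)
#   step 550   -> 5.5e-5  (cosine midpoint: base_lr * (0.1 + 0.9 * 0.5))
#   step 1000  -> 1e-5    (floor: base_lr * min_ratio)

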
def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int):
    # Estimated total training steps
    total_steps = int(config.epochs * train_metadata.total_groups * train_metadata.mean_puzzle_examples / config.global_batch_size)

    # Model
    model, optimizers, optimizer_lrs = create_model(config, train_metadata, world_size=world_size)

    return TrainState(
        step=0,
        total_steps=total_steps,
        model=model,
        optimizers=optimizers,
        optimizer_lrs=optimizer_lrs,
        carry=None
    )


def save_train_state(config: PretrainConfig, train_state: TrainState):
    # FIXME: Only the model is saved; optimizer state is not.
    if config.checkpoint_path is None:
        return

    os.makedirs(config.checkpoint_path, exist_ok=True)
    torch.save(train_state.model.state_dict(), os.path.join(config.checkpoint_path, f"step_{train_state.step}"))


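# Restoring a checkpoint written above is roughly (a minimal sketch; the step number
# is an assumption, and the model must be rebuilt with the same config first):
#   state_dict = torch.load(os.path.join(config.checkpoint_path, "step_1000"), map_location="cuda")
#   train_state.model.load_state_dict(state_dict)

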
def compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState):
    return cosine_schedule_with_warmup_lr_lambda(
        current_step=train_state.step,
        base_lr=base_lr,
        num_warmup_steps=round(config.lr_warmup_steps),
        num_training_steps=train_state.total_steps,
        min_ratio=config.lr_min_ratio
    )


def train_batch(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int):
    train_state.step += 1
    if train_state.step > train_state.total_steps:  # At most total_steps
        return

    # To device
    batch = {k: v.cuda() for k, v in batch.items()}

    # Init carry if it is None
    if train_state.carry is None:
        with torch.device("cuda"):
            train_state.carry = train_state.model.initial_carry(batch)  # type: ignore

    # Forward
    train_state.carry, loss, metrics, _, _ = train_state.model(carry=train_state.carry, batch=batch, return_keys=[])

    ((1 / global_batch_size) * loss).backward()

    # Allreduce
    if world_size > 1:
        for param in train_state.model.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad)

    # Apply optimizer
    lr_this_step = None
    for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs):
        lr_this_step = compute_lr(base_lr, config, train_state)

        for param_group in optim.param_groups:
            param_group['lr'] = lr_this_step

        optim.step()
        optim.zero_grad()

    # Reduce metrics
    if len(metrics):
        assert not any(v.requires_grad for v in metrics.values())

        metric_keys = list(sorted(metrics.keys()))  # Sort keys to guarantee all processes use the same order.

        # Reduce and reconstruct
        metric_values = torch.stack([metrics[k] for k in metric_keys])
        if world_size > 1:
            dist.reduce(metric_values, dst=0)

        if rank == 0:
            metric_values = metric_values.cpu().numpy()
            reduced_metrics = {k: metric_values[i] for i, k in enumerate(metric_keys)}

            # Postprocess
            count = max(reduced_metrics["count"], 1)  # Avoid NaNs
            reduced_metrics = {f"train/{k}": v / (global_batch_size if k.endswith("loss") else count) for k, v in reduced_metrics.items()}

            reduced_metrics["train/lr"] = lr_this_step
            return reduced_metrics


def evaluate(config: PretrainConfig, train_state: TrainState, eval_loader: torch.utils.data.DataLoader, eval_metadata: PuzzleDatasetMetadata, rank: int, world_size: int):
    with torch.inference_mode():
        set_ids = {k: idx for idx, k in enumerate(eval_metadata.sets)}

        all_preds = {}

        metric_keys = []
        metric_values = None
        metric_global_batch_size = [0 for _ in range(len(set_ids))]

        carry = None
        for set_name, batch, global_batch_size in eval_loader:
            # To device
            batch = {k: v.cuda() for k, v in batch.items()}
            with torch.device("cuda"):
                carry = train_state.model.initial_carry(batch)  # type: ignore

            # Forward
            while True:
                carry, _, metrics, preds, all_finish = train_state.model(carry=carry, batch=batch, return_keys=config.eval_save_outputs)
                if all_finish:
                    break

            for collection in (batch, preds):
                for k, v in collection.items():
                    if k in config.eval_save_outputs:
                        all_preds.setdefault(k, [])
                        all_preds[k].append(v.cpu())  # Move to CPU for saving GPU memory

            del carry, preds, batch, all_finish

            # Aggregate
            set_id = set_ids[set_name]
            if metric_values is None:
                metric_keys = list(sorted(metrics.keys()))  # Sort keys to guarantee all processes use the same order.
                metric_values = torch.zeros((len(set_ids), len(metrics.values())), dtype=torch.float32, device="cuda")

            metric_values[set_id] += torch.stack([metrics[k] for k in metric_keys])
            metric_global_batch_size[set_id] += global_batch_size

        if len(all_preds) and config.checkpoint_path is not None:
            all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()}

            os.makedirs(config.checkpoint_path, exist_ok=True)
            torch.save(all_preds, os.path.join(config.checkpoint_path, f"step_{train_state.step}_all_preds.{rank}"))

        # Logging
        # Reduce to rank 0
        if metric_values is not None:
            if world_size > 1:
                dist.reduce(metric_values, dst=0)

            if rank == 0:
                reduced_metrics = metric_values.cpu().numpy()
                reduced_metrics = {set_name: {metric_name: reduced_metrics[set_id, metric_id] for metric_id, metric_name in enumerate(metric_keys)}
                                   for set_id, set_name in enumerate(set_ids)}

                # Postprocess
                for set_name, metrics in reduced_metrics.items():
                    count = metrics.pop("count")
                    reduced_metrics[set_name] = {k: v / count for k, v in metrics.items()}

                return reduced_metrics


def save_code_and_config(config: PretrainConfig):
    if config.checkpoint_path is None or wandb.run is None:
        return

    os.makedirs(config.checkpoint_path, exist_ok=True)

    # Copy code
    code_list = [
        get_model_source_path(config.arch.name),
        get_model_source_path(config.arch.loss.name)
    ]
    for code_file in code_list:
        if code_file is not None:
            code_name = os.path.basename(code_file)
            shutil.copy(code_file, os.path.join(config.checkpoint_path, code_name))

    # Dump config as yaml
    config_file = os.path.join(config.checkpoint_path, "all_config.yaml")
    with open(config_file, "wt") as f:
        yaml.dump(config.model_dump(), f)

    # Log code
    wandb.run.log_code(config.checkpoint_path)


def load_synced_config(hydra_config: DictConfig, rank: int, world_size: int) -> PretrainConfig:
    objects = [None]
    if rank == 0:
        config = PretrainConfig(**hydra_config)  # type: ignore

        # Naming
        if config.project_name is None:
            config.project_name = f"{os.path.basename(config.data_path).capitalize()} ACT-torch"
        if config.run_name is None:
            config.run_name = f"{config.arch.name.split('@')[-1]} {coolname.generate_slug(2)}"
        if config.checkpoint_path is None:
            config.checkpoint_path = os.path.join("checkpoints", config.project_name, config.run_name)

        objects = [config]

    if world_size > 1:
        dist.broadcast_object_list(objects, src=0)

    return objects[0]  # type: ignore


# Hydra entry point (config path/name below are assumptions; adjust to the repo's config layout)
@hydra.main(config_path="config", config_name="cfg_pretrain", version_base=None)
def launch(hydra_config: DictConfig):
    RANK = 0
    WORLD_SIZE = 1

    # Initialize distributed training if in distributed environment (e.g. torchrun)
    if "LOCAL_RANK" in os.environ:
        # Initialize distributed, default device and dtype
        dist.init_process_group(backend="nccl")

        RANK = dist.get_rank()
        WORLD_SIZE = dist.get_world_size()
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # Load sync'ed config
    config = load_synced_config(hydra_config, rank=RANK, world_size=WORLD_SIZE)

    # Seed RNGs to ensure consistency
    torch.random.manual_seed(config.seed + RANK)

    # Dataset
    train_epochs_per_iter = config.eval_interval if config.eval_interval is not None else config.epochs
    total_iters = config.epochs // train_epochs_per_iter
    assert config.epochs % train_epochs_per_iter == 0, "Eval interval must be a divisor of total epochs."

    train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=train_epochs_per_iter, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
    eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)

    # Train state
    train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE)

    # Progress bar and logger
    progress_bar = None
    if RANK == 0:
        progress_bar = tqdm.tqdm(total=train_state.total_steps)
        wandb.init(project=config.project_name, name=config.run_name, config=config.model_dump(), settings=wandb.Settings(_disable_stats=True))  # type: ignore
        wandb.log({"num_params": sum(x.numel() for x in train_state.model.parameters())}, step=0)
        save_code_and_config(config)

    # Training Loop
    for _iter_id in range(total_iters):
        print(f"[Rank {RANK}, World Size {WORLD_SIZE}]: Epoch {_iter_id * train_epochs_per_iter}")

        ############ Train Iter
        train_state.model.train()
        for set_name, batch, global_batch_size in train_loader:
            metrics = train_batch(config, train_state, batch, global_batch_size, rank=RANK, world_size=WORLD_SIZE)

            if RANK == 0 and metrics is not None:
                wandb.log(metrics, step=train_state.step)
                progress_bar.update(train_state.step - progress_bar.n)  # type: ignore

        ############ Evaluation
        train_state.model.eval()
        metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE)

        if RANK == 0 and metrics is not None:
            wandb.log(metrics, step=train_state.step)

        ############ Checkpointing
        if RANK == 0 and (config.checkpoint_every_eval or (_iter_id == total_iters - 1)):
            save_train_state(config, train_state)

    # finalize
    if dist.is_initialized():
        dist.destroy_process_group()
    wandb.finish()


if __name__ == "__main__":
    launch()
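

# How this script is typically launched (a sketch; the filename "pretrain.py" and the
# GPU count are assumptions, not taken from the source):
#   single GPU:  python pretrain.py
#   multi GPU:   torchrun --nproc_per_node=8 pretrain.py
# torchrun exports LOCAL_RANK/RANK/WORLD_SIZE, which makes launch() initialize the
# NCCL process group and split the global batch size across ranks.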