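"""
Multi-GPU DDP sanity check for NeMo's exp_manager rank logging.

Each rank writes its global rank, local rank, and world size to its own
log file via `logging.info`; rank 0 then reads those files back and verifies
that the logged values match the expected ranks and world size.
"""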
import os
import shutil

import torch
from omegaconf import OmegaConf
from pytorch_lightning import Trainer
from pytorch_lightning.utilities import rank_zero_only

from nemo.core import ModelPT
from nemo.utils import logging
from nemo.utils.exp_manager import ExpManagerConfig, exp_manager


class OnesDataset(torch.utils.data.Dataset):
    """Trivial dataset that returns a fixed tensor of ones for every index."""

    def __init__(self, dataset_len):
        super().__init__()
        self.__dataset_len = dataset_len

    def __getitem__(self, *args):
        return torch.ones(2)

    def __len__(self):
        return self.__dataset_len


class ExampleModel(ModelPT):
    """Minimal ModelPT subclass used only to exercise the Trainer / exp_manager plumbing."""

    def __init__(self, *args, **kwargs):
        cfg = OmegaConf.structured({})
        super().__init__(cfg, trainer=kwargs.get('trainer', None))

        self.l1 = torch.nn.modules.Linear(in_features=2, out_features=1)

    def train_dataloader(self):
        return None

    def val_dataloader(self):
        return None

    def predict_dataloader(self):
        dataset = OnesDataset(2)
        return torch.utils.data.DataLoader(dataset, batch_size=2)

    def forward(self, batch):
        return batch.mean()

    def validation_step(self, batch, batch_idx):
        return self(batch)

    def training_step(self, batch, batch_idx):
        return self(batch)

    def list_available_models(self):
        pass

    def setup_training_data(self):
        pass

    def setup_validation_data(self):
        pass

    def validation_epoch_end(self, loss):
        self.log("val_loss", torch.stack(loss).mean())


def instantiate_multinode_ddp_if_possible():
    """Build a DDP Trainer over all visible GPUs and attach NeMo's exp_manager to it."""
    num_gpus = torch.cuda.device_count()
    trainer = Trainer(devices=num_gpus, accelerator='gpu', strategy='ddp', logger=None, enable_checkpointing=False)
    exp_manager_cfg = ExpManagerConfig(exp_dir='./ddp_check/', use_datetime_version=False, version="")
    exp_manager(trainer, cfg=OmegaConf.structured(exp_manager_cfg))
    return trainer


def setup_model(trainer: Trainer):
    """Instantiate the example model, log its rank info, and run predict to launch DDP."""
    model = ExampleModel(trainer=trainer)

    logging.info(f"M.Global Rank:{model.global_rank}")
    logging.info(f"M.Local Rank:{model.local_rank}")
    logging.info(f"M.World Size:{model.trainer.world_size}")

    trainer.predict(model)
    return model


def get_rank_info(texts: list, rank_key: str) -> int:
    """Return the integer value logged for `rank_key` in a list of log lines."""
    for line in texts:
        if rank_key in line:
            rank_value = line.split(":")[-1]
            rank_value = int(rank_value)
            return rank_value

    print("Could not find the correct rank key !")
    exit(1)


@rank_zero_only
def check_model_ranks(model: ExampleModel):
    """On rank 0, read every rank's log file and verify the logged rank and world size."""
    basedir = os.path.join('./ddp_check/', 'default', 'version_0')
    # Log files are named per rank; on a single node the local rank equals the global rank.
    file_template = "nemo_log_globalrank-{rank}_localrank-{rank}.txt"

    world_size = torch.cuda.device_count()
    for rank in range(world_size):
        filename = file_template.format(rank=rank)
        filepath = os.path.join(basedir, filename)

        with open(filepath, 'r', encoding='utf-8') as f:
            texts = f.readlines()
            texts = [t.replace("\n", "") for t in texts]

        log_global_rank = get_rank_info(texts, rank_key='M.Global Rank')
        log_world_size = get_rank_info(texts, rank_key='M.World Size')

        if log_global_rank != rank:
            print("Logged global rank is not equal to trainer.global_rank !")
            exit(1)

        if log_world_size != world_size:
            print("Logged world size is not equal to trainer.world_size !")
            exit(1)


@rank_zero_only
def cleanup():
    if os.path.exists('./ddp_check'):
        shutil.rmtree('./ddp_check', ignore_errors=True)


def run_checks():
    cleanup()

    trainer = instantiate_multinode_ddp_if_possible()
    model = setup_model(trainer)
    check_model_ranks(model)

    print("DDP checks passed !")

    cleanup()


if __name__ == '__main__':
    run_checks()