#!/usr/bin/env import os #os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:4096' import uuid import wandb import fsspec import hydra import lightning as L from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint, GradientAccumulationScheduler import omegaconf import rich.syntax import rich.tree import torch import sys import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP import dataset as dataloader import dataloading_for_dynamic_batching as dynamic_dataloader from diffusion import Diffusion import utils.utils as utils from new_tokenizer.ape_tokenizer import APETokenizer from lightning.pytorch.strategies import DDPStrategy from datasets import load_dataset from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer from helm_tokenizer.helm_tokenizer import HelmTokenizer omegaconf.OmegaConf.register_new_resolver('cwd', os.getcwd) omegaconf.OmegaConf.register_new_resolver('device_count', torch.cuda.device_count) omegaconf.OmegaConf.register_new_resolver('eval', eval) omegaconf.OmegaConf.register_new_resolver('div_up', lambda x, y: (x + y - 1) // y) omegaconf.OmegaConf.register_new_resolver("env_or", lambda k, d: os.getenv(k, d)) def _load_from_checkpoint(config, tokenizer): """Create Diffusion model; load weights if checkpoint_path is set.""" if "hf" in str(config.get("backbone", "")): return Diffusion(config, tokenizer=tokenizer).to("cuda") ckpt_path = config.eval.checkpoint_path model = Diffusion.load_from_checkpoint( ckpt_path, tokenizer=tokenizer, config=config, map_location="cuda" if torch.cuda.is_available() else "cpu", ) return model @L.pytorch.utilities.rank_zero_only def print_config( config: omegaconf.DictConfig, resolve: bool = True, save_cfg: bool = True) -> None: """ Prints content of DictConfig using Rich library and its tree structure. Args: config (DictConfig): Configuration composed by Hydra. resolve (bool): Whether to resolve reference fields of DictConfig. save_cfg (bool): Whether to save the configuration tree to a file. """ style = 'dim' tree = rich.tree.Tree('CONFIG', style=style, guide_style=style) fields = config.keys() for field in fields: branch = tree.add(field, style=style, guide_style=style) config_section = config.get(field) branch_content = str(config_section) if isinstance(config_section, omegaconf.DictConfig): branch_content = omegaconf.OmegaConf.to_yaml( config_section, resolve=resolve) branch.add(rich.syntax.Syntax(branch_content, 'yaml')) rich.print(tree) if save_cfg: with fsspec.open( '{}/config_tree.txt'.format( config.checkpointing.save_dir), 'w') as fp: rich.print(tree, file=fp) @L.pytorch.utilities.rank_zero_only def print_batch(train_ds, valid_ds, tokenizer, k=64): #for dl_type, dl in [ #('train', train_ds), ('valid', valid_ds)]: for dl_type, dl in [ ('train', train_ds)]: print(f'Printing {dl_type} dataloader batch.') batch = next(iter(dl)) print('Batch input_ids.shape', batch['input_ids'].shape) first = batch['input_ids'][0, :k] last = batch['input_ids'][0, -k:] print(f'First {k} tokens:', tokenizer.decode(first)) print('ids:', first) print(f'Last {k} tokens:', tokenizer.decode(last)) print('ids:', last) def generate_samples(config, logger, tokenizer): logger.info('Generating samples.') model = _load_from_checkpoint(config=config, tokenizer=tokenizer) # model.gen_ppl_metric.reset() #stride_length = config.sampling.stride_length #num_strides = config.sampling.num_strides for _ in range(config.sampling.num_sample_batches): samples = model.restore_model_and_sample(num_steps=config.sampling.steps) peptide_sequences = model.tokenizer.batch_decode(samples) model.compute_generative_perplexity(peptide_sequences) print('Peptide samples:', peptide_sequences) print('Generative perplexity:', model.compute_masked_perplexity()) return peptide_sequences def ppl_eval(config, logger, tokenizer, data_module): logger.info('Starting Zero Shot Eval.') model = _load_from_checkpoint(config=config, tokenizer=tokenizer) wandb_logger = None if config.get('wandb', None) is not None: wandb_logger = L.pytorch.loggers.WandbLogger( config=omegaconf.OmegaConf.to_object(config), ** config.wandb) callbacks = [] if 'callbacks' in config: for _, callback in config.callbacks.items(): callbacks.append(hydra.utils.instantiate(callback)) trainer = hydra.utils.instantiate( config.trainer, default_root_dir=os.getcwd(), callbacks=callbacks, strategy=DDPStrategy(find_unused_parameters = True), logger=wandb_logger) #_, valid_ds = dataloader.get_dataloaders(config, tokenizer, skiptrain=True, valid_seed=config.seed) trainer.test(model, data_module) def _train(config, logger, tokenizer, data_module): logger.info('Starting Training.') wandb_logger = None if config.get('wandb', None) is not None: unique_id = str(uuid.uuid4()) config.wandb.id = f"{config.wandb.id}_{unique_id}" wandb_logger = L.pytorch.loggers.WandbLogger( config=omegaconf.OmegaConf.to_object(config), ** config.wandb) if (config.checkpointing.resume_from_ckpt and config.checkpointing.resume_ckpt_path is not None and utils.fsspec_exists( config.checkpointing.resume_ckpt_path)): ckpt_path = config.checkpointing.resume_ckpt_path else: ckpt_path = None # Lightning callbacks callbacks = [] if 'callbacks' in config: for callback_name, callback_config in config.callbacks.items(): if callback_name == 'model_checkpoint': model_checkpoint_config = {k: v for k, v in callback_config.items() if k != '_target_'} callbacks.append(ModelCheckpoint(**model_checkpoint_config)) else: callbacks.append(hydra.utils.instantiate(callback_config)) if config.training.accumulator: accumulator = GradientAccumulationScheduler(scheduling = {1: 5, 2: 4, 3: 3, 4: 1}) callbacks.append(accumulator) trainer = hydra.utils.instantiate( config.trainer, default_root_dir=os.getcwd(), callbacks=callbacks, accelerator='cuda', strategy=DDPStrategy(find_unused_parameters = True), devices=[2,3,4,5,6,7], logger=wandb_logger) model = Diffusion(config, tokenizer=tokenizer) if config.backbone == 'finetune_roformer' and config.eval.checkpoint_path: checkpoint = torch.load(config.eval.checkpoint_path, map_location="cpu") state = checkpoint.get("state_dict", checkpoint) model.load_state_dict(state, strict=False) trainer.fit(model, datamodule=data_module, ckpt_path=ckpt_path) @hydra.main(version_base=None, config_path='configs', config_name='config') def main(config): """ Main entry point for training """ L.seed_everything(config.seed) # print_config(config, resolve=True, save_cfg=True) logger = utils.get_logger(__name__) # load PeptideCLM tokenizer tok_dir = config.paths.tokenizers if config.vocab == 'new_smiles': tokenizer = APETokenizer() tokenizer.load_vocabulary(f'{tok_dir}/peptide_smiles_600_vocab.json') elif config.vocab == 'old_smiles': tokenizer = SMILES_SPE_Tokenizer(f'{tok_dir}/new_vocab.txt', f'{tok_dir}/new_splits.txt') elif config.vocab == 'selfies': tokenizer = APETokenizer() tokenizer.load_vocabulary(f'{tok_dir}/peptide_selfies_600_vocab.json') elif config.vocab == 'helm': tokenizer = HelmTokenizer(f'{tok_dir}/monomer_vocab.txt') if config.backbone == 'finetune_roformer': train_dataset = load_dataset('csv', data_files=config.data.train) val_dataset = load_dataset('csv', data_files=config.data.valid) train_dataset = train_dataset['train']#.select(lst) val_dataset = val_dataset['train']#.select(lst) data_module = dataloader.CustomDataModule(train_dataset, val_dataset, None, tokenizer, batch_size=config.loader.global_batch_size) else: data_module = dynamic_dataloader.CustomDataModule(f'{config.paths.data}/smiles/11M_smiles_old_tokenizer_no_limit', tokenizer) if config.mode == 'sample_eval': generate_samples(config, logger, tokenizer) elif config.mode == 'ppl_eval': ppl_eval(config, logger, tokenizer, data_module) else: _train(config, logger, tokenizer, data_module) if __name__ == '__main__': main()