Sophia Tang committed on

Commit · 92f7053
1 Parent(s): 6612621

update

Files changed:
- config.yaml +168 -0
- diffusion.py +1 -101
- scoring/{hemolysis.py → functions/hemolysis.py} +0 -0
    	
        config.yaml
    ADDED
    
@@ -0,0 +1,168 @@
+noise:
+  type: loglinear
+  sigma_min: 1e-4
+  sigma_max: 20
+  state_dependent: True
+
+mode: ppl_eval  # train / ppl_eval / sample_eval
+diffusion: absorbing_state
+vocab: old_smiles # old_smiles / new_smiles / selfies / helm
+backbone: roformer  # peptideclm / helmgpt / dit / roformer / finetune_roformer
+parameterization: subs  # subs
+time_conditioning: False
+T: 0  # 0 (continuous time) / 1000
+subs_masking: False
+
+seed: 42
+
+mcts:
+  num_children: 50
+  num_objectives: 5
+  topk: 100
+  mask_token: 4
+  num_iter: 128
+  sampling: 0 # 0 is gumbel sampling / > 0 samples children from top k probs
+  invalid_penalty: 0.5
+  sample_prob: 1.0
+  perm: True
+  dual: False
+  single: False
+  time_dependent: True
+
+lr_scheduler:
+  _target_: transformers.get_constant_schedule_with_warmup
+  num_warmup_steps: 2500
+
+data:
+  train: /home/st512/peptune/scripts/peptide-mdlm-mcts/data/finetune2/30K-train.csv
+  valid: /home/st512/peptune/scripts/peptide-mdlm-mcts/data/finetune2/30K-val.csv
+  batching: wrapping # padding / wrapping
+
+loader:
+  global_batch_size: 64
+  eval_global_batch_size: ${.global_batch_size}
+  # Note: batch_size and eval_batch_size are **per machine**
+  batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
+  eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
+  num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
+  pin_memory: True
+
+sampling:
+  predictor: ddpm_cache  # analytic, ddpm, ddpm_cache
+  num_sequences: 100
+  sampling_eps: 1e-3
+  steps: 128
+  seq_length: 100
+  noise_removal: True
+  num_sample_batches: 2  # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
+  num_sample_log: 2
+  stride_length: 1
+  num_strides: 1
+
+training:
+  antithetic_sampling: True
+  sampling_eps: 1e-3
+  focus_mask: False
+  #dynamic_batching: True
+  accumulator: False
+
+eval:
+  checkpoint_path: /home/st512/peptune/scripts/peptide-mdlm-mcts/checkpoints/11M-old-tokenizer/epoch=10-step=156276.ckpt
+  disable_ema: False
+  compute_generative_perplexity: False
+  perplexity_batch_size: 8
+  compute_perplexity_on_sanity: False
+  gen_ppl_eval_model_name_or_path: gpt2-large  # gpt2-large, meta-llama/Llama-2-7b-hf
+  generate_samples: True
+  generation_model: /home/st512/peptune/scripts/peptide-mdlm-mcts/checkpoints/11M-old-tokenizer/
+
+optim:
+  weight_decay: 0.075
+  lr: 3e-4
+  beta1: 0.9
+  beta2: 0.999
+  eps: 1e-8
+
+pepclm:
+  hidden_size: 768
+  cond_dim: 256
+  n_heads: 20
+  n_blocks: 4
+  dropout: 0.5
+  length: 512
+  #scale_by_sigma: True
+
+model:
+  type: ddit
+  hidden_size: 768
+  cond_dim: 128
+  length: 512
+  n_blocks: 12
+  n_heads: 12
+  scale_by_sigma: True
+  dropout: 0.1
+
+roformer:
+  hidden_size: 768
+  n_layers: 8
+  n_heads: 8
+  max_position_embeddings: 1035
+
+helmgpt:
+  hidden_size: 256
+  embd_pdrop: 0.1
+  resid_pdrop: 0.1
+  attn_pdrop: 0.1
+  ff_dropout: 0.
+  block_size: 140
+  n_layer: 8
+  n_heads: 8
+
+
+trainer:
+  _target_: lightning.Trainer
+  accelerator: cuda
+  num_nodes: 1
+  devices: ${device_count:}
+  accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
+  gradient_clip_val: 1.0
+  precision: 64-true
+  num_sanity_val_steps: 2
+  max_epochs: 100
+  max_steps: 1_000_000
+  log_every_n_steps: 10
+  limit_train_batches: 1.0   # train on full dataset, can be used to toggle quick run
+  limit_val_batches: 1.0     # validate on full dataset, can be used to toggle quick run
+  #val_check_interval: 40 #954
+  check_val_every_n_epoch: 1
+
+
+wandb:
+  project: peptune
+  notes: null
+  group: null
+  job_type: null
+  name: sophia-tang
+  id: ${.name}_nov12_set2
+
+hydra:
+  run:
+    dir: ./${now:%Y.%m.%d}/
+  job:
+    chdir: True
+
+checkpointing:
+  # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
+  save_dir: ${cwd:}
+  # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
+  resume_from_ckpt: True
+  resume_ckpt_path: /home/st512/peptune/scripts/peptide-mdlm-mcts/checkpoints/11M-old-tokenizer/epoch=7-step=108225.ckpt
+
+callbacks:
+  model_checkpoint:
+    _target_: pytorch_lightning.callbacks.ModelCheckpoint
+    every_n_epochs: 1
+    monitor: "val/nll"
+    save_top_k: 10
+    mode: "min"
+    dirpath: '/home/st512/peptune/scripts/peptide-mdlm-mcts/checkpoints/11M-old-tokenizer'
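Note: the config above relies on several non-standard OmegaConf resolvers (${div_up:...}, ${eval:...}, ${device_count:}, ${cwd:}) for the per-device batch-size arithmetic in loader and trainer and for checkpointing.save_dir. They are presumably registered in the repo's training entry point, which is not part of this commit; the sketch below only illustrates how such resolvers are conventionally defined, and the entry-point shape here is an assumption.

# Hedged sketch: conventional registration of the resolvers referenced by config.yaml.
# Not this commit's code; names and defaults are assumptions.
import os

import hydra
import omegaconf
import torch

omegaconf.OmegaConf.register_new_resolver('cwd', os.getcwd)
omegaconf.OmegaConf.register_new_resolver('device_count', torch.cuda.device_count)
omegaconf.OmegaConf.register_new_resolver('eval', eval)
omegaconf.OmegaConf.register_new_resolver('div_up', lambda x, y: (x + y - 1) // y)


@hydra.main(version_base=None, config_path='.', config_name='config')
def main(config: omegaconf.DictConfig) -> None:
    # With devices=2 and num_nodes=1: batch_size = div_up(64, 2 * 1) = 32,
    # and accumulate_grad_batches = div_up(64, 2 * 32 * 1) = 1.
    print(config.loader.batch_size, config.trainer.accumulate_grad_batches)


if __name__ == '__main__':
    main()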
    	
        diffusion.py
    CHANGED
    
@@ -116,8 +116,6 @@ class Diffusion(L.LightningModule):
         self.test_metrics = metrics.clone(prefix='test/')


-    """LOSS"""
-
     """LOSS FOR INVALID PEPTIDES"""

     @torch.no_grad()
@@ -248,18 +246,6 @@ class Diffusion(L.LightningModule):
         t = (1 - self.config.training.sampling_eps) * eps_t + self.config.training.sampling_eps

         return t
-
-    """def mask_samples(self, x0, mask_prob):
-
-        # generate array of values in range [0, 1] uniformly at random
-        # will be used to determine which tokens are masked
-        mask_indices = torch.rand(* x0.shape, device=x0.device) # (batch_size, L)
-
-        # select tokens to mask if the random value in mask_indices is less than mask_prob
-        # this will mask approximately the fraction of tokens indicated by mask_prob
-        zt = torch.where(mask_indices < mask_prob, self.mask_token_id, x0)
-
-        return zt"""

     def q_xt(self, x, mask_prob):
         """Computes the noisy sample xt.
@@ -349,48 +335,6 @@ class Diffusion(L.LightningModule):
         # scale by T and return
         return self.T * L_vb

-    """def _forward_pass_diffusion(self, x0, attn_mask, mask=None):
-
-        print(x0)
-        # randomly sample time steps to start the denoising process for each x0 in batch
-        t = self.sample_t(x0.shape[0], x0.device)
-
-        # if we are training the intermediate transition blocks
-        if self.T > 0:
-            # scale by total timesteps T and cast to integer
-            t = (t * self.T).to(torch.int)
-            # scale down by T to get a multiple of 1/T
-            t = t / self.T
-            # add 1/T to ensure no 0 values
-            t += (1 / self.T)
-
-        # get noise and rate of noise at timestep t
-        sigma, dsigma = self.noise(t)
-        time_conditioning = sigma[:, None]
-        # get masking probabilities for all tokens for each batch
-        mask_prob = 1 - torch.exp(-sigma[:, None]) # (batch_size, L)
-
-        # get masked samples at different timesteps
-        if mask is None: zt = self.q_xt(x0, mask_prob)
-        else: zt = x0.where(mask==1, torch.full_like(x0, self.mask_token_id))
-
-        model_output = self.forward(zt, attn_mask, time_conditioning)
-
-        utils.print_nans(model_output, 'model_output')
-
-        if self.T > 0:
-            # compute diffusion loss
-            diffusion_loss = self.compute_diffusion_loss(model_output, zt, x0, t)
-            return diffusion_loss
-
-        # compute loss for the final that converts from z0 to x0
-        # -log(p_theta)
-        # get (batch_size, L) array of log-probabilities
-        log_p_theta = torch.gather(input=model_output, dim=-1, index=x0[:, :, None]).squeeze(-1) # (B, L)
-
-
-        return -log_p_theta * (dsigma / torch.expm1(sigma))[:, None]"""
-
     def _forward_pass_diffusion(self, x0, attn_mask, bond_mask=None, mask=None):
         """
             Training reverse diffusion model x_theta to reconstruct samples x0
@@ -634,21 +578,6 @@ class Diffusion(L.LightningModule):

     # first step in expansion
     def batch_cached_reverse_step(self, token_array, t, dt, batch_size, p_x0=None, attn_mask=None):
-        """
-        Generates batch_size different samples from the same starting point for the
-        first expansion step of MCTS
-
-        Args:
-            x (_type_): _description_
-            t (_type_): _description_
-            dt (_type_): _description_
-            batch_size (_type_): _description_
-            p_x0 (_type_, optional): _description_. Defaults to None.
-            attn_mask (_type_, optional): _description_. Defaults to None.
-
-        Returns:
-            _type_: _description_
-        """

         assert self.config.noise.type == 'loglinear'
         sigma_t, _ = self.noise(t)
@@ -880,9 +809,7 @@ class Diffusion(L.LightningModule):
                             0)[..., None]
         return edge

-
-    """TRAINING from https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py"""
-
+
     def on_train_epoch_start(self):
         torch.cuda.empty_cache()
         self.backbone.train()
@@ -1049,19 +976,6 @@ def sample_categorical(categorical_probs):
     return (categorical_probs / gumbel_norm).argmax(dim=-1)

 def sample_batched_categorical(categorical_probs, batch_size):
-    """
-    Generates `m` distinct sequences sampled from categorical probabilities
-    using the Gumbel distribution to ensure randomness while following probabilities
-
-    Args:
-        categorical_probs (torch.Tensor): tensor of shape (sequence_length, vocab_length)
-                                          representing categorical probabilities
-        m (int): number of distinct sequences to sample
-
-    Returns:
-        torch.Tensor: tensor of shape (m, sequence_length), where each row is a
-                      distinct sequence of sampled category indices.
-    """
     _, sequence_length, vocab_size = categorical_probs.shape

     # add Gumbel noise and sample m sequences
@@ -1074,20 +988,6 @@ def sample_batched_categorical(categorical_probs, batch_size):
     return sampled_sequences

 def sample_batched_top_k(categorical_probs, batch_size, k):
-    """
-    Generates `m` sequences sampled from the top-k probabilities of each token
-    using Gumbel noise to ensure randomness and reduce bias towards the most likely options.
-
-    Args:
-        categorical_probs (torch.Tensor): A tensor of shape (sequence_length, vocab_length)
-                                          representing categorical probabilities.
-        m (int): Number of sequences to sample.
-        k (int): Number of top probabilities to consider for sampling.
-
-    Returns:
-        torch.Tensor: A tensor of shape (m, sequence_length), where each row is a
-                      sampled sequence of category indices.
-    """
     _, sequence_length, vocab_length = categorical_probs.shape

     # Add Gumbel noise to the log probabilities
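The docstrings deleted above describe sample_batched_categorical and sample_batched_top_k as Gumbel-based samplers that draw several distinct sequences from the same per-token distribution, presumably for the MCTS expansion step configured by mcts.num_children and mcts.sampling. The functions themselves remain in diffusion.py; the sketch below is only a plausible reconstruction consistent with those docstrings and with the (categorical_probs / gumbel_norm).argmax(dim=-1) pattern visible in sample_categorical, not the repo's exact implementation.

# Hedged sketch of the two batched samplers, reconstructed from the deleted docstrings;
# the actual code in diffusion.py may differ.
import torch


def sample_batched_categorical_sketch(categorical_probs, batch_size):
    # categorical_probs: (1, seq_len, vocab_size) per-token probabilities.
    _, seq_len, vocab_size = categorical_probs.shape
    # Independent exponential noise per candidate sequence; dividing the probabilities
    # by it and taking argmax is the Gumbel-max trick, so each row is a distinct
    # stochastic sample from the same token-wise distribution.
    gumbel_norm = -torch.log(
        torch.rand(batch_size, seq_len, vocab_size,
                   device=categorical_probs.device).clamp_min(1e-10))
    return (categorical_probs / gumbel_norm).argmax(dim=-1)  # (batch_size, seq_len)


def sample_batched_top_k_sketch(categorical_probs, batch_size, k):
    # Same idea, but each token is only drawn from its k most likely entries.
    _, seq_len, vocab_size = categorical_probs.shape
    topk_probs, topk_idx = categorical_probs.topk(k, dim=-1)   # (1, seq_len, k)
    gumbel_norm = -torch.log(
        torch.rand(batch_size, seq_len, k,
                   device=categorical_probs.device).clamp_min(1e-10))
    choice = (topk_probs / gumbel_norm).argmax(dim=-1)         # (batch_size, seq_len)
    return topk_idx.expand(batch_size, -1, -1).gather(-1, choice.unsqueeze(-1)).squeeze(-1)

This split matches the mcts.sampling comment in config.yaml: 0 selects Gumbel sampling over the full vocabulary, while a positive value samples children from the top-k probabilities.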
    	
        scoring/{hemolysis.py → functions/hemolysis.py}
    RENAMED
    
File without changes
