import abc

import torch
import torch.nn as nn

# Module-level TorchScript fuser/profiling configuration. These calls run at
# import time and are kept exactly as in the original file.
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)

"""
MDLM Github Repo:
"""


def get_noise(config, dtype=torch.float32):
    """Build the noise schedule selected by ``config.noise.type``.

    Args:
        config: Object exposing ``config.noise.type`` and, for the
            ``'geometric'`` and ``'linear'`` schedules,
            ``config.noise.sigma_min`` / ``config.noise.sigma_max``.
        dtype: Tensor dtype; only forwarded to the ``'linear'`` schedule.

    Returns:
        A ``Noise`` module implementing the requested schedule.

    Raises:
        ValueError: If ``config.noise.type`` is not one of the supported
            schedule names.
    """
    noise_type = config.noise.type
    if noise_type == 'geometric':
        return GeometricNoise(config.noise.sigma_min,
                              config.noise.sigma_max)
    if noise_type == 'loglinear':
        return LogLinearNoise()
    if noise_type == 'cosine':
        return CosineNoise()
    if noise_type == 'cosinesqr':
        return CosineSqrNoise()
    if noise_type == 'linear':
        return Linear(config.noise.sigma_min,
                      config.noise.sigma_max,
                      dtype)
    raise ValueError(f'{config.noise.type} is not a valid noise')


def binary_discretization(z):
    """Discretize ``z`` to its sign with a straight-through gradient.

    The returned value equals ``torch.sign(z)`` (up to floating-point
    rounding), while the gradient flows only through ``z`` divided by its
    norm along the last dimension, because the hard/soft difference is
    detached.

    Note: there is no guard against a zero norm along the last dimension;
    such rows divide by zero.
    """
    z_hard = torch.sign(z)
    z_soft = z / torch.norm(z, dim=-1, keepdim=True)
    return z_soft + (z_hard - z_soft).detach()


class Noise(abc.ABC, nn.Module):
    """Base class for noise schedules.

    Subclasses provide ``total_noise(t)`` and ``rate_noise(t)``; ``forward``
    returns both for a given timestep.
    """

    def forward(self, t):
        """Return ``(total_noise(t), rate_noise(t))``."""
        # Assume time goes from 0 to 1
        return self.total_noise(t), self.rate_noise(t)


class CosineNoise(Noise):
    """Schedule with total noise ``-log(eps + (1 - eps) * cos(pi * t / 2))``."""

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def rate_noise(self, t):
        """Return the time derivative of ``total_noise`` at ``t``."""
        cos = (1 - self.eps) * torch.cos(t * torch.pi / 2)
        sin = (1 - self.eps) * torch.sin(t * torch.pi / 2)
        scale = torch.pi / 2
        return scale * sin / (cos + self.eps)

    def total_noise(self, t):
        """Return ``-log(eps + (1 - eps) * cos(pi * t / 2))``."""
        cos = torch.cos(t * torch.pi / 2)
        return - torch.log(self.eps + (1 - self.eps) * cos)


class CosineSqrNoise(Noise):
    """Schedule with total noise ``-log(eps + (1 - eps) * cos^2(pi * t / 2))``."""

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def rate_noise(self, t):
        """Return the time derivative of ``total_noise`` at ``t``."""
        cos = (1 - self.eps) * (torch.cos(t * torch.pi / 2) ** 2)
        # d/dt cos^2(pi t / 2) = -(pi / 2) * sin(pi t), hence sin(pi * t).
        sin = (1 - self.eps) * torch.sin(t * torch.pi)
        scale = torch.pi / 2
        return scale * sin / (cos + self.eps)

    def total_noise(self, t):
        """Return ``-log(eps + (1 - eps) * cos^2(pi * t / 2))``."""
        cos = torch.cos(t * torch.pi / 2) ** 2
        return - torch.log(self.eps + (1 - self.eps) * cos)


class Linear(Noise):
    """Schedule whose total noise grows linearly from sigma_min to sigma_max."""

    def __init__(self, sigma_min=0, sigma_max=10, dtype=torch.float32):
        super().__init__()
        self.sigma_min = torch.tensor(sigma_min, dtype=dtype)
        self.sigma_max = torch.tensor(sigma_max, dtype=dtype)

    def rate_noise(self, t=None):
        """Return the constant rate ``sigma_max - sigma_min``.

        ``t`` is accepted (and ignored) so that ``Noise.forward``, which
        calls ``rate_noise(t)``, works for this schedule; it defaults to
        ``None`` so zero-argument calls keep working.
        """
        return self.sigma_max - self.sigma_min

    def total_noise(self, t):
        """Return ``sigma_min + t * (sigma_max - sigma_min)``."""
        return self.sigma_min + t * (self.sigma_max - self.sigma_min)

    def importance_sampling_transformation(self, t):
        """Remap ``t`` by interpolating ``log(1 - exp(-sigma))`` linearly.

        The interpolated value is converted back to a sigma and then to a
        time by inverting ``total_noise``.

        Note: with ``sigma_min == 0`` (the default) ``f_0`` evaluates to
        ``-inf``.
        """
        f_T = torch.log1p(- torch.exp(- self.sigma_max))
        f_0 = torch.log1p(- torch.exp(- self.sigma_min))
        sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
        return (sigma_t - self.sigma_min) / (
            self.sigma_max - self.sigma_min)


class GeometricNoise(Noise):
    """Schedule with total noise ``sigma_min^(1 - t) * sigma_max^t``."""

    def __init__(self, sigma_min=1e-3, sigma_max=1):
        super().__init__()
        # Multiplying by 1.0 promotes integer inputs to floating point.
        self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])

    def rate_noise(self, t):
        """Return ``total_noise(t) * (log(sigma_max) - log(sigma_min))``."""
        return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * (
            self.sigmas[1].log() - self.sigmas[0].log())

    def total_noise(self, t):
        """Return ``sigma_min^(1 - t) * sigma_max^t``."""
        return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t


class LogLinearNoise(Noise):
    """Log Linear noise schedule.

    Built such that 1 - 1/e^(n(t)) interpolates between 0 and ~1 when t
    varies from 0 to 1. Total noise is -log(1 - (1 - eps) * t), so
    1 - exp(-total_noise) equals (1 - eps) * t.
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps
        self.sigma_max = self.total_noise(torch.tensor(1.0))
        self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))

    def rate_noise(self, t):
        """Return ``(1 - eps) / (1 - (1 - eps) * t)``."""
        return (1 - self.eps) / (1 - (1 - self.eps) * t)

    def total_noise(self, t):
        """Return ``-log(1 - (1 - eps) * t)`` computed via ``log1p``."""
        return -torch.log1p(-(1 - self.eps) * t)

    def importance_sampling_transformation(self, t):
        """Remap ``t`` by interpolating ``log(1 - exp(-sigma))`` linearly.

        The interpolated value is converted back to a sigma and then to a
        time by inverting ``total_noise``.
        """
        f_T = torch.log1p(- torch.exp(- self.sigma_max))
        f_0 = torch.log1p(- torch.exp(- self.sigma_min))
        sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
        t = - torch.expm1(- sigma_t) / (1 - self.eps)
        return t


class LogPolyNoise(Noise):
    """
    Log Polynomial noise schedule for slower masking of peptide bond tokens.

    Total noise is -log(1 - (1 - eps) * t^3).
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps
        self.sigma_max = self.total_noise(torch.tensor(1.0))
        self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))

    def rate_noise(self, t):
        """Return the time derivative of ``total_noise`` at ``t``.

        d/dt[-log(1 - (1 - eps) * t^3)]
            = 3 * (1 - eps) * t^2 / (1 - (1 - eps) * t^3)
        """
        return (3 * (1 - self.eps) * (t**2)) / (
            1 - (1 - self.eps) * (t**3))

    def total_noise(self, t):
        """Return ``-log(1 - (1 - eps) * t^3)`` computed via ``log1p``."""
        return -torch.log1p(-(1 - self.eps) * (t**3))