import torch

torch.manual_seed(42)
torch.backends.cudnn.deterministic = True

from torch import nn
import numpy as np
from typing import *
from datetime import datetime
import os
import pickle
import inspect
from tqdm.auto import trange

import spaces

from model.base_model import BaseModel


class BarlowTwins(BaseModel):
    def __init__(
        self,
        n_bits: int = 1024,
        aa_emb_size: int = 1024,
        enc_n_neurons: int = 512,
        enc_n_layers: int = 2,
        proj_n_neurons: int = 2048,
        proj_n_layers: int = 2,
        embedding_dim: int = 512,
        act_function: str = "relu",
        loss_weight: float = 0.005,
        batch_size: int = 512,
        optimizer: str = "adamw",
        momentum: float = 0.9,
        learning_rate: float = 0.0001,
        betas: tuple = (0.9, 0.999),
        weight_decay: float = 1e-3,
        step_size: int = 10,
        gamma: float = 0.1,
        verbose: bool = True,
    ):
        super().__init__()
        self.enc_aa = None
        self.enc_mol = None
        self.proj = None
        self.scheduler = None
        self.optimizer = None
        # store input in dict
        self.param_dict = {
            "act_function": self.activation_dict[
                act_function
            ],  # which activation function to use among dict options
            "loss_weight": loss_weight,  # off-diagonal cross-correlation loss weight
            "batch_size": batch_size,  # samples per gradient step
            "learning_rate": learning_rate,  # update step magnitude when training
            "betas": betas,  # momentum hyperparameters for Adam-like optimizers
            "step_size": step_size,  # decay period for the learning rate
            "gamma": gamma,  # decay coefficient for the learning rate
            "optimizer": self.optimizer_dict[
                optimizer
            ],  # which optimizer to use among dict options
            "momentum": momentum,  # momentum hyperparameter for SGD
            "enc_n_neurons": enc_n_neurons,  # neurons to use for the mlp encoders
            "enc_n_layers": enc_n_layers,  # number of hidden layers in the mlp encoders
            "proj_n_neurons": proj_n_neurons,  # neurons to use for the mlp projector
            "proj_n_layers": proj_n_layers,  # number of hidden layers in the mlp projector
            "embedding_dim": embedding_dim,  # latent space dim for downstream tasks
            "weight_decay": weight_decay,  # L2 regularization for linear layers
            "verbose": verbose,  # whether to print feedback
            "radius": "Not defined yet",  # fingerprint radius
            "n_bits": n_bits,  # fingerprint bit size
            "aa_emb_size": aa_emb_size,  # aa embedding size
        }
        # create history dictionary
        self.history = {
            "train_loss": [],
            "on_diag_loss": [],
            "off_diag_loss": [],
            "validation_loss": [],
        }
        # run NN architecture construction method
        self.construct_model()
        # run scheduler construction method
        self.construct_scheduler()
        # print if necessary
        if self.param_dict["verbose"] is True:
            self.print_config()

    @staticmethod
    def __validate_inputs(locals_dict) -> None:
        # get signature types from __init__
        init_signature = inspect.signature(BarlowTwins.__init__)
        # loop over all chosen arguments
        for param_name, param_value in locals_dict.items():
            # skip self
            if param_name != "self":
                # check that parameter exists
                if param_name in init_signature.parameters:
                    # check that param is correct type
                    expected_type = init_signature.parameters[param_name].annotation
                    assert isinstance(
                        param_value, expected_type
                    ), f"[BT]: Type mismatch for parameter '{param_name}'"
                else:
                    raise ValueError(f"[BT]: Unexpected parameter '{param_name}'")

    def construct_mlp(
        self, input_units, layer_units, n_layers, output_units
    ) -> nn.Sequential:
        # make empty list to fill
        mlp_list = []
        # make list defining hidden layer sizes (input + layer_units * n_layers)
        units = [input_units] + [layer_units] * n_layers
        # add layer stack (linear -> batchnorm -> activation)
        for i in range(len(units) - 1):
            mlp_list.append(nn.Linear(units[i], units[i + 1]))
            mlp_list.append(nn.BatchNorm1d(units[i + 1]))
            mlp_list.append(self.param_dict["act_function"]())
        # add final linear layer mapping to output_units
        mlp_list.append(nn.Linear(units[-1], output_units))
        return nn.Sequential(*mlp_list)

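    # Worked example (illustrative only): with the constructor defaults above
    # (n_bits=1024, enc_n_neurons=512, enc_n_layers=2, embedding_dim=512),
    # construct_mlp builds the molecule encoder as:
    #   Linear(1024, 512) -> BatchNorm1d(512) -> activation
    #   Linear(512, 512)  -> BatchNorm1d(512) -> activation
    #   Linear(512, 512)   (plain linear output layer, no norm/activation)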
    def construct_model(self) -> None:
        # create fingerprint transformer
        self.enc_mol = self.construct_mlp(
            self.param_dict["n_bits"],
            self.param_dict["enc_n_neurons"],
            self.param_dict["enc_n_layers"],
            self.param_dict["embedding_dim"],
        )
        # create aa transformer
        self.enc_aa = self.construct_mlp(
            self.param_dict["aa_emb_size"],
            self.param_dict["enc_n_neurons"],
            self.param_dict["enc_n_layers"],
            self.param_dict["embedding_dim"],
        )
        # create mlp projector
        self.proj = self.construct_mlp(
            self.param_dict["embedding_dim"],
            self.param_dict["proj_n_neurons"],
            self.param_dict["proj_n_layers"],
            self.param_dict["proj_n_neurons"],
        )
        # print if necessary
        if self.param_dict["verbose"] is True:
            print("[BT]: Model constructed successfully")

    def construct_scheduler(self):
        # make optimizer
        self.optimizer = self.param_dict["optimizer"](
            list(self.enc_mol.parameters())
            + list(self.enc_aa.parameters())
            + list(self.proj.parameters()),
            lr=self.param_dict["learning_rate"],
            betas=self.param_dict["betas"],
            # momentum=self.param_dict["momentum"],
            weight_decay=self.param_dict["weight_decay"],
        )
        # wrap optimizer in scheduler
        """
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=self.param_dict["step_size"],  # T_0
            # eta_min=1e-7,
            verbose=True
        )
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            patience=self.param_dict["step_size"],
            verbose=True
        )
        """
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer,
            step_size=self.param_dict["step_size"],
            gamma=self.param_dict["gamma"],
        )
        # print if necessary
        if self.param_dict["verbose"] is True:
            print("[BT]: Optimizer constructed successfully")

    def switch_mode(self, is_training: bool):
        if is_training:
            self.enc_mol.train()
            self.enc_aa.train()
            self.proj.train()
        else:
            self.enc_mol.eval()
            self.enc_aa.eval()
            self.proj.eval()

    def normalize_projection(self, tensor: torch.Tensor) -> torch.Tensor:
        # standardise each projection dimension to zero mean and unit variance
        means = torch.mean(tensor, dim=0)
        std = torch.std(tensor, dim=0)
        centered = torch.add(tensor, -means)
        scaled = torch.div(centered, std)
        return scaled

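    # Loss layout used by compute_loss below (Barlow Twins, Zbontar et al., 2021):
    # with column-standardised projections z_mol, z_aa and the cross-correlation
    # matrix C = (z_mol^T @ z_aa) / batch_size,
    #   on_diag  = sum_i (C_ii - 1)^2               (invariance term)
    #   off_diag = loss_weight * sum_{i!=j} C_ij^2  (redundancy-reduction term)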
    def compute_loss(
        self,
        mol_embedding: torch.Tensor,
        aa_embedding: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # empirical cross-correlation matrix
        mol_embedding = self.normalize_projection(mol_embedding).T
        aa_embedding = self.normalize_projection(aa_embedding)
        c = mol_embedding @ aa_embedding
        # normalize by number of samples
        c.div_(self.param_dict["batch_size"])
        # compute elements on the diagonal
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
        # compute elements off the diagonal (c is square: both projections
        # have proj_n_neurons dimensions)
        n, _ = c.shape
        off_diag = c.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
        off_diag = off_diag.pow_(2).sum() * self.param_dict["loss_weight"]
        return on_diag, off_diag

    def forward(
        self, mol_data: torch.Tensor, aa_data: torch.Tensor, is_training: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # switch according to input
        self.switch_mode(is_training)
        # get embeddings
        mol_embeddings = self.enc_mol(mol_data)
        aa_embeddings = self.enc_aa(aa_data)
        # get projections
        mol_proj = self.proj(mol_embeddings)
        aa_proj = self.proj(aa_embeddings)
        # compute loss
        on_diag, off_diag = self.compute_loss(mol_proj, aa_proj)
        return on_diag, off_diag

    def train(
        self,
        train_data: torch.utils.data.DataLoader,
        val_data: torch.utils.data.DataLoader = None,
        num_epochs: int = 20,
        patience: int = None,
    ):
        if self.param_dict["verbose"] is True:
            print("[BT]: Training started")
        if patience is None:
            patience = 2 * self.param_dict["step_size"]
        pbar = trange(num_epochs, desc="[BT]: Epochs", leave=False, colour="blue")
        for epoch in pbar:
            # initialize loss containers
            train_loss = 0.0
            on_diag_loss = 0.0
            off_diag_loss = 0.0
            val_loss = 0.0
            # loop over training set
            for _, (mol_data, aa_data) in enumerate(train_data):
                # reset grad
                self.optimizer.zero_grad()
                # compute train loss for batch
                on_diag, off_diag = self.forward(mol_data, aa_data, is_training=True)
                t_loss = on_diag + off_diag
                # backpropagation and optimization
                t_loss.backward()
                """
                nn.utils.clip_grad_norm_(
                    list(self.enc_mol.parameters()) +
                    list(self.enc_aa.parameters()) +
                    list(self.proj.parameters()),
                    1
                )
                """
                self.optimizer.step()
                # add i-th loss to training container
                train_loss += t_loss.item()
                on_diag_loss += on_diag.item()
                off_diag_loss += off_diag.item()
            # add mean epoch loss for train data to history dictionary
            self.history["train_loss"].append(train_loss / len(train_data))
            self.history["on_diag_loss"].append(on_diag_loss / len(train_data))
            self.history["off_diag_loss"].append(off_diag_loss / len(train_data))
            # define msg to be printed
            msg = (
                f"[BT]: Epoch [{epoch + 1}/{num_epochs}], "
                f"Train loss: {train_loss / len(train_data):.3f}, "
                f"On diagonal: {on_diag_loss / len(train_data):.3f}, "
                f"Off diagonal: {off_diag_loss / len(train_data):.3f}"
            )
            # loop over validation set (if present)
            if val_data is not None:
                for _, (mol_data, aa_data) in enumerate(val_data):
                    # compute val loss for batch
                    on_diag_v_loss, off_diag_v_loss = self.forward(
                        mol_data, aa_data, is_training=False
                    )
                    # add i-th loss to val container
                    v_loss = on_diag_v_loss + off_diag_v_loss
                    val_loss += v_loss.item()
                # add mean epoch loss for val data to history dictionary
                self.history["validation_loss"].append(val_loss / len(val_data))
                # add val loss to msg
                msg += f", Val loss: {val_loss / len(val_data):.3f}"
                # early stopping
                if self.early_stopping(patience=patience):
                    break
                pbar.set_postfix(
                    {
                        "train loss": train_loss / len(train_data),
                        "val loss": val_loss / len(val_data),
                    }
                )
            else:
                pbar.set_postfix({"train loss": train_loss / len(train_data)})
            # update scheduler
            self.scheduler.step()  # val_loss / len(val_data)
            if self.param_dict["verbose"] is True:
                print(msg)
        if self.param_dict["verbose"] is True:
            print("[BT]: Training finished")

    def encode(
        self,
        vector: np.ndarray,
        mode: str = "embedding",
        normalize: bool = True,
        encoder: str = "mol",
    ) -> np.ndarray:
        """
        Encodes a given vector using the Barlow Twins model.
        Args:
        - vector (np.ndarray): the input vector to encode
        - mode (str): the mode to use for encoding, either "embedding" or "projection"
        - normalize (bool): whether to L2 normalize the output vector
        - encoder (str): which encoder to use, either "mol" or "aa"
        Returns:
        - np.ndarray: the encoded vector
        """
        # set encoders and projector to eval mode
        self.switch_mode(is_training=False)
        # convert from numpy to tensor
        if not isinstance(vector, torch.Tensor):
            vector = torch.from_numpy(vector)
        # if only one vector is passed, add a batch dimension
        if len(vector.shape) == 1:
            vector = vector.unsqueeze(0)
        # get representation
        if encoder == "mol":
            embedding = self.enc_mol(vector)
            if mode == "projection":
                embedding = self.proj(embedding)
        elif encoder == "aa":
            embedding = self.enc_aa(vector)
            if mode == "projection":
                embedding = self.proj(embedding)
        else:
            raise ValueError("[BT]: Encoder not recognized")
        # L2 normalize (optional)
        if normalize:
            embedding = torch.nn.functional.normalize(embedding)
        # convert back to numpy
        return embedding.cpu().detach().numpy()

    def zero_shot(
        self,
        mol_vector: np.ndarray,
        aa_vector: np.ndarray,
        l2_norm: bool = True,
        device: str = "cpu",
    ) -> np.ndarray:
        # disable training
        self.switch_mode(is_training=False)
        # cast mol and aa vectors to arrays, forcing single precision on both
        mol_vector = np.array(mol_vector, dtype=np.float32)
        aa_vector = np.array(aa_vector, dtype=np.float32)
        # convert to tensors
        mol_vector = torch.from_numpy(mol_vector).to(device)
        aa_vector = torch.from_numpy(aa_vector).to(device)
        # get embeddings
        mol_embedding = self.encode(mol_vector, normalize=l2_norm, encoder="mol")
        aa_embedding = self.encode(aa_vector, normalize=l2_norm, encoder="aa")
        # concat mol and aa embeddings
        concat = np.concatenate((mol_embedding, aa_embedding), axis=1)
        return concat

    def zero_shot_explain(
        self, mol_vector, aa_vector, l2_norm: bool = True, device: str = "cpu"
    ) -> np.ndarray:
        # disable training
        self.switch_mode(is_training=False)
        # encode() returns numpy arrays, so concatenate with numpy rather than torch
        mol_embedding = self.encode(mol_vector, normalize=l2_norm, encoder="mol")
        aa_embedding = self.encode(aa_vector, normalize=l2_norm, encoder="aa")
        return np.concatenate((mol_embedding, aa_embedding), axis=1)

    def consume_preprocessor(self, preprocessor) -> None:
        # save attributes related to fingerprint generation from
        # preprocessor object
        self.param_dict["radius"] = preprocessor.radius
        self.param_dict["n_bits"] = preprocessor.n_bits

    def save_model(self, path: str) -> None:
        # get current date and time for the folder name
        now = datetime.now()
        formatted_date = now.strftime("%d%m%Y")
        formatted_time = now.strftime("%H%M")
        folder_name = f"{formatted_date}_{formatted_time}"
        # make full path string and folder
        folder_path = os.path.join(path, folder_name)
        os.makedirs(folder_path)
        # make paths for weights, config and history
        weight_path = os.path.join(folder_path, "weights.pt")
        param_path = os.path.join(folder_path, "params.pkl")
        history_path = os.path.join(folder_path, "history.json")
        # save each Sequential state dict in one object to the path
        torch.save(
            {
                "enc_mol": self.enc_mol.state_dict(),
                "enc_aa": self.enc_aa.state_dict(),
                "proj": self.proj.state_dict(),
            },
            weight_path,
        )
        # dump params in pkl
        with open(param_path, "wb") as file:
            pickle.dump(self.param_dict, file)
        # dump history with pickle (filename keeps the original .json extension)
        with open(history_path, "wb") as file:
            pickle.dump(self.history, file)
        # print if verbose is True
        if self.param_dict["verbose"] is True:
            print(f"[BT]: Model saved at {folder_path}")

    def load_model(self, path: str) -> None:
        # make weights, config and history paths
        weights_path = os.path.join(path, "weights.pt")
        param_path = os.path.join(path, "params.pkl")
        history_path = os.path.join(path, "history.json")
        # load weights, history and params (fall back to CPU if no device is set yet)
        checkpoint = torch.load(weights_path, map_location=getattr(self, "device", "cpu"))
        with open(param_path, "rb") as file:
            param_dict = pickle.load(file)
        with open(history_path, "rb") as file:
            history = pickle.load(file)
        # construct model again, overriding old verbose key with new instance
        verbose = self.param_dict["verbose"]
        self.param_dict = param_dict
        self.param_dict["verbose"] = verbose
        self.history = history
        self.construct_model()
        # set weights in Sequential models
        self.enc_mol.load_state_dict(checkpoint["enc_mol"])
        self.enc_aa.load_state_dict(checkpoint["enc_aa"])
        self.proj.load_state_dict(checkpoint["proj"])
        # recreate scheduler and optimizer in order to add the new weights
        # to the graph
        self.construct_scheduler()
        # print if verbose is True
        if self.param_dict["verbose"] is True:
            print(f"[BT]: Model loaded from {path}")
            print("[BT]: Loaded parameters:")
            print(self.param_dict)

    def move_to_device(self, device) -> None:
        # move each Sequential model to device and remember it
        self.enc_mol.to(device)
        self.enc_aa.to(device)
        self.proj.to(device)
        self.device = device
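

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only, not part of the original
    # pipeline). It assumes the constructor defaults above (n_bits=1024
    # fingerprint bits, aa_emb_size=1024 amino-acid embedding size,
    # embedding_dim=512) and that BaseModel provides the activation_dict /
    # optimizer_dict lookups used in __init__. Inputs are random tensors
    # standing in for real fingerprints and residue embeddings.
    from torch.utils.data import DataLoader, TensorDataset

    model = BarlowTwins(verbose=False)
    mol_fp = torch.rand(64, 1024)  # hypothetical molecular fingerprints
    aa_emb = torch.rand(64, 1024)  # hypothetical amino-acid embeddings
    loader = DataLoader(TensorDataset(mol_fp, aa_emb), batch_size=32)
    # one epoch of Barlow Twins training on the toy batch
    model.train(loader, num_epochs=1)
    # joint zero-shot representation: concatenated mol and aa embeddings
    joint = model.zero_shot(mol_fp[:4].numpy(), aa_emb[:4].numpy())
    print(joint.shape)  # expected (4, 2 * embedding_dim) = (4, 1024)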