Spaces:
Sleeping
Sleeping
| import itertools | |
| import json | |
| import random | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| from tools import create_key | |
| from model.timbre_encoder_pretrain import get_timbre_encoder | |
class ProjectionLayer(nn.Module):
    """Residual projection block: Linear -> GELU -> Linear -> Dropout, plus skip, then LayerNorm.

    The input is first mapped to ``output_dim`` by ``projection``. That projected
    tensor feeds a GELU / Linear / Dropout branch and is also added back as a
    residual before the final layer normalisation.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
        dropout: Dropout probability applied to the non-residual branch.
    """

    def __init__(self, input_dim, output_dim, dropout):
        super().__init__()
        # Attribute names form the state_dict keys, so they must stay stable for
        # saved checkpoints. Creation order also fixes seeded weight init.
        self.projection = nn.Linear(input_dim, output_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(output_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(output_dim)

    def forward(self, x):
        residual = self.projection(x)
        branch = self.dropout(self.fc(self.gelu(residual)))
        return self.layer_norm(branch + residual)
class ProjectionHead(nn.Module):
    """A chain of ``ProjectionLayer`` blocks applied one after another.

    The first layer maps ``embedding_dim`` -> ``projection_dim``; every later
    layer maps ``projection_dim`` -> ``projection_dim``. With ``num_layers`` of
    0 the head is empty and ``forward`` returns its input unchanged.

    Args:
        embedding_dim: Size of the last dimension of the input.
        projection_dim: Size of the last dimension of every layer's output.
        dropout: Dropout probability forwarded to each ``ProjectionLayer``.
        num_layers: Number of stacked ``ProjectionLayer`` blocks.
    """

    def __init__(self, embedding_dim, projection_dim, dropout, num_layers=2):
        super().__init__()
        blocks = []
        in_dim = embedding_dim
        for _ in range(num_layers):
            blocks.append(ProjectionLayer(in_dim, projection_dim, dropout))
            # After the first block every input already has projection_dim.
            in_dim = projection_dim
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        out = x
        for block in self.layers:
            out = block(out)
        return out
class multi_modal_model(nn.Module):
    """CLIP-style contrastive model pairing spectrogram (timbre) and text embeddings.

    Each modality is encoded by its own encoder and mapped into a shared
    ``multi_modal_emb_dim``-sized space by a dedicated ``ProjectionHead``.
    ``forward`` returns the symmetric soft-target contrastive loss for a batch.

    Args:
        timbre_encoder: Module whose output unpacks into exactly 5 items; only
            the first is used as the spectrogram feature.
        text_encoder: Object providing ``get_text_features(**tokenized)`` and
            ``parameters()`` (presumably a Hugging Face text model - confirm).
        spectrogram_feature_dim: Input size of the spectrogram projection head.
        text_feature_dim: Input size of the text projection head.
        multi_modal_emb_dim: Output size of both projection heads.
        temperature: Logits are divided by it; the similarity targets are
            multiplied by it before the softmax.
        dropout: Dropout probability used inside the projection heads.
        num_projection_layers: Number of ``ProjectionLayer`` blocks per head.
        freeze_spectrogram_encoder: When True, timbre-encoder parameters get
            ``requires_grad=False``.
        freeze_text_encoder: When True, text-encoder parameters get
            ``requires_grad=False``.
    """

    def __init__(
        self,
        timbre_encoder,
        text_encoder,
        spectrogram_feature_dim,
        text_feature_dim,
        multi_modal_emb_dim,
        temperature,
        dropout,
        num_projection_layers=1,
        freeze_spectrogram_encoder=True,
        freeze_text_encoder=True,
    ):
        super().__init__()
        self.timbre_encoder = timbre_encoder
        self.text_encoder = text_encoder
        self.multi_modal_emb_dim = multi_modal_emb_dim

        # Both heads share everything except their input size. The text head is
        # created first; under a fixed seed this order determines initial weights.
        shared = {"projection_dim": multi_modal_emb_dim, "dropout": dropout, "num_layers": num_projection_layers}
        self.text_projection = ProjectionHead(embedding_dim=text_feature_dim, **shared)
        self.spectrogram_projection = ProjectionHead(embedding_dim=spectrogram_feature_dim, **shared)
        self.temperature = temperature

        self._set_requires_grad(self.timbre_encoder, not freeze_spectrogram_encoder)
        self._set_requires_grad(self.text_encoder, not freeze_text_encoder)

    @staticmethod
    def _set_requires_grad(module, trainable):
        """Set ``requires_grad`` to ``trainable`` on every parameter of ``module``."""
        for parameter in module.parameters():
            parameter.requires_grad = trainable

    def forward(self, spectrogram_batch, tokenized_text_batch):
        """Return the scalar contrastive loss for paired spectrograms and tokenized texts."""
        # Modality-specific features; their sizes differ until projected.
        spec_features, _, _, _, _ = self.timbre_encoder(spectrogram_batch)
        txt_features = self.text_encoder.get_text_features(**tokenized_text_batch)

        # Map each modality into the shared space (no concatenation takes place).
        spec_emb = self.spectrogram_projection(spec_features)
        txt_emb = self.text_projection(txt_features)

        # Rows index texts, columns index spectrograms.
        logits = (txt_emb @ spec_emb.T) / self.temperature

        # Soft targets: average of the two intra-modal similarity matrices,
        # scaled by the temperature and softmaxed row-wise.
        intra_similarity = (spec_emb @ spec_emb.T + txt_emb @ txt_emb.T) / 2
        targets = F.softmax(intra_similarity * self.temperature, dim=-1)

        # Symmetric loss: text->spectrogram and spectrogram->text directions.
        per_sample = (
            cross_entropy(logits.T, targets.T, reduction='none')
            + cross_entropy(logits, targets, reduction='none')
        ) / 2.0
        return per_sample.mean()

    def get_text_features(self, input_ids, attention_mask):
        """Encode token ids/mask with the text encoder and project into the shared space."""
        raw = self.text_encoder.get_text_features(input_ids=input_ids, attention_mask=attention_mask)
        return self.text_projection(raw)

    def get_timbre_features(self, spectrogram_batch):
        """Encode spectrograms with the timbre encoder and project the first output."""
        encoded, _, _, _, _ = self.timbre_encoder(spectrogram_batch)
        return self.spectrogram_projection(encoded)
def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy between logits ``preds`` and soft (probability) ``targets``.

    Computes ``-sum(targets * log_softmax(preds))`` over the last dimension.

    Args:
        preds: Unnormalised logits; the softmax is taken over the last dimension.
        targets: Target distribution, same shape as ``preds``.
        reduction: ``'none'`` returns the per-row losses, ``'mean'`` their mean,
            ``'sum'`` their sum.

    Returns:
        Tensor of per-row losses for ``'none'``, otherwise a scalar tensor.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """
    # Reduce over the same axis the softmax normalises (the last one). A
    # hard-coded ``.sum(1)`` only agreed with that for 2-D inputs.
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(dim=-1)
    if reduction == 'none':
        return loss
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    # An unknown reduction used to fall through and silently return None.
    raise ValueError(f"Unsupported reduction {reduction!r}; expected 'none', 'mean' or 'sum'.")
def get_multi_modal_model(timbre_encoder, text_encoder, model_Config, load_pretrain=False, model_name=None, device="cpu"):
    """Build a ``multi_modal_model`` for inference.

    The model is constructed from ``model_Config`` (passed as keyword
    arguments) and moved to ``device``. When ``load_pretrain`` is True,
    ``model_state_dict`` is loaded from ``models/{model_name}_MMM.pth``.
    The model is always switched to eval mode before being returned.

    Returns:
        The constructed ``multi_modal_model`` instance.
    """
    model = multi_modal_model(timbre_encoder, text_encoder, **model_Config)
    # Only parameters left trainable (e.g. the projection heads) are counted.
    trainable_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model intialized, size: {trainable_count}")
    model.to(device)

    if load_pretrain:
        print(f"Loading weights from models/{model_name}_MMM.pth")
        state = torch.load(f'models/{model_name}_MMM.pth', map_location=device)
        model.load_state_dict(state['model_state_dict'])

    model.eval()
    return model
def _sample_unique_batch(loader):
    """Draw ``(data, keys)`` from ``loader``, resampling until all attribute keys in the batch are distinct."""
    # A fresh iterator is built on every draw, so each attempt takes the first
    # batch of a new pass. NOTE(review): presumably only a shuffling loader
    # yields different batches here; with a non-shuffling loader whose first
    # batch contains a duplicate key, this loop would never terminate.
    while True:
        data, attributes = next(iter(loader))
        keys = [create_key(attribute) for attribute in attributes]
        # Keys must be unique within the batch (presumably so no two samples
        # share a caption set in the contrastive loss).
        if len(set(keys)) == len(keys):
            return data, keys


def _sampled_batch_loss(text_tokenizer, model, loader, labels_mapping, device):
    """Sample a unique-key batch, pick one random caption per sample, and return the model's loss tensor."""
    data, keys = _sample_unique_batch(loader)
    data = data.to(device)
    # One caption is chosen uniformly at random from each sample's caption list.
    selected_texts = [random.choice(labels_mapping[key]) for key in keys]
    tokenized_text = text_tokenizer(selected_texts, padding=True, return_tensors="pt").to(device)
    return model(data, tokenized_text)


def train_epoch(text_tokenizer, model, train_loader, labels_mapping, optimizer, device):
    """Run one optimisation step on a single sampled batch and return its loss as a float.

    Despite the name, only one batch is processed, not a full pass over
    ``train_loader``.
    """
    loss = _sampled_batch_loss(text_tokenizer, model, train_loader, labels_mapping, device)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


def valid_epoch(text_tokenizer, model, valid_loader, labels_mapping, device):
    """Compute the loss on a single sampled batch without updating the model.

    Gradient tracking is disabled, so no autograd graph is kept alive. The
    model's train/eval mode is not changed by this function.
    """
    with torch.no_grad():
        loss = _sampled_batch_loss(text_tokenizer, model, valid_loader, labels_mapping, device)
    return loss.item()
def _current_optimizer_step(optimizer):
    """Return the ``step`` counter of the first parameter that has optimizer state, as an ``int``.

    Assumes at least one ``optimizer.step()`` has already been taken.
    """
    param_states = optimizer.state_dict()['state']
    first_state = next(iter(param_states.values()))
    # ``step`` may be a tensor or a plain number depending on the PyTorch
    # version; ``int()`` handles both, whereas ``.cpu().numpy()`` needs a tensor.
    return int(first_state['step'])


def train_multi_modal_model(device, training_dataloader, labels_mapping, text_tokenizer, text_encoder,
                            timbre_encoder_Config, MMM_config, MMM_training_config,
                            mmm_name, BATCH_SIZE, max_iter=0, load_pretrain=True,
                            timbre_encoder_name=None, init_loss=None, save_steps=2000):
    """Build (and optionally resume) a ``multi_modal_model`` and train its projection heads.

    A pretrained timbre encoder is loaded via ``get_timbre_encoder`` and combined
    with ``text_encoder``. AdamW always optimises the projection heads, and each
    encoder too when its ``freeze_*`` flag in ``MMM_config`` is False.

    If ``load_pretrain`` is True, model and optimizer state are restored from
    ``models/{mmm_name}_MMM.pth``. With ``max_iter == 0`` the model is returned
    right away.

    Otherwise ``max_iter`` single-batch training steps run. Every ``save_steps``
    iterations, the mean training loss over that window is compared with the
    lowest loss so far, which starts at ``init_loss`` or at ``valid_epoch`` on
    ``training_dataloader``. If it is lower, the checkpoint and a hyperparameter
    JSON under ``models/hyperparameters/`` are written.

    Returns:
        ``(model, optimizer)``.
    """

    def save_model_hyperparameter(model_name, MMM_config, MMM_training_config, BATCH_SIZE, model_size, current_iter,
                                  current_loss):
        # Build a fresh dict. Updating ``MMM_config`` in place would leak training
        # keys (learning rates, BATCH_SIZE, ...) into the caller's model config
        # and break later ``multi_modal_model(**MMM_config)`` calls.
        model_hyperparameter = {**MMM_config, **MMM_training_config}
        model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
        model_hyperparameter["model_size"] = model_size
        model_hyperparameter["current_iter"] = current_iter
        model_hyperparameter["current_loss"] = float(current_loss)
        with open(f"models/hyperparameters/{model_name}_MMM.json", "w") as json_file:
            json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)

    timbreEncoder = get_timbre_encoder(timbre_encoder_Config, load_pretrain=True, model_name=timbre_encoder_name,
                                       device=device)
    mmm = multi_modal_model(timbreEncoder, text_encoder, **MMM_config).to(device)

    print(f"spectrogram_encoder parameter: {sum(p.numel() for p in mmm.timbre_encoder.parameters())}")
    print(f"text_encoder parameter: {sum(p.numel() for p in mmm.text_encoder.parameters())}")
    print(f"spectrogram_projection parameter: {sum(p.numel() for p in mmm.spectrogram_projection.parameters())}")
    print(f"text_projection parameter: {sum(p.numel() for p in mmm.text_projection.parameters())}")
    total_parameters = sum(p.numel() for p in mmm.parameters())
    trainable_parameters = sum(p.numel() for p in mmm.parameters() if p.requires_grad)
    print(f"Trainable/Total parameter: {trainable_parameters}/{total_parameters}")

    params = [
        {"params": itertools.chain(
            mmm.spectrogram_projection.parameters(),
            mmm.text_projection.parameters(),
        ), "lr": MMM_training_config["head_lr"], "weight_decay": MMM_training_config["head_weight_decay"]},
    ]
    if not MMM_config["freeze_text_encoder"]:
        params.append({"params": mmm.text_encoder.parameters(), "lr": MMM_training_config["text_encoder_lr"],
                       "weight_decay": MMM_training_config["text_encoder_weight_decay"]})
    if not MMM_config["freeze_spectrogram_encoder"]:
        params.append({"params": mmm.timbre_encoder.parameters(), "lr": MMM_training_config["spectrogram_encoder_lr"],
                       "weight_decay": MMM_training_config["timbre_encoder_weight_decay"]})
    # Per-group weight_decay values above override this default.
    optimizer = torch.optim.AdamW(params, weight_decay=0.)

    if load_pretrain:
        print(f"Loading weights from models/{mmm_name}_MMM.pth")
        # map_location (as in get_multi_modal_model) lets GPU-saved checkpoints
        # load on a CPU-only host.
        checkpoint = torch.load(f'models/{mmm_name}_MMM.pth', map_location=device)
        mmm.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print("Model initialized.")

    if max_iter == 0:
        print("Return model directly.")
        return mmm, optimizer

    if init_loss is None:
        previous_lowest_loss = valid_epoch(text_tokenizer, mmm, training_dataloader, labels_mapping, device)
    else:
        previous_lowest_loss = init_loss
    print(f"Initial total loss: {previous_lowest_loss}")

    train_loss_list = []
    for i in range(max_iter):
        mmm.train()
        train_loss = train_epoch(text_tokenizer, mmm, training_dataloader, labels_mapping, optimizer, device)
        train_loss_list.append(train_loss)

        # The global step is read lazily, only when it is printed or saved,
        # rather than serialising the optimizer state every iteration.
        if (i + 1) % 100 == 0:
            print('%d step' % (_current_optimizer_step(optimizer)))

        if (i + 1) % save_steps == 0:
            current_loss = np.mean(train_loss_list[-save_steps:])
            print(f"train_total_loss: {current_loss}")
            if current_loss < previous_lowest_loss:
                previous_lowest_loss = current_loss
                torch.save({
                    'model_state_dict': mmm.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, f'models/{mmm_name}_MMM.pth')
                save_model_hyperparameter(mmm_name, MMM_config, MMM_training_config, BATCH_SIZE, total_parameters,
                                          _current_optimizer_step(optimizer), current_loss)
    return mmm, optimizer