import os
from os.path import exists
from pathlib import Path

import numpy as np
import torch

from encoder.data_objects import DataLoader, Train_Dataset, Dev_Dataset
from encoder.model import SpeakerEncoder
from encoder.params_model import *
from encoder.visualizations import Visualizations
from utils.profiler import Profiler


def sync(device: torch.device):
    # For correct profiling (CUDA operations are asynchronous)
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def update_lr(optimizer, lr):
    # Set the learning rate on every parameter group of the optimizer
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int,
          save_every: int, backup_every: int, vis_every: int, force_restart: bool,
          visdom_server: str, no_visdom: bool):
    # Create the datasets and the dataloaders
    train_dataset = Train_Dataset(clean_data_root.joinpath("train"))
    dev_dataset = Dev_Dataset(clean_data_root.joinpath("dev"))
    train_loader = DataLoader(
        train_dataset,
        speakers_per_batch,
        utterances_per_speaker,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
    )
    dev_batch = len(dev_dataset)
    dev_loader = DataLoader(
        dev_dataset,
        dev_batch,
        utterances_per_speaker,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )
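    # Each training batch covers speakers_per_batch speakers with
    # utterances_per_speaker utterances each, which is the layout the
    # speaker-verification loss below expects. The dev loader uses
    # dev_batch = len(dev_dataset): every dev speaker is packed into a single
    # full-size batch per validation pass.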
    # Set up the device on which to run the forward pass and the loss. These can
    # be different, because the forward pass is faster on the GPU whereas the loss
    # is often (depending on your hyperparameters) faster on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # FIXME: currently, the gradient is None if loss_device is cuda
    # loss_device = torch.device("cpu")
    # Modified from the CPU default above: the loss now also runs on the GPU when available
    loss_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Create the model and the optimizer
    model = SpeakerEncoder(device, loss_device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
    current_lr = learning_rate_init
    init_step = 1

    # Configure the file path for the model
    model_dir = models_dir / run_id
    model_dir.mkdir(exist_ok=True, parents=True)
    state_fpath = model_dir / "encoder.pt"
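    # All artifacts for this run (the checkpoint and the UMAP projections drawn
    # below) live under models_dir / run_id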
    # Load any existing model
    if not force_restart:
        if state_fpath.exists():
            print(f"Found existing model \"{run_id}\", loading it and resuming training.")
            checkpoint = torch.load(state_fpath)
            init_step = checkpoint["step"]
            print(f"Resuming training from step {init_step}")
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            # Reset the learning rate to its initial value rather than keeping
            # the one stored in the optimizer state
            optimizer.param_groups[0]["lr"] = learning_rate_init
        else:
            print(f"No model \"{run_id}\" found, starting training from scratch.")
    else:
        print("Starting the training from scratch.")
    # Initialize the visualization environment
    vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
    vis.log_dataset(train_dataset)
    vis.log_params()
    device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
    vis.log_implementation({"Device": device_name})

    # Track the best dev EER seen so far (1.0 is the worst possible EER)
    best_eer_file_path = "encoder_loss/best_eer.npy"
    os.makedirs("encoder_loss", exist_ok=True)
    best_eer = np.load(best_eer_file_path)[0] if exists(best_eer_file_path) else 1.0
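    # Note that the best EER persists across restarts through
    # encoder_loss/best_eer.npy, and the path is not keyed to run_id:
    # all runs share the same best-EER file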

    # Training loop
    profiler = Profiler(summarize_every=1000, disabled=False)
    dev_embeds = None  # only populated on validation steps (see the umap guard below)
    for step, speaker_batch in enumerate(train_loader, init_step):
        model.train()
        profiler.tick("Blocking, waiting for batch (threaded)")

        # Data to device memory
        inputs = torch.from_numpy(speaker_batch.data).to(device)
        sync(device)
        profiler.tick("Data to %s" % device)

        # Forward pass
        embeds = model(inputs)
        sync(device)
        profiler.tick("Forward pass")
        embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
        loss, eer = model.loss(embeds_loss)
        sync(loss_device)
        profiler.tick("Loss")

        # Backward pass
        model.zero_grad()  # zero the gradients of all model parameters
        loss.backward()    # compute the gradients of all model parameters
        profiler.tick("Backward pass")
        # In the reference encoder implementation, do_gradient_ops() rescales the
        # similarity-layer gradients and clips the gradient norm; check the model
        # for this fork's exact operations
        model.do_gradient_ops()
        optimizer.step()  # take one optimization step on all model parameters
        profiler.tick("Parameter update")

        # Periodically decay the learning rate, validate, and checkpoint the
        # model whenever the dev EER improves
        if save_every != 0 and step % save_every == 0:
            # Multiplicative decay once per save interval; at 0.995 per interval,
            # the learning rate halves roughly every 138 intervals (0.995**138 ≈ 0.5)
            current_lr *= 0.995
            update_lr(optimizer, current_lr)
            dev_loss, dev_eer, dev_embeds = validate(dev_loader, model, dev_batch, device, loss_device)
            sync(device)
            sync(loss_device)
            profiler.tick("validate")
            vis.update(loss.item(), eer, step, dev_loss, dev_eer)
            if dev_eer < best_eer:
                best_eer = dev_eer
                np.save(best_eer_file_path, np.array([best_eer]))
                print("Saving the model (step %d)" % step)
                torch.save({
                    "step": step + 1,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                }, state_fpath)
        else:
            vis.update(loss.item(), eer, step)

        # Draw projections and save them to the model folder
        if umap_every != 0 and step % umap_every == 0:
            print("Drawing and saving projections (step %d)" % step)
            projection_fpath = model_dir / f"umap_{step:06d}.png"
            dev_projection_fpath = model_dir / f"dev_umap_{step:06d}.png"
            embeds_np = embeds.detach().cpu().numpy()
            # dev_embeds only exists once validate() has run at least once (i.e.
            # after a save_every step); skip the projections until then
            if dev_embeds is not None:
                dev_embeds_np = dev_embeds.detach().cpu().numpy()
                vis.draw_projections(embeds_np, dev_embeds_np, utterances_per_speaker,
                                     step, projection_fpath, dev_projection_fpath)
            vis.save()

        # # Make a backup
        # if backup_every != 0 and step % backup_every == 0:
        #     print("Making a backup (step %d)" % step)
        #     backup_fpath = model_dir / f"encoder_{step:06d}.bak"
        #     torch.save({
        #         "step": step + 1,
        #         "model_state": model.state_dict(),
        #         "optimizer_state": optimizer.state_dict(),
        #     }, backup_fpath)
        profiler.tick("Extras (visualizations, saving)")


def validate(dev_loader: DataLoader, model: SpeakerEncoder, dev_batch, device, loss_device):
    model.eval()
    losses = []
    eers = []
    with torch.no_grad():
        for step, speaker_batch in enumerate(dev_loader, 1):
            frames = torch.from_numpy(speaker_batch.data).to(device)
            embeds = model(frames)
            embeds_loss = embeds.view((dev_batch, utterances_per_speaker, -1)).to(loss_device)
            loss, eer = model.loss(embeds_loss)
            losses.append(loss.item())
            eers.append(eer)
    # Return the mean loss and EER, plus the embeddings of the last dev batch
    # (with the single full-size dev batch configured above, that is the whole dev set)
    return sum(losses) / len(losses), sum(eers) / len(eers), embeds.detach()
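

# A minimal sketch of how train() might be invoked. The parameter names match the
# signature above, but the values and paths are hypothetical examples, not this
# project's actual CLI:
#
#     if __name__ == "__main__":
#         train(run_id="my_run",
#               clean_data_root=Path("datasets/SV2TTS/encoder"),
#               models_dir=Path("encoder/saved_models"),
#               umap_every=100, save_every=50, backup_every=0, vis_every=10,
#               force_restart=False, visdom_server="http://localhost",
#               no_visdom=True)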