import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from huggingface_hub import ModelCard
from tasnet import ConvTasNetStereo


class DynamicSourceSeparator(torch.nn.Module, PyTorchModelHubMixin):
    """Dispatch a mixture to per-instrument sub-models chosen at call time.

    Each sub-model is stored under its instrument name. On every forward
    pass an ``indicator`` mapping decides which instruments are actually
    run; the others are answered with zeros.
    """

    def __init__(self, pre_trained_models):
        """Store the given sub-models.

        Args:
            pre_trained_models: Mapping from instrument name to the module
                used for that instrument. It is handed to ``nn.ModuleDict``
                so the sub-models are registered as child modules.
        """
        super().__init__()
        self.models = nn.ModuleDict(pre_trained_models)

    def forward(self, mixture, indicator):
        """Separate the instruments flagged in ``indicator``.

        Args:
            mixture: Input tensor, passed unchanged to every active
                sub-model.
            indicator: Mapping from instrument name to a truthy/falsy flag.
                A truthy flag requires a sub-model registered under the
                same name; the lookup in ``self.models`` fails otherwise.

        Returns:
            A dict with one entry per key of ``indicator``, in iteration
            order. Active instruments map to index 0 along dim 1 of the
            sub-model output; inactive ones map to zeros shaped like
            ``mixture``.
        """
        results = {}
        for name, is_active in indicator.items():
            if not is_active:
                # No model call for inactive instruments: emit silence with
                # the mixture's shape, dtype and device.
                results[name] = torch.zeros_like(mixture)
                continue

            # The sub-model output is indexed as a 4-D tensor; only index 0
            # on dim 1 is kept.
            # NOTE(review): presumably dim 1 enumerates estimated sources and
            # the remaining dims match ``mixture`` -- confirm against the
            # sub-model (e.g. ConvTasNetStereo) output layout.
            results[name] = self.models[name](mixture)[:, 0, :, :]
        return results