|
import torch |
|
import torch.nn as nn |
|
from huggingface_hub import PyTorchModelHubMixin |
|
from huggingface_hub import ModelCard |
|
|
|
from tasnet import ConvTasNetStereo |
|
|
|
|
|
class DynamicSourceSeparator(torch.nn.Module, PyTorchModelHubMixin):
    """Run per-instrument separation models selected at call time.

    Each entry of ``pre_trained_models`` maps an instrument name to a module
    that separates that instrument from a mixture. In ``forward``, an
    ``indicator`` mapping chooses which instruments to run. Instruments that
    are not selected get a zero tensor as a placeholder.
    """

    def __init__(self, pre_trained_models):
        """Register the per-instrument models.

        Args:
            pre_trained_models: Mapping of instrument name -> ``nn.Module``.
                It is wrapped in ``nn.ModuleDict`` so the sub-models become
                registered submodules of this module. As a result they move
                with ``.to()`` and are included in ``state_dict()``.
        """
        super().__init__()
        self.models = nn.ModuleDict(pre_trained_models)

    def forward(self, mixture, indicator):
        """Separate the requested instruments from ``mixture``.

        Args:
            mixture: Tensor passed unchanged to every active model. It also
                serves as the template for the zero placeholders of inactive
                instruments.
            indicator: Mapping of instrument name -> flag. Only instruments
                whose flag is truthy are run through their model.

        Returns:
            A dict with one entry per key of ``indicator``. For an active
            instrument, the value is ``model(mixture)[:, 0, :, :]``, i.e.
            index 0 along dim 1 of that model's output. For an inactive
            instrument, the value is ``torch.zeros_like(mixture)``.

        Raises:
            KeyError: An active instrument has no registered model. Inactive
                instruments are never looked up, so unknown inactive names
                are accepted.
        """
        separated_sources = {}
        for instrument, active in indicator.items():
            if not active:
                # Keep every requested key in the output, even if it was skipped.
                separated_sources[instrument] = torch.zeros_like(mixture)
                continue

            if instrument not in self.models:
                raise KeyError(
                    f"No separation model registered for instrument "
                    f"{instrument!r}; available: {sorted(self.models.keys())}"
                )

            est_source = self.models[instrument](mixture)
            # NOTE(review): assumes model output is (batch, n_src, channels, time)
            # and mixture is (batch, channels, time). Under that assumption the
            # active result and the zero placeholder have the same shape.
            # Confirm against the model definitions.
            separated_sources[instrument] = est_source[:, 0, :, :]

        return separated_sources
|
|