File size: 824 Bytes
7c6b998 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from huggingface_hub import ModelCard
from tasnet import ConvTasNetStereo
class DynamicSourceSeparator(nn.Module, PyTorchModelHubMixin):
    """Route a mixture through per-instrument separation models.

    Holds one pre-trained separation model per instrument and, on each
    forward pass, runs only the models that the caller's ``indicator``
    mapping marks as active. Inactive instruments get an all-zero tensor.

    Args:
        pre_trained_models: Mapping from instrument name to the ``nn.Module``
            that separates that instrument. It is wrapped in an
            ``nn.ModuleDict`` so the sub-models are registered as children
            (their parameters appear in ``parameters()``/``state_dict()`` and
            follow ``.to()``/``.eval()``). ``nn.ModuleDict`` requires string
            keys.

    NOTE(review): ``PyTorchModelHubMixin`` may try to record ``__init__``
    arguments as a config when saving/pushing; a dict of modules is likely
    not serializable that way. Confirm save/load behaviour.
    """

    def __init__(self, pre_trained_models):
        super().__init__()
        self.models = nn.ModuleDict(pre_trained_models)

    def forward(self, mixture, indicator):
        """Separate the requested instruments from ``mixture``.

        Args:
            mixture: Input tensor passed unchanged to each active model.
                Presumably (batch, channels, time); TODO confirm.
            indicator: Mapping from instrument name to a truthy/falsy flag
                saying whether that instrument should be separated.

        Returns:
            dict mapping every key of ``indicator`` to a tensor. For active
            instruments this is ``model(mixture)[:, 0, :, :]``, the first
            entry along dim 1 of the model output. For inactive instruments it
            is ``torch.zeros_like(mixture)``.

        Raises:
            KeyError: if an instrument is marked active but no model was
                registered for it.
        """
        separated_sources = {}
        for instrument, active in indicator.items():
            if not active:
                # NOTE(review): zeros take the mixture's shape. This agrees
                # with the active branch only if est_source[:, 0] has the same
                # shape as the mixture; confirm against the model's output.
                separated_sources[instrument] = torch.zeros_like(mixture)
                continue

            if instrument not in self.models:
                # Fail with a message that says which instrument is missing,
                # instead of ModuleDict's bare KeyError.
                raise KeyError(
                    f"No model registered for active instrument "
                    f"{instrument!r}; available: {sorted(self.models.keys())}"
                )

            est_source = self.models[instrument](mixture)
            # Keep only index 0 along dim 1 of the model output; the 4-D
            # slice implies the model returns a rank-4 tensor.
            separated_sources[instrument] = est_source[:, 0, :, :]
        return separated_sources
|