Dynamic_Source_Separator_Causal / dynamic_source_separator.py
groadabike's picture
Upload 2 files
7c6b998 verified
raw
history blame contribute delete
824 Bytes
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from huggingface_hub import ModelCard
from tasnet import ConvTasNetStereo
class DynamicSourceSeparator(torch.nn.Module, PyTorchModelHubMixin):
    """Dispatch a mixture to per-instrument separation models on demand.

    Holds one pre-trained model per instrument name. On each forward pass,
    only the models whose instrument is flagged as active are run; every
    other requested instrument gets an all-zero tensor instead.

    Args:
        pre_trained_models: Mapping from instrument name to the ``nn.Module``
            that separates that instrument. It is wrapped in an
            ``nn.ModuleDict`` so the sub-models are registered as children
            (their parameters follow ``.to()``, ``state_dict()``, etc.).
    """

    # NOTE(review): ``__init__`` is intentionally left without annotations;
    # newer PyTorchModelHubMixin versions may inspect this signature when
    # building the saved config — confirm before changing it.
    def __init__(self, pre_trained_models):
        super().__init__()
        self.models = nn.ModuleDict(pre_trained_models)

    def forward(self, mixture: torch.Tensor, indicator) -> dict:
        """Separate the requested instruments from ``mixture``.

        Args:
            mixture: Input tensor passed unchanged to each active model.
                Presumably shaped (batch, channels, time) — TODO confirm.
            indicator: Mapping of instrument name -> truthy/falsy flag. Truthy
                entries are separated with the matching model; falsy entries
                are filled with zeros and do not need a model at all.

        Returns:
            A dict with one entry per key of ``indicator`` (same iteration
            order). Active instruments map to ``model(mixture)[:, 0, :, :]``;
            inactive ones map to ``torch.zeros_like(mixture)``.

        Raises:
            KeyError: If an active instrument has no model in ``self.models``.
                Checked before any model runs, so no work is wasted.
        """
        # Validate up front so a bad request does not run some (possibly
        # expensive) models before failing. Only active entries need a model.
        missing = [
            name
            for name, active in indicator.items()
            if active and name not in self.models
        ]
        if missing:
            raise KeyError(
                f"No separation model for active instrument(s) {missing}; "
                f"available: {list(self.models.keys())}"
            )

        separated_sources = {}
        for instrument, active in indicator.items():
            if active:
                est_source = self.models[instrument](mixture)
                # Keep index 0 along dim 1 of the model output. Assumes the
                # model returns (batch, n_src, channels, time) with the target
                # instrument in slot 0 — TODO confirm against ConvTasNetStereo.
                separated_sources[instrument] = est_source[:, 0, :, :]
            else:
                # Inactive: zeros with the mixture's shape/dtype/device.
                # NOTE(review): assumes this matches the shape of an active
                # model's ``est_source[:, 0]`` — verify for the models in use.
                separated_sources[instrument] = torch.zeros_like(mixture)
        return separated_sources