File size: 1,040 Bytes
c199313 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
# adapted from https://github.com/jik876/hifi-gan
from transformers.modeling_utils import PreTrainedModel
from quantizer_config import QuantizerConfig
from modules.jukebox import Encoder, Decoder
from modules.vq import Bottleneck
class Quantizer(PreTrainedModel):
    """Encoder -> vector-quantization bottleneck -> decoder model.

    The three stages are built from the ``f0_encoder_params``,
    ``f0_vq_params`` and ``f0_decoder_params`` mappings found on the config
    object; each mapping is expanded as keyword arguments to ``Encoder``,
    ``Bottleneck`` and ``Decoder`` respectively.

    NOTE(review): the ``f0`` naming suggests the model quantizes a
    fundamental-frequency (pitch) signal -- confirm against the training code.
    """

    # Configuration class associated with this model.
    config_class = QuantizerConfig

    def __init__(self, config):
        """Build the encoder, bottleneck and decoder from ``config``.

        Args:
            config: Configuration object exposing ``f0_encoder_params``,
                ``f0_vq_params`` and ``f0_decoder_params`` (each unpacked
                with ``**`` into the matching sub-module constructor).
        """
        super().__init__(config)
        self.config = config

        # Stages are created in pipeline order: encode, quantize, decode.
        self.encoder = Encoder(**config.f0_encoder_params)
        self.vq = Bottleneck(**config.f0_vq_params)
        self.decoder = Decoder(**config.f0_decoder_params)

    def forward(self, **kwargs):
        """Run the input through the encoder, bottleneck and decoder.

        Args:
            **kwargs: Must contain the key ``'features'``, whose value is
                handed to the encoder unchanged. Other keys are ignored.

        Returns:
            dict with the keys:
                ``'f0'``: decoder output computed from the quantized latents.
                ``'commit_losses'``: third value returned by the bottleneck.
                ``'metrics'``: fourth value returned by the bottleneck.
                ``'codes'``: first value returned by the bottleneck.
                ``'hidden_states'``: second value returned by the bottleneck
                    (the input that is fed to the decoder).

        Raises:
            KeyError: if ``'features'`` is not supplied.
        """
        features = kwargs['features']
        latents = self.encoder(features)

        # The bottleneck yields a 4-tuple; only the second element is decoded.
        codes, quantized, commit_losses, metrics = self.vq(latents)
        reconstruction = self.decoder(quantized)

        return {
            'f0': reconstruction,
            'commit_losses': commit_losses,
            'metrics': metrics,
            'codes': codes,
            'hidden_states': quantized,
        }
|