# adapted from https://github.com/jik876/hifi-gan

from transformers.modeling_utils import PreTrainedModel

from quantizer_config import QuantizerConfig
from modules.jukebox import Encoder, Decoder
from modules.vq import Bottleneck


class Quantizer(PreTrainedModel):
    """Encoder -> vector-quantization bottleneck -> decoder model.

    Three sub-modules are built from keyword-argument dicts on the config:
    ``f0_encoder_params`` (Encoder), ``f0_vq_params`` (Bottleneck) and
    ``f0_decoder_params`` (Decoder).

    NOTE(review): the "f0" naming suggests the inputs are pitch (F0)
    features; the expected tensor shapes are defined by ``modules.jukebox``
    and ``modules.vq`` and are not visible here — confirm there.
    """

    config_class = QuantizerConfig

    def __init__(self, config):
        """Build the encoder, quantizer bottleneck and decoder from ``config``.

        Args:
            config: a ``QuantizerConfig`` exposing ``f0_encoder_params``,
                ``f0_vq_params`` and ``f0_decoder_params`` as mappings of
                constructor keyword arguments.
        """
        super().__init__(config)
        # PreTrainedModel.__init__ already keeps the config on the instance;
        # the explicit assignment is retained and is harmless.
        self.config = config

        # Sub-modules are registered in a fixed order (encoder, vq, decoder)
        # because nn.Module registration order determines state_dict key
        # order, which saved checkpoints depend on.
        self.encoder = Encoder(**config.f0_encoder_params)
        self.vq = Bottleneck(**config.f0_vq_params)
        self.decoder = Decoder(**config.f0_decoder_params)

    def forward(self, **kwargs):
        """Encode, quantize and reconstruct the given features.

        Keyword Args:
            features: input passed straight to the encoder. It is required:
                a missing ``features`` key raises ``KeyError``.
            Any other keyword arguments are accepted and ignored.

        Returns:
            dict with keys, in this order:
                ``'f0'``: decoder output reconstructed from the quantized latent;
                ``'commit_losses'``: commitment losses reported by the bottleneck;
                ``'metrics'``: metrics reported by the bottleneck;
                ``'codes'``: discrete codes chosen by the bottleneck;
                ``'hidden_states'``: the quantized latent fed to the decoder.
        """
        features = kwargs['features']

        latent = self.encoder(features)
        # The bottleneck yields (codes, quantized latent, commit losses, metrics).
        codes, quantized, commit_losses, metrics = self.vq(latent)
        reconstruction = self.decoder(quantized)

        return dict(
            f0=reconstruction,
            commit_losses=commit_losses,
            metrics=metrics,
            codes=codes,
            hidden_states=quantized,
        )