| # adapted from https://github.com/jik876/hifi-gan | |
| from transformers.modeling_utils import PreTrainedModel | |
| from quantizer_config import QuantizerConfig | |
| from modules.jukebox import Encoder, Decoder | |
| from modules.vq import Bottleneck | |
class Quantizer(PreTrainedModel):
    """Encoder -> bottleneck (``modules.vq.Bottleneck``) -> decoder model.

    The constructor arguments of the three sub-modules are taken from
    ``config.f0_encoder_params``, ``config.f0_vq_params`` and
    ``config.f0_decoder_params`` respectively.
    """

    config_class = QuantizerConfig

    def __init__(self, config):
        """Build the encoder, bottleneck and decoder from *config*.

        Args:
            config: A ``QuantizerConfig`` exposing ``f0_encoder_params``,
                ``f0_vq_params`` and ``f0_decoder_params``; each is unpacked
                as keyword arguments into the matching sub-module.
        """
        super().__init__(config)
        # NOTE(review): PreTrainedModel.__init__ presumably stores the config
        # already; the explicit assignment is kept so behaviour is unchanged.
        self.config = config
        self.encoder = Encoder(**config.f0_encoder_params)
        self.vq = Bottleneck(**config.f0_vq_params)
        self.decoder = Decoder(**config.f0_decoder_params)

    def forward(self, features=None, **kwargs):
        """Run *features* through the encoder, bottleneck and decoder.

        Args:
            features: Input handed directly to ``self.encoder``. Its expected
                shape/dtype is defined by ``modules.jukebox.Encoder`` (not
                visible here -- TODO confirm). May be passed by keyword (as
                before) or positionally.
            **kwargs: Any other keyword arguments are accepted and ignored,
                so a whole batch dict can still be splatted into the call.

        Returns:
            dict with keys:
                ``'f0'``: output of ``self.decoder``.
                ``'commit_losses'``: commit losses returned by the bottleneck.
                ``'metrics'``: metrics returned by the bottleneck.
                ``'codes'``: ``zs``, the first output of the bottleneck.
                ``'hidden_states'``: the bottleneck output that was fed to
                    the decoder (presumably the quantized encoder states).

        Raises:
            KeyError: if ``features`` is not supplied -- the same exception
                type the previous ``kwargs['features']`` lookup raised, but
                with an explanatory message.
        """
        if features is None:
            raise KeyError("Quantizer.forward() requires a 'features' argument")

        f0_h = self.encoder(features)
        # Bottleneck yields (zs, f0_h_q, commit losses, metrics); only f0_h_q
        # is passed on to the decoder, the rest is surfaced to the caller.
        zs, f0_h_q, f0_commit_losses, f0_metrics = self.vq(f0_h)
        f0 = self.decoder(f0_h_q)
        return {
            'f0': f0,
            'commit_losses': f0_commit_losses,
            'metrics': f0_metrics,
            'codes': zs,
            'hidden_states': f0_h_q,
        }