# Uploaded by Daporte (commit c199313, 1.04 kB).
# Files added from https://github.com/facebookresearch/speech-resynthesis
# adapted from https://github.com/jik876/hifi-gan
from transformers.modeling_utils import PreTrainedModel
from quantizer_config import QuantizerConfig
from modules.jukebox import Encoder, Decoder
from modules.vq import Bottleneck
class Quantizer(PreTrainedModel):
    """Encode F0 features, vector-quantize the encoding, then decode it.

    The pipeline is ``Encoder`` (``modules.jukebox``) -> ``Bottleneck``
    (``modules.vq``) -> ``Decoder`` (``modules.jukebox``). Each stage is built
    from a keyword-argument mapping stored on the config.
    """

    config_class = QuantizerConfig

    def __init__(self, config):
        """Build the three sub-modules from ``config``.

        Args:
            config: a ``QuantizerConfig`` exposing the mappings
                ``f0_encoder_params``, ``f0_vq_params`` and
                ``f0_decoder_params``. Each is unpacked as ``**kwargs`` into
                its sub-module's constructor.
        """
        super().__init__(config)
        self.config = config
        # Registration order (encoder, vq, decoder) is kept as-is; it
        # determines the order of parameters / state_dict entries.
        self.encoder = Encoder(**config.f0_encoder_params)
        self.vq = Bottleneck(**config.f0_vq_params)
        self.decoder = Decoder(**config.f0_decoder_params)

    def forward(self, **kwargs):
        """Run encode -> quantize -> decode on ``kwargs['features']``.

        Only the ``'features'`` entry is read; any other keyword arguments are
        accepted and ignored. Its expected shape depends on the jukebox
        ``Encoder`` (NOTE(review): confirm against ``modules.jukebox``).

        Returns:
            dict with keys ``'f0'`` (decoder output), ``'commit_losses'`` and
            ``'metrics'`` (as returned by the VQ bottleneck), ``'codes'``
            (the bottleneck's discrete codes) and ``'hidden_states'`` (the
            quantized encoding that was fed to the decoder).

        Raises:
            KeyError: if ``'features'`` is not passed.
        """
        features = kwargs['features']
        latent = self.encoder(features)

        # The bottleneck returns a 4-tuple; only its position is relied on here.
        vq_out = self.vq(latent)
        codes, quantized, commit_losses, metrics = vq_out

        reconstruction = self.decoder(quantized)

        outputs = {
            'f0': reconstruction,
            'commit_losses': commit_losses,
            'metrics': metrics,
            'codes': codes,
            'hidden_states': quantized,
        }
        return outputs