File size: 1,040 Bytes
c199313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# adapted from https://github.com/jik876/hifi-gan

from transformers.modeling_utils import PreTrainedModel

from quantizer_config import QuantizerConfig
from modules.jukebox import Encoder, Decoder
from modules.vq import Bottleneck



class Quantizer(PreTrainedModel):
    """Encoder -> vector-quantization bottleneck -> decoder model.

    Wraps the project's ``Encoder``/``Decoder`` (``modules.jukebox``) and
    ``Bottleneck`` (``modules.vq``) in a ``transformers.PreTrainedModel`` so
    the model is configured by a ``QuantizerConfig``.
    """

    config_class = QuantizerConfig

    def __init__(self, config):
        """Build the encoder, bottleneck and decoder from *config*.

        Args:
            config: A ``QuantizerConfig``. Its ``f0_encoder_params``,
                ``f0_vq_params`` and ``f0_decoder_params`` attributes are
                unpacked as keyword arguments into ``Encoder``,
                ``Bottleneck`` and ``Decoder`` respectively.
        """
        super().__init__(config)

        self.config = config
        self.encoder = Encoder(**config.f0_encoder_params)
        self.vq = Bottleneck(**config.f0_vq_params)
        self.decoder = Decoder(**config.f0_decoder_params)

    def forward(self, features, **kwargs):
        """Encode *features*, quantize the latent, and decode the result.

        ``features`` is a named parameter (rather than being read out of
        ``**kwargs``) so the model can be called positionally, a missing
        input raises a normal ``TypeError``, and signature-inspecting tools
        can see which input the model consumes. ``model(features=...)`` and
        ``model(**batch)`` keep working exactly as before.

        Args:
            features: Input passed unchanged to ``self.encoder``.
                NOTE(review): the ``f0_`` naming suggests an F0 (pitch)
                contour; shape/dtype are not visible here — confirm against
                ``modules.jukebox.Encoder``.
            **kwargs: Any other keyword arguments; accepted and ignored
                (e.g. extra keys when calling ``model(**batch)``).

        Returns:
            dict with keys:
                ``'f0'``: output of ``self.decoder`` on the quantized latent.
                ``'commit_losses'``: third value returned by ``self.vq``
                    (commitment losses, per the original variable name).
                ``'metrics'``: fourth value returned by ``self.vq``.
                ``'codes'``: first value returned by ``self.vq`` (``zs``).
                ``'hidden_states'``: quantized latent (second value returned
                    by ``self.vq``), i.e. the decoder's input.
        """
        f0_h = self.encoder(features)

        # Bottleneck returns (codes, quantized latent, commit losses, metrics).
        zs, f0_h_q, f0_commit_losses, f0_metrics = self.vq(f0_h)

        f0 = self.decoder(f0_h_q)

        return {
            'f0': f0,
            'commit_losses': f0_commit_losses,
            'metrics': f0_metrics,
            'codes': zs,
            'hidden_states': f0_h_q,
        }