from transformers import PretrainedConfig
from typing import List, Optional

class QuantizerConfig(PretrainedConfig):
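    """Configuration for a VQ-based prosody (F0) quantizer.

    Groups the VQ bottleneck, encoder, decoder, and training
    hyperparameters so they can be saved and loaded through the
    Hugging Face `PretrainedConfig` machinery.
    """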
    model_type = "prosody_quantizer"
    
    def __init__(
        self,
        # VQ parameters
        l_bins: int = 320,
        emb_width: int = 64,
        mu: float = 0.99,
        levels: int = 1,
        
        # Encoder parameters
        encoder_input_emb_width: int = 3,
        encoder_output_emb_width: int = 64,
        encoder_levels: int = 1,
        encoder_downs_t: Optional[List[int]] = None,
        encoder_strides_t: Optional[List[int]] = None,
        encoder_width: int = 32,
        encoder_depth: int = 4,
        encoder_m_conv: float = 1.0,
        encoder_dilation_growth_rate: int = 3,
        
        # Decoder parameters
        decoder_input_emb_width: int = 3,
        decoder_output_emb_width: int = 64,
        decoder_levels: int = 1,
        decoder_downs_t: Optional[List[int]] = None,
        decoder_strides_t: Optional[List[int]] = None,
        decoder_width: int = 32,
        decoder_depth: int = 4,
        decoder_m_conv: float = 1.0,
        decoder_dilation_growth_rate: int = 3,
        
        # Training parameters
        lambda_commit: float = 0.02,
        f0_normalize: bool = True,
        intensity_normalize: bool = True,
        multispkr: str = "single",
        f0_feats: bool = False,
        f0_median: bool = False,
        
        # Optional training hyperparameters
        learning_rate: float = 0.0002,
        adam_b1: float = 0.8,
        adam_b2: float = 0.99,
        lr_decay: float = 0.999,
        **kwargs
    ):
        super().__init__(**kwargs)
        
        # VQ parameters
        self.l_bins = l_bins
        self.emb_width = emb_width
        self.mu = mu
        self.levels = levels
        
        # Encoder parameters
        self.encoder_input_emb_width = encoder_input_emb_width
        self.encoder_output_emb_width = encoder_output_emb_width
        self.encoder_levels = encoder_levels
        # downs_t/strides_t default to None in the signature to avoid the
        # shared-mutable-default pitfall; resolve the actual defaults here.
        self.encoder_downs_t = encoder_downs_t if encoder_downs_t is not None else [4]
        self.encoder_strides_t = encoder_strides_t if encoder_strides_t is not None else [2]
        self.encoder_width = encoder_width
        self.encoder_depth = encoder_depth
        self.encoder_m_conv = encoder_m_conv
        self.encoder_dilation_growth_rate = encoder_dilation_growth_rate
        
        # Decoder parameters
        self.decoder_input_emb_width = decoder_input_emb_width
        self.decoder_output_emb_width = decoder_output_emb_width
        self.decoder_levels = decoder_levels
        self.decoder_downs_t = decoder_downs_t if decoder_downs_t is not None else [4]
        self.decoder_strides_t = decoder_strides_t if decoder_strides_t is not None else [2]
        self.decoder_width = decoder_width
        self.decoder_depth = decoder_depth
        self.decoder_m_conv = decoder_m_conv
        self.decoder_dilation_growth_rate = decoder_dilation_growth_rate
        
        # Training parameters
        self.lambda_commit = lambda_commit
        self.f0_normalize = f0_normalize
        self.intensity_normalize = intensity_normalize
        self.multispkr = multispkr
        self.f0_feats = f0_feats
        self.f0_median = f0_median
        
        # Training hyperparameters
        self.learning_rate = learning_rate
        self.adam_b1 = adam_b1
        self.adam_b2 = adam_b2
        self.lr_decay = lr_decay
    
    @property
    def f0_vq_params(self):
        return {
            "l_bins": self.l_bins,
            "emb_width": self.emb_width,
            "mu": self.mu,
            "levels": self.levels
        }
    
    @property
    def f0_encoder_params(self):
        return {
            "input_emb_width": self.encoder_input_emb_width,
            "output_emb_width": self.encoder_output_emb_width,
            "levels": self.encoder_levels,
            "downs_t": self.encoder_downs_t,
            "strides_t": self.encoder_strides_t,
            "width": self.encoder_width,
            "depth": self.encoder_depth,
            "m_conv": self.encoder_m_conv,
            "dilation_growth_rate": self.encoder_dilation_growth_rate
        }
    
    @property
    def f0_decoder_params(self):
        return {
            "input_emb_width": self.decoder_input_emb_width,
            "output_emb_width": self.decoder_output_emb_width,
            "levels": self.decoder_levels,
            "downs_t": self.decoder_downs_t,
            "strides_t": self.decoder_strides_t,
            "width": self.decoder_width,
            "depth": self.decoder_depth,
            "m_conv": self.decoder_m_conv,
            "dilation_growth_rate": self.decoder_dilation_growth_rate
        }
    
    @classmethod
    def from_yaml(cls, yaml_path: str):
        """Load config from yaml file"""
        import yaml
        with open(yaml_path, 'r') as f:
            config = yaml.safe_load(f)
        
        # Convert yaml config to kwargs
        kwargs = {
            # VQ params
            **config['f0_vq_params'],
            
            # Encoder params
            **{f"encoder_{k}": v for k, v in config['f0_encoder_params'].items()},
            
            # Decoder params  
            **{f"decoder_{k}": v for k, v in config['f0_decoder_params'].items()},
            
            # Training params
            "lambda_commit": config.get('lambda_commit', 0.02),
            "f0_normalize": config.get('f0_normalize', True),
            "intensity_normalize": config.get('intensity_normalize', True),
            "multispkr": config.get('multispkr', "single"),
            "f0_feats": config.get('f0_feats', False),
            "f0_median": config.get('f0_median', False),
            
            # Training hyperparams
            "learning_rate": config.get('learning_rate', 0.0002),
            "adam_b1": config.get('adam_b1', 0.8), 
            "adam_b2": config.get('adam_b2', 0.99),
            "lr_decay": config.get('lr_decay', 0.999),
        }
        
        return cls(**kwargs)
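

# --- Usage sketch (illustrative) ---
# A minimal example of constructing the config and reading the grouped
# parameter views. The yaml layout sketched below is an assumption for
# illustration, not a file shipped with this module; it simply mirrors
# the keys `from_yaml` expects.
#
#   f0_vq_params:      {l_bins: 320, emb_width: 64, mu: 0.99, levels: 1}
#   f0_encoder_params: {input_emb_width: 3, output_emb_width: 64, levels: 1,
#                       downs_t: [4], strides_t: [2], width: 32, depth: 4,
#                       m_conv: 1.0, dilation_growth_rate: 3}
#   f0_decoder_params: {input_emb_width: 3, output_emb_width: 64, levels: 1,
#                       downs_t: [4], strides_t: [2], width: 32, depth: 4,
#                       m_conv: 1.0, dilation_growth_rate: 3}
#   lambda_commit: 0.02

if __name__ == "__main__":
    # Build with defaults and inspect the kwargs each sub-module would receive.
    config = QuantizerConfig()
    print(config.f0_vq_params)                  # {'l_bins': 320, 'emb_width': 64, ...}
    print(config.f0_encoder_params["downs_t"])  # [4]

    # Round-trip through the standard PretrainedConfig serialization.
    config.save_pretrained("./quantizer_config")  # writes config.json
    reloaded = QuantizerConfig.from_pretrained("./quantizer_config")
    assert reloaded.l_bins == config.l_bins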