# Vi-SparkTTS-0.5B / configuration_spark_tts.py
# coding=utf-8
# Copyright 2025 SparkAudio & The HuggingFace Inc. team. All rights reserved.
# ... (License headers remain the same) ...
""" SparkTTS model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from typing import List, Optional # Added typing
logger = logging.get_logger(__name__)
# --- Define Individual Sub-Component Config Classes ---
class SparkTTSMelParamsConfig(PretrainedConfig):
    """Configuration for Mel Spectrogram parameters."""

    model_type = "spark-tts-mel-params"

    def __init__(self, sample_rate=16000, n_fft=1024, win_length=640, hop_length=320,
                 mel_fmin=10, mel_fmax=None, num_mels=128, **kwargs):
        super().__init__(**kwargs)
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.num_mels = num_mels

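# Informational note (not used by the code above): with the defaults sample_rate=16000 and
# hop_length=320, mel frames are computed at 16000 / 320 = 50 frames per second, and
# win_length=640 corresponds to a 40 ms analysis window. This is arithmetic on the defaults
# only; the values actually used come from config.json at load time.
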
class SparkTTSEncoderConfig(PretrainedConfig):
    """Configuration for the BiCodec Feature Encoder."""

    model_type = "spark-tts-encoder"

    def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
                 vocos_num_layers=12, out_channels=1024, sample_ratios=[1, 1], **kwargs):
        super().__init__(**kwargs)
        self.input_channels = input_channels
        self.vocos_dim = vocos_dim
        self.vocos_intermediate_dim = vocos_intermediate_dim
        self.vocos_num_layers = vocos_num_layers
        self.out_channels = out_channels
        self.sample_ratios = sample_ratios

class SparkTTSDecoderConfig(PretrainedConfig):
    """Configuration for the BiCodec Wave Generator (Decoder)."""

    model_type = "spark-tts-decoder"

    def __init__(self, input_channel=1024, channels=1536, rates=[8, 5, 4, 2],
                 kernel_sizes=[16, 11, 8, 4], **kwargs):
        super().__init__(**kwargs)
        self.input_channel = input_channel
        self.channels = channels
        self.rates = rates
        self.kernel_sizes = kernel_sizes

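# Informational note: the total upsampling factor of the wave generator is the product of
# `rates`, i.e. 8 * 5 * 4 * 2 = 320 with the defaults above, which matches the default
# hop_length of 320 samples, so one latent frame presumably maps to 320 output samples.
# This is an observation about the defaults, not a constraint enforced by this file.
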
class SparkTTSQuantizerConfig(PretrainedConfig):
    """Configuration for the BiCodec Factorized Vector Quantizer."""

    model_type = "spark-tts-quantizer"

    def __init__(self, input_dim=1024, codebook_size=8192, codebook_dim=8,
                 commitment=0.25, codebook_loss_weight=2.0, decay=0.99,
                 threshold_ema_dead_code=0.2, **kwargs):
        # Note: Removed use_l2_normlize as it wasn't in the original class __init__ args.
        # Add it back if it's actually used by the FactorizedVectorQuantize class init.
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment
        self.codebook_loss_weight = codebook_loss_weight
        self.decay = decay
        self.threshold_ema_dead_code = threshold_ema_dead_code

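# Informational note: codebook_size=8192 means each quantized token index fits in
# log2(8192) = 13 bits, and codebook_dim=8 is the dimension of each (factorized) codebook
# entry. The actual quantization behaviour is defined by the FactorizedVectorQuantize
# implementation in the modeling file, not by this config class.
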
class SparkTTSSpeakerEncoderConfig(PretrainedConfig):
    """Configuration for the BiCodec Speaker Encoder."""

    model_type = "spark-tts-speaker-encoder"

    def __init__(self, input_dim=128, out_dim=1024, latent_dim=128, token_num=32,
                 fsq_levels=[4, 4, 4, 4, 4, 4], fsq_num_quantizers=1, **kwargs):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.out_dim = out_dim
        self.latent_dim = latent_dim
        self.token_num = token_num
        self.fsq_levels = fsq_levels
        self.fsq_num_quantizers = fsq_num_quantizers

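# Informational note: with fsq_levels=[4, 4, 4, 4, 4, 4], a single FSQ quantizer can represent
# 4**6 = 4096 distinct codes, and token_num=32 fixed-length speaker tokens are produced.
# This is arithmetic on the defaults above, not a guarantee about the speaker-encoder module.
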
class SparkTTSPrenetConfig(PretrainedConfig):
    """Configuration for the BiCodec Prenet."""

    model_type = "spark-tts-prenet"

    def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
                 vocos_num_layers=12, out_channels=1024, condition_dim=1024,
                 sample_ratios=[1, 1], use_tanh_at_final=False, **kwargs):
        super().__init__(**kwargs)
        self.input_channels = input_channels
        self.vocos_dim = vocos_dim
        self.vocos_intermediate_dim = vocos_intermediate_dim
        self.vocos_num_layers = vocos_num_layers
        self.out_channels = out_channels
        self.condition_dim = condition_dim
        self.sample_ratios = sample_ratios
        self.use_tanh_at_final = use_tanh_at_final

class SparkTTSPostnetConfig(PretrainedConfig):
    """Configuration for the BiCodec Postnet."""

    model_type = "spark-tts-postnet"

    def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
                 vocos_num_layers=6, out_channels=1024, use_tanh_at_final=False, **kwargs):
        # Note: Removed condition_dim as it wasn't in the original config example for postnet.
        super().__init__(**kwargs)
        self.input_channels = input_channels
        self.vocos_dim = vocos_dim
        self.vocos_intermediate_dim = vocos_intermediate_dim
        self.vocos_num_layers = vocos_num_layers
        self.out_channels = out_channels
        self.use_tanh_at_final = use_tanh_at_final

# --- Define the Intermediate BiCodec Config Class ---
class SparkTTSBiCodecConfig(PretrainedConfig):
    """
    Intermediate configuration class for the BiCodec component within SparkTTS.
    It holds instances of the individual sub-component configurations.
    """

    model_type = "spark-tts-bicodec"
    # Map keys in the 'bicodec_config' dict to their respective classes
    sub_configs = {
        "mel_params": SparkTTSMelParamsConfig,
        "encoder_config": SparkTTSEncoderConfig,
        "decoder_config": SparkTTSDecoderConfig,
        "quantizer_config": SparkTTSQuantizerConfig,
        "speaker_encoder_config": SparkTTSSpeakerEncoderConfig,
        "prenet_config": SparkTTSPrenetConfig,
        "postnet_config": SparkTTSPostnetConfig,
    }

    def __init__(
        self,
        mel_params=None,
        encoder_config=None,
        decoder_config=None,
        quantizer_config=None,
        speaker_encoder_config=None,
        prenet_config=None,
        postnet_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Instantiate sub-configs from dictionaries or use defaults/provided instances
        self.mel_params = self._init_sub_config(mel_params, "mel_params")
        self.encoder_config = self._init_sub_config(encoder_config, "encoder_config")
        self.decoder_config = self._init_sub_config(decoder_config, "decoder_config")
        self.quantizer_config = self._init_sub_config(quantizer_config, "quantizer_config")
        self.speaker_encoder_config = self._init_sub_config(speaker_encoder_config, "speaker_encoder_config")
        self.prenet_config = self._init_sub_config(prenet_config, "prenet_config")
        self.postnet_config = self._init_sub_config(postnet_config, "postnet_config")

    def _init_sub_config(self, config_input, config_key):
        """Helper to initialize a sub-config from a dict, an existing instance, or defaults."""
        config_cls = self.sub_configs[config_key]
        if isinstance(config_input, dict):
            return config_cls(**config_input)
        elif config_input is None:
            return config_cls()  # Initialize with defaults
        elif isinstance(config_input, config_cls):
            return config_input  # Already an instance
        else:
            raise TypeError(
                f"Invalid type for {config_key}: {type(config_input)}. "
                f"Expected dict, None, or {config_cls.__name__}."
            )

# --- Define the Main SparkTTS Config Class ---
class SparkTTSConfig(PretrainedConfig):
    r"""
    Main configuration class for SparkTTSModel, including the nested BiCodec configuration.

    Args:
        llm_model_name_or_path (`str`, *optional*, defaults to `"./LLM"`): Path or ID of the LLM.
        bicodec_model_name_or_path (`str`, *optional*, defaults to `"./BiCodec"`): Path or ID of the BiCodec checkpoint.
        wav2vec2_model_name_or_path (`str`, *optional*, defaults to `"./wav2vec2-large-xlsr-53"`): Path or ID of the Wav2Vec2 model.
        sample_rate (`int`, *optional*, defaults to 16000): Audio sample rate.
        highpass_cutoff_freq (`int`, *optional*, defaults to 40): Cutoff frequency (Hz) of the high-pass filter applied to input audio.
        latent_hop_length (`int`, *optional*, defaults to 320): Hop length (in samples) of the latent token sequence.
        ref_segment_duration (`float`, *optional*, defaults to 6.0): Duration (in seconds) of the reference audio segment.
        volume_normalize (`bool`, *optional*, defaults to `True`): Whether to apply volume normalization to input audio.
        bicodec_config (`dict`, *optional*): Dictionary used to initialize `SparkTTSBiCodecConfig`.
        torch_dtype (`str`, *optional*, defaults to `"auto"`): Torch dtype.
        kwargs (*optional*): Dictionary of additional keyword arguments.
    """

    model_type = "spark-tts"
    # Map the key in config.json to the intermediate BiCodec config class
    sub_configs = {"bicodec_config": SparkTTSBiCodecConfig}
    # Illustrative mapping only: `hidden_size` is redirected to `d_model`, which this config
    # does not define, so accessing `config.hidden_size` will fail unless `d_model` is set.
    attribute_map = {"hidden_size": "d_model"}

    def __init__(
        self,
        llm_model_name_or_path="./LLM",
        bicodec_model_name_or_path="./BiCodec",
        wav2vec2_model_name_or_path="./wav2vec2-large-xlsr-53",
        sample_rate=16000,
        highpass_cutoff_freq=40,
        latent_hop_length=320,
        ref_segment_duration=6.0,
        volume_normalize=True,
        bicodec_config=None,  # Expects a dictionary or None
        torch_dtype="auto",
        **kwargs,
    ):
        # --- Top-level parameters ---
        self.llm_model_name_or_path = llm_model_name_or_path
        self.bicodec_model_name_or_path = bicodec_model_name_or_path
        self.wav2vec2_model_name_or_path = wav2vec2_model_name_or_path
        self.sample_rate = sample_rate
        self.highpass_cutoff_freq = highpass_cutoff_freq
        self.latent_hop_length = latent_hop_length
        self.ref_segment_duration = ref_segment_duration
        self.volume_normalize = volume_normalize
        self.torch_dtype = torch_dtype

        # --- Nested BiCodec Configuration ---
        # Instantiate the intermediate BiCodec config class, which will handle its own sub-configs
        if isinstance(bicodec_config, dict):
            self.bicodec_config = self.sub_configs["bicodec_config"](**bicodec_config)
        elif bicodec_config is None:
            logger.info("`bicodec_config` not provided. Initializing `SparkTTSBiCodecConfig` with its defaults.")
            self.bicodec_config = self.sub_configs["bicodec_config"]()
        elif isinstance(bicodec_config, self.sub_configs["bicodec_config"]):
            self.bicodec_config = bicodec_config  # Use existing instance
        else:
            raise TypeError(
                f"Invalid type for bicodec_config: {type(bicodec_config)}. "
                "Expected dict, None, or SparkTTSBiCodecConfig."
            )

        # Set processor class and auto_map
        kwargs["processor_class"] = kwargs.get("processor_class", "SparkTTSProcessor")
        kwargs["auto_map"] = kwargs.get(
            "auto_map",
            {
                "AutoConfig": "configuration_spark_tts.SparkTTSConfig",
                "AutoModel": "modeling_spark_tts.SparkTTSModel",
                "AutoProcessor": "processing_spark_tts.SparkTTSProcessor",
            },
        )
        super().__init__(**kwargs)
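

if __name__ == "__main__":
    # Minimal smoke test / usage sketch; it runs only when this module is executed directly,
    # never on import. The nested values simply repeat the defaults defined above; a real
    # config.json for this repository may override any of them.
    cfg = SparkTTSConfig(
        bicodec_config={
            "mel_params": {"sample_rate": 16000, "hop_length": 320},
            "quantizer_config": {"codebook_size": 8192},
        }
    )
    print(cfg.model_type)                                      # spark-tts
    print(cfg.bicodec_config.mel_params.hop_length)            # 320
    print(cfg.bicodec_config.quantizer_config.codebook_size)   # 8192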