|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Processor class for SparkTTS. Combines text tokenization and audio feature extraction/processing. |
|
""" |
|
|
|
import os
import random
import re

import numpy as np
import soundfile
import soxr
import torch

from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

from transformers.feature_extraction_utils import BatchFeature
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.models.wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import BatchEncoding
from transformers.utils import PushToHubMixin, logging
|
|
|
|
|
from .configuration_spark_tts import SparkTTSConfig |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def audio_volume_normalize(audio: np.ndarray, coeff: float = 0.2) -> np.ndarray: |
|
""" |
|
Normalize the volume of an audio signal. |
|
|
|
Parameters: |
|
audio (numpy array): Input audio signal array. |
|
coeff (float): Target coefficient for normalization, default is 0.2. |
|
|
|
Returns: |
|
numpy array: The volume-normalized audio signal. |
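
    Example (illustrative only; the input is a synthetic sine wave, not real speech):

        >>> import numpy as np
        >>> x = 0.5 * np.sin(2 * np.pi * 220.0 * np.linspace(0, 1, 16000))
        >>> y = audio_volume_normalize(x, coeff=0.2)
        >>> bool(np.max(np.abs(y)) <= 1.0)  # output is always kept within [-1, 1]
        True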
|
""" |
|
|
|
temp = np.sort(np.abs(audio)) |
|
|
|
|
|
if temp[-1] < 0.1: |
|
scaling_factor = max( |
|
temp[-1], 1e-3 |
|
) |
|
audio = audio / scaling_factor * 0.1 |
|
|
|
|
|
temp = temp[temp > 0.01] |
|
L = temp.shape[0] |
|
|
|
|
|
if L <= 10: |
|
return audio |
|
|
|
|
|
volume = np.mean(temp[int(0.9 * L) : int(0.99 * L)]) |
|
|
|
|
|
audio = audio * np.clip(coeff / volume, a_min=0.1, a_max=10) |
|
|
|
|
|
max_value = np.max(np.abs(audio)) |
|
if max_value > 1: |
|
audio = audio / max_value |
|
|
|
return audio |
|
|
|
|
|
def load_audio( |
|
adfile: Path, |
|
sampling_rate: int = None, |
|
length: int = None, |
|
volume_normalize: bool = False, |
|
segment_duration: int = None, |
|
) -> np.ndarray: |
|
r"""Load audio file with target sampling rate and lsength |
|
|
|
Args: |
|
adfile (Path): path to audio file. |
|
sampling_rate (int, optional): target sampling rate. Defaults to None. |
|
length (int, optional): target audio length. Defaults to None. |
|
        volume_normalize (bool, optional): whether to perform volume normalization. Defaults to False.

        segment_duration (int, optional): randomly select a segment of {segment_duration} seconds.

            Defaults to None, which means the whole audio is used.
|
|
|
Returns: |
|
audio (np.ndarray): audio |
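
    Example (illustrative; "prompt.wav" is a hypothetical local file):

        >>> wav = load_audio("prompt.wav", sampling_rate=16000, volume_normalize=True)
        >>> wav.ndim  # mono, resampled to 16 kHz
        1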
|
""" |
|
|
|
audio, sr = soundfile.read(adfile) |
|
if len(audio.shape) > 1: |
|
audio = audio[:, 0] |
|
|
|
if sampling_rate is not None and sr != sampling_rate: |
|
audio = soxr.resample(audio, sr, sampling_rate, quality="VHQ") |
|
sr = sampling_rate |
|
|
|
if segment_duration is not None: |
|
seg_length = int(sr * segment_duration) |
|
audio = random_select_audio_segment(audio, seg_length) |
|
|
|
|
|
if volume_normalize: |
|
audio = audio_volume_normalize(audio) |
|
|
|
if length is not None: |
|
assert abs(audio.shape[0] - length) < 1000 |
|
if audio.shape[0] > length: |
|
audio = audio[:length] |
|
else: |
|
audio = np.pad(audio, (0, int(length - audio.shape[0]))) |
|
return audio |
|
|
|
|
|
def random_select_audio_segment(audio: np.ndarray, length: int) -> np.ndarray: |
|
"""get an audio segment given the length |
|
|
|
Args: |
|
audio (np.ndarray): |
|
length (int): audio length = sampling_rate * duration |
|
""" |
|
if audio.shape[0] < length: |
|
audio = np.pad(audio, (0, int(length - audio.shape[0]))) |
|
start_index = random.randint(0, audio.shape[0] - length) |
|
end_index = int(start_index + length) |
|
|
|
return audio[start_index:end_index] |
|
|
|
def get_ref_clip(wav: np.ndarray, config) -> np.ndarray: |
|
"""Get reference audio clip for speaker embedding.""" |
|
|
|
if not all(hasattr(config, attr) for attr in ['sample_rate', 'ref_segment_duration', 'latent_hop_length']): |
|
raise AttributeError("Config object missing required attributes for get_ref_clip") |
|
ref_segment_length = ( |
|
int(config.sample_rate * config.ref_segment_duration) |
|
// config.latent_hop_length |
|
* config.latent_hop_length |
|
) |
|
wav_length = len(wav) |
|
if ref_segment_length > wav_length: |
|
wav = np.tile(wav, ref_segment_length // wav_length + 1) |
|
return wav[:ref_segment_length] |
|
|
|
|
|
|
|
|
|
TASK_TOKEN_MAP = { |
|
"vc": "<|task_vc|>", |
|
"tts": "<|task_tts|>", |
|
"asr": "<|task_asr|>", |
|
"s2s": "<|task_s2s|>", |
|
"t2s": "<|task_t2s|>", |
|
"understand": "<|task_understand|>", |
|
"caption": "<|task_cap|>", |
|
"controllable_tts": "<|task_controllable_tts|>", |
|
"prompt_tts": "<|task_prompt_tts|>", |
|
"speech_edit": "<|task_edit|>", |
|
} |
|
|
|
LEVELS_MAP = { |
|
"very_low": 0, |
|
"low": 1, |
|
"moderate": 2, |
|
"high": 3, |
|
"very_high": 4, |
|
} |
|
|
|
LEVELS_MAP_UI = { |
|
1: 'very_low', |
|
2: 'low', |
|
3: 'moderate', |
|
4: 'high', |
|
5: 'very_high' |
|
} |
|
|
|
GENDER_MAP = { |
|
"female": 0, |
|
"male": 1, |
|
} |
|
|
|
AGE_MAP = {"Child": 0, "Teenager": 1, "Youth-Adult": 2, "Middle-aged": 3, "Elderly": 4} |
|
|
|
EMO_MAP = { |
|
"UNKNOWN": 0, |
|
"NEUTRAL": 1, |
|
"ANGRY": 2, |
|
"HAPPY": 3, |
|
"SAD": 4, |
|
"FEARFUL": 5, |
|
"DISGUSTED": 6, |
|
"SURPRISED": 7, |
|
"SARCASTIC": 8, |
|
"EXCITED": 9, |
|
"SLEEPY": 10, |
|
"CONFUSED": 11, |
|
"EMPHASIS": 12, |
|
"LAUGHING": 13, |
|
"SINGING": 14, |
|
"WORRIED": 15, |
|
"WHISPER": 16, |
|
"ANXIOUS": 17, |
|
"NO-AGREEMENT": 18, |
|
"APOLOGETIC": 19, |
|
"CONCERNED": 20, |
|
"ENUNCIATED": 21, |
|
"ASSERTIVE": 22, |
|
"ENCOURAGING": 23, |
|
"CONTEMPT": 24, |
|
} |
|
|
|
|
|
class TokenParser: |
|
"""Turn label to special token""" |
|
|
|
def __init__(self): |
|
pass |
|
|
|
"""Parse the attributes of a person.""" |
|
|
|
def __init__(self): |
|
pass |
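
    # A few illustrative mappings, based on the maps defined above:
    #   TokenParser.gender("female")  -> "<|gender_0|>"
    #   TokenParser.task("tts")       -> "<|task_tts|>"
    #   TokenParser.emotion("HAPPY")  -> "<|emotion_3|>"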
|
|
|
@staticmethod |
|
def age(age: str) -> str: |
|
"""Turn age token.""" |
|
age_id = AGE_MAP[age] |
|
return f"<|age_{age_id}|>" |
|
|
|
@staticmethod |
|
def gender(gender: str) -> str: |
|
"""Turn gender token.""" |
|
gender_id = GENDER_MAP[gender] |
|
return f"<|gender_{gender_id}|>" |
|
|
|
@staticmethod |
|
def mel_value(mel: int): |
|
"""Turn special token of mel scale pitch.""" |
|
mel = max(0, int(mel)) |
|
mel = min(1000, int(mel)) |
|
return f"<|pitch_value_{mel}|>" |
|
|
|
@staticmethod |
|
def mel_level(level: str): |
|
"""Turn special token of mel level.""" |
|
level_tag = LEVELS_MAP[level] |
|
return f"<|pitch_label_{level_tag}|>" |
|
|
|
@staticmethod |
|
def pitch_var_value(pitch_std: int): |
|
"""Turn special token of pitch_std value.""" |
|
assert isinstance(pitch_std, int) |
|
pitch_std = max(0, int(pitch_std)) |
|
pitch_std = min(10, int(pitch_std)) |
|
return f"<|pitch_var_value_{pitch_std}|>" |
|
|
|
@staticmethod |
|
def pitch_var_level(level: str): |
|
"""Turn special token of pitch std level.""" |
|
level_tag = LEVELS_MAP[level] |
|
return f"<|pitch_var_label_{level_tag}|>" |
|
|
|
@staticmethod |
|
def loudness_value(loudness: int): |
|
"""Turn special toak of loudness value [0, 30]""" |
|
assert loudness >= 0 |
|
loudness = max(0, int(loudness)) |
|
loudness = min(30, int(loudness)) |
|
return f"<|loudness_value_{loudness}|>" |
|
|
|
@staticmethod |
|
def loudness_level(level: str): |
|
"""Turn special token of loudness level.""" |
|
level_tag = LEVELS_MAP[level] |
|
return f"<|loudness_label_{level_tag}|>" |
|
|
|
@staticmethod |
|
def speed_value(speed: int): |
|
"""Turn special token of speed value.""" |
|
speed = max(0, int(speed)) |
|
speed = min(10, int(speed)) |
|
return f"<|speed_value_{speed}|>" |
|
|
|
@staticmethod |
|
def speed_level(level: str): |
|
"""Turn special token of speed level.""" |
|
level_tag = LEVELS_MAP[level] |
|
return f"<|speed_label_{level_tag}|>" |
|
|
|
@staticmethod |
|
def task(task: str) -> str: |
|
"""Turn special token of task.""" |
|
assert task in TASK_TOKEN_MAP.keys() |
|
|
|
return TASK_TOKEN_MAP[task] |
|
|
|
@staticmethod |
|
def emotion(emotion: str): |
|
emo_id = EMO_MAP[emotion] |
|
|
|
return f"<|emotion_{emo_id}|>" |
|
|
|
|
|
|
|
|
|
|
|
|
|
class SparkTTSProcessor(ProcessorMixin, PushToHubMixin): |
|
r""" |
|
Constructs a SparkTTS processor which wraps a text tokenizer and relevant audio processing logic. |
|
|
|
Args: |
|
tokenizer ([`PreTrainedTokenizer`]): |
|
An instance of [`PreTrainedTokenizer`]. This handles the text tokenization for the LLM. |
|
feature_extractor ([`Wav2Vec2FeatureExtractor`]): |
|
An instance of [`Wav2Vec2FeatureExtractor`]. Although Wav2Vec2 features are extracted |
|
within the model's `tokenize_audio`, the extractor's configuration (like sampling rate) |
|
is useful, and it aligns with the ProcessorMixin pattern. |
|
config ([`SparkTTSConfig`], *optional*): |
|
An instance of [`SparkTTSConfig`] to access configuration parameters like sample rate. |
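
    Example (illustrative sketch; the checkpoint id, file names, and generation settings below are
    assumptions for demonstration, not values defined in this file):

        >>> import torch
        >>> from transformers import AutoModel, AutoProcessor

        >>> repo_id = "SparkAudio/Spark-TTS-0.5B"  # hypothetical checkpoint id
        >>> processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
        >>> model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
        >>> processor.link_model(model)  # required before voice cloning or decoding

        >>> # Controllable TTS: gender/pitch/speed instead of a prompt recording
        >>> inputs = processor(text="Hello world.", gender="female", pitch="moderate", speed="moderate")
        >>> generated_ids = model.generate(**inputs, max_new_tokens=1024)
        >>> output = processor.decode(generated_ids, input_ids_len=inputs["input_ids"].shape[1])
        >>> audio, sr = output["audio"], output["sampling_rate"]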
|
""" |
|
attributes = ["tokenizer", "feature_extractor"] |
|
tokenizer_class = "AutoTokenizer" |
|
feature_extractor_class = "Wav2Vec2FeatureExtractor" |
|
|
|
def __init__(self, tokenizer, feature_extractor, config: Optional[SparkTTSConfig] = None, **kwargs): |
|
super().__init__(tokenizer=tokenizer, feature_extractor=feature_extractor, **kwargs) |
|
self.model = None |
|
self.config = config |
|
|
|
if config and hasattr(config, 'sample_rate'): |
|
self.sampling_rate = config.sample_rate |
|
elif feature_extractor and hasattr(feature_extractor, 'sampling_rate'): |
|
self.sampling_rate = feature_extractor.sampling_rate |
|
else: |
|
self.sampling_rate = 16000 |
|
logger.warning(f"Could not determine sampling rate. Defaulting to {self.sampling_rate} Hz.") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def link_model(self, model): |
|
"""Links the processor to a SparkTTSModel instance for audio processing calls.""" |
|
if not hasattr(model, 'tokenize_audio') or not hasattr(model, 'detokenize_audio'): |
|
raise TypeError("The provided model instance does not have the required 'tokenize_audio' and 'detokenize_audio' methods.") |
|
if not hasattr(model, 'config'): |
|
logger.warning("Linked model does not have a 'config' attribute. Some processor functionalities might rely on it.") |
|
|
|
self.model = model |
|
logger.info("SparkTTSModel successfully linked to the processor.") |
|
|
|
if hasattr(model, 'config') and hasattr(model.config, 'sample_rate'): |
|
if self.sampling_rate != model.config.sample_rate: |
|
logger.info(f"Updating processor sampling rate from {self.sampling_rate} to {model.config.sample_rate} based on linked model config.") |
|
self.sampling_rate = model.config.sample_rate |
|
|
|
if hasattr(self, 'feature_extractor') and self.feature_extractor.sampling_rate != model.config.sample_rate: |
|
logger.info(f"Updating feature_extractor sampling rate from {self.feature_extractor.sampling_rate} to {model.config.sample_rate}.") |
|
self.feature_extractor.sampling_rate = model.config.sample_rate |
|
|
|
|
|
def __call__( |
|
self, |
|
text: str, |
|
prompt_speech_path: Optional[Union[str, Path]] = None, |
|
prompt_text: Optional[str] = None, |
|
gender: Optional[str] = None, |
|
pitch: Optional[str] = None, |
|
speed: Optional[str] = None, |
|
return_tensors: Optional[str] = "pt", |
|
**kwargs, |
|
) -> BatchEncoding: |
|
""" |
|
Processes the input text and optional prompt audio/control parameters into a format suitable for [`SparkTTSModel`]. |
|
|
|
Args: |
|
text (`str`): |
|
The main text to be synthesized. |
|
prompt_speech_path (`str` or `Path`, *optional*): |
|
Path to the prompt audio file for voice cloning. Required if `gender` is not set. |
|
prompt_text (`str`, *optional*): |
|
Transcript of the prompt audio. Used only in voice cloning mode. |
|
gender (`str`, *optional*): |
|
Target gender ("male" or "female") for controllable synthesis. If set, enables control mode. |
|
pitch (`str`, *optional*): |
|
Target pitch level ("very_low", "low", "moderate", "high", "very_high") for control mode. Required if `gender` is set. |
|
speed (`str`, *optional*): |
|
Target speed level ("very_low", "low", "moderate", "high", "very_high") for control mode. Required if `gender` is set. |
|
return_tensors (`str`, *optional*, defaults to `"pt"`): |
|
If set, will return tensors instead of list of python integers. Only "pt" (PyTorch) is supported currently. |
|
**kwargs: |
|
Additional arguments passed to the underlying tokenizer's `__call__` method. |
|
|
|
Returns: |
|
[`BatchEncoding`]: A dictionary containing the `input_ids` and `attention_mask` for the LLM. |
|
In voice cloning mode, it also includes `global_token_ids_prompt` (torch.Tensor) representing the |
|
global tokens extracted from the prompt audio. |
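
        Example (voice cloning mode; "prompt.wav" is a hypothetical local file and the processor is
        assumed to already be linked to a model via `link_model`):

            >>> inputs = processor(
            ...     text="Text to synthesize in the prompt speaker's voice.",
            ...     prompt_speech_path="prompt.wav",
            ...     prompt_text="Transcript of the prompt recording.",
            ... )
            >>> "global_token_ids_prompt" in inputs  # prompt speaker tokens, reused by `decode`
            True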
|
""" |
|
|
|
global_token_ids_prompt = None |
|
|
|
|
|
is_control_mode = gender is not None |
|
is_cloning_mode = prompt_speech_path is not None and not is_control_mode |
|
|
|
if is_control_mode: |
|
|
|
if not all([pitch, speed]): |
|
raise ValueError("For controllable TTS, 'gender', 'pitch', and 'speed' must all be provided.") |
|
if prompt_speech_path is not None: |
|
logger.warning("`prompt_speech_path` provided but ignored because `gender` is set (controllable TTS mode).") |
|
|
|
            if gender not in GENDER_MAP:
|
raise ValueError(f"Invalid gender provided: {gender}. Must be one of {list(GENDER_MAP.keys())}") |
|
if not all(k in LEVELS_MAP for k in [pitch, speed]): |
|
raise ValueError(f"Invalid pitch or speed level provided. Must be one of {list(LEVELS_MAP.keys())}") |
|
|
|
gender_id = GENDER_MAP[gender] |
|
pitch_level_id = LEVELS_MAP[pitch] |
|
speed_level_id = LEVELS_MAP[speed] |
|
|
|
pitch_label_tokens = f"<|pitch_label_{pitch_level_id}|>" |
|
speed_label_tokens = f"<|speed_label_{speed_level_id}|>" |
|
gender_tokens = f"<|gender_{gender_id}|>" |
|
|
|
attribute_tokens = "".join([gender_tokens, pitch_label_tokens, speed_label_tokens]) |
|
|
|
prompt_list = [ |
|
TASK_TOKEN_MAP["controllable_tts"], |
|
"<|start_content|>", |
|
text, |
|
"<|end_content|>", |
|
"<|start_style_label|>", |
|
attribute_tokens, |
|
"<|end_style_label|>", |
|
] |
|
prompt_string = "".join(prompt_list) |
|
|
|
elif is_cloning_mode: |
|
|
|
if self.model is None: |
|
raise RuntimeError("Processor must be linked to a SparkTTSModel instance via `processor.link_model(model)` before performing voice cloning.") |
|
prompt_speech_path = Path(prompt_speech_path) |
|
if not prompt_speech_path.exists(): |
|
raise FileNotFoundError(f"Prompt audio file not found: {prompt_speech_path}") |
|
|
|
|
|
try: |
|
model_config = self.model.config if self.model and hasattr(self.model, 'config') else self.config |
|
if model_config is None: |
|
raise ValueError("Configuration not available in processor or linked model.") |
|
|
|
|
|
wav = load_audio( |
|
prompt_speech_path, |
|
sampling_rate=self.sampling_rate, |
|
volume_normalize=getattr(model_config, 'volume_normalize', True), |
|
) |
|
|
|
wav_ref_np = get_ref_clip(wav, model_config) |
|
wav_ref = torch.from_numpy(wav_ref_np).unsqueeze(0).float() |
|
wav_tensor = torch.from_numpy(wav).unsqueeze(0).float() |
|
|
|
|
|
|
|
global_tokens_tensor, semantic_tokens_tensor = self.model.tokenize_audio(wav_tensor, wav_ref) |
|
|
|
|
|
global_token_ids_prompt = global_tokens_tensor |
|
|
|
|
|
global_token_list = global_tokens_tensor.squeeze().tolist() |
|
semantic_token_list = semantic_tokens_tensor.squeeze().tolist() |
|
|
|
except Exception as e: |
|
logger.error(f"Error processing prompt audio {prompt_speech_path}: {e}") |
|
import traceback |
|
traceback.print_exc() |
|
raise |
|
|
|
|
|
|
|
|
|
|
|
global_tokens_str = "".join([f"<|bicodec_global_{gid}|>" for gid in global_token_list]) |
|
semantic_tokens_str = "".join([f"<|bicodec_semantic_{sid}|>" for sid in semantic_token_list]) |
|
|
|
|
|
|
|
if prompt_text is not None and prompt_text.strip(): |
|
logger.info("Using prompt text in voice cloning prompt.") |
|
prompt_list = [ |
|
TASK_TOKEN_MAP["tts"], |
|
"<|start_content|>", |
|
prompt_text, |
|
text, |
|
"<|end_content|>", |
|
"<|start_global_token|>", |
|
global_tokens_str, |
|
"<|end_global_token|>", |
|
"<|start_semantic_token|>", |
|
semantic_tokens_str, |
|
|
|
] |
|
else: |
|
|
|
logger.info("No prompt text provided, using text-only voice cloning prompt.") |
|
prompt_list = [ |
|
TASK_TOKEN_MAP["tts"], |
|
"<|start_content|>", |
|
text, |
|
"<|end_content|>", |
|
"<|start_global_token|>", |
|
global_tokens_str, |
|
"<|end_global_token|>", |
|
] |
|
prompt_string = "".join(prompt_list) |
|
logger.debug(f"Generated prompt string (cloning): {prompt_string[:200]}...") |
|
|
|
else: |
|
raise ValueError("Invalid input combination. Either provide `prompt_speech_path` for cloning or (`gender`, `pitch`, `speed`) for control.") |
|
|
|
|
|
|
|
inputs = self.tokenizer( |
|
prompt_string, |
|
return_tensors=return_tensors, |
|
padding=kwargs.get("padding", False), |
|
truncation=kwargs.get("truncation", True), |
|
max_length=kwargs.get("max_length", self.tokenizer.model_max_length), |
|
add_special_tokens=kwargs.get("add_special_tokens", True), |
|
return_attention_mask=kwargs.get("return_attention_mask", True), |
|
**{k: v for k, v in kwargs.items() if k not in ["padding", "truncation", "max_length", "add_special_tokens", "return_attention_mask"]} |
|
) |
|
logger.debug(f"Tokenized input_ids shape: {inputs['input_ids'].shape}") |
|
|
|
|
|
|
|
if is_cloning_mode and global_token_ids_prompt is not None: |
|
if return_tensors == "pt": |
|
inputs["global_token_ids_prompt"] = global_token_ids_prompt |
|
else: |
|
|
|
inputs["global_token_ids_prompt"] = global_token_ids_prompt.tolist() |
|
|
|
return inputs |
|
|
|
|
|
def decode( |
|
self, |
|
generated_ids: torch.Tensor, |
|
global_token_ids_prompt: Optional[torch.Tensor] = None, |
|
input_ids_len: Optional[int] = None, |
|
skip_special_tokens: bool = True, |
|
) -> Dict[str, Any]: |
|
""" |
|
Decodes the generated token IDs from [`SparkTTSModel`] into an audio waveform. |
|
|
|
Args: |
|
generated_ids (`torch.Tensor`): |
|
Tensor of token IDs generated by `model.generate()`, including the input prompt part. Shape [B, seq_len]. |
|
global_token_ids_prompt (`torch.Tensor`, *optional*): |
|
The global tokens extracted from the prompt audio during the `__call__` step (for voice cloning). |
|
Shape [B, N_global]. Required if the generation was for voice cloning. |
|
input_ids_len (`int`, *optional*): |
|
The length of the original input prompt `input_ids` fed to `model.generate()`. Required to |
|
correctly isolate the newly generated tokens. |
|
skip_special_tokens (`bool`, *optional*, defaults to `True`): |
|
Whether to skip special tokens during the text decoding step (used to extract audio tokens). |
|
|
|
Returns: |
|
Dict[str, Any]: A dictionary containing: |
|
- "audio": The decoded audio waveform as a NumPy array. Shape [T_audio] (if B=1) or [B, T_audio]. |
|
- "sampling_rate": The sampling rate of the audio. |
|
""" |
|
if self.model is None: |
|
raise RuntimeError("Processor must be linked to a SparkTTSModel instance via `processor.link_model(model)` before decoding.") |
|
if input_ids_len is None: |
|
raise ValueError("`input_ids_len` (length of the prompt input_ids) must be provided for decoding.") |
|
|
|
|
|
|
|
|
|
        # Isolate the newly generated tokens (everything after the prompt).
        if generated_ids.shape[1] < input_ids_len:
            logger.warning(
                f"Generated sequence length ({generated_ids.shape[1]}) is shorter than input prompt length "
                f"({input_ids_len}). Decoding might be incorrect."
            )
        output_only_ids = generated_ids[:, input_ids_len:]
|
|
|
|
|
|
|
|
|
|
|
decoded_texts = self.tokenizer.batch_decode(output_only_ids, skip_special_tokens=skip_special_tokens) |
|
|
|
|
|
|
|
batch_size = generated_ids.shape[0] |
|
all_semantic_ids = [] |
|
all_global_tokens = [] |
|
successful_indices = [] |
|
|
|
for i in range(batch_size): |
|
decoded_text = decoded_texts[i] |
|
current_semantic_ids = None |
|
current_global_tokens = None |
|
|
|
|
|
try: |
|
pred_semantic_indices = [int(token) for token in re.findall(r"bicodec_semantic_(\d+)", decoded_text)] |
|
if not pred_semantic_indices: |
|
logger.warning(f"Batch item {i}: No semantic tokens found in decoded text: '{decoded_text[:200]}...'") |
|
continue |
|
|
|
current_semantic_ids = torch.tensor(pred_semantic_indices).long() |
|
except Exception as e: |
|
logger.error(f"Batch item {i}: Error parsing semantic tokens from: '{decoded_text[:200]}...'. Error: {e}") |
|
continue |
|
|
|
|
|
if global_token_ids_prompt is not None: |
|
|
|
if global_token_ids_prompt.shape[0] != batch_size: |
|
raise ValueError(f"Batch size mismatch: generated_ids has {batch_size}, but global_token_ids_prompt has {global_token_ids_prompt.shape[0]}.") |
|
current_global_tokens = global_token_ids_prompt[i] |
|
else: |
|
|
|
try: |
|
pred_global_indices = [int(token) for token in re.findall(r"bicodec_global_(\d+)", decoded_text)] |
|
if not pred_global_indices: |
|
logger.warning(f"Batch item {i}: No global tokens found in decoded text for control mode: '{decoded_text[:200]}...'") |
|
continue |
|
|
|
current_global_tokens = torch.tensor(pred_global_indices).long() |
|
|
|
except Exception as e: |
|
logger.error(f"Batch item {i}: Error parsing global tokens from: '{decoded_text[:200]}...'. Error: {e}") |
|
continue |
|
|
|
|
|
all_semantic_ids.append(current_semantic_ids) |
|
all_global_tokens.append(current_global_tokens) |
|
successful_indices.append(i) |
|
|
|
if not successful_indices: |
|
logger.error("Failed to extract audio tokens for any item in the batch.") |
|
return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate} |
|
|
|
|
|
|
|
|
|
|
|
|
|
if batch_size > 1 and len(successful_indices) < batch_size: |
|
logger.warning(f"Only successfully decoded {len(successful_indices)} out of {batch_size} batch items.") |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
|
|
if len(successful_indices) != 1: |
|
raise NotImplementedError("Batch decoding (B > 1) requires verification of BiCodec's batch handling and potentially padding.") |
|
|
|
final_semantic_ids = all_semantic_ids[0].unsqueeze(0) |
|
final_global_tokens = all_global_tokens[0].unsqueeze(0) |
|
|
|
except IndexError: |
|
logger.error("Internal error during token batch preparation.") |
|
return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate} |
|
|
|
|
|
|
|
try: |
|
|
|
|
|
output_wav = self.model.detokenize_audio(final_global_tokens, final_semantic_ids) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
except Exception as e: |
|
logger.error(f"Error during audio detokenization: {e}") |
|
import traceback |
|
traceback.print_exc() |
|
raise RuntimeError("Audio detokenization failed.") from e |
|
|
|
return {"audio": output_wav, "sampling_rate": self.sampling_rate} |
|
|
|
|
|
@classmethod |
|
def from_pretrained( |
|
cls, |
|
pretrained_model_name_or_path: Union[str, os.PathLike], |
|
cache_dir: Optional[Union[str, os.PathLike]] = None, |
|
force_download: bool = False, |
|
local_files_only: bool = False, |
|
token: Optional[Union[str, bool]] = None, |
|
revision: str = "main", |
|
trust_remote_code: bool = False, |
|
**kwargs, |
|
): |
|
r""" |
|
Instantiate a SparkTTSProcessor from pretrained components. |
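
        Example (illustrative; the repository id is an assumption):

            >>> processor = SparkTTSProcessor.from_pretrained(
            ...     "SparkAudio/Spark-TTS-0.5B", trust_remote_code=True
            ... )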
|
""" |
|
|
|
config = kwargs.pop("config", None) |
|
|
|
|
|
|
|
|
|
loaded_config = None |
|
if not isinstance(config, SparkTTSConfig): |
|
try: |
|
|
|
loaded_config = SparkTTSConfig.from_pretrained( |
|
pretrained_model_name_or_path, |
|
cache_dir=cache_dir, |
|
force_download=force_download, |
|
local_files_only=local_files_only, |
|
token=token, |
|
revision=revision, |
|
trust_remote_code=trust_remote_code, |
|
**kwargs, |
|
) |
|
except Exception as e: |
|
logger.warning( |
|
f"Could not load SparkTTSConfig from {pretrained_model_name_or_path}. " |
|
f"Attempting to load components from default relative paths ('LLM', 'wav2vec2-large-xlsr-53'). Error: {e}" |
|
) |
|
loaded_config = None |
|
else: |
|
|
|
loaded_config = config |
|
|
|
|
|
|
|
llm_tokenizer_path_or_id = "./LLM" |
|
w2v_processor_path_or_id = "./wav2vec2-large-xlsr-53" |
|
|
|
if loaded_config: |
|
llm_tokenizer_path_or_id = getattr(loaded_config, 'llm_model_name_or_path', llm_tokenizer_path_or_id) |
|
w2v_processor_path_or_id = getattr(loaded_config, 'wav2vec2_model_name_or_path', w2v_processor_path_or_id) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
component_loading_kwargs = { |
|
"cache_dir": cache_dir, |
|
"force_download": force_download, |
|
"local_files_only": local_files_only, |
|
"token": token, |
|
"revision": revision, |
|
**kwargs |
|
} |
|
try: |
|
|
|
tokenizer = AutoTokenizer.from_pretrained( |
|
pretrained_model_name_or_path, |
|
                subfolder=llm_tokenizer_path_or_id.removeprefix("./"),
|
trust_remote_code=trust_remote_code, |
|
**component_loading_kwargs |
|
) |
|
except Exception as e: |
|
|
|
if llm_tokenizer_path_or_id != "./LLM": |
|
try: |
|
logger.info(f"Retrying tokenizer load directly from: {llm_tokenizer_path_or_id}") |
|
tokenizer = AutoTokenizer.from_pretrained( |
|
llm_tokenizer_path_or_id, |
|
trust_remote_code=trust_remote_code, |
|
**component_loading_kwargs |
|
) |
|
except Exception as e2: |
|
raise OSError(f"Could not load tokenizer using main path + subfolder or directly from '{llm_tokenizer_path_or_id}'. Error: {e2}") from e |
|
else: |
|
raise OSError(f"Could not load tokenizer from subfolder '{llm_tokenizer_path_or_id}' within '{pretrained_model_name_or_path}'. Error: {e}") |
|
|
|
|
|
try: |
|
|
|
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( |
|
pretrained_model_name_or_path, |
|
                subfolder=w2v_processor_path_or_id.removeprefix("./"),
|
**component_loading_kwargs |
|
) |
|
except Exception as e: |
|
|
|
if w2v_processor_path_or_id != "./wav2vec2-large-xlsr-53": |
|
try: |
|
logger.info(f"Retrying feature extractor load directly from: {w2v_processor_path_or_id}") |
|
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( |
|
w2v_processor_path_or_id, |
|
**component_loading_kwargs |
|
) |
|
except Exception as e2: |
|
raise OSError(f"Could not load feature extractor using main path + subfolder or directly from '{w2v_processor_path_or_id}'. Error: {e2}") from e |
|
else: |
|
raise OSError(f"Could not load feature extractor from subfolder '{w2v_processor_path_or_id}' within '{pretrained_model_name_or_path}'. Error: {e}") |
|
|
|
|
|
|
|
|
|
return cls(tokenizer=tokenizer, feature_extractor=feature_extractor, config=loaded_config) |
|
|
|
|
|
def save_pretrained( |
|
self, |
|
save_directory: Union[str, os.PathLike], |
|
push_to_hub: bool = False, |
|
**kwargs, |
|
): |
|
""" |
|
Save the processor's state (tokenizer and feature extractor files) to a directory. |
|
|
|
Args: |
|
save_directory (`str` or `os.PathLike`): |
|
Directory where the processor files will be saved. |
|
push_to_hub (`bool`, *optional*, defaults to `False`): |
|
Whether or not to push your model to the Hugging Face Hub after saving it. |
|
**kwargs: |
|
                Additional keyword arguments passed along to the `push_to_hub` method.
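
        Example (illustrative):

            >>> processor.save_pretrained("./my_spark_tts_processor")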
|
""" |
|
save_directory = Path(save_directory) |
|
save_directory.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
self.tokenizer.save_pretrained(str(save_directory), **kwargs) |
|
|
|
|
|
self.feature_extractor.save_pretrained(str(save_directory), **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(f"Processor components saved in {save_directory}") |
|
|
|
        if push_to_hub:
            commit_message = kwargs.pop("commit_message", "Save processor")
            # `push_to_hub` expects a repository id rather than a local path; default to the directory name.
            repo_id = kwargs.pop("repo_id", save_directory.name)
            return self.push_to_hub(repo_id, commit_message=commit_message, **kwargs)
|
|
|
return str(save_directory) |