from dataclasses import dataclass
from typing import Optional, Tuple, List
import os
import yaml
import re
import torch
import torch.nn as nn
import torchaudio
from hyperpyyaml import load_hyperpyyaml

from transformers import Qwen2Config, PreTrainedModel
from transformers import Qwen2ForCausalLM, AutoTokenizer
from audio_detokenizer.cli.model import AudioDetokenizerModel
from s3bpe_tokenizer import S3BpeTokenizer
from configuration_bailing_talker import BailingTalkerConfig
from transformers.utils import ModelOutput
from sentence_manager.sentence_manager import SentenceNormalizer

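# This module bundles the speech-generation ("talker") side of the model:
# BailingTalkerOutputWithPast is the output container for the talker forward pass,
# BailingTalkerForConditionalGeneration maps text, thinker hidden states and a
# voice-print embedding to discrete audio tokens, and AudioDetokenizer renders
# those tokens back into a waveform with the flow + HiFi-GAN detokenizer model.

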
@dataclass
class BailingTalkerOutputWithPast(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[torch.FloatTensor] = None


class BailingTalkerForConditionalGeneration(PreTrainedModel):
    config_class = BailingTalkerConfig
    base_model_prefix = 'model'

    def __init__(self, config: BailingTalkerConfig):
        super().__init__(config)

        self.config = config
        self.vocab_size = self.config.vocab_size
        self.tokenizer = AutoTokenizer.from_pretrained(self.config._name_or_path)
        self.model_config = Qwen2Config.from_pretrained(self.config._name_or_path)
        self.model = Qwen2ForCausalLM(self.model_config)
        self.model.resize_token_embeddings(self.vocab_size)
        # Project thinker hidden states into the talker's embedding space.
        self.thinker_to_talker_proj = nn.Linear(self.config.qa_model_hidden_size, self.model_config.hidden_size)
        # Voice-print encoder: maps the speaker feature to a single talker embedding slot.
        self.vp_head = nn.Conv1d(
            self.config.vp_feature_size,
            self.model_config.hidden_size,
            kernel_size=self.config.vp_kernel_size,
            stride=self.config.vp_stride,
            padding=self.config.vp_kernel_size // 2,
        )
        self.s3bpe_tokenizer = S3BpeTokenizer(
            bpe_model=f"{self.config._name_or_path}/s3_bpe/tokenizer.json",
            mapping_file=f"{self.config._name_or_path}/s3_bpe/char_mapping.txt",
        )

        self.loss_function = nn.CrossEntropyLoss()

        default_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sentence_manager/default_config.yaml")
        with open(default_config_path) as f:
            self.sentence_manager_config = yaml.safe_load(f)
        if "split_token" not in self.sentence_manager_config:
            self.sentence_manager_config["split_token"] = []
        assert isinstance(self.sentence_manager_config["split_token"], list)
        self.sentence_manager_config["split_token"].append(re.escape(self.tokenizer.eos_token))
        self.normalizer = SentenceNormalizer(self.sentence_manager_config.get("text_norm", {}))

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def encode_audio_segments(
        self,
        inputs_embeds: torch.FloatTensor,
        vp_emb: torch.FloatTensor,
        vp_insert_loc: torch.LongTensor,
        thinker_reply_part: Optional[torch.FloatTensor] = None,
        thinker_reply_length: Optional[List] = None,
        thinker_prefix_insert_loc: Optional[torch.LongTensor] = None,
    ):
        # Encode the voice-print feature and write it into its reserved embedding slot.
        vp_emb_encoded = self.vp_head(vp_emb.transpose(-1, -2)).transpose(-1, -2)

        for idx in range(vp_insert_loc.shape[0]):
            inputs_embeds[idx, vp_insert_loc[idx].item():vp_insert_loc[idx].item() + 1, :] = vp_emb_encoded[idx, :, :]

        # Optionally splice the projected thinker hidden states into the prefix region.
        if thinker_prefix_insert_loc is not None:
            thinker_reply_part = self.thinker_to_talker_proj(thinker_reply_part)
            for idx in range(thinker_prefix_insert_loc.shape[0]):
                real_length = thinker_reply_length[idx]
                inputs_embeds[idx, thinker_prefix_insert_loc[idx].item():thinker_prefix_insert_loc[idx].item() + real_length, :] = thinker_reply_part[idx, :real_length, :]

        return inputs_embeds

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        text_input_ids: Optional[torch.LongTensor] = None,
        vp_emb: Optional[torch.FloatTensor] = None,
        vp_insert_loc: Optional[torch.LongTensor] = None,
        thinker_reply_part: Optional[torch.FloatTensor] = None,
        thinker_reply_length: Optional[torch.LongTensor] = None,
        thinker_prefix_insert_loc: Optional[torch.LongTensor] = None,
    ):
        if inputs_embeds is None:
            # The talker consumes the sum of the audio-token and text-token embeddings.
            audio_input_embeds = self.model.get_input_embeddings()(input_ids)
            text_input_embeds = self.model.get_input_embeddings()(text_input_ids)
            inputs_embeds = audio_input_embeds + text_input_embeds
            if past_key_values is None:
                # Only the prefill step carries the voice-print and thinker prefix.
                inputs_embeds = self.encode_audio_segments(
                    inputs_embeds, vp_emb, vp_insert_loc, thinker_reply_part=thinker_reply_part,
                    thinker_reply_length=thinker_reply_length, thinker_prefix_insert_loc=thinker_prefix_insert_loc
                )

        if position_ids is None:
            position_ids = (attention_mask.cumsum(-1) - 1).masked_fill_((attention_mask == 0), 1)

        outputs = self.model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = outputs.logits

        loss = None
        if labels is not None:
            loss = self.loss_function(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))

        return BailingTalkerOutputWithPast(
            loss=loss,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            logits=logits,
        )

    def sample(self, logits, topk=20, filter_value=-float("Inf"), stopping_criteria=False, eos_id=151666):
        logits = logits.reshape(1, -1)
        # Top-k filtering: mask every logit outside the top-k candidates.
        indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]
        # While stopping is suppressed, also mask <audio_eos> so it cannot be sampled.
        if stopping_criteria:
            indices_to_remove[0][eos_id] = True
        logits[indices_to_remove] = filter_value
        token_id = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1).to(torch.long)
        return token_id

    def omni_audio_generation_func(
        self,
        tts_text,
        prompt,
        prefix_from_thinker,
        vp,
        position_ids,
        talker_audio_prefix,
        vp_insert_loc,
        thinker_length,
        vp_emb=None,
        thinker_reply_part=None,
        prompt_text=None,
        prompt_speech_token=None,
    ):
        text_input_part = self.tokenizer.encode(tts_text)

        # Encode the prompt transcript, and convert its speech tokens to s3-BPE ids
        # offset past the text vocabulary so they land in the audio-token range.
        prompt_text_input_part = self.tokenizer.encode(prompt_text)
        prompt_speech_token = prompt_speech_token[0].tolist()
        prompt_speech_token_bpe = self.s3bpe_tokenizer.encode(prompt_speech_token)[0]
        prompt_speech_token_bpe = (torch.tensor(prompt_speech_token_bpe) + len(self.tokenizer)).tolist()

        talker_text_prefix = (
            prompt +
            prefix_from_thinker +
            vp +
            prompt_text_input_part[:1]
        )

        talker_text_input_part = (
            prompt_text_input_part[1:] +
            text_input_part +
            self.tokenizer.encode("<text_eos>") +
            self.tokenizer.encode("<text_pad>")
        )

        talker_text_prefix = torch.tensor(talker_text_prefix).reshape(1, -1).to(self.device)

        audio_token = self.generate(
            talker_audio_prefix=talker_audio_prefix,
            talker_text_prefix=talker_text_prefix,
            talker_text_input_part=talker_text_input_part,
            position_ids=position_ids,
            vp_emb=vp_emb,
            vp_insert_loc=vp_insert_loc,
            thinker_reply_part=thinker_reply_part,
            thinker_reply_length=torch.tensor([thinker_length]).to(self.device),
            thinker_prefix_insert_loc=torch.tensor([len(prompt) + 1]).to(self.device) if thinker_reply_part is not None else None,
            prompt_wav_token=prompt_speech_token_bpe,
        )

        # Map the generated ids back to raw speech tokens for the detokenizer.
        audio_token = [ele - len(self.tokenizer) for ele in audio_token]
        audio_token = self.s3bpe_tokenizer.decode(audio_token)
        audio_token = torch.tensor([audio_token], dtype=torch.int32)

        return audio_token

    def text_length(self, text):
        # Count CJK characters only; used to bound the size of each synthesis chunk.
        return len(re.findall("[\u4e00-\u4E27\u4E29-\u4E3E\u4E42-\u9fa4]", text))

    def cut_text(self, text, max_length, tail_min_length=5):
        # Split long input into chunks of at most max_length CJK characters,
        # preferring sentence (。) and clause (,) boundaries; a very short tail after
        # a finished sentence starts a new chunk instead of being glued on.
        def text_append(text_list, text, max_length):
            if len(text_list) == 0:
                text_list.append(text)
            else:
                if len(text_list[-1]) + self.text_length(text) <= max_length:
                    if text_list[-1].endswith("。") and self.text_length(text) < tail_min_length:
                        text_list.append(text.lstrip(","))
                    else:
                        text_list[-1] += text
                else:
                    text_list.append(text.lstrip(","))
            return text_list

        text = text.replace("\n", " ")
        text = self.normalizer.normalize(text)
        text = text.replace("。,", "。")
        if len(text) <= max_length:
            return [text]
        text_list = []
        text = text.replace(".", "。").replace(",", ",")

        sps1 = []
        for t in text.split("。"):
            t = t.strip()
            if len(t) > 0:
                if t[-1] not in "!?,。!?,.":
                    t += "。"
                sps1.append(t)

        for text_piece1 in sps1:
            sps2 = []
            for t in text_piece1.split(","):
                t = t.strip()
                if len(t) > 0:
                    if t[-1] not in "!?,。!?,.":
                        t += ","
                    sps2.append(t)

            for text_piece2 in sps2:
                text_piece2 = text_piece2.replace("。,", "。")
                if self.text_length(text_piece2) > max_length:
                    # A single clause longer than max_length is cut by character count.
                    for i in range(0, len(text_piece2), max_length):
                        text_list = text_append(text_list, text_piece2[i:i + max_length], max_length)
                else:
                    text_list = text_append(text_list, text_piece2, max_length)
        return text_list

    def omni_audio_generation(
        self,
        tts_text,
        vp_emb=None,
        thinker_reply_part=None,
        max_length=50,
        prompt_text=None,
        prompt_speech_token=None,
        **kwargs,
    ):
        # Build the shared prompt scaffolding once, then synthesize each text chunk.
        thinker_length = thinker_reply_part.size(1) if thinker_reply_part is not None else 0
        prefix_from_thinker = (
            self.tokenizer.encode("<thinker_prefix>") +
            self.tokenizer.encode("<audio_pad>") * thinker_length +
            self.tokenizer.encode("</thinker_prefix>")
        )

        prompt = self.tokenizer.encode("<prompt>") + self.tokenizer.encode("</prompt>")
        vp = (
            self.tokenizer.encode("<vp>") +
            self.tokenizer.encode("<audio_pad>") +
            self.tokenizer.encode("</vp>")
        )
        talker_audio_prefix = (
            prompt +
            prefix_from_thinker +
            vp +
            self.tokenizer.encode("<audio_bos>")
        )
        attention_mask = torch.ones(len(talker_audio_prefix)).reshape(1, -1).to(self.device)
        position_ids = (attention_mask.cumsum(-1) - 1).masked_fill_((attention_mask == 0), 1)[:, -1].view(1, -1)
        talker_audio_prefix = torch.tensor(talker_audio_prefix).reshape(1, -1).to(self.device)
        vp_insert_loc = torch.tensor(len(prompt) + len(prefix_from_thinker) + 1, dtype=torch.long).reshape(1, -1)
        vp_emb = vp_emb.unsqueeze(0).to(torch.bfloat16).to(self.device)

        assert max_length > 0, f"max_length must be greater than 0, but got {max_length}"
        text_list = self.cut_text(tts_text, max_length)

        audio_tokens = []
        for text in text_list:
            audio_tokens_piece = self.omni_audio_generation_func(
                tts_text=text,
                prompt=prompt,
                prefix_from_thinker=prefix_from_thinker,
                vp=vp,
                position_ids=position_ids,
                talker_audio_prefix=talker_audio_prefix,
                vp_insert_loc=vp_insert_loc,
                thinker_length=thinker_length,
                vp_emb=vp_emb,
                thinker_reply_part=thinker_reply_part,
                prompt_text=prompt_text,
                prompt_speech_token=prompt_speech_token,
            )
            audio_tokens.append(audio_tokens_piece)
        return audio_tokens

    @torch.no_grad()
    def generate(
        self,
        talker_audio_prefix: torch.LongTensor,
        talker_text_prefix: torch.LongTensor,
        talker_text_input_part: List,
        position_ids: Optional[torch.LongTensor] = None,
        vp_emb: Optional[torch.FloatTensor] = None,
        vp_insert_loc: Optional[torch.LongTensor] = None,
        thinker_reply_part: Optional[torch.FloatTensor] = None,
        thinker_reply_length: Optional[torch.LongTensor] = None,
        thinker_prefix_insert_loc: Optional[torch.LongTensor] = None,
        prompt_wav_token: List = [],
        min_new_token=10,
    ):
        # Step-by-step decoding with a KV cache: step 0 feeds the full prefix,
        # every later step feeds one audio token paired with the next text token.
        result = []
        step = 0
        eos_id = self.tokenizer.encode("<audio_eos>")[0]
        prompt_wav_token_len = len(prompt_wav_token)
        while step < 1000:
            if step == 0:
                talker_audio_input_ids = talker_audio_prefix
                talker_text_input_ids = talker_text_prefix
                attention_mask = torch.ones(talker_audio_input_ids.shape).to(talker_audio_prefix.device)
            else:
                talker_audio_input_ids = next_token
                talker_text_input_ids = torch.tensor(talker_text_input_part[0], dtype=torch.long).reshape(1, -1).to(
                    talker_audio_prefix.device)
                attention_mask = torch.ones(next_token.shape[0], 1).to(talker_audio_prefix.device)
                position_ids += 1
                thinker_prefix_insert_loc = None

                # Consume the paired text stream one token per step; once exhausted,
                # its final token keeps being repeated.
                if len(talker_text_input_part) > 1:
                    talker_text_input_part = talker_text_input_part[1:]

            outputs = self(
                input_ids=talker_audio_input_ids,
                text_input_ids=talker_text_input_ids,
                thinker_reply_part=thinker_reply_part,
                thinker_reply_length=thinker_reply_length,
                thinker_prefix_insert_loc=thinker_prefix_insert_loc,
                vp_emb=vp_emb,
                vp_insert_loc=vp_insert_loc,
                attention_mask=attention_mask,
                position_ids=position_ids,
                use_cache=True,
                past_key_values=outputs.past_key_values if step > 0 else None,
            )

            logits = outputs.logits[:, -1, :]

            # Forbid <audio_eos> until the prompt speech tokens have been replayed
            # and at least min_new_token positions have passed.
            stopping_criteria = position_ids.item() < prompt_wav_token_len + min_new_token
            next_token = self.sample(logits, stopping_criteria=stopping_criteria)
            if next_token.item() == eos_id:
                break

            # Teacher-force the prompt speech tokens first; only tokens generated
            # afterwards are collected as output.
            if len(prompt_wav_token) > 0:
                next_token = torch.tensor([[prompt_wav_token[0]]]).to(logits.device)
                prompt_wav_token = prompt_wav_token[1:]
            else:
                result.append(next_token.item())
            step += 1

        return result


class AudioDetokenizer:
    def __init__(self, config_path, flow_model_path, hifigan_model_path):
        with open(config_path, 'r') as f:
            configs = load_hyperpyyaml(f)

        self.model = AudioDetokenizerModel(configs['flow'], configs['hift'])
        self.model.load(flow_model_path, hifigan_model_path)
        self.sr = 22050

    def token2wav(self, audio_tokens, save_path=None, **kwargs):
        assert isinstance(audio_tokens, list), "audio_tokens should be a list"
        speech_list = []
        for audio_token in audio_tokens:
            model_input = {"tts_speech_token": audio_token}
            kwargs.update(**model_input)

            model_output = self.model.inference(**kwargs)

            # Overwrite the first 20 ms with silence to suppress clicks at chunk boundaries.
            silent_dur = 0.02
            silent_tensor = torch.Tensor([0.0] * int(self.sr * silent_dur))
            model_output['tts_speech'][0][:int(self.sr * silent_dur)] = silent_tensor

            speech_list.append(model_output['tts_speech'])

        speech = speech_list[0] if len(speech_list) == 1 else torch.cat(speech_list, dim=1)
        if save_path is not None:
            torchaudio.save(save_path, speech, sample_rate=self.sr)
        return speech
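

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: it only illustrates how
    # the talker and the detokenizer fit together. Every path below, the voice-print
    # tensor shape, and the prompt inputs are placeholders/assumptions -- point them
    # at your local assets (and pass any extra kwargs AudioDetokenizerModel.inference
    # requires through token2wav) before running.
    model_dir = "path/to/bailing_talker"  # assumed checkpoint directory

    # omni_audio_generation casts the voice-print to bfloat16, so load the talker in
    # bfloat16 on a GPU for matching dtypes.
    talker = BailingTalkerForConditionalGeneration.from_pretrained(
        model_dir, torch_dtype=torch.bfloat16
    ).to("cuda").eval()

    detokenizer = AudioDetokenizer(
        config_path=f"{model_dir}/audio_detokenizer/config.yaml",     # placeholder path
        flow_model_path=f"{model_dir}/audio_detokenizer/flow.pt",     # placeholder path
        hifigan_model_path=f"{model_dir}/audio_detokenizer/hift.pt",  # placeholder path
    )

    vp_emb = torch.zeros(1, talker.config.vp_feature_size)      # placeholder voice-print feature
    prompt_speech_token = torch.zeros(1, 10, dtype=torch.long)  # placeholder prompt speech tokens

    audio_tokens = talker.omni_audio_generation(
        "你好,欢迎使用语音合成。",            # text to synthesize ("Hello, welcome to TTS.")
        vp_emb=vp_emb,
        prompt_text="参考音频对应的文本。",    # transcript of the prompt audio
        prompt_speech_token=prompt_speech_token,
    )
    detokenizer.token2wav(audio_tokens, save_path="out.wav")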