import spaces
import gradio as gr
import torch
import soundfile as sf
from transformers import AutoTokenizer, AutoModelForCausalLM
from xcodec2.modeling_xcodec2 import XCodec2Model
import tempfile
import json

device = "cuda" if torch.cuda.is_available() else "cpu"
####################
# Load models globally
####################
llasa_model_id = "HKUSTAudio/Llasa-1B-multi-speakers-genshin-zh-en-ja-ko"

print("Loading tokenizer & model ...")
tokenizer = AutoTokenizer.from_pretrained(llasa_model_id)
model = AutoModelForCausalLM.from_pretrained(llasa_model_id)
model.eval().to(device)

print("Loading XCodec2Model ...")
codec_model_path = "HKUSTAudio/xcodec2"
Codec_model = XCodec2Model.from_pretrained(codec_model_path)
Codec_model.eval().to(device)
print("Models loaded.")

# Reference transcripts for each game/speaker pair.
with open("Reference_Voice/text.json", "r", encoding="utf-8") as f:
    prompt_text_dict = json.load(f)
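# Expected layout of Reference_Voice/text.json (illustrative sketch inferred
# from the prompt_text_dict[game][speaker] lookups below; the real transcripts
# live in the repo's Reference_Voice folder):
# {
#   "HonkaiSR": {"Kafka": "<reference transcript>", ...},
#   "Zenless":  {"Yixuan": "<reference transcript>", ...},
#   "Genshin":  {"Mavuika": "<reference transcript>", ...}
# }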
####################
# Inference helpers
####################
def extract_speech_ids(speech_tokens_str):
    """
    Convert tokens like <|s_23456|> back to the int 23456.
    """
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith("<|s_") and token_str.endswith("|>"):
            num_str = token_str[4:-2]
            speech_ids.append(int(num_str))
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids
def ids_to_speech_tokens(speech_ids):
    """
    Convert ints (or 0-dim tensors) like 23456 into tokens like <|s_23456|>.
    """
    speech_tokens_str = []
    for speech_id in speech_ids:
        speech_tokens_str.append(f"<|s_{speech_id}|>")
    return speech_tokens_str
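# Round-trip sanity check (illustrative values):
#   ids_to_speech_tokens([23456])       -> ["<|s_23456|>"]
#   extract_speech_ids(["<|s_23456|>"]) -> [23456]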
# @spaces.GPU requests a ZeroGPU slot on Hugging Face Spaces for the duration
# of this call; it is assumed here since `spaces` is imported but otherwise
# unused.
@spaces.GPU
def text2speech(target_text, game, speaker):
    """
    Convert text to a speech waveform and return the audio file path.
    """
    # Load the speaker's reference clip. sf.read returns (frames, channels)
    # for multi-channel audio; XCodec2 expects 16 kHz input, so the reference
    # clips are assumed to already be at that rate.
    prompt_wav, sr = sf.read(f"Reference_Voice/{game}/{speaker}/audio.mp3")
    prompt_wav = torch.from_numpy(prompt_wav).float().unsqueeze(0)
    # Downmix stereo (1, frames, channels) to mono (1, frames).
    if prompt_wav.ndim == 3:
        prompt_wav = prompt_wav.mean(dim=2)

    # Prepend the reference transcript so the model can voice-clone from it.
    prompt_text = prompt_text_dict[game][speaker]
    input_text = prompt_text + " " + target_text
    with torch.no_grad():
        # Encode the prompt wav into discrete codec tokens.
        vq_code_prompt = Codec_model.encode_code(input_waveform=prompt_wav)
        print("Prompt VQ code shape:", vq_code_prompt.shape)
        vq_code_prompt = vq_code_prompt[0, 0, :]

        # Convert int 12345 to token <|s_12345|>
        speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)

        # Wrap the input text in text-understanding tokens.
        formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"

        # Tokenize the text and the speech prefix; continue_final_message lets
        # generation pick up right after the prefix tokens.
        chat = [
            {"role": "user", "content": "Convert the text to speech:" + formatted_text},
            {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>" + ''.join(speech_ids_prefix)}
        ]
        input_ids = tokenizer.apply_chat_template(
            chat,
            tokenize=True,
            return_tensors='pt',
            continue_final_message=True
        )
        input_ids = input_ids.to(device)
        speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')
        # Generate the speech tokens autoregressively.
        outputs = model.generate(
            input_ids,
            max_length=2048,  # The model was trained with a max length of 2048.
            eos_token_id=speech_end_id,
            do_sample=True,
            top_p=1,
            temperature=0.8,
        )

        # Keep the prompt's speech tokens plus everything generated, dropping
        # the trailing <|SPEECH_GENERATION_END|>.
        generated_ids = outputs[0][input_ids.shape[1] - len(speech_ids_prefix):-1]
        speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        # Convert token <|s_23456|> back to int 23456.
        speech_tokens = extract_speech_ids(speech_tokens)

        # decode_code expects shape (batch, n_quantizers, T) = (1, 1, T).
        speech_tokens = torch.tensor(speech_tokens, device=device).unsqueeze(0).unsqueeze(0)

        # Decode the speech tokens to a waveform.
        gen_wav = Codec_model.decode_code(speech_tokens)

    # Pull out the audio data; XCodec2 operates at 16 kHz.
    audio = gen_wav[0, 0, :].cpu().numpy()
    sample_rate = 16000

    # Save the audio to a temporary file and return its path.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
        sf.write(tmpfile.name, audio, sample_rate)
        audio_path = tmpfile.name
    return audio_path
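# Example call (assumes the Kafka reference clip and transcript exist):
#   wav_path = text2speech("Hello there!", "HonkaiSR", "Kafka")
#   print(wav_path)  # path to a temporary 16 kHz .wav file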
####################
# Gradio UI
####################
game_choices = [
    "HonkaiSR",
    "Zenless",
    "Genshin"
]

speaker_game_dict = {
    "HonkaiSR": ["Kafka", "Firefly", "Silverwolf"],
    "Zenless": ["Yixuan", "Miyabi", "Jane"],
    "Genshin": ["Mavuika", "Navia", "Kokomi", "Furina", "Yoimiya"]
}
#["puck", "kore"] | |
if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.Markdown("## Text to Speech Generation")
        with gr.Row():
            game = gr.Dropdown(label="Game", choices=game_choices, value="HonkaiSR")
            speaker = gr.Dropdown(
                label="Speaker",
                choices=speaker_game_dict["HonkaiSR"],
                value=speaker_game_dict["HonkaiSR"][0],
                allow_custom_value=True
            )
        target_text = gr.Textbox(label="Target Text", placeholder="Enter the text you want to convert to speech.")
        output_audio = gr.Audio(label="Generated Audio", type="filepath")

        def update_speakers(game):
            # Returning a plain list would only set the dropdown's value;
            # gr.update swaps out the choices as well.
            return gr.update(choices=speaker_game_dict[game], value=speaker_game_dict[game][0])

        game.change(update_speakers, inputs=game, outputs=speaker)

        text2speech_button = gr.Button("Generate Speech")
        text2speech_button.click(text2speech, inputs=[target_text, game, speaker], outputs=output_audio)

    demo.launch()