# Mejiro J
# compatible with cpu instance
# 21e13bb
import spaces
import gradio as gr
import torch
import soundfile as sf
from transformers import AutoTokenizer, AutoModelForCausalLM
from xcodec2.modeling_xcodec2 import XCodec2Model
import tempfile
import json
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

####################
# Load models globally (once, at import time)
####################
# NOTE: despite the variable name, this checkpoint id refers to the 1B
# multi-speaker model.
llasa_3b = "HKUSTAudio/Llasa-1B-multi-speakers-genshin-zh-en-ja-ko"
print("Loading tokenizer & model ...")
tokenizer = AutoTokenizer.from_pretrained(llasa_3b)
model = AutoModelForCausalLM.from_pretrained(llasa_3b)
model.eval().to(device)

print("Loading XCodec2Model ...")
codec_model_path = "HKUSTAudio/xcodec2"
Codec_model = XCodec2Model.from_pretrained(codec_model_path)
Codec_model.eval().to(device)
print("Models loaded.")

# Transcripts of the reference clips, indexed as prompt_text_dict[game][speaker]
# by text2speech. Read inside a context manager so the file handle is closed
# deterministically instead of being left to the garbage collector.
with open("Reference_Voice/text.json", "r", encoding="utf-8") as _prompt_text_file:
    prompt_text_dict = json.load(_prompt_text_file)
####################
# Inference helpers
####################
def extract_speech_ids(speech_tokens_str):
    """Convert speech token strings such as ``<|s_23456|>`` back to ints.

    Args:
        speech_tokens_str: Iterable of token strings.

    Returns:
        A list with the integer id of every well-formed token, in input
        order. Tokens that lack the ``<|s_`` / ``|>`` wrapper, or whose
        payload is not an integer, are reported on stdout and skipped.
    """
    speech_ids = []
    for token_str in speech_tokens_str:
        speech_id = None
        if token_str.startswith("<|s_") and token_str.endswith("|>"):
            try:
                # Strip the 4-character prefix and the 2-character suffix.
                speech_id = int(token_str[4:-2])
            except ValueError:
                # Wrapped but non-numeric payload (e.g. "<|s_|>"): treat it
                # like any other unexpected token instead of aborting.
                speech_id = None
        if speech_id is None:
            print(f"Unexpected token: {token_str}")
        else:
            speech_ids.append(speech_id)
    return speech_ids
def ids_to_speech_tokens(speech_ids):
    """Wrap each id in the ``<|s_N|>`` speech-token notation.

    Args:
        speech_ids: Iterable of ids; each one is formatted with an f-string.

    Returns:
        A list of token strings, one per input id, in input order.
    """
    return [f"<|s_{token_id}|>" for token_id in speech_ids]
@spaces.GPU
def text2speech(target_text, game, speaker):
    """Synthesize ``target_text`` using a reference speaker as the voice prompt.

    Args:
        target_text: Text to convert to speech.
        game: Key into ``prompt_text_dict`` and directory name under
            ``Reference_Voice/``.
        speaker: Key into ``prompt_text_dict[game]`` and directory name under
            ``Reference_Voice/<game>/``.

    Returns:
        Path of a temporary ``.wav`` file written at 16 kHz. The file is not
        deleted by this function.

    Raises:
        gr.Error: If ``game`` / ``speaker`` is not present in
            ``prompt_text_dict``, or if no speech tokens could be decoded.
    """
    # Resolve the transcript first. The speaker dropdown allows custom values,
    # so this lookup also rejects unknown, empty or list-valued selections
    # before they are interpolated into a filesystem path.
    try:
        prompt_text = prompt_text_dict[game][speaker]
    except (KeyError, TypeError):
        raise gr.Error(
            f"Unknown game/speaker combination: {game!r} / {speaker!r}"
        ) from None

    # NOTE(review): the file's sample rate is ignored and no resampling is
    # done; the output below is written at 16 kHz - confirm the reference
    # clips are stored at that rate.
    prompt_wav, _ = sf.read(f"Reference_Voice/{game}/{speaker}/audio.mp3")
    prompt_wav = torch.from_numpy(prompt_wav).float().unsqueeze(0)
    if prompt_wav.ndim == 3:
        # Multi-channel audio: average over the last axis to get mono.
        prompt_wav = prompt_wav.mean(dim=2)

    input_text = prompt_text + " " + target_text

    with torch.no_grad():
        # Encode the prompt wav into codec ids.
        vq_code_prompt = Codec_model.encode_code(input_waveform=prompt_wav)
        print("Prompt Vq Code Shape:", vq_code_prompt.shape )
        vq_code_prompt = vq_code_prompt[0,0,:]

        # Convert int 12345 to token <|s_12345|>
        speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)

        # Wrap the input text in the text-understanding marker tokens.
        formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"

        # Tokenize the text and the speech prefix; the assistant turn is left
        # open so the model continues the speech-token sequence.
        chat = [
            {"role": "user", "content": "Convert the text to speech:" + formatted_text},
            {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>" + ''.join(speech_ids_prefix)}
        ]
        input_ids = tokenizer.apply_chat_template(
            chat,
            tokenize=True,
            return_tensors='pt',
            continue_final_message=True
        )
        input_ids = input_ids.to(device)
        speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')

        # Generate the speech autoregressively
        outputs = model.generate(
            input_ids,
            max_length=2048,  # We trained our model with a max length of 2048
            eos_token_id=speech_end_id,
            do_sample=True,
            top_p=1,
            temperature=0.8,
        )

        # Extract the speech tokens. The slice starts len(speech_ids_prefix)
        # positions before the end of the prompt, so the prompt's own speech
        # tokens are decoded together with the newly generated ones (assuming
        # each prefix token maps to a single id).
        generated_ids = outputs[0][input_ids.shape[1]-len(speech_ids_prefix):]
        # Drop the end marker only when it is actually there: if generation
        # stopped at max_length, the last id is a real speech token.
        if generated_ids.numel() > 0 and generated_ids[-1].item() == speech_end_id:
            generated_ids = generated_ids[:-1]

        speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        # Convert token <|s_23456|> to int 23456
        speech_tokens = extract_speech_ids(speech_tokens)
        if not speech_tokens:
            raise gr.Error("The model produced no speech tokens; please try again.")

        # Add two leading axes and place the ids on the same device as the
        # models (``device`` is "cuda" exactly when CUDA is available).
        speech_tokens = torch.tensor(speech_tokens).to(device).unsqueeze(0).unsqueeze(0)

        # Decode the speech tokens to speech waveform
        gen_wav = Codec_model.decode_code(speech_tokens)

        # Audio samples and sample rate for the output file.
        audio = gen_wav[0, 0, :].cpu().numpy()
        sample_rate = 16000

    # Reserve a temporary file name, close the handle, then write the audio:
    # writing to the path while the handle is still open fails on platforms
    # that lock open files.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
        audio_path = tmpfile.name
    sf.write(audio_path, audio, sample_rate)

    return audio_path
####################
# Gradio interface
####################
# Speakers offered for each game, keyed by game name.
speaker_game_dict = {
    "HonkaiSR": ["Kafka", "Firefly", "Silverwolf"],
    "Zenless": ["Yixuan", "Miyabi", "Jane"],
    "Genshin": ["Mavuika", "Navia", "Kokomi", "Furina", "Yoimiya"],
}

# Game names offered in the "Game" dropdown, in the dict's insertion order.
game_choices = list(speaker_game_dict)
#["puck", "kore"]
if __name__ == "__main__":
    # Build and launch the Gradio UI only when run as a script.
    with gr.Blocks() as demo:
        gr.Markdown("## Text to Speech Generation")

        default_game = "HonkaiSR"
        default_speakers = speaker_game_dict[default_game]

        # NOTE(review): the original indentation was lost, so which components
        # sit inside the row is assumed - confirm the intended layout.
        with gr.Row():
            game = gr.Dropdown(label="Game", choices=game_choices, value=default_game)
            # Start with a real speaker selected; an empty default makes the
            # first "Generate Speech" click fail.
            speaker = gr.Dropdown(
                label="Speaker",
                choices=default_speakers,
                value=default_speakers[0],
                allow_custom_value=True,
            )

        target_text = gr.Textbox(label="Target Text", placeholder="Enter the text you want to convert to speech.")
        output_audio = gr.Audio(label="Generated Audio", type="filepath")

        def update_speakers(selected_game):
            """Refresh the speaker dropdown for the newly selected game."""
            speakers = speaker_game_dict[selected_game]
            # Returning a bare list would set the dropdown's *value*; an
            # update object is needed to replace its choices.
            return gr.update(choices=speakers, value=speakers[0])

        game.change(update_speakers, inputs=game, outputs=speaker)

        text2speech_button = gr.Button("Generate Speech")
        text2speech_button.click(text2speech, inputs=[target_text, game, speaker], outputs=output_audio)

    demo.launch()