Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,381 Bytes
7fbdb0f 50adbda 7fbdb0f 39ce78e 7fbdb0f 50adbda 7fbdb0f 50adbda 7fbdb0f 50adbda 7fbdb0f 50adbda 7fbdb0f dc1f428 7fbdb0f 50adbda 7fbdb0f dc1f428 7fbdb0f dc1f428 d5b16d8 dc1f428 d5b16d8 dc1f428 d5b16d8 dc1f428 d5b16d8 dc1f428 7fbdb0f 50adbda 7fbdb0f dc1f428 7fbdb0f dc1f428 561072b 7fbdb0f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import spaces
import gradio as gr
import torch
import soundfile as sf
from transformers import AutoTokenizer, AutoModelForCausalLM
from xcodec2.modeling_xcodec2 import XCodec2Model
import tempfile
# Prefer the GPU when torch can see one; otherwise everything runs on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# ------------------------------------------------------------------
# Global model loading (done once, at import time; the objects below
# are module-level globals used by the inference function)
# ------------------------------------------------------------------
model_name = "fakeavatar/vtubers-4"
print("Loading tokenizer & model ...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# Switch to inference mode, then move the weights to the selected device.
model.eval()
model.to(device)

print("Loading XCodec2Model ...")
codec_model_path = "HKUSTAudio/xcodec2"
Codec_model = XCodec2Model.from_pretrained(codec_model_path)
# Same treatment for the codec that turns speech token ids into a waveform.
Codec_model.eval()
Codec_model.to(device)
print("Models loaded.")
####################
# Inference function
####################
def extract_speech_ids(speech_tokens_str):
    """
    Recover integer speech ids from token strings such as ``<|s_23456|>``.

    Args:
        speech_tokens_str: Iterable of decoded token strings.

    Returns:
        A list of ints, one per well-formed speech token, in input order.
        Tokens that do not have the ``<|s_N|>`` shape, or whose payload is
        not an integer, are reported on stdout and skipped.
    """
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith("<|s_") and token_str.endswith("|>"):
            # Strip the 4-character prefix "<|s_" and the 2-character suffix "|>".
            num_str = token_str[4:-2]
            try:
                speech_ids.append(int(num_str))
                continue
            except ValueError:
                # Malformed payload (e.g. "<|s_|>"): treat it like any other
                # unexpected token instead of aborting the whole request.
                pass
        print(f"Unexpected token: {token_str}")
    return speech_ids
@spaces.GPU
def text2speech(input_text, num_samples):
    """
    Synthesize speech for ``input_text`` and return paths to the WAV files.

    Args:
        input_text: The text to be spoken.
        num_samples: How many independent samples to generate. The value is
            converted to ``int`` and clamped to the range 1..10.

    Returns:
        A list of exactly 10 paths to temporary ``.wav`` files. When fewer
        than 10 samples were produced, the last path is repeated to fill the
        remaining slots.

    Raises:
        RuntimeError: If none of the samples yielded any speech tokens.
    """
    # The Gradio interface in this file declares 10 audio outputs, so the
    # returned list is always padded to this length.
    max_outputs = 10
    sample_rate = 16000

    # Accept float-valued slider input and keep the count within the number
    # of output slots (also guarantees at least one generation attempt).
    num_samples = max(1, min(int(num_samples), max_outputs))

    # The prompt is identical for every sample, so build and tokenize it once.
    formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
    chat = [
        {"role": "user", "content": "Convert the text to speech:" + formatted_text},
        {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"},
    ]
    # continue_final_message=True leaves the assistant turn open so the model
    # continues right after the speech-start marker.
    input_ids = tokenizer.apply_chat_template(
        chat,
        tokenize=True,
        return_tensors="pt",
        continue_final_message=True,
    ).to(device)
    prompt_len = input_ids.shape[1]
    speech_end_id = tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")

    results = []
    with torch.no_grad():
        for _ in range(num_samples):
            outputs = model.generate(
                input_ids,
                max_length=2048,  # We trained our model with a max length of 2048
                eos_token_id=speech_end_id,
                do_sample=True,
                top_p=0.95,  # Adjusts the diversity of generated content
                temperature=0.9,  # Controls randomness in output
                repetition_penalty=1.2,
            )

            # Keep only the newly generated tokens. Drop the trailing token
            # only when it really is the end marker: if generation stopped at
            # max_length instead, the last token is still speech.
            generated_ids = outputs[0][prompt_len:]
            if generated_ids.numel() > 0 and generated_ids[-1].item() == speech_end_id:
                generated_ids = generated_ids[:-1]

            # Decoding a 1-D id tensor yields one string per token.
            speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            speech_ids = extract_speech_ids(speech_tokens_str)
            if not speech_ids:
                # Nothing to hand to the codec for this sample; try the next one.
                print("No speech tokens generated for this sample; skipping.")
                continue

            # The ids are given two leading singleton dimensions before being
            # passed to the codec.
            codes = torch.tensor(speech_ids).to(device).unsqueeze(0).unsqueeze(0)
            gen_wav = Codec_model.decode_code(codes)  # [batch, channels, samples]
            audio = gen_wav[0, 0, :].cpu().numpy()

            # Reserve a unique path, close the handle, then write the audio.
            # delete=False keeps the file on disk so the caller can serve it.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
                audio_path = tmpfile.name
            sf.write(audio_path, audio, sample_rate)
            results.append(audio_path)

    if not results:
        raise RuntimeError("Speech generation produced no audio for the given text.")

    # Pad with the last generated file so every output slot receives a value.
    results.extend([results[-1]] * (max_outputs - len(results)))
    return results
####################
# Gradio Interface
####################
# Lets the user pick how many audio samples are generated per request.
num_samples_slider = gr.Slider(
    label="Number of Audio Samples",
    minimum=1,
    maximum=10,
    step=1,
    value=4,
)

# One text box plus the slider feed text2speech; ten audio players receive
# the ten values it returns.
demo = gr.Interface(
    fn=text2speech,
    inputs=[
        gr.Textbox(label="Enter text", lines=5),
        num_samples_slider,
    ],
    outputs=[
        gr.Audio(label=f"Generated Audio {slot}", type="numpy")
        for slot in range(1, 11)
    ],
    title="VTuber TTS",
    description="Input a piece of text in English, and click to generate speech.",
)
if __name__ == "__main__":
    # Start the Gradio app only when this file is executed as a script.
    # share=True asks Gradio for a public share link in addition to the
    # local server.
    demo.launch(share=True)