import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import soundfile as sf
from xcodec2.modeling_xcodec2 import XCodec2Model
import torchaudio
import gradio as gr
import re

# Load the Llasa TTS language model.
llasa_model_id = 'OmniAICreator/Galgame-Llasa-1B-v3'
tokenizer = AutoTokenizer.from_pretrained(llasa_model_id)
model = AutoModelForCausalLM.from_pretrained(
    llasa_model_id,
    trust_remote_code=True,
)
model.eval().cuda()

# Load the XCodec2 codec that converts between waveforms and discrete speech tokens.
xcodec2_model_id = "HKUSTAudio/xcodec2"
codec_model = XCodec2Model.from_pretrained(xcodec2_model_id)
codec_model.eval().cuda()

# Whisper transcribes the reference audio so its text can be prepended to the prompt.
whisper_turbo_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",
    torch_dtype=torch.float16,
    device='cuda',
)

# Regex cleanup rules applied to the input text before tokenization.
REPLACE_MAP: dict[str, str] = {
    r"\t": "",
    r"\[n\]": "",
    r"　": "",  # full-width space
    r" ": "",  # half-width space
    r"[;▼♀♂《》≪≫①②③④⑤⑥]": "",
    r"[\u02d7\u2010-\u2015\u2043\u2212\u23af\u23e4\u2500\u2501\u2e3a\u2e3b]": "",
    r"[\uff5e\u301C]": "ー",  # wave dashes -> long-vowel mark
    r"？": "?",  # full-width -> half-width
    r"！": "!",  # full-width -> half-width
    r"[●◯〇]": "○",
    r"♥": "♡",
}

# Full-width A-Z/a-z -> ASCII.
FULLWIDTH_ALPHA_TO_HALFWIDTH = str.maketrans(
    {
        chr(full): chr(half)
        for full, half in zip(
            list(range(0xFF21, 0xFF3B)) + list(range(0xFF41, 0xFF5B)),
            list(range(0x41, 0x5B)) + list(range(0x61, 0x7B)),
        )
    }
)

# Half-width katakana block (U+FF61-U+FF9E) -> full-width katakana block, by codepoint offset.
HALFWIDTH_KATAKANA_TO_FULLWIDTH = str.maketrans(
    {
        chr(half): chr(full)
        for half, full in zip(range(0xFF61, 0xFF9F), range(0x30A1, 0x30FB))
    }
)

# Full-width digits -> ASCII.
FULLWIDTH_DIGITS_TO_HALFWIDTH = str.maketrans(
    {
        chr(full): chr(half)
        for full, half in zip(range(0xFF10, 0xFF1A), range(0x30, 0x3A))
    }
)

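# For example:
#   "ＡＢＣ".translate(FULLWIDTH_ALPHA_TO_HALFWIDTH)  -> "ABC"
#   "１２３".translate(FULLWIDTH_DIGITS_TO_HALFWIDTH) -> "123"
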
# Matches any character outside the supported set: hiragana, katakana, CJK ideographs,
# ASCII letters and digits, and a small set of punctuation.
INVALID_PATTERN = re.compile(
    r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005"
    r"\u0041-\u005A\u0061-\u007A"
    r"\u0030-\u0039"
    r"。、!?…♪♡○]"
)

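# For instance, INVALID_PATTERN.findall("こんにちは, world!") yields [",", " "]:
# the comma and space fall outside the supported set, while the letters and "!" pass.
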
def normalize(text: str) -> str:
    for pattern, replacement in REPLACE_MAP.items():
        text = re.sub(pattern, replacement, text)
    text = text.translate(FULLWIDTH_ALPHA_TO_HALFWIDTH)
    text = text.translate(FULLWIDTH_DIGITS_TO_HALFWIDTH)
    text = text.translate(HALFWIDTH_KATAKANA_TO_FULLWIDTH)
    text = re.sub(r"…{3,}", "……", text)

    def replace_special_chars(match: re.Match) -> str:
        # Collapse a run to one character if it is all the same symbol ("!!!" -> "!"),
        # otherwise keep only the first and last characters ("!?!?" -> "!?").
        seq = match.group(0)
        return seq[0] if len(set(seq)) == 1 else seq[0] + seq[-1]

    # Assumption: the helper above was defined but never applied; collapsing runs of
    # three or more ! / ? marks matches its apparent intent.
    text = re.sub(r"[!?]{3,}", replace_special_chars, text)
    return text

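# Example (derived from the rules above):
#   normalize("ＡＢＣ　１２３…………こんにちは〜") -> "ABC123……こんにちはー"
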
def ids_to_speech_tokens(speech_ids):
    # Convert integer codec IDs to their token form, e.g. 12345 -> "<|s_12345|>".
    return [f"<|s_{speech_id}|>" for speech_id in speech_ids]

def extract_speech_ids(speech_tokens_str):
    # Inverse of ids_to_speech_tokens: pull the integer ID out of each "<|s_N|>" token.
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith('<|s_') and token_str.endswith('|>'):
            speech_ids.append(int(token_str[4:-2]))
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids

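# Round trip:
#   ids_to_speech_tokens([12, 345])               -> ["<|s_12|>", "<|s_345|>"]
#   extract_speech_ids(["<|s_12|>", "<|s_345|>"]) -> [12, 345]
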
@spaces.GPU  # Required on ZeroGPU Spaces: allocates a GPU for the duration of the call.
def infer(sample_audio_path, target_text, temperature, top_p, repetition_penalty, progress=gr.Progress()):
    if not target_text or not target_text.strip():
        gr.Warning("Please input text to generate audio.")
        return None  # A single value, matching the single output component.
    if len(target_text) > 300:
        gr.Warning("Text is too long; only the first 300 characters will be used.")
        target_text = target_text[:300]
    target_text = normalize(target_text)
    with torch.no_grad():
        if sample_audio_path:
            progress(0, 'Loading and trimming audio...')
            waveform, sample_rate = torchaudio.load(sample_audio_path)
            if len(waveform[0]) / sample_rate > 15:
                gr.Warning("Trimming audio to the first 15 seconds.")
                waveform = waveform[:, :sample_rate * 15]
            # Downmix stereo (or multi-channel) audio to mono by averaging the channels.
            if waveform.size(0) > 1:
                waveform_mono = torch.mean(waveform, dim=0, keepdim=True)
            else:
                waveform_mono = waveform
            # XCodec2 and the Whisper pipeline both expect 16 kHz input.
            prompt_wav = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_mono)
            prompt_wav_len = prompt_wav.shape[1]
            # Transcribe the reference audio; its text is prepended to the target text.
            prompt_text = whisper_turbo_pipe(prompt_wav[0].numpy())['text'].strip()
            progress(0.5, 'Transcribed! Encoding audio...')
            # Encode the reference audio into codec IDs, e.g. 12345 -> <|s_12345|>.
            vq_code_prompt = codec_model.encode_code(input_waveform=prompt_wav)[0, 0, :]
            speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)
            input_text = prompt_text + ' ' + target_text
            assistant_content = "<|SPEECH_GENERATION_START|>" + ''.join(speech_ids_prefix)
        else:
            # No reference audio: generate speech for the target text alone.
            progress(0, "Preparing...")
            input_text = target_text
            assistant_content = "<|SPEECH_GENERATION_START|>"
            speech_ids_prefix = []
            prompt_wav_len = 0
        progress(0.75, "Generating audio...")
        formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
        # Build the chat prompt; continue_final_message keeps the assistant turn open
        # so generation continues from the speech-token prefix.
        chat = [
            {"role": "user", "content": "Convert the text to speech:" + formatted_text},
            {"role": "assistant", "content": assistant_content}
        ]
        input_ids = tokenizer.apply_chat_template(
            chat,
            tokenize=True,
            return_tensors='pt',
            continue_final_message=True
        ).to('cuda')
        speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')
        # Generate speech tokens autoregressively until the end-of-speech token.
        outputs = model.generate(
            input_ids,
            max_length=2048,  # The model was trained with a max length of 2048.
            eos_token_id=speech_end_id,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
        )
        # Slice out the speech tokens; when a reference was given, include the prefix
        # tokens so the codec decodes a contiguous sequence (the prompt part is
        # trimmed from the waveform below).
        if sample_audio_path:
            generated_ids = outputs[0][input_ids.shape[1] - len(speech_ids_prefix):-1]
        else:
            generated_ids = outputs[0][input_ids.shape[1]:-1]
        speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        # Convert tokens like <|s_23456|> back to the integer 23456.
        speech_tokens = extract_speech_ids(speech_tokens)
        if not speech_tokens:
            raise gr.Error("Audio generation failed.")
        speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
        # Decode the speech tokens back to a waveform.
        gen_wav = codec_model.decode_code(speech_tokens)
        # Keep only the newly generated part, dropping the reconstructed reference audio.
        if sample_audio_path and prompt_wav_len > 0:
            gen_wav = gen_wav[:, :, prompt_wav_len:]
        progress(1, 'Synthesized!')
        return (16000, gen_wav[0, 0, :].cpu().numpy())

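# Standalone usage sketch (outside the UI; assumes a CUDA GPU and a local "ref.wav"):
#   sr, wav = infer("ref.wav", "こんにちは、テストです。", 0.8, 1.0, 1.1)
#   sf.write("output.wav", wav, sr)  # writes the 16 kHz mono result via soundfile
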
with gr.Blocks() as app_tts:
    gr.Markdown("# Galgame Llasa 1B v3")
    ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
    gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
    with gr.Row():
        temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.8, step=0.05, label="Temperature")
        top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.05, label="Top-p")
        repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=1.5, value=1.1, step=0.05, label="Repetition Penalty")
    generate_btn = gr.Button("Synthesize", variant="primary")
    audio_output = gr.Audio(label="Synthesized Audio")
    generate_btn.click(
        infer,
        inputs=[
            ref_audio_input,
            gen_text_input,
            temperature_slider,
            top_p_slider,
            repetition_penalty_slider,
        ],
        outputs=[audio_output],
    )

with gr.Blocks() as app_credits:
    gr.Markdown("""
# Credits
* [zhenye234](https://github.com/zhenye234) for the original [repo](https://github.com/zhenye234/LLaSA_training)
* [mrfakename](https://huggingface.co/mrfakename) for the [gradio demo code](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
* [SunderAli17](https://huggingface.co/SunderAli17) for the [gradio demo code](https://huggingface.co/spaces/SunderAli17/llasa-3b-tts)
""")

with gr.Blocks() as app:
    gr.Markdown(
        """
# Galgame Llasa 1B v3
This is a web UI for the Galgame Llasa 1B v3 TTS model. You can check out the model [here](https://huggingface.co/OmniAICreator/Galgame-Llasa-1B-v3).
The model is fine-tuned on Japanese audio data.
If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15 seconds, and shortening your prompt text.
"""
    )
    gr.TabbedInterface([app_tts, app_credits], ["TTS", "Credits"])
app.launch()
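# On Hugging Face Spaces this script runs as-is; for local debugging,
# app.launch(share=True) would also expose a temporary public link.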