import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import soundfile as sf
from xcodec2.modeling_xcodec2 import XCodec2Model
import torchaudio
import gradio as gr
import re

llasa_model_id = 'OmniAICreator/Galgame-Llasa-1B-v3'

tokenizer = AutoTokenizer.from_pretrained(llasa_model_id)
model = AutoModelForCausalLM.from_pretrained(
    llasa_model_id,
    trust_remote_code=True,
)
model.eval().cuda()

xcodec2_model_id = "HKUSTAudio/xcodec2"
codec_model = XCodec2Model.from_pretrained(xcodec2_model_id)
codec_model.eval().cuda()

whisper_turbo_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",
    torch_dtype=torch.float16,
    device='cuda',
)

# Japanese text-normalization rules: regex pattern -> replacement.
REPLACE_MAP: dict[str, str] = {
    r"\t": "",
    r"\[n\]": "",
    r" ": "",
    r"　": "",
    r"[;▼♀♂《》≪≫①②③④⑤⑥]": "",
    r"[\u02d7\u2010-\u2015\u2043\u2212\u23af\u23e4\u2500\u2501\u2e3a\u2e3b]": "",
    r"[\uff5e\u301C]": "ー",
    r"？": "?",
    r"！": "!",
    r"[●◯〇]": "○",
    r"♥": "♡",
}

# Full-width ASCII letters -> half-width.
FULLWIDTH_ALPHA_TO_HALFWIDTH = str.maketrans(
    {
        chr(full): chr(half)
        for full, half in zip(
            list(range(0xFF21, 0xFF3B)) + list(range(0xFF41, 0xFF5B)),
            list(range(0x41, 0x5B)) + list(range(0x61, 0x7B)),
        )
    }
)

# Half-width katakana -> full-width katakana.
HALFWIDTH_KATAKANA_TO_FULLWIDTH = str.maketrans(
    {
        chr(half): chr(full)
        for half, full in zip(range(0xFF61, 0xFF9F), range(0x30A1, 0x30FB))
    }
)

# Full-width digits -> half-width.
FULLWIDTH_DIGITS_TO_HALFWIDTH = str.maketrans(
    {
        chr(full): chr(half)
        for full, half in zip(range(0xFF10, 0xFF1A), range(0x30, 0x3A))
    }
)

# Characters outside this set are not expected by the model.
INVALID_PATTERN = re.compile(
    r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005"
    r"\u0041-\u005A\u0061-\u007A"
    r"\u0030-\u0039"
    r"。、!?…♪♡○]"
)


def normalize(text: str) -> str:
    for pattern, replacement in REPLACE_MAP.items():
        text = re.sub(pattern, replacement, text)
    text = text.translate(FULLWIDTH_ALPHA_TO_HALFWIDTH)
    text = text.translate(FULLWIDTH_DIGITS_TO_HALFWIDTH)
    text = text.translate(HALFWIDTH_KATAKANA_TO_FULLWIDTH)
    # Collapse runs of three or more ellipsis characters into exactly two.
    text = re.sub(r"…{3,}", "……", text)

    def replace_special_chars(match):
        # For a run of identical special characters keep one; otherwise keep the first and last.
        seq = match.group(0)
        return seq[0] if len(set(seq)) == 1 else seq[0] + seq[-1]

    return text


def ids_to_speech_tokens(speech_ids):
    # Convert integer codec IDs into token strings, e.g. 12345 -> <|s_12345|>.
    speech_tokens_str = []
    for speech_id in speech_ids:
        speech_tokens_str.append(f"<|s_{speech_id}|>")
    return speech_tokens_str


def extract_speech_ids(speech_tokens_str):
    # Convert token strings like <|s_23456|> back into integer codec IDs.
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith('<|s_') and token_str.endswith('|>'):
            num_str = token_str[4:-2]
            num = int(num_str)
            speech_ids.append(num)
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids


@spaces.GPU(duration=60)
def infer(sample_audio_path, target_text, temperature, top_p, repetition_penalty, progress=gr.Progress()):
    if not target_text or not target_text.strip():
        gr.Warning("Please input text to generate audio.")
        return None
    if len(target_text) > 300:
        gr.Warning("Text is too long. Please keep it under 300 characters.")
        target_text = target_text[:300]

    target_text = normalize(target_text)

    with torch.no_grad():
        if sample_audio_path:
            progress(0, 'Loading and trimming audio...')
            waveform, sample_rate = torchaudio.load(sample_audio_path)
            if len(waveform[0]) / sample_rate > 15:
                gr.Warning("Trimming audio to the first 15 seconds.")
                waveform = waveform[:, :sample_rate * 15]

            # Check if the audio is stereo (i.e., has more than one channel)
            if waveform.size(0) > 1:
                # Convert stereo to mono by averaging the channels
                waveform_mono = torch.mean(waveform, dim=0, keepdim=True)
            else:
                # If already mono, just use the original waveform
                waveform_mono = waveform

            prompt_wav = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_mono)
            prompt_wav_len = prompt_wav.shape[1]

            prompt_text = whisper_turbo_pipe(prompt_wav[0].numpy())['text'].strip()
            progress(0.5, 'Transcribed! Encoding audio...')

            # Encode the prompt wav
            vq_code_prompt = codec_model.encode_code(input_waveform=prompt_wav)[0, 0, :]
            # Convert int 12345 to token <|s_12345|>
            speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)

            input_text = prompt_text + ' ' + target_text
            assistant_content = "<|SPEECH_GENERATION_START|>" + ''.join(speech_ids_prefix)
        else:
            progress(0, "Preparing...")
            input_text = target_text
            assistant_content = "<|SPEECH_GENERATION_START|>"
            speech_ids_prefix = []
            prompt_wav_len = 0

        progress(0.75, "Generating audio...")
        formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"

        # Tokenize the text and the speech prefix
        chat = [
            {"role": "user", "content": "Convert the text to speech:" + formatted_text},
            {"role": "assistant", "content": assistant_content},
        ]

        input_ids = tokenizer.apply_chat_template(
            chat,
            tokenize=True,
            return_tensors='pt',
            continue_final_message=True,
        ).to('cuda')
        speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')

        # Generate the speech autoregressively
        outputs = model.generate(
            input_ids,
            max_length=2048,  # We trained our model with a max length of 2048
            eos_token_id=speech_end_id,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
        )

        # Extract the speech tokens
        if sample_audio_path:
            generated_ids = outputs[0][input_ids.shape[1] - len(speech_ids_prefix):-1]
        else:
            generated_ids = outputs[0][input_ids.shape[1]:-1]
        speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        # Convert token <|s_23456|> to int 23456
        speech_tokens = extract_speech_ids(speech_tokens)
        if not speech_tokens:
            raise gr.Error("Audio generation failed.")

        speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)

        # Decode the speech tokens to speech waveform
        gen_wav = codec_model.decode_code(speech_tokens)

        # Keep only the generated part (drop the reference-prompt samples)
        if sample_audio_path and prompt_wav_len > 0:
            gen_wav = gen_wav[:, :, prompt_wav_len:]

    progress(1, 'Synthesized!')
    return (16000, gen_wav[0, 0, :].cpu().numpy())


with gr.Blocks() as app_tts:
    gr.Markdown("# Galgame Llasa 1B v3")
    ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
    gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
    with gr.Row():
        temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.8, step=0.05, label="Temperature")
        top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.05, label="Top-p")
        repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=1.5, value=1.1, step=0.05, label="Repetition Penalty")
    generate_btn = gr.Button("Synthesize", variant="primary")
    audio_output = gr.Audio(label="Synthesized Audio")

    generate_btn.click(
        infer,
        inputs=[
            ref_audio_input,
            gen_text_input,
            temperature_slider,
            top_p_slider,
            repetition_penalty_slider,
        ],
        outputs=[audio_output],
    )

with gr.Blocks() as app_credits:
    gr.Markdown("""
# Credits

* [zhenye234](https://github.com/zhenye234) for the original [repo](https://github.com/zhenye234/LLaSA_training)
* [mrfakename](https://huggingface.co/mrfakename) for the [gradio demo code](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
* [SunderAli17](https://huggingface.co/SunderAli17) for the [gradio demo code](https://huggingface.co/spaces/SunderAli17/llasa-3b-tts)
""")

with gr.Blocks() as app:
    gr.Markdown(
        """
# Galgame Llasa 1B v3

This is a local web UI for the Galgame Llasa 1B v3 TTS model. You can check out the model [here](https://huggingface.co/OmniAICreator/Galgame-Llasa-1B-v3). The model is fine-tuned on Japanese audio data.

If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15 seconds, and shortening your prompt.
"""
    )
    gr.TabbedInterface([app_tts, app_credits], ["TTS", "Credits"])

app.launch()