import random
import re

import numpy as np
import torch
import gradio as gr
import spaces

from chatterbox.src.chatterbox.tts import ChatterboxTTS

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"๐Ÿš€ Running on device: {DEVICE}")

# --- Global Model Initialization ---
# Loaded once and cached at module level so repeated Gradio callbacks reuse
# the same weights instead of re-downloading/re-initializing per request.
MODEL = None


def get_or_load_model():
    """Load the ChatterboxTTS model on first use and cache it globally.

    Returns:
        The cached ChatterboxTTS instance, moved to DEVICE if necessary.

    Raises:
        Exception: re-raises whatever ``from_pretrained`` raised on failure.
    """
    global MODEL
    if MODEL is None:
        print("Model not loaded, initializing...")
        try:
            MODEL = ChatterboxTTS.from_pretrained(DEVICE)
            # Some builds expose a .device attribute; move only on mismatch.
            if hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE:
                MODEL.to(DEVICE)
            print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
    return MODEL


# Attempt to load the model at startup so the first request is not slow;
# a failure here is logged but the app still starts (requests will retry).
try:
    get_or_load_model()
except Exception as e:
    print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {e}")


def set_seed(seed: int):
    """Seed torch, CUDA, random and numpy for reproducible generation."""
    torch.manual_seed(seed)
    if DEVICE == "cuda":
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)


def split_text_into_chunks(text: str, max_chars: int = 250) -> list[str]:
    """Split text into sentence-based chunks of at most ``max_chars`` chars.

    Sentences are packed together while they fit. A sentence longer than
    ``max_chars`` is split on word boundaries, and a single word longer than
    ``max_chars`` is hard-sliced, so no returned chunk exceeds the limit.

    Args:
        text: Input text of arbitrary length.
        max_chars: Maximum length of each returned chunk.

    Returns:
        List of non-empty chunk strings (empty list for empty input).
    """
    # Basic sentence segmentation: split after ., ! or ? followed by spaces.
    sentences = re.split(r'(?<=[.!?])\s+', text.strip())
    chunks: list[str] = []
    current_chunk = ""
    for sentence in sentences:
        # Keep packing sentences into the current chunk while they fit.
        if len(current_chunk) + len(sentence) + 1 <= max_chars:
            current_chunk = f"{current_chunk} {sentence}" if current_chunk else sentence
            continue
        # Flush the filled chunk and start a new one.
        if current_chunk:
            chunks.append(current_chunk)
        if len(sentence) > max_chars:
            # Sentence itself exceeds the limit: split on word boundaries.
            temp_chunk = ""
            for word in sentence.split():
                # Hard-slice pathological words longer than max_chars so the
                # per-chunk limit is never violated.
                while len(word) > max_chars:
                    if temp_chunk:
                        chunks.append(temp_chunk)
                        temp_chunk = ""
                    chunks.append(word[:max_chars])
                    word = word[max_chars:]
                if not word:
                    continue
                if len(temp_chunk) + len(word) + 1 <= max_chars:
                    temp_chunk = f"{temp_chunk} {word}" if temp_chunk else word
                else:
                    if temp_chunk:
                        chunks.append(temp_chunk)
                    temp_chunk = word
            # Unconditional reset (even to "") avoids re-appending the chunk
            # that was already flushed above.
            current_chunk = temp_chunk
        else:
            current_chunk = sentence
    # Flush the trailing chunk.
    if current_chunk:
        chunks.append(current_chunk)
    return chunks


@spaces.GPU
def generate_tts_audio(
    text_input: str,
    audio_prompt_path_input: str,
    exaggeration_input: float,
    temperature_input: float,
    seed_num_input: int,
    cfgw_input: float,
    chunk_size_input: int,
    progress=gr.Progress(),
) -> tuple[int, np.ndarray]:
    """Generate TTS audio for long text by chunking and concatenating.

    All chunks are synthesized inside a single GPU context; 0.2 s of silence
    is inserted between chunks for natural transitions. A chunk that fails
    to synthesize is skipped (best-effort) rather than aborting the run.

    Args:
        text_input: Text of arbitrary length.
        audio_prompt_path_input: Optional reference-audio file path.
        exaggeration_input: Expressiveness control.
        temperature_input: Sampling temperature.
        seed_num_input: Random seed (0 means "do not seed").
        cfgw_input: CFG / pace weight.
        chunk_size_input: Maximum characters per chunk.
        progress: Gradio progress callback (injected by Gradio).

    Returns:
        ``(sample_rate, waveform)`` tuple suitable for a ``gr.Audio`` output.

    Raises:
        RuntimeError: if the model is unavailable or no chunk succeeded.
    """
    current_model = get_or_load_model()
    if current_model is None:
        raise RuntimeError("TTS model is not loaded.")
    if seed_num_input != 0:
        set_seed(int(seed_num_input))

    # Split the text into synthesizable chunks.
    chunks = split_text_into_chunks(text_input, max_chars=chunk_size_input)
    total_chunks = len(chunks)
    print(f"ํ…์ŠคํŠธ๋ฅผ {total_chunks}๊ฐœ์˜ ์ฒญํฌ๋กœ ๋ถ„ํ• ํ–ˆ์Šต๋‹ˆ๋‹ค.")

    # Synthesize each chunk independently.
    audio_segments = []
    for i, chunk in enumerate(chunks):
        progress((i + 1) / total_chunks, f"์ฒญํฌ {i + 1}/{total_chunks} ์ƒ์„ฑ ์ค‘...")
        print(f"์ฒญํฌ {i + 1}/{total_chunks} ์ƒ์„ฑ ์ค‘: '{chunk[:50]}...'")
        try:
            wav = current_model.generate(
                chunk,
                audio_prompt_path=audio_prompt_path_input,
                exaggeration=exaggeration_input,
                temperature=temperature_input,
                cfg_weight=cfgw_input,
            )
            # detach/cpu before numpy(): .numpy() raises on CUDA-resident or
            # grad-tracking tensors; this is a no-op for plain CPU tensors.
            audio_segments.append(wav.squeeze(0).detach().cpu().numpy())
        except Exception as e:
            # Best-effort: log and skip the failed chunk, keep going.
            print(f"์ฒญํฌ {i + 1} ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
            continue

    if not audio_segments:
        raise RuntimeError("์˜ค๋””์˜ค ์ƒ์„ฑ์— ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค.")

    # Insert 0.2 s of silence between chunks. Matching the segment dtype
    # prevents np.concatenate from silently promoting the result to float64.
    silence = np.zeros(int(0.2 * current_model.sr), dtype=audio_segments[0].dtype)
    final_audio = []
    for i, segment in enumerate(audio_segments):
        final_audio.append(segment)
        if i < len(audio_segments) - 1:  # no silence after the last segment
            final_audio.append(silence)
    concatenated_audio = np.concatenate(final_audio)
    print(f"์˜ค๋””์˜ค ์ƒ์„ฑ ์™„๋ฃŒ. ์ด ๊ธธ์ด: {len(concatenated_audio) / current_model.sr:.2f}์ดˆ")
    return (current_model.sr, concatenated_audio)


# Simple wrapper for single-chunk generation (with its own GPU decorator).
@spaces.GPU
def generate_single_audio(
    text_input: str,
    audio_prompt_path_input: str,
    exaggeration_input: float,
    temperature_input: float,
    seed_num_input: int,
    cfgw_input: float,
) -> tuple[int, np.ndarray]:
    """Generate TTS audio for a single short text (truncated to 300 chars).

    Returns:
        ``(sample_rate, waveform)`` tuple for a ``gr.Audio`` output.

    Raises:
        RuntimeError: if the model is unavailable.
    """
    current_model = get_or_load_model()
    if current_model is None:
        raise RuntimeError("TTS model is not loaded.")
    if seed_num_input != 0:
        set_seed(int(seed_num_input))
    print(f"Generating audio for text: '{text_input[:50]}...'")
    wav = current_model.generate(
        text_input[:300],  # hard cap at 300 chars for safety
        audio_prompt_path=audio_prompt_path_input,
        exaggeration=exaggeration_input,
        temperature=temperature_input,
        cfg_weight=cfgw_input,
    )
    print("Audio generation complete.")
    # detach/cpu before numpy() — see generate_tts_audio.
    return (current_model.sr, wav.squeeze(0).detach().cpu().numpy())


with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Chatterbox TTS Demo - ๋ฌด์ œํ•œ ๊ธธ์ด ๋ฒ„์ „
        ๊ธด ํ…์ŠคํŠธ๋„ ์ฒญํฌ๋กœ ๋‚˜๋ˆ„์–ด ์ฒ˜๋ฆฌํ•˜์—ฌ ์ œํ•œ ์—†์ด ์Œ์„ฑ์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
        """
    )
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(
                value="Now let's make my mum's favourite. So three mars bars into the pan. Then we add the tuna and just stir for a bit, just let the chocolate and fish infuse. A sprinkle of olive oil and some tomato ketchup. Now smell that. Oh boy this is going to be incredible.",
                label="ํ…์ŠคํŠธ ์ž…๋ ฅ (๊ธธ์ด ์ œํ•œ ์—†์Œ)",
                lines=10,
                max_lines=30
            )
            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Reference Audio File (Optional)",
                value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac"
            )
            with gr.Row():
                exaggeration = gr.Slider(
                    0.25, 2, step=.05,
                    label="Exaggeration (Neutral = 0.5)",
                    value=.5
                )
                cfg_weight = gr.Slider(
                    0.2, 1, step=.05,
                    label="CFG/Pace",
                    value=0.5
                )
            with gr.Row():
                chunk_size = gr.Slider(
                    100, 300, step=50,
                    label="์ฒญํฌ ํฌ๊ธฐ (๋ฌธ์ž ์ˆ˜)",
                    value=250,
                    info="ํ…์ŠคํŠธ๋ฅผ ๋‚˜๋ˆŒ ์ฒญํฌ์˜ ์ตœ๋Œ€ ํฌ๊ธฐ์ž…๋‹ˆ๋‹ค. ์ž‘์„์ˆ˜๋ก ๋” ์ž์—ฐ์Šค๋Ÿฝ์ง€๋งŒ ์ฒ˜๋ฆฌ ์‹œ๊ฐ„์ด ๊ธธ์–ด์ง‘๋‹ˆ๋‹ค."
                )
                mode = gr.Radio(
                    choices=["๋‹จ์ผ ์ƒ์„ฑ (300์ž ์ดํ•˜)", "์ฒญํฌ ๋ถ„ํ•  (๋ฌด์ œํ•œ)"],
                    value="์ฒญํฌ ๋ถ„ํ•  (๋ฌด์ œํ•œ)",
                    label="์ƒ์„ฑ ๋ชจ๋“œ"
                )
            with gr.Accordion("๊ณ ๊ธ‰ ์˜ต์…˜", open=False):
                seed_num = gr.Number(value=0, label="Random seed (0 for random)")
                temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8)
            run_btn = gr.Button("์Œ์„ฑ ์ƒ์„ฑ", variant="primary")
        with gr.Column():
            audio_output = gr.Audio(label="์ƒ์„ฑ๋œ ์Œ์„ฑ")
            # Live text statistics display.
            char_count = gr.Textbox(
                label="ํ…์ŠคํŠธ ์ •๋ณด",
                value="0 ๋ฌธ์ž, ์•ฝ 0๊ฐœ ์ฒญํฌ",
                interactive=False
            )

    # Update character count and estimated chunk count as the inputs change.
    def update_char_count(text, chunk_size, mode):
        char_len = len(text)
        if mode == "๋‹จ์ผ ์ƒ์„ฑ (300์ž ์ดํ•˜)":
            if char_len > 300:
                return f"{char_len} ๋ฌธ์ž (โš ๏ธ 300์ž ์ดˆ๊ณผ - ์ž˜๋ฆด ์ˆ˜ ์žˆ์Œ)"
            return f"{char_len} ๋ฌธ์ž"
        chunks = split_text_into_chunks(text, max_chars=chunk_size)
        return f"{char_len} ๋ฌธ์ž, ์•ฝ {len(chunks)}๊ฐœ ์ฒญํฌ๋กœ ๋ถ„ํ• ๋จ"

    text.change(
        fn=update_char_count,
        inputs=[text, chunk_size, mode],
        outputs=[char_count]
    )
    chunk_size.change(
        fn=update_char_count,
        inputs=[text, chunk_size, mode],
        outputs=[char_count]
    )
    mode.change(
        fn=update_char_count,
        inputs=[text, chunk_size, mode],
        outputs=[char_count]
    )

    # Dispatch to the single-shot or chunked generator based on the mode.
    def process_audio(text, ref_wav, exaggeration, temp, seed_num, cfg_weight, chunk_size, mode):
        if mode == "๋‹จ์ผ ์ƒ์„ฑ (300์ž ์ดํ•˜)":
            return generate_single_audio(text, ref_wav, exaggeration, temp, seed_num, cfg_weight)
        return generate_tts_audio(text, ref_wav, exaggeration, temp, seed_num, cfg_weight, chunk_size)

    run_btn.click(
        fn=process_audio,
        inputs=[
            text,
            ref_wav,
            exaggeration,
            temp,
            seed_num,
            cfg_weight,
            chunk_size,
            mode
        ],
        outputs=[audio_output],
    )

    gr.Markdown(
        """
        ### ์‚ฌ์šฉ ํŒ:
        - **๋‹จ์ผ ์ƒ์„ฑ ๋ชจ๋“œ**: 300์ž ์ดํ•˜์˜ ์งง์€ ํ…์ŠคํŠธ์— ์ ํ•ฉํ•˜๋ฉฐ ๋น ๋ฅด๊ฒŒ ์ƒ์„ฑ๋ฉ๋‹ˆ๋‹ค
        - **์ฒญํฌ ๋ถ„ํ•  ๋ชจ๋“œ**: ๊ธด ํ…์ŠคํŠธ๋ฅผ ์ž๋™์œผ๋กœ ์—ฌ๋Ÿฌ ๋ถ€๋ถ„์œผ๋กœ ๋‚˜๋ˆ„์–ด ์ฒ˜๋ฆฌํ•ฉ๋‹ˆ๋‹ค
        - ์ฒญํฌ ํฌ๊ธฐ๋ฅผ ์กฐ์ ˆํ•˜์—ฌ ํ’ˆ์งˆ๊ณผ ์†๋„์˜ ๊ท ํ˜•์„ ๋งž์ถœ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค
        - ๊ฐ ์ฒญํฌ ์‚ฌ์ด์—๋Š” ์ž์—ฐ์Šค๋Ÿฌ์šด ์ „ํ™˜์„ ์œ„ํ•ด ์งง์€ ๋ฌด์Œ์ด ์ถ”๊ฐ€๋ฉ๋‹ˆ๋‹ค
        - ๋งค์šฐ ๊ธด ํ…์ŠคํŠธ์˜ ๊ฒฝ์šฐ ์ฒ˜๋ฆฌ ์‹œ๊ฐ„์ด ์˜ค๋ž˜ ๊ฑธ๋ฆด ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค
        """
    )

demo.launch()