ollieollie committed on
Commit
5cd7466
·
verified ·
1 Parent(s): bff7fc0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -23
app.py CHANGED
@@ -6,8 +6,11 @@ import gradio as gr
6
  import spaces
7
  from chatterbox.tts_turbo import ChatterboxTurboTTS
8
 
9
- # Check for GPU, but ZeroGPU handles the actual assignment dynamically
10
- DEVICE = "cpu"
 
 
 
11
 
12
  EVENT_TAGS = [
13
  "[clear throat]", "[sigh]", "[shush]", "[cough]", "[groan]",
@@ -73,18 +76,14 @@ def set_seed(seed: int):
73
  np.random.seed(seed)
74
 
75
 
76
- # We don't need to decorate load_model, it runs on CPU or during startup
77
  def load_model():
 
78
  print(f"Loading Chatterbox-Turbo on {DEVICE}...")
79
- model = ChatterboxTurboTTS.from_pretrained(DEVICE)
80
- return model
81
 
82
- # --- 2. THE CRITICAL DECORATOR ---
83
- # This tells ZeroGPU to assign a GPU to this specific function call.
84
- # The duration param is optional but helps with scheduling (e.g. 60s limit).
85
  @spaces.GPU
86
  def generate(
87
- model,
88
  text,
89
  audio_prompt_path,
90
  temperature,
@@ -95,16 +94,18 @@ def generate(
95
  repetition_penalty,
96
  norm_loudness
97
  ):
98
- # Reload model inside the GPU context if it was lost (ZeroGPU quirk)
99
- if model is None:
100
- model = ChatterboxTurboTTS.from_pretrained("cpu")
 
 
 
 
101
 
102
- model.to("cuda")
103
-
104
  if seed_num != 0:
105
  set_seed(int(seed_num))
106
 
107
- wav = model.generate(
108
  text,
109
  audio_prompt_path=audio_prompt_path,
110
  temperature=temperature,
@@ -114,18 +115,17 @@ def generate(
114
  repetition_penalty=repetition_penalty,
115
  norm_loudness=norm_loudness,
116
  )
117
- return (model.sr, wav.squeeze(0).numpy())
 
118
 
119
 
120
  with gr.Blocks(title="Chatterbox Turbo") as demo:
121
  gr.Markdown("# ⚡ Chatterbox Turbo")
122
 
123
- model_state = gr.State(None)
124
-
125
  with gr.Row():
126
  with gr.Column():
127
  text = gr.Textbox(
128
- value="Oh, that's hilarious! [chuckle] Um anyway, we do have a new model in store. It's the SkyNet T-800 series and it's got basically everything. Including AI integration with ChatGPT and all that jazz. Would you like me to get some prices for you?",
129
  label="Text to synthesize (max chars 300)",
130
  max_lines=5,
131
  elem_id="main_textbox"
@@ -162,13 +162,12 @@ with gr.Blocks(title="Chatterbox Turbo") as demo:
162
  min_p = gr.Slider(0.00, 1.00, step=0.01, label="Min P (Set to 0 to disable)", value=0.00)
163
  norm_loudness = gr.Checkbox(value=True, label="Normalize Loudness (Match prompt volume)")
164
 
165
-
166
- demo.load(fn=load_model, inputs=[], outputs=model_state)
167
 
168
  run_btn.click(
169
  fn=generate,
170
  inputs=[
171
- model_state,
172
  text,
173
  ref_wav,
174
  temp,
@@ -182,4 +181,9 @@ with gr.Blocks(title="Chatterbox Turbo") as demo:
182
  outputs=audio_output,
183
  )
184
 
185
- demo.launch(mcp_server=True, css=CUSTOM_CSS)
 
 
 
 
 
 
6
  import spaces
7
  from chatterbox.tts_turbo import ChatterboxTurboTTS
8
 
9
+ # --- 1. FORCE CPU FOR GLOBAL LOADING ---
10
+ # ZeroGPU forbids CUDA during startup. We only move to CUDA inside the decorated function.
11
+ DEVICE = "cpu"
12
+
13
+ MODEL = None
14
 
15
  EVENT_TAGS = [
16
  "[clear throat]", "[sigh]", "[shush]", "[cough]", "[groan]",
 
76
  np.random.seed(seed)
77
 
78
 
 
79
  def load_model():
80
+ global MODEL
81
  print(f"Loading Chatterbox-Turbo on {DEVICE}...")
82
+ MODEL = ChatterboxTurboTTS.from_pretrained(DEVICE)
83
+ return MODEL
84
 
 
 
 
85
  @spaces.GPU
86
  def generate(
 
87
  text,
88
  audio_prompt_path,
89
  temperature,
 
94
  repetition_penalty,
95
  norm_loudness
96
  ):
97
+ global MODEL
98
+ # Reload if the worker lost the global state
99
+ if MODEL is None:
100
+ MODEL = ChatterboxTurboTTS.from_pretrained("cpu")
101
+
102
+ # --- MOVE TO GPU HERE ---
103
+ MODEL.to("cuda")
104
 
 
 
105
  if seed_num != 0:
106
  set_seed(int(seed_num))
107
 
108
+ wav = MODEL.generate(
109
  text,
110
  audio_prompt_path=audio_prompt_path,
111
  temperature=temperature,
 
115
  repetition_penalty=repetition_penalty,
116
  norm_loudness=norm_loudness,
117
  )
118
+
119
+ return (MODEL.sr, wav.squeeze(0).cpu().numpy())
120
 
121
 
122
  with gr.Blocks(title="Chatterbox Turbo") as demo:
123
  gr.Markdown("# ⚡ Chatterbox Turbo")
124
 
 
 
125
  with gr.Row():
126
  with gr.Column():
127
  text = gr.Textbox(
128
+ value="Congratulations Miss Connor! [chuckle] Um anyway, we do have a new model in store. It's the SkyNet T-800 series and it's got basically everything. Including AI integration with ChatGPT and all that jazz. Would you like me to get some prices for you?",
129
  label="Text to synthesize (max chars 300)",
130
  max_lines=5,
131
  elem_id="main_textbox"
 
162
  min_p = gr.Slider(0.00, 1.00, step=0.01, label="Min P (Set to 0 to disable)", value=0.00)
163
  norm_loudness = gr.Checkbox(value=True, label="Normalize Loudness (Match prompt volume)")
164
 
165
+ # Load on startup (CPU)
166
+ demo.load(fn=load_model, inputs=[], outputs=[])
167
 
168
  run_btn.click(
169
  fn=generate,
170
  inputs=[
 
171
  text,
172
  ref_wav,
173
  temp,
 
181
  outputs=audio_output,
182
  )
183
 
184
+ if __name__ == "__main__":
185
+ demo.queue().launch(
186
+ mcp_server=True,
187
+ css=CUSTOM_CSS,
188
+ ssr_mode=False
189
+ )