jhansss commited on
Commit
b03522b
·
1 Parent(s): 1df5822

Refactor run_pipeline and update_metrics methods to use a global pipeline instance and improve parameter handling

Browse files
Files changed (1) hide show
  1. interface.py +89 -29
interface.py CHANGED
@@ -1,6 +1,5 @@
1
  import time
2
  import uuid
3
- from functools import partial
4
 
5
  import gradio as gr
6
  import spaces
@@ -9,35 +8,85 @@ import yaml
9
  from characters import CHARACTERS
10
  from pipeline import SingingDialoguePipeline
11
 
 
 
 
 
 
 
 
 
 
12
 
13
  @spaces.GPU(duration=120)
14
- def run_pipeline(audio_path, interface):
 
 
15
  if not audio_path:
16
- return gr.update(value=None), gr.update(value=None)
 
 
 
17
  tmp_file = f"audio_{int(time.time())}_{uuid.uuid4().hex[:8]}.wav"
18
- results = interface.pipeline.run(
19
  audio_path,
20
- interface.svs_model_map[interface.current_svs_model]["lang"],
21
- interface.character_info[interface.current_character].prompt,
22
- interface.current_voice,
23
  output_audio_path=tmp_file,
24
  )
 
25
  formatted_logs = f"ASR: {results['asr_text']}\nLLM: {results['llm_text']}"
26
- return gr.update(value=formatted_logs), gr.update(
27
- value=results["output_audio_path"]
 
 
28
  )
29
 
30
 
31
  @spaces.GPU(duration=120)
32
- def update_metrics(audio_path, interface):
33
- if not audio_path or not interface.results:
 
 
34
  return gr.update(value="")
35
- results = interface.pipeline.evaluate(audio_path, **interface.results)
36
- results.update(interface.results.get("metrics", {}))
 
 
 
37
  formatted_metrics = "\n".join([f"{k}: {v}" for k, v in results.items()])
38
  return gr.update(value=formatted_metrics)
39
 
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  class GradioInterface:
42
  def __init__(self, options_config: str, default_config: str):
43
  self.options = self.load_config(options_config)
@@ -53,7 +102,6 @@ class GradioInterface:
53
  self.current_voice = self.svs_model_map[self.current_svs_model]["voices"][
54
  self.character_info[self.current_character].default_voice
55
  ]
56
- self.pipeline = SingingDialoguePipeline(self.default_config)
57
  self.results = None
58
 
59
  def load_config(self, path: str):
@@ -177,18 +225,18 @@ class GradioInterface:
177
  fn=self.update_voice, inputs=voice_radio, outputs=voice_radio
178
  )
179
  mic_input.change(
180
- fn=partial(run_pipeline, interface=self),
181
  inputs=mic_input,
182
  outputs=[interaction_log, audio_output],
183
  )
184
  metrics_button.click(
185
- fn=partial(update_metrics, interface=self),
186
  inputs=audio_output,
187
  outputs=[metrics_output],
188
  )
189
 
190
  return demo
191
- except Exception as e:
192
  import traceback
193
 
194
  print(traceback.format_exc())
@@ -205,12 +253,12 @@ class GradioInterface:
205
  )
206
 
207
  def update_asr_model(self, asr_model):
208
- self.pipeline.set_asr_model(asr_model)
209
- return gr.update(value=asr_model)
210
 
211
  def update_llm_model(self, llm_model):
212
- self.pipeline.set_llm_model(llm_model)
213
- return gr.update(value=llm_model)
214
 
215
  def update_svs_model(self, svs_model):
216
  self.current_svs_model = svs_model
@@ -218,12 +266,9 @@ class GradioInterface:
218
  self.current_voice = self.svs_model_map[self.current_svs_model]["voices"][
219
  character_voice
220
  ]
221
- self.pipeline.set_svs_model(
222
- self.svs_model_map[self.current_svs_model]["model_path"]
223
- )
224
- print(
225
- f"SVS model updated to {self.current_svs_model}. Will set gradio svs_radio to {svs_model} and voice_radio to {character_voice}"
226
- )
227
  return (
228
  gr.update(value=svs_model),
229
  gr.update(
@@ -236,9 +281,24 @@ class GradioInterface:
236
 
237
  def update_melody_source(self, melody_source):
238
  self.current_melody_source = melody_source
239
- self.pipeline.set_melody_controller(melody_source)
240
- return gr.update(value=self.current_melody_source)
241
 
242
  def update_voice(self, voice):
243
  self.current_voice = self.svs_model_map[self.current_svs_model]["voices"][voice]
244
  return gr.update(value=voice)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import time
2
  import uuid
 
3
 
4
  import gradio as gr
5
  import spaces
 
8
  from characters import CHARACTERS
9
  from pipeline import SingingDialoguePipeline
10
 
11
+ pipe = None
12
+
13
+
14
+ def _ensure_pipeline(config):
15
+ """Ensure pipeline is initialized in GPU worker context."""
16
+ global pipe
17
+ if pipe is None:
18
+ pipe = SingingDialoguePipeline(config)
19
+
20
 
21
  @spaces.GPU(duration=120)
22
+ def run_pipeline(audio_path, config, svs_model_info, character_prompt, current_voice):
23
+ global pipe
24
+
25
  if not audio_path:
26
+ return gr.update(value=None), gr.update(value=None), None
27
+
28
+ _ensure_pipeline(config)
29
+
30
  tmp_file = f"audio_{int(time.time())}_{uuid.uuid4().hex[:8]}.wav"
31
+ results = pipe.run(
32
  audio_path,
33
+ svs_model_info["lang"],
34
+ character_prompt,
35
+ current_voice,
36
  output_audio_path=tmp_file,
37
  )
38
+
39
  formatted_logs = f"ASR: {results['asr_text']}\nLLM: {results['llm_text']}"
40
+ return (
41
+ gr.update(value=formatted_logs),
42
+ gr.update(value=results["output_audio_path"]),
43
+ results,
44
  )
45
 
46
 
47
  @spaces.GPU(duration=120)
48
+ def update_metrics(audio_path, config, results_data):
49
+ global pipe
50
+
51
+ if not audio_path or not results_data:
52
  return gr.update(value="")
53
+
54
+ _ensure_pipeline(config)
55
+
56
+ results = pipe.evaluate(audio_path, **results_data)
57
+ results.update(results_data.get("metrics", {}))
58
  formatted_metrics = "\n".join([f"{k}: {v}" for k, v in results.items()])
59
  return gr.update(value=formatted_metrics)
60
 
61
 
62
+ @spaces.GPU(duration=120)
63
+ def update_asr_model_in_pipeline(config, asr_model):
64
+ _ensure_pipeline(config)
65
+ pipe.set_asr_model(asr_model)
66
+ return gr.update(value=asr_model)
67
+
68
+
69
+ @spaces.GPU(duration=120)
70
+ def update_llm_model_in_pipeline(config, llm_model):
71
+ _ensure_pipeline(config)
72
+ pipe.set_llm_model(llm_model)
73
+ return gr.update(value=llm_model)
74
+
75
+
76
+ @spaces.GPU(duration=120)
77
+ def update_svs_model_in_pipeline(config, svs_model_path):
78
+ _ensure_pipeline(config)
79
+ pipe.set_svs_model(svs_model_path)
80
+ return gr.update()
81
+
82
+
83
+ @spaces.GPU(duration=120)
84
+ def update_melody_source_in_pipeline(config, melody_source):
85
+ _ensure_pipeline(config)
86
+ pipe.set_melody_controller(melody_source)
87
+ return gr.update(value=melody_source)
88
+
89
+
90
  class GradioInterface:
91
  def __init__(self, options_config: str, default_config: str):
92
  self.options = self.load_config(options_config)
 
102
  self.current_voice = self.svs_model_map[self.current_svs_model]["voices"][
103
  self.character_info[self.current_character].default_voice
104
  ]
 
105
  self.results = None
106
 
107
  def load_config(self, path: str):
 
225
  fn=self.update_voice, inputs=voice_radio, outputs=voice_radio
226
  )
227
  mic_input.change(
228
+ fn=self._run_pipeline_wrapper,
229
  inputs=mic_input,
230
  outputs=[interaction_log, audio_output],
231
  )
232
  metrics_button.click(
233
+ fn=self._update_metrics_wrapper,
234
  inputs=audio_output,
235
  outputs=[metrics_output],
236
  )
237
 
238
  return demo
239
+ except Exception:
240
  import traceback
241
 
242
  print(traceback.format_exc())
 
253
  )
254
 
255
  def update_asr_model(self, asr_model):
256
+ self.default_config["asr_model"] = asr_model
257
+ return update_asr_model_in_pipeline(self.default_config, asr_model)
258
 
259
  def update_llm_model(self, llm_model):
260
+ self.default_config["llm_model"] = llm_model
261
+ return update_llm_model_in_pipeline(self.default_config, llm_model)
262
 
263
  def update_svs_model(self, svs_model):
264
  self.current_svs_model = svs_model
 
266
  self.current_voice = self.svs_model_map[self.current_svs_model]["voices"][
267
  character_voice
268
  ]
269
+ svs_model_path = self.svs_model_map[self.current_svs_model]["model_path"]
270
+ self.default_config["svs_model"] = svs_model_path
271
+ update_svs_model_in_pipeline(self.default_config, svs_model_path)
 
 
 
272
  return (
273
  gr.update(value=svs_model),
274
  gr.update(
 
281
 
282
  def update_melody_source(self, melody_source):
283
  self.current_melody_source = melody_source
284
+ self.default_config["melody_source"] = melody_source
285
+ return update_melody_source_in_pipeline(self.default_config, melody_source)
286
 
287
  def update_voice(self, voice):
288
  self.current_voice = self.svs_model_map[self.current_svs_model]["voices"][voice]
289
  return gr.update(value=voice)
290
+
291
+ def _run_pipeline_wrapper(self, audio_path):
292
+ log_update, audio_update, pipeline_results = run_pipeline(
293
+ audio_path,
294
+ self.default_config,
295
+ self.svs_model_map[self.current_svs_model],
296
+ self.character_info[self.current_character].prompt,
297
+ self.current_voice,
298
+ )
299
+ if pipeline_results:
300
+ self.results = pipeline_results
301
+ return log_update, audio_update
302
+
303
+ def _update_metrics_wrapper(self, audio_path):
304
+ return update_metrics(audio_path, self.default_config, self.results or {})