monatolmats committed on
Commit
d5b5af5
·
verified ·
1 Parent(s): cd817b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -38
app.py CHANGED
@@ -1,93 +1,131 @@
 
1
  import gradio as gr
2
  import torch, torchaudio, importlib, pandas as pd, tempfile, zipfile, pathlib, shutil, numpy as np
3
  from types import SimpleNamespace
4
  from utmosv2.utils import get_model
5
 
 
 
 
6
# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fixed clip length in samples that every waveform is cut or padded to.
MAX_LEN = 160_000  # 10 s @ 16 kHz

# ---- the model is loaded once for the whole lifetime of the Space ----
cfg_mod = importlib.import_module("utmosv2.config.fusion_stage3")

# Copy every non-dunder attribute of the config module into a mutable namespace.
cfg = SimpleNamespace(
    **{
        name: getattr(cfg_mod, name)
        for name in dir(cfg_mod)
        if not name.startswith("__")
    }
)
cfg.phase = "test"
cfg.data_config = None
cfg.print_config = False
cfg.weight = "utmosv2_estonian.pth"

model = get_model(cfg, DEVICE).eval()

# Per-view spectrogram settings, read by compute_spec.
specs_cfg = cfg.dataset.specs
16
 
17
def compute_spec(wav: torch.Tensor):
    """Build the stacked mel-spectrogram views for one waveform.

    For every entry in ``specs_cfg`` a mel spectrogram is computed, converted
    to decibels, resized to 512x512 when it has a different shape, replicated
    to 3 channels and added to the output twice.

    Args:
        wav: Waveform tensor; it is given a leading axis before the transform.
            Presumably 1-D mono audio at 16 kHz -- confirm against the caller.

    Returns:
        A tensor stacking two ``[3, 512, 512]`` views per spec entry.
    """
    target_size = (512, 512)
    to_db = torchaudio.transforms.AmplitudeToDB()

    stacked = []
    for spec_cfg in specs_cfg:
        mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=16000,
            n_fft=spec_cfg.n_fft,
            hop_length=spec_cfg.hop_length,
            win_length=spec_cfg.win_length,
            n_mels=spec_cfg.n_mels,
        ).to(DEVICE)

        db = to_db(mel_transform(wav[None]))[0]
        if db.shape != target_size:
            resized = torch.nn.functional.interpolate(
                db[None, None],
                size=target_size,
                mode="bilinear",
                align_corners=False,
            )
            db = resized[0, 0]

        # The same 3-channel image is used for both views of this spec.
        rgb = db.repeat(3, 1, 1)
        stacked.extend([rgb, rgb])

    return torch.stack(stacked)
 
 
 
29
 
30
def single_predict(audio_path, domain, quick):
    """Predict the MOS of a single audio file.

    Args:
        audio_path: Path handed to ``torchaudio.load``.
        domain: Not used by the computation; kept for the UI callback signature.
        quick: When truthy, only the first domain is evaluated.

    Returns:
        The mean of the per-domain model outputs as a Python float.
    """
    wav, sr = torchaudio.load(audio_path)
    if sr != 16000:
        wav = torchaudio.transforms.Resample(sr, 16000)(wav)
    wav = wav[0]  # only the first channel is scored

    # Cut or zero-pad to exactly MAX_LEN samples.
    n_samples = wav.numel()
    if n_samples > MAX_LEN:
        wav = wav[:MAX_LEN]
    else:
        wav = torch.nn.functional.pad(wav, (0, MAX_LEN - n_samples))
    wav = wav.to(DEVICE)

    spec = compute_spec(wav)

    # Change this if fine-tuning used a different number of domains (e.g. 8).
    NUM_DOMAINS = 3
    domains = range(1 if quick else NUM_DOMAINS)

    preds = []
    for dom in domains:
        dom_index = torch.tensor(dom, device=DEVICE)
        d_oh = torch.nn.functional.one_hot(
            dom_index, num_classes=model.num_dataset
        ).float()[None]
        with torch.no_grad():
            score = model(wav[None], spec[None], d_oh).item()
        preds.append(score)

    return float(np.mean(preds))
51
 
52
def batch_predict(csv_file, wav_zip, num_domains):
    """Score every clip listed in a CSV, taking the audio from a ZIP archive.

    Args:
        csv_file: Uploaded CSV handle; ``csv_file.name`` is read as its path.
            The CSV must have an ``audio`` column of paths relative to the
            archive root.
        wav_zip: Uploaded ZIP handle; ``wav_zip.name`` is opened as the archive.
        num_domains: Accepted for the UI callback signature; not used here,
            every clip is scored in quick mode.

    Returns:
        ``(df, out_file)``: the CSV contents with an added ``pred_mos`` column,
        and the path of a temporary CSV file holding the same table.

    Raises:
        ValueError: If an ``audio`` entry points outside the extracted archive.
    """
    # TemporaryDirectory removes the extracted audio even when a step fails.
    with tempfile.TemporaryDirectory() as tdir:
        root = pathlib.Path(tdir).resolve()
        with zipfile.ZipFile(wav_zip.name) as zf:
            zf.extractall(root)

        df = pd.read_csv(csv_file.name)

        outs = []
        for rel in df["audio"]:
            path = (root / rel).resolve()
            # Both uploads are user-supplied: refuse absolute or ``..`` paths
            # that would read files outside the extracted archive.
            if root not in path.parents:
                raise ValueError(
                    f"audio path escapes the uploaded archive: {rel!r}"
                )
            # The domain value does not matter here.
            outs.append(single_predict(str(path), "dummy", quick=True))

    df["pred_mos"] = outs

    # Reserve a persistent file name and close the handle before writing to it.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp:
        out_file = tmp.name
    df.to_csv(out_file, index=False)
    return df, out_file
66
 
 
 
 
67
# Gradio UI: one tab for a single clip, one for CSV + ZIP batches.
with gr.Blocks(title="UTMOS-v2 MOS-hinnang") as demo:
    gr.Markdown(
        """
        # UTMOS-v2
        Laadi üksik `.wav` või kogu partii ning saa ennustatud MOS-id.
        Mudel laetakse GPU-le ühe korra, seega järgmised päringud on kiiremad.
        """
    )

    # --- single clip ---
    with gr.Tab("Üksik klipp"):
        audio = gr.Audio(type="filepath", label="Helifail (16 kHz WAV)")
        domain = gr.Dropdown(
            choices=["default"],
            value="default",
            label="Domään (valikuline, kui ise muutsid koodi)",
        )
        quick = gr.Checkbox(value=True, label="Kiire (1 iteratsioon/fold)")
        out_mos = gr.Number(label="Ennustatud MOS")
        gr.Button("Hinda").click(
            fn=single_predict,
            inputs=[audio, domain, quick],
            outputs=out_mos,
        )

    # --- batch: CSV + ZIP ---
    with gr.Tab("Partii (CSV + ZIP)"):
        csv_in = gr.File(file_types=[".csv"], label="CSV (audio[, MOS, method])")
        zip_in = gr.File(file_types=[".zip"], label="ZIP kõikide WAV-idega")
        n_dom = gr.Number(value=8, precision=0, label="Treeningu domäänide arv")
        df_out = gr.Dataframe(label="Tulemused")
        file_dl = gr.File(label="Lae CSV ennustustega")
        gr.Button("Start").click(
            fn=batch_predict,
            inputs=[csv_in, zip_in, n_dom],
            outputs=[df_out, file_dl],
        )

# Queue requests (at most 10 waiting), then start the app.
demo.queue(max_size=10).launch()
 
1
+ import os
2
  import gradio as gr
3
  import torch, torchaudio, importlib, pandas as pd, tempfile, zipfile, pathlib, shutil, numpy as np
4
  from types import SimpleNamespace
5
  from utmosv2.utils import get_model
6
 
7
# ----------------------------------------------------------
# Device selection - GPU when available, otherwise CPU
# ----------------------------------------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_grad_enabled(False)  # inference only - no backprop needed

# Free CPU Space: 2 vCPUs, so cap the number of threads.
if DEVICE.type == "cpu":
    cpu_total = os.cpu_count() or 2
    torch.set_num_threads(min(2, cpu_total))

# Fixed clip length in samples; lower it (e.g. 80 000) for faster inference.
MAX_LEN = 160_000  # 10 s @ 16 kHz

# ----------------------------------------------------------
# Load the model once when the Space starts
# ----------------------------------------------------------
cfg_mod = importlib.import_module("utmosv2.config.fusion_stage3")

# Copy every non-dunder attribute of the config module into a mutable namespace.
cfg = SimpleNamespace(
    **{
        name: getattr(cfg_mod, name)
        for name in dir(cfg_mod)
        if not name.startswith("__")
    }
)
cfg.phase = "test"
cfg.data_config = None
cfg.print_config = False
cfg.weight = "utmosv2_estonian.pth"  # your checkpoint file

model = get_model(cfg, DEVICE).to(DEVICE).eval()

# Configuration of the mel views, read by compute_spec.
specs_cfg = cfg.dataset.specs
28
 
29
+ # ----------------------------------------------------------
30
+ # Abifunktsioonid
31
+ # ----------------------------------------------------------
32
def compute_spec(wav: torch.Tensor) -> torch.Tensor:
    """Create the multi-view mel-spectrogram tensor ``[V, 3, 512, 512]``.

    Each entry of ``specs_cfg`` yields a dB-scaled mel spectrogram that is
    resized to 512x512 when needed, replicated to 3 channels and added twice,
    so ``V`` is twice the number of spec entries.

    Args:
        wav: Waveform tensor; it is given a leading axis before the transform.
            Presumably 1-D mono audio at 16 kHz -- confirm against the caller.

    Returns:
        The stacked views.
    """
    side = 512
    amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
    batched_wav = wav[None]

    out = []
    for view_cfg in specs_cfg:
        mel_kwargs = dict(
            sample_rate=16_000,
            n_fft=view_cfg.n_fft,
            hop_length=view_cfg.hop_length,
            win_length=view_cfg.win_length,
            n_mels=view_cfg.n_mels,
        )
        mel_transform = torchaudio.transforms.MelSpectrogram(**mel_kwargs).to(DEVICE)

        db = amplitude_to_db(mel_transform(batched_wav))[0]
        if db.shape != (side, side):
            db = torch.nn.functional.interpolate(
                db[None, None],
                size=(side, side),
                mode="bilinear",
                align_corners=False,
            )[0, 0]

        # Two "sides" share the same spectrum; the source comment says this
        # mirrors the original model -- not verifiable from this file.
        three_channel = db.repeat(3, 1, 1)
        out += [three_channel, three_channel]

    return torch.stack(out)  # [V, 3, 512, 512]
52
 
53
def single_predict(audio_path, domain, quick):
    """Predict the MOS of one audio file.

    Args:
        audio_path: Path handed to ``torchaudio.load``.
        domain: Not used by the computation; kept so existing callers work.
        quick: When truthy, only the first domain is evaluated; otherwise the
            prediction is averaged over ``NUM_DOMAINS`` domains.

    Returns:
        The mean of the per-domain model outputs as a Python float.
    """
    wav, sr = torchaudio.load(audio_path)
    if sr != 16_000:
        wav = torchaudio.transforms.Resample(sr, 16_000)(wav)
    wav = wav[0]  # only the first channel is scored

    # Cut or zero-pad to exactly MAX_LEN samples.
    n_samples = wav.numel()
    if n_samples > MAX_LEN:
        wav = wav[:MAX_LEN]
    else:
        wav = torch.nn.functional.pad(wav, (0, MAX_LEN - n_samples))
    wav = wav.to(DEVICE)

    NUM_DOMAINS = 3  # number of domains used in training
    preds = []

    # Grad mode is thread-local in PyTorch, so the module-level
    # torch.set_grad_enabled(False) does not apply when this handler runs on
    # a worker thread. Disable autograd here so no graph is built per call.
    with torch.no_grad():
        spec = compute_spec(wav)
        for dom in range(NUM_DOMAINS):
            d_oh = torch.nn.functional.one_hot(
                torch.tensor(dom, device=DEVICE),
                num_classes=model.num_dataset,
            ).float()[None]
            preds.append(model(wav[None], spec[None], d_oh).item())
            if quick:  # quick => first domain only
                break

    return float(np.mean(preds))
76
 
77
def batch_predict(csv_file, wav_zip, num_domains):
    """Compute a MOS for every file listed in the CSV, reading audio from the ZIP.

    Args:
        csv_file: Uploaded CSV handle; ``csv_file.name`` is read as its path.
            The CSV must have an ``audio`` column of paths relative to the
            archive root.
        wav_zip: Uploaded ZIP handle; ``wav_zip.name`` is opened as the archive.
        num_domains: Value from the UI. Exactly 1 selects quick mode (first
            domain only); any other value lets ``single_predict`` average over
            its own fixed number of domains.

    Returns:
        ``(df, out_file)``: the CSV contents with an added ``pred_mos`` column,
        and the path of a temporary CSV file holding the same table.

    Raises:
        ValueError: If an ``audio`` entry points outside the extracted archive.
    """
    quick = int(num_domains) == 1

    # TemporaryDirectory removes the extracted audio even when a step fails.
    with tempfile.TemporaryDirectory() as tdir:
        root = pathlib.Path(tdir).resolve()
        with zipfile.ZipFile(wav_zip.name) as zf:
            zf.extractall(root)

        df = pd.read_csv(csv_file.name)

        outs = []
        for rel in df["audio"]:
            path = (root / rel).resolve()
            # Both uploads are user-supplied: refuse absolute or ``..`` paths
            # that would read files outside the extracted archive.
            if root not in path.parents:
                raise ValueError(
                    f"audio path escapes the uploaded archive: {rel!r}"
                )
            outs.append(single_predict(str(path), "dummy", quick))

    df["pred_mos"] = outs

    # Reserve a persistent file name and close the handle before writing to it.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp:
        out_file = tmp.name
    df.to_csv(out_file, index=False)
    return df, out_file
96
 
97
# ----------------------------------------------------------
# Gradio user interface
# ----------------------------------------------------------
with gr.Blocks(title="UTMOS-v2 MOS-hinnang") as demo:
    gr.Markdown(
        """
        # UTMOS-v2 (eesti kõne)
        Ennusta objektiivseid MOS-skoore üksikutele või suurele failihulgale.
        Mudel laetakse mällu korra Space'i käivitumisel (CPU-s töötab ~ paar s / klipp).
        """
    )

    # --- single WAV ---
    with gr.Tab("Üksik klipp"):
        audio = gr.Audio(type="filepath", label="Helifail (16 kHz WAV)")
        quick = gr.Checkbox(value=True, label="Kiire režiim (1 domään)")
        out_mos = gr.Number(label="Ennustatud MOS")
        # The domain argument is fixed to "default"; only path and quick flag vary.
        gr.Button("Hinda").click(
            fn=lambda path, is_quick: single_predict(path, "default", is_quick),
            inputs=[audio, quick],
            outputs=out_mos,
        )

    # --- batch: CSV + ZIP ---
    with gr.Tab("Partii (CSV + ZIP)"):
        csv_in = gr.File(file_types=[".csv"], label="CSV (audio[, MOS, method])")
        zip_in = gr.File(file_types=[".zip"], label="ZIP kõigi WAV-idega")
        n_dom = gr.Number(
            value=3,
            precision=0,
            label="Domäänide arv (1 = quick, >1 = täiskeskmistamine)",
        )
        df_out = gr.Dataframe(label="Tulemused")
        file_dl = gr.File(label="Lae ennustused CSV-na")
        gr.Button("Start").click(
            fn=batch_predict,
            inputs=[csv_in, zip_in, n_dom],
            outputs=[df_out, file_dl],
        )

# Queue requests (at most 10 waiting); launch() is called without the share flag.
demo.queue(max_size=10).launch()