monatolmats committed on
Commit
90101f3
·
verified ·
1 Parent(s): 68a55a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -30
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # ---------------- monkey-patch so CUDA checkpoints load on CPU --------------
2
  import torch
3
  _original_torch_load = torch.load
4
  def _cpu_load(*args, **kwargs):
@@ -18,94 +18,127 @@ torch.set_grad_enabled(False)
18
  if DEVICE.type == "cpu":
19
  torch.set_num_threads(min(2, os.cpu_count() or 2))
20
 
21
- MAX_LEN = 160_000 # 10 s @16 kHz
22
- NUM_DOMAINS = 3 # ennustame alati 3 domeeni põhjal
23
  # ---------------------------------------------------------------------------
24
 
25
- # ---------- load model once -------------------------------------------------
26
  cfg_mod = importlib.import_module("utmosv2.config.fusion_stage3")
27
  cfg = SimpleNamespace(**{k: getattr(cfg_mod, k) for k in dir(cfg_mod)
28
  if not k.startswith("__")})
29
  cfg.phase, cfg.data_config, cfg.print_config = "test", None, False
30
  cfg.weight = "utmosv2_estonian.pth"
31
- model = get_model(cfg, DEVICE).to(DEVICE).eval()
32
- specs_cfg = cfg.dataset.specs
33
  # ---------------------------------------------------------------------------
34
 
35
  def compute_spec(wav: torch.Tensor) -> torch.Tensor:
 
36
  views = []
37
  for s in specs_cfg:
38
  mel = torchaudio.transforms.MelSpectrogram(
39
- sample_rate=16_000, n_fft=s.n_fft,
40
- hop_length=s.hop_length, win_length=s.win_length,
41
- n_mels=s.n_mels).to(DEVICE)
42
  db = torchaudio.transforms.AmplitudeToDB()(mel(wav[None]))[0]
43
  if db.shape != (512, 512):
44
  db = torch.nn.functional.interpolate(db[None, None],
45
  size=(512, 512),
46
  mode="bilinear",
47
  align_corners=False)[0, 0]
48
- views.extend([db.repeat(3, 1, 1)] * 2) # kaks vaadet
49
- return torch.stack(views) # [V,3,512,512]
50
 
51
  def single_predict(audio_path: str) -> float:
52
- """MOS ühe faili kohta – keskmista 3 domeeni."""
53
  wav, sr = torchaudio.load(audio_path)
54
  if sr != 16_000:
55
  wav = torchaudio.transforms.Resample(sr, 16_000)(wav)[0]
56
  else:
57
  wav = wav[0]
58
 
59
- wav = (wav[:MAX_LEN] if wav.numel() > MAX_LEN
60
- else torch.nn.functional.pad(wav, (0, MAX_LEN - wav.numel()))
61
- ).to(DEVICE)
62
  spec = compute_spec(wav)
63
 
64
  preds = []
65
  for dom in range(NUM_DOMAINS):
66
  dom_oh = torch.nn.functional.one_hot(
67
- torch.tensor(dom, device=DEVICE),
68
- num_classes=model.num_dataset
69
- ).float()[None]
70
  preds.append(model(wav[None], spec[None], dom_oh).item())
71
  return float(np.mean(preds))
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  def batch_predict(csv_file, wav_zip):
74
- """Partii ennustused: loe CSV, lae ZIPist WAV-id, lisa 'pred_mos' veerg."""
75
  tdir = tempfile.mkdtemp()
76
  with zipfile.ZipFile(wav_zip.name) as zf:
77
  zf.extractall(tdir)
78
 
79
- df, outs = pd.read_csv(csv_file.name), []
80
- for rel in df["faili_uus_nimi"]:
81
- outs.append(single_predict(str(pathlib.Path(tdir) / rel)))
 
 
 
 
 
 
 
 
82
 
83
  df["pred_mos"] = outs
84
  out_file = tempfile.NamedTemporaryFile(delete=False, suffix=".csv").name
85
  df.to_csv(out_file, index=False)
86
  shutil.rmtree(tdir)
 
 
 
 
 
87
  return df, out_file
88
  # ---------------------------------------------------------------------------
89
 
 
90
  with gr.Blocks(title="UTMOS-v2 MOS-hinnang (3 domeeni)") as demo:
91
  gr.Markdown(
92
  """
93
  # UTMOS-v2 (eesti kõne)
94
- Ennustab objektiivse MOS-i **kolme treeningudomääni** keskmisena.
95
  """
96
  )
97
 
98
- # ---- üksik klipp -------------------------------------------------------
99
  with gr.Tab("Üksik klipp"):
100
- audio = gr.Audio(type="filepath", label="WAV (16 kHz)")
101
  out_mos = gr.Number(label="Ennustatud MOS")
102
- gr.Button("Hinda").click(fn=single_predict,
103
- inputs=audio,
104
- outputs=out_mos)
105
 
106
- # ---- partii ------------------------------------------------------------
107
  with gr.Tab("Partii (CSV + ZIP)"):
108
- csv_in = gr.File(file_types=[".csv"], label="CSV (audio[, MOS, method])")
109
  zip_in = gr.File(file_types=[".zip"], label="ZIP kõigi WAV-idega")
110
  df_out = gr.Dataframe(label="Tulemused")
111
  file_dl = gr.File(label="Lae CSV ennustustega")
 
1
+ # ---------------- monkey-patch: CUDA-checkpoint CPU -----------------------
2
  import torch
3
  _original_torch_load = torch.load
4
  def _cpu_load(*args, **kwargs):
 
18
  if DEVICE.type == "cpu":
19
  torch.set_num_threads(min(2, os.cpu_count() or 2))
20
 
21
# Fixed input length fed to the model: 10 s at 16 kHz.
MAX_LEN = 160_000
# Predictions are always averaged over this many domains.
NUM_DOMAINS = 3
# ---------------------------------------------------------------------------

# ---------- load the model once --------------------------------------------
cfg_mod = importlib.import_module("utmosv2.config.fusion_stage3")
# Copy every non-dunder attribute of the config module into a mutable namespace.
cfg = SimpleNamespace(
    **{k: getattr(cfg_mod, k) for k in dir(cfg_mod) if not k.startswith("__")}
)
cfg.phase = "test"
cfg.data_config = None
cfg.print_config = False
cfg.weight = "utmosv2_estonian.pth"
model = get_model(cfg, DEVICE).to(DEVICE).eval()
# Spectrogram settings iterated by compute_spec().
specs_cfg = cfg.dataset.specs
33
  # ---------------------------------------------------------------------------
34
 
35
def compute_spec(wav: torch.Tensor) -> torch.Tensor:
    """Return the stacked multi-view mel spectrogram for one waveform.

    For each entry of the module-level ``specs_cfg`` a 16 kHz mel spectrogram
    of ``wav`` is computed, converted to decibels, resized to 512x512 when its
    shape differs, repeated to three channels and added to the output twice.
    The original docstring gives the result shape as ``[V, 3, 512, 512]``.
    """
    target_shape = (512, 512)
    to_db = torchaudio.transforms.AmplitudeToDB()

    stacked = []
    for spec in specs_cfg:
        mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=16_000,
            n_fft=spec.n_fft,
            hop_length=spec.hop_length,
            win_length=spec.win_length,
            n_mels=spec.n_mels,
        ).to(DEVICE)

        # Add a leading batch axis for the transform, drop it afterwards.
        db = to_db(mel_transform(wav[None]))[0]

        if db.shape != target_shape:
            resized = torch.nn.functional.interpolate(
                db[None, None],
                size=target_shape,
                mode="bilinear",
                align_corners=False,
            )
            db = resized[0, 0]

        # The same three-channel image is used for both views of this spec.
        three_channel = db.repeat(3, 1, 1)
        stacked.append(three_channel)
        stacked.append(three_channel)

    return torch.stack(stacked)
51
 
52
def single_predict(audio_path: str) -> float:
    """Return the MOS for one audio file, averaged over ``NUM_DOMAINS`` domains.

    The file is loaded with torchaudio, resampled to 16 kHz when its sample
    rate differs, reduced to its first channel and then truncated or
    right-padded to exactly ``MAX_LEN`` samples. The model is run once per
    domain index with a one-hot domain vector and the scores are averaged.
    """
    wav, sr = torchaudio.load(audio_path)
    if sr != 16_000:
        wav = torchaudio.transforms.Resample(sr, 16_000)(wav)
    # Only the first channel is scored.
    wav = wav[0]

    n_samples = wav.numel()
    if n_samples > MAX_LEN:
        wav = wav[:MAX_LEN]
    else:
        wav = torch.nn.functional.pad(wav, (0, MAX_LEN - n_samples))
    wav = wav.to(DEVICE)

    spec = compute_spec(wav)

    def _score(domain_index: int) -> float:
        """Run the model for one domain index and return its scalar output."""
        domain_one_hot = torch.nn.functional.one_hot(
            torch.tensor(domain_index, device=DEVICE),
            num_classes=model.num_dataset,
        ).float()[None]
        return model(wav[None], spec[None], domain_one_hot).item()

    scores = [_score(dom) for dom in range(NUM_DOMAINS)]
    return float(np.mean(scores))
73
 
74
+ # ---------- abifunktsioon faili otsimiseks ---------------------------------
75
def locate_audio(root: pathlib.Path, rel_path: str) -> pathlib.Path:
    """Find an audio file inside the extraction directory.

    Lookup order:
      1. ``root / rel_path``, if it is a regular file located inside ``root``.
      2. Otherwise the first regular file (in sorted path order) anywhere
         under ``root`` whose name equals the last component of ``rel_path``.

    Args:
        root: Directory to search.
        rel_path: Path to look for; surrounding whitespace and leading
            slashes/backslashes are ignored.

    Returns:
        Path of the located file.

    Raises:
        FileNotFoundError: If no matching file exists under ``root``
            (including when ``rel_path`` is empty).
    """
    rel_path = rel_path.strip().lstrip("/\\")
    name = pathlib.Path(rel_path).name

    # An empty name (e.g. blank input) can never match a file; skip straight
    # to the error instead of handing an empty pattern to rglob().
    if name:
        direct = root / rel_path
        if direct.is_file():
            # Reject paths that climb out of root through ".." components.
            try:
                direct.resolve().relative_to(root.resolve())
            except ValueError:
                pass
            else:
                return direct

        # Compare names literally rather than passing the name to rglob() as
        # a pattern, so characters such as "[", "]" or "*" in a file name are
        # not interpreted as glob syntax. Sorting keeps the pick deterministic.
        matches = sorted(
            p for p in root.rglob("*") if p.name == name and p.is_file()
        )
        if matches:
            return matches[0]

    raise FileNotFoundError(f"'{rel_path}' ei leitud ZIP-ist – "
                            f"kontrolli CSV ja ZIP teede vastavust.")
92
+
93
+ # ---------- partii-töötlus --------------------------------------------------
94
def batch_predict(csv_file, wav_zip):
    """Predict a MOS for every audio file listed in a CSV, using WAVs from a ZIP.

    File names are read from the ``faili_uus_nimi`` column when present,
    otherwise from ``audio``. The ZIP is unpacked into a temporary directory
    that is removed again even when extraction or prediction raises. A file
    that cannot be located or scored gets ``NaN``; such failures are reported
    together through one ``gr.Warning`` (first 15 messages only).

    Args:
        csv_file: Uploaded CSV; only its ``.name`` attribute (a path) is read.
        wav_zip: Uploaded ZIP; only its ``.name`` attribute (a path) is read.

    Returns:
        ``(df, out_file)``: the DataFrame with an added ``pred_mos`` column
        and the path of a temporary CSV file holding the same table.
    """
    df = pd.read_csv(csv_file.name)
    col = "faili_uus_nimi" if "faili_uus_nimi" in df.columns else "audio"
    names = df[col]
    outs, errors = [], []

    # The context manager guarantees the extracted files are deleted on every
    # exit path, not only after a fully successful run.
    with tempfile.TemporaryDirectory() as tdir:
        with zipfile.ZipFile(wav_zip.name) as zf:
            zf.extractall(tdir)

        root = pathlib.Path(tdir)
        for rel in names:
            # Deliberately broad: one bad row must not abort the whole batch;
            # the failure is recorded and surfaced to the user below.
            try:
                full = locate_audio(root, str(rel))
                outs.append(single_predict(str(full)))
            except Exception as e:
                outs.append(np.nan)
                errors.append(f"{rel}: {e}")

    df["pred_mos"] = outs

    # delete=False keeps the file on disk for download; closing the handle
    # here avoids leaking an open file descriptor.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp:
        out_file = tmp.name
    df.to_csv(out_file, index=False)

    if errors:
        gr.Warning("⚠️ Mõned failid jäid leidmata või tekkis viga:\n" +
                   "\n".join(errors[:15]))

    return df, out_file
122
  # ---------------------------------------------------------------------------
123
 
124
+ # --------------------------- Gradio UI -------------------------------------
125
  with gr.Blocks(title="UTMOS-v2 MOS-hinnang (3 domeeni)") as demo:
126
  gr.Markdown(
127
  """
128
  # UTMOS-v2 (eesti kõne)
129
+ Ennustab objektiivse MOS-i **kolme treeningu-domääni** keskmisena.
130
  """
131
  )
132
 
133
+ # Üksik klipp -----------------------------------------------------------
134
  with gr.Tab("Üksik klipp"):
135
+ audio = gr.Audio(type="filepath", label="WAV (16 kHz või muu)")
136
  out_mos = gr.Number(label="Ennustatud MOS")
137
+ gr.Button("Hinda").click(single_predict, inputs=audio, outputs=out_mos)
 
 
138
 
139
+ # Partii ---------------------------------------------------------------
140
  with gr.Tab("Partii (CSV + ZIP)"):
141
+ csv_in = gr.File(file_types=[".csv"], label="CSV (audio|faili_uus_nimi)")
142
  zip_in = gr.File(file_types=[".zip"], label="ZIP kõigi WAV-idega")
143
  df_out = gr.Dataframe(label="Tulemused")
144
  file_dl = gr.File(label="Lae CSV ennustustega")