Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
# ---------------- monkey-patch
|
2 |
import torch
|
3 |
_original_torch_load = torch.load
|
4 |
def _cpu_load(*args, **kwargs):
|
@@ -18,94 +18,127 @@ torch.set_grad_enabled(False)
|
|
18 |
if DEVICE.type == "cpu":
|
19 |
torch.set_num_threads(min(2, os.cpu_count() or 2))
|
20 |
|
21 |
-
MAX_LEN
|
22 |
-
NUM_DOMAINS
|
23 |
# ---------------------------------------------------------------------------
|
24 |
|
25 |
-
# ----------
|
26 |
cfg_mod = importlib.import_module("utmosv2.config.fusion_stage3")
|
27 |
cfg = SimpleNamespace(**{k: getattr(cfg_mod, k) for k in dir(cfg_mod)
|
28 |
if not k.startswith("__")})
|
29 |
cfg.phase, cfg.data_config, cfg.print_config = "test", None, False
|
30 |
cfg.weight = "utmosv2_estonian.pth"
|
31 |
-
model
|
32 |
-
specs_cfg
|
33 |
# ---------------------------------------------------------------------------
|
34 |
|
35 |
def compute_spec(wav: torch.Tensor) -> torch.Tensor:
|
|
|
36 |
views = []
|
37 |
for s in specs_cfg:
|
38 |
mel = torchaudio.transforms.MelSpectrogram(
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
db = torchaudio.transforms.AmplitudeToDB()(mel(wav[None]))[0]
|
43 |
if db.shape != (512, 512):
|
44 |
db = torch.nn.functional.interpolate(db[None, None],
|
45 |
size=(512, 512),
|
46 |
mode="bilinear",
|
47 |
align_corners=False)[0, 0]
|
48 |
-
views.extend([db.repeat(3, 1, 1)] * 2)
|
49 |
-
return torch.stack(views)
|
50 |
|
51 |
def single_predict(audio_path: str) -> float:
|
52 |
-
"""MOS ühe faili kohta –
|
53 |
wav, sr = torchaudio.load(audio_path)
|
54 |
if sr != 16_000:
|
55 |
wav = torchaudio.transforms.Resample(sr, 16_000)(wav)[0]
|
56 |
else:
|
57 |
wav = wav[0]
|
58 |
|
59 |
-
wav
|
60 |
-
|
61 |
-
|
62 |
spec = compute_spec(wav)
|
63 |
|
64 |
preds = []
|
65 |
for dom in range(NUM_DOMAINS):
|
66 |
dom_oh = torch.nn.functional.one_hot(
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
preds.append(model(wav[None], spec[None], dom_oh).item())
|
71 |
return float(np.mean(preds))
|
72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
def batch_predict(csv_file, wav_zip):
|
74 |
-
"""
|
75 |
tdir = tempfile.mkdtemp()
|
76 |
with zipfile.ZipFile(wav_zip.name) as zf:
|
77 |
zf.extractall(tdir)
|
78 |
|
79 |
-
df
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
|
83 |
df["pred_mos"] = outs
|
84 |
out_file = tempfile.NamedTemporaryFile(delete=False, suffix=".csv").name
|
85 |
df.to_csv(out_file, index=False)
|
86 |
shutil.rmtree(tdir)
|
|
|
|
|
|
|
|
|
|
|
87 |
return df, out_file
|
88 |
# ---------------------------------------------------------------------------
|
89 |
|
|
|
90 |
with gr.Blocks(title="UTMOS-v2 MOS-hinnang (3 domeeni)") as demo:
|
91 |
gr.Markdown(
|
92 |
"""
|
93 |
# UTMOS-v2 (eesti kõne)
|
94 |
-
Ennustab objektiivse MOS-i **kolme treeningu
|
95 |
"""
|
96 |
)
|
97 |
|
98 |
-
#
|
99 |
with gr.Tab("Üksik klipp"):
|
100 |
-
audio = gr.Audio(type="filepath", label="WAV (16 kHz)")
|
101 |
out_mos = gr.Number(label="Ennustatud MOS")
|
102 |
-
gr.Button("Hinda").click(
|
103 |
-
inputs=audio,
|
104 |
-
outputs=out_mos)
|
105 |
|
106 |
-
#
|
107 |
with gr.Tab("Partii (CSV + ZIP)"):
|
108 |
-
csv_in = gr.File(file_types=[".csv"], label="CSV (audio
|
109 |
zip_in = gr.File(file_types=[".zip"], label="ZIP kõigi WAV-idega")
|
110 |
df_out = gr.Dataframe(label="Tulemused")
|
111 |
file_dl = gr.File(label="Lae CSV ennustustega")
|
|
|
1 |
+
# ---------------- monkey-patch: CUDA-checkpoint → CPU -----------------------
|
2 |
import torch
|
3 |
_original_torch_load = torch.load
|
4 |
def _cpu_load(*args, **kwargs):
|
|
|
18 |
if DEVICE.type == "cpu":
    # On CPU, cap the number of threads torch may use at two.
    torch.set_num_threads(min(2, os.cpu_count() or 2))

MAX_LEN = 160_000                       # 10 s @ 16 kHz
NUM_DOMAINS = 3                         # predictions are always averaged over 3 domains
# ---------------------------------------------------------------------------

# ---------- load the model once --------------------------------------------
cfg_mod = importlib.import_module("utmosv2.config.fusion_stage3")
# Copy every non-dunder attribute of the config module into a mutable
# namespace so individual settings can be overridden below.
cfg = SimpleNamespace()
cfg.__dict__.update(
    (name, getattr(cfg_mod, name))
    for name in dir(cfg_mod)
    if not name.startswith("__")
)
cfg.phase = "test"
cfg.data_config = None
cfg.print_config = False
cfg.weight = "utmosv2_estonian.pth"
model = get_model(cfg, DEVICE)
model = model.to(DEVICE).eval()
# Per-view spectrogram settings, read by compute_spec().
specs_cfg = cfg.dataset.specs
|
33 |
# ---------------------------------------------------------------------------
|
34 |
|
35 |
def compute_spec(wav: torch.Tensor) -> torch.Tensor:
    """Return the multi-view mel spectrogram stack for one waveform.

    For every entry of ``specs_cfg`` a dB-scaled mel spectrogram is computed,
    resized to 512x512 when it has a different shape, tiled to three
    channels and added to the stack twice. The stacked result is
    ``[V, 3, 512, 512]`` with ``V`` equal to twice the number of entries.
    """
    target_size = (512, 512)
    to_db = torchaudio.transforms.AmplitudeToDB()

    stacked = []
    for spec in specs_cfg:
        mel_tf = torchaudio.transforms.MelSpectrogram(
            sample_rate=16_000,
            n_fft=spec.n_fft,
            hop_length=spec.hop_length,
            win_length=spec.win_length,
            n_mels=spec.n_mels,
        ).to(DEVICE)

        # Add a batch axis for the transform, then drop it again.
        image = to_db(mel_tf(wav[None]))[0]

        if image.shape != target_size:
            # interpolate() needs a [N, C, H, W] input, hence the two extra axes.
            resized = torch.nn.functional.interpolate(
                image[None, None],
                size=target_size,
                mode="bilinear",
                align_corners=False,
            )
            image = resized[0, 0]

        three_channel = image.repeat(3, 1, 1)
        stacked += [three_channel, three_channel]

    return torch.stack(stacked)
|
51 |
|
52 |
def single_predict(audio_path: str) -> float:
    """Predict the MOS of one audio file, averaged over the domains.

    The file is loaded with torchaudio, resampled to 16 kHz when its sample
    rate differs, and reduced to index 0 of the first dimension. The signal
    is then cropped or zero-padded to exactly ``MAX_LEN`` samples. The model
    is evaluated once per domain id in ``range(NUM_DOMAINS)``, each passed as
    a one-hot vector, and the mean of the scores is returned.
    """
    loaded, rate = torchaudio.load(audio_path)
    if rate == 16_000:
        signal = loaded[0]
    else:
        signal = torchaudio.transforms.Resample(rate, 16_000)(loaded)[0]

    # Force a fixed length: crop long clips, right-pad short ones with zeros.
    n_samples = signal.numel()
    if n_samples > MAX_LEN:
        signal = signal[:MAX_LEN]
    else:
        signal = torch.nn.functional.pad(signal, (0, MAX_LEN - n_samples))
    signal = signal.to(DEVICE)

    spec = compute_spec(signal)

    scores = []
    for domain_id in range(NUM_DOMAINS):
        domain_index = torch.tensor(domain_id, device=DEVICE)
        one_hot = torch.nn.functional.one_hot(
            domain_index, num_classes=model.num_dataset
        )
        domain_vec = one_hot.float()[None]
        score = model(signal[None], spec[None], domain_vec)
        scores.append(score.item())

    return float(np.mean(scores))
|
73 |
|
74 |
+
# ---------- abifunktsioon faili otsimiseks ---------------------------------
|
75 |
+
def _is_within(root: pathlib.Path, candidate: pathlib.Path) -> bool:
    """Return True when *candidate* resolves to a location inside *root*."""
    try:
        candidate.resolve().relative_to(root.resolve())
    except ValueError:
        return False
    return True


def locate_audio(root: pathlib.Path, rel_path: str) -> pathlib.Path:
    """Find an audio file inside the extracted ZIP directory.

    Lookup order:
      1) ``root / rel_path`` (also tried with backslashes turned into ``/``);
      2) otherwise, the first file anywhere below ``root`` whose name equals
         the basename of ``rel_path``, in sorted path order.

    Args:
        root: Directory the ZIP archive was extracted into.
        rel_path: Path taken from the CSV; surrounding whitespace and leading
            slashes/backslashes are ignored.

    Returns:
        Path of the matching file.

    Raises:
        FileNotFoundError: If no matching file exists below ``root``.
    """
    rel_path = rel_path.strip().lstrip("/\\")

    # CSVs written on Windows may use backslashes as separators; try that
    # reading too, without giving up the literal spelling.
    variants = [rel_path]
    normalised = rel_path.replace("\\", "/")
    if normalised != rel_path:
        variants.append(normalised)

    if rel_path:
        for variant in variants:
            direct = root / variant
            # The CSV is user-supplied: never follow '..' out of the ZIP dir.
            if direct.is_file() and _is_within(root, direct):
                return direct

        # Compare names literally instead of passing them to rglob() as a
        # pattern, so glob metacharacters in file names ('[', '*', '?') work.
        wanted = {pathlib.PurePath(variant).name for variant in variants}
        wanted.discard("")
        if wanted:
            hits = sorted(
                path for path in root.rglob("*")
                if path.name in wanted and path.is_file()
            )
            if hits:
                return hits[0]

    raise FileNotFoundError(f"'{rel_path}' ei leitud ZIP-ist – "
                            f"kontrolli CSV ja ZIP teede vastavust.")
|
92 |
+
|
93 |
+
# ---------- partii-töötlus --------------------------------------------------
|
94 |
def batch_predict(csv_file, wav_zip):
    """Score every audio file listed in a CSV against an uploaded ZIP.

    The ZIP is extracted into a temporary directory, the CSV is read, and
    each path in column ``faili_uus_nimi`` (or ``audio`` when that column is
    absent) is located and scored. Rows that fail get NaN, and the first 15
    error messages are shown as a Gradio warning.

    Args:
        csv_file: Uploaded CSV file object; only its ``.name`` path is used.
        wav_zip: Uploaded ZIP file object; only its ``.name`` path is used.

    Returns:
        Tuple ``(df, out_file)``: the DataFrame with an added ``pred_mos``
        column and the path of a temporary CSV file holding the same data.
    """
    outs, errors = [], []

    # The context manager removes the extracted files even when extraction,
    # CSV parsing or the column lookup raises.
    with tempfile.TemporaryDirectory() as tdir:
        with zipfile.ZipFile(wav_zip.name) as zf:
            zf.extractall(tdir)

        df = pd.read_csv(csv_file.name)
        col = "faili_uus_nimi" if "faili_uus_nimi" in df.columns else "audio"
        root = pathlib.Path(tdir)

        for rel in df[col]:
            try:
                full = locate_audio(root, str(rel))
                outs.append(single_predict(str(full)))
            except Exception as e:
                # Deliberate best effort: one bad row must not abort the
                # batch; it is recorded as NaN and reported below.
                outs.append(np.nan)
                errors.append(f"{rel}: {e}")

    df["pred_mos"] = outs

    # Reserve a file name and close the handle before pandas writes to it.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as handle:
        out_file = handle.name
    df.to_csv(out_file, index=False)

    if errors:
        gr.Warning("⚠️ Mõned failid jäid leidmata või tekkis viga:\n" +
                   "\n".join(errors[:15]))

    return df, out_file
|
122 |
# ---------------------------------------------------------------------------
|
123 |
|
124 |
+
# --------------------------- Gradio UI -------------------------------------
|
125 |
with gr.Blocks(title="UTMOS-v2 MOS-hinnang (3 domeeni)") as demo:
|
126 |
gr.Markdown(
|
127 |
"""
|
128 |
# UTMOS-v2 (eesti kõne)
|
129 |
+
Ennustab objektiivse MOS-i **kolme treeningu-domääni** keskmisena.
|
130 |
"""
|
131 |
)
|
132 |
|
133 |
+
# Üksik klipp -----------------------------------------------------------
|
134 |
with gr.Tab("Üksik klipp"):
|
135 |
+
audio = gr.Audio(type="filepath", label="WAV (16 kHz või muu)")
|
136 |
out_mos = gr.Number(label="Ennustatud MOS")
|
137 |
+
gr.Button("Hinda").click(single_predict, inputs=audio, outputs=out_mos)
|
|
|
|
|
138 |
|
139 |
+
# Partii ---------------------------------------------------------------
|
140 |
with gr.Tab("Partii (CSV + ZIP)"):
|
141 |
+
csv_in = gr.File(file_types=[".csv"], label="CSV (audio|faili_uus_nimi)")
|
142 |
zip_in = gr.File(file_types=[".zip"], label="ZIP kõigi WAV-idega")
|
143 |
df_out = gr.Dataframe(label="Tulemused")
|
144 |
file_dl = gr.File(label="Lae CSV ennustustega")
|