Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -1,93 +1,131 @@
|
|
|
|
1 |
import gradio as gr
|
2 |
import torch, torchaudio, importlib, pandas as pd, tempfile, zipfile, pathlib, shutil, numpy as np
|
3 |
from types import SimpleNamespace
|
4 |
from utmosv2.utils import get_model
|
5 |
|
|
|
|
|
|
|
6 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
7 |
-
|
8 |
|
9 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
cfg_mod = importlib.import_module("utmosv2.config.fusion_stage3")
|
11 |
cfg = SimpleNamespace(**{k: getattr(cfg_mod, k) for k in dir(cfg_mod) if not k.startswith("__")})
|
12 |
cfg.phase, cfg.data_config, cfg.print_config = "test", None, False
|
13 |
-
cfg.weight = "utmosv2_estonian.pth"
|
14 |
-
model = get_model(cfg, DEVICE).eval()
|
15 |
-
specs_cfg = cfg.dataset.specs
|
16 |
|
17 |
-
|
|
|
|
|
|
|
|
|
18 |
views = []
|
19 |
for s in specs_cfg:
|
20 |
mel = torchaudio.transforms.MelSpectrogram(
|
21 |
-
sample_rate=
|
22 |
-
|
|
|
|
|
|
|
|
|
23 |
db = torchaudio.transforms.AmplitudeToDB()(mel(wav[None]))[0]
|
24 |
-
if db.shape != (512,512):
|
25 |
-
db = torch.nn.functional.interpolate(db[None,None],
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
29 |
|
30 |
def single_predict(audio_path, domain, quick):
|
31 |
-
|
32 |
wav, sr = torchaudio.load(audio_path)
|
33 |
-
if sr !=
|
34 |
-
wav = torchaudio.transforms.Resample(sr,
|
35 |
else:
|
36 |
wav = wav[0]
|
|
|
37 |
wav = (wav[:MAX_LEN] if wav.numel() > MAX_LEN
|
38 |
else torch.nn.functional.pad(wav, (0, MAX_LEN - wav.numel()))).to(DEVICE)
|
39 |
spec = compute_spec(wav)
|
|
|
|
|
40 |
preds = []
|
41 |
-
# kui sul on fine-tuningus kasutatud nt 8 domääni, muuda siit
|
42 |
-
NUM_DOMAINS = 3
|
43 |
for dom in range(NUM_DOMAINS):
|
44 |
-
d_oh = torch.nn.functional.one_hot(
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
preds.append(
|
49 |
-
if quick:
|
|
|
50 |
return float(np.mean(preds))
|
51 |
|
52 |
def batch_predict(csv_file, wav_zip, num_domains):
|
|
|
53 |
tdir = tempfile.mkdtemp()
|
54 |
with zipfile.ZipFile(wav_zip.name) as zf:
|
55 |
zf.extractall(tdir)
|
56 |
-
|
|
|
57 |
outs = []
|
|
|
|
|
58 |
for rel in df["audio"]:
|
59 |
path = pathlib.Path(tdir) / rel
|
60 |
-
outs.append(single_predict(str(path), "dummy", quick
|
|
|
61 |
df["pred_mos"] = outs
|
62 |
out_file = tempfile.NamedTemporaryFile(delete=False, suffix=".csv").name
|
63 |
df.to_csv(out_file, index=False)
|
64 |
shutil.rmtree(tdir)
|
65 |
return df, out_file
|
66 |
|
|
|
|
|
|
|
67 |
with gr.Blocks(title="UTMOS-v2 MOS-hinnang") as demo:
|
68 |
gr.Markdown(
|
69 |
"""
|
70 |
-
# UTMOS-v2
|
71 |
-
|
72 |
-
Mudel laetakse
|
73 |
"""
|
74 |
)
|
|
|
|
|
75 |
with gr.Tab("Üksik klipp"):
|
76 |
audio = gr.Audio(type="filepath", label="Helifail (16 kHz WAV)")
|
77 |
-
|
78 |
-
label="Domään (valikuline, kui ise muutsid koodi)")
|
79 |
-
quick = gr.Checkbox(value=True, label="Kiire (1 iteratsioon/fold)")
|
80 |
out_mos = gr.Number(label="Ennustatud MOS")
|
81 |
-
gr.Button("Hinda").click(
|
|
|
|
|
82 |
|
|
|
83 |
with gr.Tab("Partii (CSV + ZIP)"):
|
84 |
csv_in = gr.File(file_types=[".csv"], label="CSV (audio[, MOS, method])")
|
85 |
-
zip_in = gr.File(file_types=[".zip"], label="ZIP
|
86 |
-
n_dom = gr.Number(value=
|
|
|
87 |
df_out = gr.Dataframe(label="Tulemused")
|
88 |
-
file_dl = gr.File(label="Lae CSV
|
89 |
gr.Button("Start").click(batch_predict,
|
90 |
-
[csv_in, zip_in, n_dom],
|
91 |
-
[df_out, file_dl])
|
92 |
|
|
|
93 |
demo.queue(max_size=10).launch()
|
|
|
1 |
+
import os
|
2 |
import gradio as gr
|
3 |
import torch, torchaudio, importlib, pandas as pd, tempfile, zipfile, pathlib, shutil, numpy as np
|
4 |
from types import SimpleNamespace
|
5 |
from utmosv2.utils import get_model
|
6 |
|
7 |
+
# ----------------------------------------------------------
# Device selection: GPU when one is available, otherwise CPU
# ----------------------------------------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_grad_enabled(False)  # inference only, no backprop needed

# Free CPU Space: 2 vCPUs, so cap the number of threads
if DEVICE.type == "cpu":
    torch.set_num_threads(min(2, os.cpu_count() or 2))

MAX_LEN = 160_000  # 10 s @ 16 kHz (lower it, e.g. to 80 000, to go faster)

# ----------------------------------------------------------
# Load the model once, when the Space starts up
# ----------------------------------------------------------
cfg_mod = importlib.import_module("utmosv2.config.fusion_stage3")
# Copy every non-dunder attribute of the config module into a mutable
# namespace so individual settings can be overridden below.
cfg = SimpleNamespace(
    **{
        name: getattr(cfg_mod, name)
        for name in dir(cfg_mod)
        if not name.startswith("__")
    }
)
cfg.phase = "test"
cfg.data_config = None
cfg.print_config = False
cfg.weight = "utmosv2_estonian.pth"  # <- your checkpoint file
model = (
    get_model(cfg, DEVICE)
    .to(DEVICE)
    .eval()
)
specs_cfg = cfg.dataset.specs  # configuration of the mel views
|
28 |
|
29 |
+
# ----------------------------------------------------------
|
30 |
+
# Abifunktsioonid
|
31 |
+
# ----------------------------------------------------------
|
32 |
+
def compute_spec(wav: torch.Tensor) -> torch.Tensor:
    """Build the multi-view mel-spectrogram stack for one waveform.

    For every entry of the module-level ``specs_cfg`` a dB-scaled mel
    spectrogram is computed, resized to 512x512 when it has another shape,
    replicated onto 3 channels and added to the stack twice.

    Args:
        wav: Waveform tensor on ``DEVICE`` (assumed 1-D, as passed by
            ``single_predict`` - TODO confirm for other callers).

    Returns:
        The stacked views, ``[V, 3, 512, 512]`` with ``V = 2 * len(specs_cfg)``
        for a 1-D input.
    """
    target_size = (512, 512)
    to_db = torchaudio.transforms.AmplitudeToDB()
    batched_wav = wav[None]

    views = []
    for spec in specs_cfg:
        mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=16_000,
            n_fft=spec.n_fft,
            hop_length=spec.hop_length,
            win_length=spec.win_length,
            n_mels=spec.n_mels,
        ).to(DEVICE)
        db = to_db(mel_transform(batched_wav))[0]
        if tuple(db.shape) != target_size:
            resized = torch.nn.functional.interpolate(
                db[None, None],
                size=target_size,
                mode="bilinear",
                align_corners=False,
            )
            db = resized[0, 0]
        # two "sides" with the same spectrum - same logic as the original model
        three_channel = db.repeat(3, 1, 1)
        views.append(three_channel)
        views.append(three_channel)
    return torch.stack(views)  # [V, 3, 512, 512]
|
52 |
|
53 |
def single_predict(audio_path, domain, quick):
    """Predict the MOS of a single audio file.

    The first channel of the file is resampled to 16 kHz when needed, cropped
    or zero-padded to ``MAX_LEN`` samples, and scored by the module-level
    ``model`` once per evaluated domain; the scores are averaged.

    Args:
        audio_path: Path of the audio file, loaded with ``torchaudio.load``.
        domain: Unused; kept so existing callers keep working.
        quick: When truthy only domain 0 is evaluated; otherwise the score is
            averaged over up to ``NUM_DOMAINS`` domains.

    Returns:
        The mean model output over the evaluated domains, as a ``float``.

    Raises:
        ValueError: If ``audio_path`` is ``None`` or empty.
    """
    if not audio_path:
        # Fail with a clear message instead of an opaque loader error, e.g.
        # when a UI callback fires before any file was provided.
        raise ValueError("audio_path is required: no audio file was provided")

    wav, sr = torchaudio.load(audio_path)
    wav = wav[0]  # only the first channel is scored
    if sr != 16_000:
        wav = torchaudio.transforms.Resample(sr, 16_000)(wav)

    # Crop or right-pad with zeros to exactly MAX_LEN samples.
    if wav.numel() > MAX_LEN:
        wav = wav[:MAX_LEN]
    else:
        wav = torch.nn.functional.pad(wav, (0, MAX_LEN - wav.numel()))
    wav = wav.to(DEVICE)
    spec = compute_spec(wav)

    NUM_DOMAINS = 3  # how many domains were used in training
    # Never ask for a domain index that the model's one-hot encoding cannot
    # represent; quick mode evaluates the first domain only.
    if quick:
        n_domains = 1
    else:
        n_domains = max(1, min(NUM_DOMAINS, model.num_dataset))

    # The batched inputs do not depend on the domain, so build them once.
    wav_batch = wav[None]
    spec_batch = spec[None]

    preds = []
    for dom in range(n_domains):
        d_oh = torch.nn.functional.one_hot(
            torch.tensor(dom, device=DEVICE),
            num_classes=model.num_dataset,
        ).float()[None]
        preds.append(model(wav_batch, spec_batch, d_oh).item())
    return float(np.mean(preds))
|
76 |
|
77 |
def _upload_path(upload):
    """Return the filesystem path of an upload given as a path or as an object with ``.name``."""
    if isinstance(upload, (str, os.PathLike)):
        return upload
    return upload.name


def batch_predict(csv_file, wav_zip, num_domains):
    """Score every audio file listed in a CSV, reading the audio from a ZIP.

    Args:
        csv_file: The CSV upload, either a path or an object whose ``name``
            attribute is the path. It must have an ``audio`` column holding
            paths relative to the root of the archive.
        wav_zip: The ZIP upload (path or object with ``name``) containing the
            audio files.
        num_domains: ``1`` selects quick mode (first domain only); any other
            value, including ``None``, requests averaging over domains.

    Returns:
        ``(df, out_file)``: the CSV contents with an added ``pred_mos``
        column, and the path of a temporary CSV file holding the same table.

    Raises:
        ValueError: If the CSV has no ``audio`` column, or an entry resolves
            to a location outside the extracted archive.
    """
    csv_path = _upload_path(csv_file)
    zip_path = _upload_path(wav_zip)
    # An empty number field arrives as None; treat it as "not quick".
    quick = num_domains is not None and int(num_domains) == 1

    df = pd.read_csv(csv_path)
    if "audio" not in df.columns:
        raise ValueError('CSV must contain an "audio" column')

    outs = []
    # TemporaryDirectory removes the extracted files even if a prediction fails.
    with tempfile.TemporaryDirectory() as tdir:
        root = pathlib.Path(tdir).resolve()
        with zipfile.ZipFile(zip_path) as zf:
            zf.extractall(root)

        for rel in df["audio"]:
            path = (root / str(rel)).resolve()
            # CSV content is user supplied: refuse absolute paths and ".."
            # tricks that would read files outside the extracted archive.
            if root not in path.parents:
                raise ValueError(f"audio path escapes the archive: {rel!r}")
            outs.append(single_predict(str(path), "dummy", quick))

    df["pred_mos"] = outs
    # Close the handle right away; only the (persistent) file name is needed.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as handle:
        out_file = handle.name
    df.to_csv(out_file, index=False)
    return df, out_file
|
96 |
|
97 |
+
# ----------------------------------------------------------
|
98 |
+
# Gradio kasutajaliides
|
99 |
+
# ----------------------------------------------------------
|
100 |
with gr.Blocks(title="UTMOS-v2 MOS-hinnang") as demo:
    gr.Markdown(
        """
        # UTMOS-v2 (eesti kõne)
        Ennusta objektiivseid MOS-skoore üksikutele või suurele failihulgale.
        Mudel laetakse mällu korra Space'i käivitumisel (CPU-s töötab ~ paar s / klipp).
        """
    )

    # --- single WAV ---
    with gr.Tab("Üksik klipp"):
        audio = gr.Audio(type="filepath", label="Helifail (16 kHz WAV)")
        quick = gr.Checkbox(value=True, label="Kiire režiim (1 domään)")
        out_mos = gr.Number(label="Ennustatud MOS")

        def _score_clip(a, q):
            # The domain argument is fixed; only the path and quick flag come from the UI.
            return single_predict(a, "default", q)

        score_btn = gr.Button("Hinda")
        score_btn.click(fn=_score_clip, inputs=[audio, quick], outputs=out_mos)

    # --- batch: CSV + ZIP ---
    with gr.Tab("Partii (CSV + ZIP)"):
        csv_in = gr.File(file_types=[".csv"], label="CSV (audio[, MOS, method])")
        zip_in = gr.File(file_types=[".zip"], label="ZIP kõigi WAV-idega")
        n_dom = gr.Number(
            value=3,
            precision=0,
            label="Domäänide arv (1 = quick, >1 = täiskeskmistamine)",
        )
        df_out = gr.Dataframe(label="Tulemused")
        file_dl = gr.File(label="Lae ennustused CSV-na")
        start_btn = gr.Button("Start")
        start_btn.click(
            batch_predict,
            inputs=[csv_in, zip_in, n_dom],
            outputs=[df_out, file_dl],
        )

# Queue requests so they are processed in order; launch() without the share flag
queued_demo = demo.queue(max_size=10)
queued_demo.launch()
|