Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -1,76 +1,93 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
|
3 |
-
from utmosv2.utils import get_model
|
4 |
from types import SimpleNamespace
|
5 |
-
|
6 |
|
7 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
8 |
|
9 |
-
#
|
10 |
cfg_mod = importlib.import_module("utmosv2.config.fusion_stage3")
|
11 |
-
cfg
|
12 |
-
for k in dir(cfg_mod) if not k.startswith("_")})
|
13 |
cfg.phase, cfg.data_config, cfg.print_config = "test", None, False
|
14 |
-
cfg.weight = "YOUR_FINE_TUNED_WEIGHT.ckpt"
|
15 |
-
model
|
16 |
-
specs_cfg
|
17 |
|
18 |
-
def
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
wav
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
gr.
|
63 |
-
gr.
|
64 |
-
gr.
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
gr.File(label="Download predictions CSV")
|
69 |
-
],
|
70 |
-
title="UTMOS-v2 MOS Estimator",
|
71 |
-
description="Upload the ground-truth CSV and a ZIP containing all WAVs. "
|
72 |
-
"The Space appends predicted MOS scores."
|
73 |
-
)
|
74 |
|
75 |
-
|
76 |
-
demo.launch()
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch, torchaudio, importlib, pandas as pd, tempfile, zipfile, pathlib, shutil, numpy as np
|
|
|
3 |
from types import SimpleNamespace
|
4 |
+
from utmosv2.utils import get_model
|
5 |
|
6 |
# Run on the GPU when one is visible to torch, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Waveforms are cropped / zero-padded to this many samples (10 s @ 16 kHz).
MAX_LEN = 160_000

# ---- the model is loaded once for the whole lifetime of the Space ----
cfg_mod = importlib.import_module("utmosv2.config.fusion_stage3")
# Copy every non-dunder attribute of the config module into a mutable
# namespace so individual fields can be overridden below.
cfg = SimpleNamespace(
    **{
        name: getattr(cfg_mod, name)
        for name in dir(cfg_mod)
        if not name.startswith("__")
    }
)
cfg.phase = "test"
cfg.data_config = None
cfg.print_config = False
# Placeholder: add the checkpoint through the repo's "Files" view.
cfg.weight = "YOUR_FINE_TUNED_WEIGHT.ckpt"
model = get_model(cfg, DEVICE).eval()
# Per-view spectrogram settings consumed by compute_spec().
specs_cfg = cfg.dataset.specs
|
16 |
|
17 |
+
def compute_spec(wav: torch.Tensor):
    """Build the stacked spectrogram views for one waveform.

    For every entry of the module-level ``specs_cfg`` a mel spectrogram is
    computed with a 16 kHz sample rate, converted to decibels, resized to
    512x512 when it does not already have that shape, tiled to three
    channels and added to the output twice.

    Args:
        wav: Waveform tensor located on ``DEVICE``.  Presumably 1-D mono
            audio sampled at 16 kHz (the indexing below suggests this) --
            confirm against the caller.

    Returns:
        The result of ``torch.stack`` over all views, two views per
        ``specs_cfg`` entry.
    """
    to_db = torchaudio.transforms.AmplitudeToDB()
    target_size = (512, 512)
    views = []
    for spec_cfg in specs_cfg:
        mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=16000,
            n_fft=spec_cfg.n_fft,
            hop_length=spec_cfg.hop_length,
            win_length=spec_cfg.win_length,
            n_mels=spec_cfg.n_mels,
        ).to(DEVICE)
        # Add a leading axis for the transform and drop it again afterwards.
        mel_db = to_db(mel_transform(wav[None]))[0]
        if mel_db.shape != target_size:
            resized = torch.nn.functional.interpolate(
                mel_db[None, None],
                size=target_size,
                mode="bilinear",
                align_corners=False,
            )
            mel_db = resized[0, 0]
        # Tile the single-channel image to three channels; the same view is
        # deliberately listed twice.
        three_channel = mel_db.repeat(3, 1, 1)
        views += [three_channel, three_channel]
    return torch.stack(views)
|
29 |
|
30 |
+
def single_predict(audio_path, domain, quick):
    """Predict the MOS of a single audio file.

    The file is loaded, resampled to 16 kHz when needed, reduced to its
    first channel, cropped or zero-padded to ``MAX_LEN`` samples and scored
    by the module-level ``model`` once per domain index.  The per-domain
    scores are averaged.

    Args:
        audio_path: Path of the audio file to score.
        domain: Unused.  Accepted only so the signature matches the inputs
            wired to this callback.
        quick: When truthy, only domain index 0 is evaluated.

    Returns:
        The mean of the collected model outputs as a Python ``float``.

    Raises:
        gr.Error: If ``audio_path`` is empty or ``None``.
    """
    if not audio_path:
        # Without this guard a missing upload surfaces as an opaque loader
        # error instead of a readable message.
        raise gr.Error("No audio file provided.")

    wav, sr = torchaudio.load(audio_path)
    if sr != 16000:
        wav = torchaudio.transforms.Resample(sr, 16000)(wav)
    wav = wav[0]  # keep only the first channel

    # Force a fixed length: crop long clips, zero-pad short ones.
    if wav.numel() > MAX_LEN:
        wav = wav[:MAX_LEN]
    else:
        wav = torch.nn.functional.pad(wav, (0, MAX_LEN - wav.numel()))
    wav = wav.to(DEVICE)

    spec = compute_spec(wav)

    # If fine-tuning used a different number of domains, change it here.
    NUM_DOMAINS = 8
    # one_hot() raises for an index >= num_classes, so never iterate past
    # the number of domains the loaded model actually has.
    num_domains = min(NUM_DOMAINS, model.num_dataset)

    preds = []
    with torch.no_grad():
        for dom in range(num_domains):
            d_oh = torch.nn.functional.one_hot(
                torch.tensor(dom, device=DEVICE),
                num_classes=model.num_dataset,
            ).float()[None]
            preds.append(model(wav[None], spec[None], d_oh).item())
            if quick:
                break
    return float(np.mean(preds))
|
51 |
|
52 |
+
def batch_predict(csv_file, wav_zip, num_domains):
    """Score every audio file listed in a CSV and return the augmented table.

    The ZIP archive is extracted into a temporary directory, each entry of
    the CSV's ``audio`` column is resolved relative to that directory and
    scored with ``single_predict(..., quick=True)``, and the scores are
    stored in a new ``pred_mos`` column.

    Args:
        csv_file: Uploaded CSV, either a path string or an object exposing
            the path as ``.name``.  Must contain an ``audio`` column.
        wav_zip: Uploaded ZIP archive, same conventions as ``csv_file``.
        num_domains: Currently unused; kept because the UI passes it.
            NOTE(review): ``single_predict`` takes no domain count, so this
            value has no effect on the predictions.

    Returns:
        A ``(DataFrame, path)`` tuple: the table including ``pred_mos`` and
        the path of a CSV file with the same content.

    Raises:
        gr.Error: If an upload is missing, the CSV lacks an ``audio``
            column, or an ``audio`` entry points outside the extracted
            archive.
    """
    if csv_file is None or wav_zip is None:
        raise gr.Error("Both a CSV file and a ZIP archive are required.")

    # Accept both plain path strings and file wrappers carrying ``.name``.
    csv_path = getattr(csv_file, "name", csv_file)
    zip_path = getattr(wav_zip, "name", wav_zip)

    df = pd.read_csv(csv_path)
    if "audio" not in df.columns:
        raise gr.Error("The CSV must contain an 'audio' column.")

    # TemporaryDirectory removes the extracted files even when a prediction
    # fails part-way through.
    with tempfile.TemporaryDirectory() as tdir:
        root = pathlib.Path(tdir).resolve()
        with zipfile.ZipFile(zip_path) as zf:
            zf.extractall(root)

        outs = []
        for rel in df["audio"]:
            path = (root / str(rel)).resolve()
            # The CSV is user-supplied: reject absolute paths and ``..``
            # tricks that would read files outside the extracted archive.
            if root not in path.parents:
                raise gr.Error(f"Invalid audio path in CSV: {rel}")
            # The domain value is irrelevant here.
            outs.append(single_predict(str(path), "dummy", quick=True))

    df["pred_mos"] = outs

    # Reserve a persistent file name and close the handle right away so it
    # does not leak; the file must outlive this call for the download.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp:
        out_file = tmp.name
    df.to_csv(out_file, index=False)
    return df, out_file
|
66 |
|
67 |
+
# Two-tab UI: one tab scores a single clip, the other a CSV + ZIP batch.
with gr.Blocks(title="UTMOS-v2 MOS-hinnang") as demo:
    gr.Markdown(
        """
# UTMOS-v2
Laadi üksik `.wav` või kogu partii ning saa ennustatud MOS-id.
Mudel laetakse GPU-le ühe korra, seega järgmised päringud on kiiremad.
"""
    )

    # --- single clip -------------------------------------------------------
    with gr.Tab("Üksik klipp"):
        audio = gr.Audio(type="filepath", label="Helifail (16 kHz WAV)")
        domain = gr.Dropdown(
            ["default"],
            value="default",
            label="Domään (valikuline, kui ise muutsid koodi)",
        )
        quick = gr.Checkbox(value=True, label="Kiire (1 iteratsioon/fold)")
        out_mos = gr.Number(label="Ennustatud MOS")
        score_button = gr.Button("Hinda")
        score_button.click(
            fn=single_predict,
            inputs=[audio, domain, quick],
            outputs=out_mos,
        )

    # --- batch -------------------------------------------------------------
    with gr.Tab("Partii (CSV + ZIP)"):
        csv_in = gr.File(file_types=[".csv"], label="CSV (audio[, MOS, method])")
        zip_in = gr.File(file_types=[".zip"], label="ZIP kõikide WAV-idega")
        n_dom = gr.Number(value=8, precision=0, label="Treeningu domäänide arv")
        df_out = gr.Dataframe(label="Tulemused")
        file_dl = gr.File(label="Lae CSV ennustustega")
        start_button = gr.Button("Start")
        start_button.click(
            fn=batch_predict,
            inputs=[csv_in, zip_in, n_dom],
            outputs=[df_out, file_dl],
        )

# Queue requests (at most 10 waiting) and start the app.
demo.queue(max_size=10).launch()
|
|