monatolmats committed on
Commit dc1aa8f · verified · 1 Parent(s): e8a1e5d

Update app.py

Files changed (1)
  1. app.py +81 -64
app.py CHANGED
@@ -1,76 +1,93 @@
- import gradio as gr, pandas as pd, zipfile, tempfile, shutil, pathlib, torch
- from utmosv2_batch_predict import compute_spec, MAX_LEN   # reuse the function
- from utmosv2.utils import get_model
  from types import SimpleNamespace
- import importlib, torchaudio, numpy as np, torch.nn as nn

  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

- # --- load UTMOSv2 once ---
  cfg_mod = importlib.import_module("utmosv2.config.fusion_stage3")
- cfg = SimpleNamespace(**{k: getattr(cfg_mod, k)
-                          for k in dir(cfg_mod) if not k.startswith("_")})
  cfg.phase, cfg.data_config, cfg.print_config = "test", None, False
- cfg.weight = "YOUR_FINE_TUNED_WEIGHT.ckpt"   # put the file in the repo
- model = get_model(cfg, DEVICE).eval()
- specs_cfg = cfg.dataset.specs

- def run_space(csv_file, wav_zip, num_domains):
-     """
-     Inputs:
-         csv_file – CSV with an 'audio' column and an optional 'method' column
-         wav_zip  – ZIP that contains all referenced .wav files
-     Output:
-         DataFrame shown in the UI + downloadable CSV
-     """
-     # ----- prepare wav directory -----
-     tempdir = tempfile.mkdtemp()
-     with zipfile.ZipFile(wav_zip.name) as zf:
-         zf.extractall(tempdir)

-     df = pd.read_csv(csv_file.name)
-     pred = []
-     for relpath in df["audio"]:
-         path = pathlib.Path(tempdir) / relpath
-         wav, sr = torchaudio.load(path)
-         if sr != 16_000:
-             wav = torchaudio.transforms.Resample(sr, 16_000)(wav)[0]
-         else:
-             wav = wav[0]
-         wav = (wav[:MAX_LEN] if wav.numel() > MAX_LEN
-                else nn.functional.pad(wav, (0, MAX_LEN - wav.numel()))).to(DEVICE)

-         spec = compute_spec(wav, specs_cfg, DEVICE)
-         dom_p = []
-         for dom in range(int(num_domains)):
-             dom_oh = torch.nn.functional.one_hot(
-                 torch.tensor(dom, device=DEVICE),
-                 num_classes=model.num_dataset).float()[None]
-             with torch.no_grad():
-                 dom_p.append(model(wav[None], spec[None], dom_oh).item())
-         pred.append(float(np.mean(dom_p)))

-     shutil.rmtree(tempdir)
-     df["pred_mos"] = pred
-     out_path = tempfile.NamedTemporaryFile(delete=False, suffix=".csv").name
-     df.to_csv(out_path, index=False)
-     return df, out_path   # gr.File returns a link automatically

- demo = gr.Interface(
-     run_space,
-     inputs=[
-         gr.File(label="CSV (audio, method, MOS)", file_types=[".csv"]),
-         gr.File(label="ZIP with .wav files", file_types=[".zip"]),
-         gr.Number(label="Training domains", value=8, precision=0)
-     ],
-     outputs=[
-         gr.Dataframe(label="Results"),
-         gr.File(label="Download predictions CSV")
-     ],
-     title="UTMOS-v2 MOS Estimator",
-     description="Upload the ground-truth CSV and a ZIP containing all WAVs. "
-                 "The Space appends predicted MOS scores."
- )

- if __name__ == "__main__":
-     demo.launch()
+ import gradio as gr
+ import torch, torchaudio, importlib, pandas as pd, tempfile, zipfile, pathlib, shutil, numpy as np
  from types import SimpleNamespace
+ from utmosv2.utils import get_model

  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ MAX_LEN = 160_000   # 10 s @ 16 kHz

+ # ---- the model is loaded once for the lifetime of the Space ----
  cfg_mod = importlib.import_module("utmosv2.config.fusion_stage3")
+ cfg = SimpleNamespace(**{k: getattr(cfg_mod, k) for k in dir(cfg_mod) if not k.startswith("__")})
  cfg.phase, cfg.data_config, cfg.print_config = "test", None, False
+ cfg.weight = "YOUR_FINE_TUNED_WEIGHT.ckpt"   # upload the checkpoint to the repo's "Files" view
+ model = get_model(cfg, DEVICE).eval()
+ specs_cfg = cfg.dataset.specs

+ def compute_spec(wav: torch.Tensor):
+     views = []
+     for s in specs_cfg:
+         mel = torchaudio.transforms.MelSpectrogram(
+             sample_rate=16000, n_fft=s.n_fft, hop_length=s.hop_length,
+             win_length=s.win_length, n_mels=s.n_mels).to(DEVICE)
+         db = torchaudio.transforms.AmplitudeToDB()(mel(wav[None]))[0]
+         if db.shape != (512, 512):
+             db = torch.nn.functional.interpolate(db[None, None], size=(512, 512),
+                                                  mode="bilinear", align_corners=False)[0, 0]
+         views.extend([db.repeat(3, 1, 1)] * 2)
+     return torch.stack(views)

+ def single_predict(audio_path, domain, quick):
+     # same logic as the sarulab-speech Space
+     wav, sr = torchaudio.load(audio_path)
+     if sr != 16000:
+         wav = torchaudio.transforms.Resample(sr, 16000)(wav)[0]
+     else:
+         wav = wav[0]
+     wav = (wav[:MAX_LEN] if wav.numel() > MAX_LEN
+            else torch.nn.functional.pad(wav, (0, MAX_LEN - wav.numel()))).to(DEVICE)
+     spec = compute_spec(wav)
+     preds = []
+     # if your fine-tuning used e.g. 8 domains, adjust this here
+     NUM_DOMAINS = 8
+     for dom in range(NUM_DOMAINS):
+         d_oh = torch.nn.functional.one_hot(torch.tensor(dom, device=DEVICE),
+                                            num_classes=model.num_dataset).float()[None]
+         with torch.no_grad():
+             p = model(wav[None], spec[None], d_oh).item()
+         preds.append(p)
+         if quick:
+             break
+     return float(np.mean(preds))

+ def batch_predict(csv_file, wav_zip, num_domains):
+     tdir = tempfile.mkdtemp()
+     with zipfile.ZipFile(wav_zip.name) as zf:
+         zf.extractall(tdir)
+     df = pd.read_csv(csv_file.name)
+     outs = []
+     for rel in df["audio"]:
+         path = pathlib.Path(tdir) / rel
+         outs.append(single_predict(str(path), "dummy", quick=True))   # the domain value does not matter here
+     df["pred_mos"] = outs
+     out_file = tempfile.NamedTemporaryFile(delete=False, suffix=".csv").name
+     df.to_csv(out_file, index=False)
+     shutil.rmtree(tdir)
+     return df, out_file

+ with gr.Blocks(title="UTMOS-v2 MOS Estimator") as demo:
+     gr.Markdown(
+         """
+         # UTMOS-v2
+         Upload a single `.wav` or a whole batch and get predicted MOS scores.
+         The model is loaded onto the GPU once, so subsequent requests are faster.
+         """
+     )
+     with gr.Tab("Single clip"):
+         audio = gr.Audio(type="filepath", label="Audio file (16 kHz WAV)")
+         domain = gr.Dropdown(["default"], value="default",
+                              label="Domain (optional, only if you changed the code yourself)")
+         quick = gr.Checkbox(value=True, label="Quick (1 iteration/fold)")
+         out_mos = gr.Number(label="Predicted MOS")
+         gr.Button("Score").click(single_predict, [audio, domain, quick], out_mos)

+     with gr.Tab("Batch (CSV + ZIP)"):
+         csv_in = gr.File(file_types=[".csv"], label="CSV (audio[, MOS, method])")
+         zip_in = gr.File(file_types=[".zip"], label="ZIP with all WAV files")
+         n_dom = gr.Number(value=8, precision=0, label="Number of training domains")
+         df_out = gr.Dataframe(label="Results")
+         file_dl = gr.File(label="Download CSV with predictions")
+         gr.Button("Start").click(batch_predict,
+                                  [csv_in, zip_in, n_dom],
+                                  [df_out, file_dl])

+ demo.queue(max_size=10).launch()
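
Note: the batch tab expects each CSV row to give a path, relative to the ZIP root, in an "audio" column; any extra columns such as "method" or "MOS" are carried through unchanged into the output CSV. A minimal sketch of building such an input file with pandas (the filenames below are illustrative, not part of this commit):

# hypothetical example input for the "Batch (CSV + ZIP)" tab
import pandas as pd

pd.DataFrame({
    "audio":  ["clip_001.wav", "subdir/clip_002.wav"],   # paths relative to the ZIP root
    "method": ["system_a", "system_b"],                   # optional metadata column
}).to_csv("batch_input.csv", index=False)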