import gradio as gr, pandas as pd, zipfile, tempfile, shutil, pathlib, torch
from utmosv2_batch_predict import compute_spec, MAX_LEN # reuse the function
from utmosv2.utils import get_model
from types import SimpleNamespace
import importlib, torchaudio, numpy as np, torch.nn as nn
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# --- load UTMOSv2 once ---
cfg_mod = importlib.import_module("utmosv2.config.fusion_stage3")
cfg = SimpleNamespace(**{k: getattr(cfg_mod, k)
                         for k in dir(cfg_mod) if not k.startswith("_")})
cfg.phase, cfg.data_config, cfg.print_config = "test", None, False
cfg.weight = "YOUR_FINE_TUNED_WEIGHT.ckpt" # put the file in the repo
model = get_model(cfg, DEVICE).eval()
specs_cfg = cfg.dataset.specs
def run_space(csv_file, wav_zip, num_domains):
"""
Inputs:
csv_file – csv with 'audio' and optional 'method'
wav_zip – zip that contains all referenced .wav files
Output:
DataFrame shown + downloadable CSV
"""
    # ----- prepare wav directory -----
    tempdir = tempfile.mkdtemp()
    with zipfile.ZipFile(wav_zip.name) as zf:
        zf.extractall(tempdir)

    df = pd.read_csv(csv_file.name)
    pred = []
    for relpath in df["audio"]:
        path = pathlib.Path(tempdir) / relpath
        wav, sr = torchaudio.load(path)
        if sr != 16_000:
            wav = torchaudio.transforms.Resample(sr, 16_000)(wav)[0]
        else:
            wav = wav[0]
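        # Trim or zero-pad each clip to MAX_LEN samples before moving it to DEVICE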
        wav = (wav[:MAX_LEN] if wav.numel() > MAX_LEN
               else nn.functional.pad(wav, (0, MAX_LEN - wav.numel()))).to(DEVICE)
        spec = compute_spec(wav, specs_cfg, DEVICE)
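        # Average the prediction over one-hot embeddings of the requested training domains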
        dom_p = []
        for dom in range(int(num_domains)):
            dom_oh = torch.nn.functional.one_hot(
                torch.tensor(dom, device=DEVICE),
                num_classes=model.num_dataset).float()[None]
            with torch.no_grad():
                dom_p.append(model(wav[None], spec[None], dom_oh).item())
        pred.append(float(np.mean(dom_p)))

    shutil.rmtree(tempdir)
    df["pred_mos"] = pred
    out_path = tempfile.NamedTemporaryFile(delete=False, suffix=".csv").name
    df.to_csv(out_path, index=False)
    return df, out_path  # gr.File returns a link automatically
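# --- Gradio UI: CSV + ZIP uploads and a domain count in, results table + CSV download out ---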
demo = gr.Interface(
    run_space,
    inputs=[
        gr.File(label="CSV (audio, method, MOS)", file_types=[".csv"]),
        gr.File(label="ZIP with .wav files", file_types=[".zip"]),
        gr.Number(label="Training domains", value=8, precision=0),
    ],
    outputs=[
        gr.Dataframe(label="Results"),
        gr.File(label="Download predictions CSV"),
    ],
    title="UTMOS-v2 MOS Estimator",
    description="Upload the ground-truth CSV and a ZIP containing all WAVs. "
                "The Space appends predicted MOS scores.",
)
if __name__ == "__main__":
    demo.launch()