# Space status at capture time: Build error
import gradio as gr, pandas as pd, zipfile, tempfile, shutil, pathlib, torch | |
from utmosv2_batch_predict import compute_spec, MAX_LEN # reuse the function | |
from utmosv2.utils import get_model | |
from types import SimpleNamespace | |
import importlib, torchaudio, numpy as np, torch.nn as nn | |
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Load UTMOSv2 a single time, at import, so every request reuses it ---
# Copy the public (non-underscore) names of the stage-3 fusion config module
# into an attribute namespace, then override the fields used for inference.
cfg_mod = importlib.import_module("utmosv2.config.fusion_stage3")
cfg = SimpleNamespace(
    **{
        name: getattr(cfg_mod, name)
        for name in dir(cfg_mod)
        if not name.startswith("_")
    }
)
cfg.phase = "test"
cfg.data_config = None
cfg.print_config = False
# Placeholder checkpoint name: the fine-tuned weight file must be committed to the repo.
cfg.weight = "YOUR_FINE_TUNED_WEIGHT.ckpt"
model = get_model(cfg, DEVICE).eval()
# Spectrogram settings taken from the config; run_space uses them per file.
specs_cfg = cfg.dataset.specs
def _upload_path(upload):
    """Return the filesystem path behind a ``gr.File`` value.

    Older Gradio releases hand over a tempfile-like wrapper exposing ``.name``;
    newer ones hand over a path string. Both forms are accepted.

    Raises:
        ValueError: if nothing was uploaded (Gradio passes ``None``).
    """
    if upload is None:
        raise ValueError("Please upload both the CSV file and the ZIP of WAV files.")
    if isinstance(upload, (str, pathlib.PurePath)):
        return str(upload)
    return upload.name


def _audio_path(root, relpath):
    """Resolve one CSV ``audio`` entry against the extracted-ZIP directory *root*.

    The CSV is user-uploaded, so an absolute path or ``..`` segments must not
    be allowed to reach files outside the extracted upload.

    Raises:
        ValueError: if the entry resolves to a location outside *root*.
    """
    base = pathlib.Path(root).resolve()
    path = (base / str(relpath)).resolve()
    if base not in path.parents:
        raise ValueError(f"Audio path {relpath!r} points outside the uploaded ZIP.")
    return path


def _load_fixed_length(path):
    """Load *path* as a fixed-length 16 kHz waveform on ``DEVICE``.

    Resamples to 16 kHz when needed, keeps only the first channel, then
    truncates or right-zero-pads to exactly ``MAX_LEN`` samples.
    """
    wav, sr = torchaudio.load(path)
    if sr != 16_000:
        wav = torchaudio.transforms.Resample(sr, 16_000)(wav)
    wav = wav[0]
    if wav.numel() > MAX_LEN:
        wav = wav[:MAX_LEN]
    else:
        wav = torch.nn.functional.pad(wav, (0, MAX_LEN - wav.numel()))
    return wav.to(DEVICE)


def _domain_count(num_domains):
    """Validate the 'Training domains' input against ``model.num_dataset``.

    Raises:
        ValueError: if the value is missing or outside ``1..model.num_dataset``.
        Zero or negative values would otherwise average an empty list (NaN),
        and values that are too large would fail inside ``one_hot``.
    """
    if num_domains is None:
        raise ValueError("Please enter the number of training domains.")
    count = int(num_domains)
    if not 1 <= count <= model.num_dataset:
        raise ValueError(
            f"Training domains must be between 1 and {model.num_dataset}, got {count}."
        )
    return count


def run_space(csv_file, wav_zip, num_domains):
    """Predict a MOS for every audio file listed in an uploaded CSV.

    Args:
        csv_file: uploaded CSV with an ``audio`` column of paths relative to
            the ZIP root. Other columns (e.g. ``method``) are passed through.
        wav_zip: uploaded ZIP containing every referenced audio file.
        num_domains: number of domain one-hot conditions (domains
            ``0 .. num_domains-1``) whose predictions are averaged per file.

    Returns:
        ``(df, out_path)``: the input table with an added ``pred_mos`` column,
        and the path of a CSV copy that Gradio offers for download.

    Raises:
        ValueError: on a missing upload, a CSV without an ``audio`` column,
            an out-of-range ``num_domains``, or an audio path escaping the ZIP.
    """
    csv_path = _upload_path(csv_file)
    zip_path = _upload_path(wav_zip)
    df = pd.read_csv(csv_path)
    if "audio" not in df.columns:
        raise ValueError("The CSV needs an 'audio' column with paths inside the ZIP.")
    n_domains = _domain_count(num_domains)

    # The domain conditions are the same for every file, so build them once.
    # Row d is the float one-hot vector for domain d.
    domain_onehots = torch.eye(model.num_dataset, device=DEVICE)[:n_domains]

    preds = []
    # TemporaryDirectory removes the extracted files even if a file fails to load.
    with tempfile.TemporaryDirectory() as tempdir, torch.no_grad():
        with zipfile.ZipFile(zip_path) as zf:
            zf.extractall(tempdir)
        for relpath in df["audio"]:
            wav = _load_fixed_length(_audio_path(tempdir, relpath))
            spec = compute_spec(wav, specs_cfg, DEVICE)
            scores = [
                model(wav[None], spec[None], onehot[None]).item()
                for onehot in domain_onehots
            ]
            # One score per file: the mean over all domain conditions.
            preds.append(float(np.mean(scores)))

    df["pred_mos"] = preds
    # delete=False: Gradio serves this file after the function has returned.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".csv", delete=False, encoding="utf-8", newline=""
    ) as fh:
        df.to_csv(fh, index=False)
        out_path = fh.name
    return df, out_path
def _build_demo():
    """Assemble the Gradio interface that wraps run_space.

    Inputs: a CSV upload, a ZIP upload and an integer domain count (default 8).
    Outputs: the results table and a downloadable predictions CSV.
    """
    csv_input = gr.File(label="CSV (audio, method, MOS)", file_types=[".csv"])
    zip_input = gr.File(label="ZIP with .wav files", file_types=[".zip"])
    domains_input = gr.Number(label="Training domains", value=8, precision=0)
    table_output = gr.Dataframe(label="Results")
    file_output = gr.File(label="Download predictions CSV")
    blurb = (
        "Upload the ground-truth CSV and a ZIP containing all WAVs. "
        "The Space appends predicted MOS scores."
    )
    return gr.Interface(
        fn=run_space,
        inputs=[csv_input, zip_input, domains_input],
        outputs=[table_output, file_output],
        title="UTMOS-v2 MOS Estimator",
        description=blurb,
    )


demo = _build_demo()

if __name__ == "__main__":
    demo.launch()