# Uploaded by: monatolmats
# Commit message: Update app.py
# Commit: 4d5d07d (verified)
# ---------------- monkey-patch: CUDA-checkpoint → CPU -----------------------
# Checkpoints saved on a GPU carry CUDA device tags; loading them without a
# map_location fails on a CPU-only host. torch.load is patched process-wide
# (before any checkpoint is loaded — presumably the utmosv2 weights loaded via
# get_model below) so calls that do not choose a map_location get the CPU.
import functools
import torch

_original_torch_load = torch.load


@functools.wraps(_original_torch_load)
def _cpu_load(*args, **kwargs):
    """Call the original torch.load, defaulting ``map_location`` to the CPU.

    An explicit ``map_location`` (keyword or second positional argument) is
    respected; ``None`` is treated as "not chosen" and replaced by the CPU.
    """
    cpu = torch.device("cpu")
    if len(args) >= 2:
        # map_location is torch.load's 2nd positional parameter; adding the
        # keyword as well would raise "got multiple values" (TypeError).
        if args[1] is None:
            args = (args[0], cpu, *args[2:])
    elif kwargs.get("map_location") is None:
        kwargs["map_location"] = cpu
    return _original_torch_load(*args, **kwargs)


torch.load = _cpu_load
# ---------------------------------------------------------------------------
import os, gradio as gr, torchaudio, importlib, pandas as pd
import tempfile, zipfile, pathlib, shutil, numpy as np
from types import SimpleNamespace
from utmosv2.utils import get_model
# ---------- device & perf tweaks -------------------------------------------
# Run on the GPU when one is visible, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Inference-only app: switch autograd off. NOTE: PyTorch grad mode is
# thread-local, so this only affects the thread that imports this module.
torch.set_grad_enabled(False)

if DEVICE.type == "cpu":
    # Use at most 2 intra-op threads (os.cpu_count() may return None).
    torch.set_num_threads(min(2, os.cpu_count() or 2))

MAX_LEN = 10 * 16_000  # samples per clip: 10 s at 16 kHz
NUM_DOMAINS = 3  # predictions are always averaged over 3 domains
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# ---------- load the model once ---------------------------------------------
cfg_mod = importlib.import_module("utmosv2.config.fusion_stage3")
# Copy every non-dunder attribute of the config module into a mutable
# namespace, so the overrides below rebind names on the copy, not the module.
cfg = SimpleNamespace(**dict(
    (attr, getattr(cfg_mod, attr))
    for attr in dir(cfg_mod)
    if not attr.startswith("__")
))

# Inference-time overrides.
cfg.phase = "test"
cfg.data_config = None
cfg.print_config = False
# NOTE(review): relative path, resolved against the process's current
# working directory — confirm the app is always launched from its own folder.
cfg.weight = "utmosv2_estonian.pth"

model = get_model(cfg, DEVICE)
model = model.to(DEVICE).eval()
# Spectrogram view configurations iterated by compute_spec.
specs_cfg = cfg.dataset.specs
# ---------------------------------------------------------------------------
def compute_spec(wav: torch.Tensor) -> torch.Tensor:
    """Build the multi-view log-mel spectrogram input for the model.

    ``wav`` is a 1-D waveform on ``DEVICE``, assumed to be sampled at 16 kHz
    (the rate the mel transforms are configured with). For every entry of
    ``specs_cfg`` one log-mel image is computed, resized to 512x512
    (bilinear) if needed, replicated to 3 channels and added twice.

    Returns:
        Tensor of shape ``[2 * len(specs_cfg), 3, 512, 512]``.
    """
    side = 512
    to_db = torchaudio.transforms.AmplitudeToDB()
    views = []
    for spec_cfg in specs_cfg:
        mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=16_000,
            n_fft=spec_cfg.n_fft,
            hop_length=spec_cfg.hop_length,
            win_length=spec_cfg.win_length,
            n_mels=spec_cfg.n_mels,
        ).to(DEVICE)
        # NOTE(review): AmplitudeToDB runs with its defaults; confirm this
        # matches the spectrogram normalisation used when training.
        log_mel = to_db(mel_transform(wav.unsqueeze(0))).squeeze(0)
        if tuple(log_mel.shape) != (side, side):
            log_mel = torch.nn.functional.interpolate(
                log_mel.unsqueeze(0).unsqueeze(0),
                size=(side, side),
                mode="bilinear",
                align_corners=False,
            ).squeeze(0).squeeze(0)
        rgb = log_mel.repeat(3, 1, 1)  # single-channel image -> 3 identical channels
        # Each view is added twice — presumably to match the number of views
        # the fusion model expects; TODO confirm against cfg.
        views += [rgb, rgb]
    return torch.stack(views)
def single_predict(audio_path: str) -> float:
    """Predict the MOS of one audio file, averaged over ``NUM_DOMAINS`` domains.

    The audio is down-mixed to mono and resampled to 16 kHz if needed. It is
    then cropped to its first ``MAX_LEN`` samples, or zero-padded up to that
    length. The model is run once per domain index ``0 .. NUM_DOMAINS - 1``,
    and the per-domain scores are averaged.

    Args:
        audio_path: path to an audio file readable by ``torchaudio.load``.

    Returns:
        Mean of the per-domain predictions as a Python float.

    Raises:
        gr.Error: if no file was given (e.g. "Hinda" clicked before uploading).
    """
    if not audio_path:
        raise gr.Error("Palun laadi enne hindamist üles helifail.")
    wav, sr = torchaudio.load(audio_path)  # [channels, samples]
    # Average all channels instead of silently scoring only channel 0
    # (same down-mix as librosa.load(mono=True)).
    wav = wav.mean(dim=0)
    if sr != 16_000:
        wav = torchaudio.transforms.Resample(sr, 16_000)(wav)

    # Fixed-length input: crop the tail or zero-pad up to MAX_LEN samples.
    n_samples = wav.numel()
    if n_samples > MAX_LEN:
        wav = wav[:MAX_LEN]
    else:
        wav = torch.nn.functional.pad(wav, (0, MAX_LEN - n_samples))
    wav = wav.to(DEVICE)

    # Grad mode is thread-local: the module-level torch.set_grad_enabled(False)
    # does not apply to the worker threads Gradio runs handlers on, so
    # autograd is disabled explicitly here.
    with torch.no_grad():
        spec = compute_spec(wav)
        preds = []
        for dom in range(NUM_DOMAINS):
            # One-hot domain selector of width model.num_dataset, plus batch dim.
            dom_oh = torch.nn.functional.one_hot(
                torch.tensor(dom, device=DEVICE),
                num_classes=model.num_dataset,
            ).float()[None]
            preds.append(model(wav[None], spec[None], dom_oh).item())
    return float(np.mean(preds))
# ---------- helper: locate an audio file ------------------------------------
def locate_audio(root: pathlib.Path, rel_path: str) -> pathlib.Path:
    """Find an audio file inside the extracted-ZIP directory ``root``.

    ``rel_path`` is first normalised:
      * surrounding whitespace is stripped;
      * backslashes are treated as separators (CSV files written on Windows);
      * leading separators are removed.

    Lookup order:
      1) ``root / rel_path``, if it is a regular file that resolves inside
         ``root``;
      2) otherwise, the first (in sorted order) regular file anywhere under
         ``root`` whose name equals the last component of ``rel_path``.

    Raises:
        FileNotFoundError: if nothing matches (including an empty ``rel_path``).
    """
    rel_path = rel_path.strip().replace("\\", "/").lstrip("/")
    if rel_path:
        direct = root / rel_path
        # rel_path comes from a user-supplied CSV: ignore direct hits that
        # resolve outside root (e.g. "../../etc/passwd").
        try:
            direct.resolve().relative_to(root.resolve())
        except ValueError:
            inside_root = False
        else:
            inside_root = True
        if inside_root and direct.is_file():
            return direct

        # Compare names literally: rglob(name) would treat "[", "]", "*" and
        # "?" in a real file name as glob wildcards.
        name = pathlib.PurePosixPath(rel_path).name
        matches = sorted(
            p for p in root.rglob("*") if p.name == name and p.is_file()
        )
        if matches:
            return matches[0]
    raise FileNotFoundError(f"'{rel_path}' ei leitud ZIP-ist – "
                            f"kontrolli CSV ja ZIP teede vastavust.")
# ---------- batch processing -------------------------------------------------
def _upload_path(upload) -> str:
    """Return the filesystem path of a Gradio upload (object with ``.name`` or a path string)."""
    return getattr(upload, "name", upload)


def batch_predict(csv_file, wav_zip):
    """Score every file listed in an uploaded CSV against audio in an uploaded ZIP.

    Args:
        csv_file: uploaded CSV; must have a ``faili_nimi`` column holding
            paths (or bare file names) of the audio inside the ZIP.
        wav_zip: uploaded ZIP archive with the audio files.

    Returns:
        ``(df, out_file)``:
          * ``df`` is the CSV table with an added ``pred_mos`` column; the
            value is NaN where a file was missing or scoring failed;
          * ``out_file`` is the path of a CSV copy of ``df`` for download.

    Raises:
        gr.Error: if an upload is missing or the CSV lacks the file-name column.
    """
    if csv_file is None or wav_zip is None:
        raise gr.Error("Laadi üles nii CSV- kui ka ZIP-fail.")
    col = "faili_nimi"
    df = pd.read_csv(_upload_path(csv_file))
    if col not in df.columns:
        raise gr.Error(f"CSV-failis puudub veerg '{col}'.")

    outs, errors = [], []
    # TemporaryDirectory removes the extracted audio even if an exception
    # escapes; mkdtemp + a trailing rmtree leaked it on any error.
    with tempfile.TemporaryDirectory() as tdir:
        root = pathlib.Path(tdir)
        # NOTE: the ZIP is user-supplied; extractall strips absolute paths and
        # ".." components, but archive size is not limited here.
        with zipfile.ZipFile(_upload_path(wav_zip)) as zf:
            zf.extractall(root)
        for rel in df[col]:
            try:
                full = locate_audio(root, str(rel))
                outs.append(single_predict(str(full)))
            except Exception as e:  # best-effort: one bad file must not abort the batch
                outs.append(np.nan)
                errors.append(f"{rel}: {e}")
    df["pred_mos"] = outs

    # mkstemp returns an open descriptor; close it so only pandas writes the file.
    fd, out_file = tempfile.mkstemp(suffix=".csv")
    os.close(fd)
    df.to_csv(out_file, index=False)

    if errors:
        gr.Warning("⚠️ Mõned failid jäid leidmata või tekkis viga:\n" +
                   "\n".join(errors[:15]))
    return df, out_file
# ---------------------------------------------------------------------------
# --------------------------- Gradio UI -------------------------------------
_INTRO_MD = (
    "\n"
    "# UTMOS-v2 (eesti kõne)\n"
    "Ennustab objektiivse MOS-i **kolme treeningu-domeeni** keskmisena.\n"
)

with gr.Blocks(title="UTMOS-v2 MOS-hinnang (3 domeeni)") as demo:
    gr.Markdown(_INTRO_MD)

    # Single clip: upload one file, get its predicted MOS.
    with gr.Tab("1 WAV fail"):
        audio = gr.Audio(type="filepath", label="WAV (16 kHz või muu)")
        out_mos = gr.Number(label="Ennustatud MOS")
        score_btn = gr.Button("Hinda")
        score_btn.click(fn=single_predict, inputs=audio, outputs=out_mos)

    # Batch: CSV of file names + ZIP of audio -> results table + CSV download.
    with gr.Tab("Partii (CSV + ZIP)"):
        csv_in = gr.File(file_types=[".csv"], label="CSV (veerud: faili_nimi)")
        zip_in = gr.File(file_types=[".zip"], label="ZIP (WAV failid)")
        df_out = gr.Dataframe(label="Tulemused")
        file_dl = gr.File(label="Lae CSV ennustustega")
        run_btn = gr.Button("Start")
        run_btn.click(
            fn=batch_predict,
            inputs=[csv_in, zip_in],
            outputs=[df_out, file_dl],
        )

# Keep at most 10 pending events in the queue, then start the server
# (this happens at import time).
demo.queue(max_size=10)
demo.launch()