# ---------------- monkey-patch: CUDA-checkpoint → CPU -----------------------
# Installed before utmosv2 is imported/used so that every torch.load call it
# makes defaults to map_location=CPU (lets GPU-saved checkpoints load anywhere).
import torch

_original_torch_load = torch.load


def _cpu_load(*args, **kwargs):
    """Call the original ``torch.load`` with ``map_location`` defaulting to CPU.

    ``map_location`` is ``torch.load``'s second positional parameter, so the
    default is injected only when the caller passed neither a second
    positional argument nor a ``map_location`` keyword. Injecting it in that
    case as well would raise ``TypeError: got multiple values``.
    """
    if len(args) < 2:
        kwargs.setdefault("map_location", torch.device("cpu"))
    return _original_torch_load(*args, **kwargs)


torch.load = _cpu_load
# ---------------------------------------------------------------------------
import glob
import importlib
import os
import pathlib
import shutil
import tempfile
import zipfile
from types import SimpleNamespace

import gradio as gr
import numpy as np
import pandas as pd
import torchaudio
from utmosv2.utils import get_model

# ---------- device & perf tweaks -------------------------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE: grad mode is thread-local, so this only covers the importing thread
# (model loading). Gradio may run handlers in worker threads, which is why the
# inference functions below are additionally wrapped in torch.no_grad().
torch.set_grad_enabled(False)
if DEVICE.type == "cpu":
    torch.set_num_threads(min(2, os.cpu_count() or 2))

SAMPLE_RATE = 16_000  # rate the model input is built at
MAX_LEN = 160_000  # 10 s @ 16 kHz; longer clips are cropped, shorter padded
NUM_DOMAINS = 3  # predictions are always averaged over 3 domains
# ---------------------------------------------------------------------------

# ---------- load the model once --------------------------------------------
cfg_mod = importlib.import_module("utmosv2.config.fusion_stage3")
cfg = SimpleNamespace(
    **{k: getattr(cfg_mod, k) for k in dir(cfg_mod) if not k.startswith("__")}
)
cfg.phase, cfg.data_config, cfg.print_config = "test", None, False
cfg.weight = "utmosv2_estonian.pth"
model = get_model(cfg, DEVICE).to(DEVICE).eval()
specs_cfg = cfg.dataset.specs

# Spectrogram front-end: one MelSpectrogram per entry of specs_cfg, built once
# here instead of on every request.
_MEL_TRANSFORMS = [
    torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_fft=s.n_fft,
        hop_length=s.hop_length,
        win_length=s.win_length,
        n_mels=s.n_mels,
    ).to(DEVICE)
    for s in specs_cfg
]
_TO_DB = torchaudio.transforms.AmplitudeToDB()
# ---------------------------------------------------------------------------


@torch.no_grad()
def compute_spec(wav: torch.Tensor) -> torch.Tensor:
    """Build the multi-view mel-spectrogram input for the model.

    Args:
        wav: 1-D waveform on ``DEVICE`` at ``SAMPLE_RATE``, as prepared by
            ``single_predict``.

    Returns:
        Tensor of shape ``[V, 3, 512, 512]``, where ``V = 2 * len(specs_cfg)``.
        Each view is one dB-scaled mel spectrogram per ``specs_cfg`` entry,
        resized to 512x512, repeated over 3 channels and included twice.
    """
    views = []
    for mel in _MEL_TRANSFORMS:
        db = _TO_DB(mel(wav[None]))[0]
        if db.shape != (512, 512):
            db = torch.nn.functional.interpolate(
                db[None, None], size=(512, 512), mode="bilinear", align_corners=False
            )[0, 0]
        # Each view is added twice, as in the original pipeline. This is
        # presumably the view count the model expects; confirm against utmosv2.
        views.extend([db.repeat(3, 1, 1)] * 2)
    return torch.stack(views)


@torch.no_grad()
def single_predict(audio_path: str) -> float:
    """Predict the MOS of one audio file, averaged over ``NUM_DOMAINS`` domains.

    Only the first channel is used. The clip is resampled to ``SAMPLE_RATE``
    if needed, then cropped or zero-padded to exactly ``MAX_LEN`` samples.

    Raises:
        gr.Error: if no file was provided (e.g. the button was clicked without
            an upload).
    """
    if not audio_path:
        raise gr.Error("Palun lae kõigepealt helifail üles.")
    wav, sr = torchaudio.load(audio_path)
    # Keep channel 0 before resampling; resampling acts on the time axis only,
    # so this equals resampling everything and then dropping the other channels.
    wav = wav[0]
    if sr != SAMPLE_RATE:
        wav = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(wav)
    if wav.numel() > MAX_LEN:
        wav = wav[:MAX_LEN]
    else:
        wav = torch.nn.functional.pad(wav, (0, MAX_LEN - wav.numel()))
    wav = wav.to(DEVICE)

    spec = compute_spec(wav)
    preds = []
    for dom in range(NUM_DOMAINS):
        dom_oh = torch.nn.functional.one_hot(
            torch.tensor(dom, device=DEVICE), num_classes=model.num_dataset
        ).float()[None]
        preds.append(model(wav[None], spec[None], dom_oh).item())
    return float(np.mean(preds))


# ---------- helpers for locating files -------------------------------------
def _is_within(path: pathlib.Path, root: pathlib.Path) -> bool:
    """Return True if the already-resolved *path* lies inside resolved *root*."""
    try:
        path.relative_to(root)
    except ValueError:
        return False
    return True


def locate_audio(root: pathlib.Path, rel_path: str) -> pathlib.Path:
    """Find an audio file inside the temporary (extracted ZIP) directory.

    The lookup proceeds in order:

    1. ``root / rel_path``, if it is a file that stays inside ``root``.
       Backslashes are treated as separators and leading separators are
       ignored.
    2. Otherwise, search the whole tree for a file with the same name and
       return the first match in sorted order.

    Raises:
        FileNotFoundError: if neither lookup finds a file.
    """
    rel_path = rel_path.strip().lstrip("/\\")
    # CSVs written on Windows may use backslashes; normalise to '/'.
    norm = rel_path.replace("\\", "/")
    root_abs = root.resolve()
    direct = (root_abs / norm).resolve()
    # The CSV is untrusted input: never follow '../' outside the ZIP directory.
    if direct.is_file() and _is_within(direct, root_abs):
        return direct
    name = pathlib.PurePosixPath(norm).name
    if name:
        # glob.escape makes '[', '*' and '?' in real file names match literally.
        matches = sorted(p for p in root.rglob(glob.escape(name)) if p.is_file())
        if matches:
            return matches[0]
    raise FileNotFoundError(f"'{rel_path}' ei leitud ZIP-ist – "
                            f"kontrolli CSV ja ZIP teede vastavust.")


def _as_path(upload) -> str:
    """Return the filesystem path of a ``gr.File`` upload.

    Older Gradio versions pass a tempfile-like object exposing ``.name``.
    Newer ones (``type="filepath"``) pass the path string itself.
    """
    if isinstance(upload, (str, os.PathLike)):
        return os.fspath(upload)
    return upload.name


# ---------- batch processing -----------------------------------------------
def batch_predict(csv_file, wav_zip):
    """Score every file listed in a CSV, using audio from a ZIP archive.

    Reads the ``faili_nimi`` column from the CSV and finds each entry in the
    extracted ZIP via ``locate_audio``. Each prediction goes into a new
    ``pred_mos`` column; entries that failed get NaN and are reported through
    a ``gr.Warning``.

    Returns:
        A ``(DataFrame, path)`` pair: the annotated DataFrame, and the path of
        a temporary CSV containing it, offered for download.

    Raises:
        gr.Error: if an upload is missing or the CSV lacks the filename column.
    """
    if csv_file is None or wav_zip is None:
        raise gr.Error("Palun lae üles nii CSV- kui ka ZIP-fail.")
    col = "faili_nimi"
    df = pd.read_csv(_as_path(csv_file))
    if col not in df.columns:
        raise gr.Error(f"CSV-s puudub veerg '{col}'.")

    outs, errors = [], []
    tdir = tempfile.mkdtemp()
    try:
        with zipfile.ZipFile(_as_path(wav_zip)) as zf:
            zf.extractall(tdir)
        root = pathlib.Path(tdir)
        for rel in df[col]:
            try:
                full = locate_audio(root, str(rel))
                outs.append(single_predict(str(full)))
            except Exception as e:  # best effort per file: record and continue
                outs.append(np.nan)
                errors.append(f"{rel}: {e}")
    finally:
        # Always remove the extracted audio, even if extraction/scoring fails.
        shutil.rmtree(tdir, ignore_errors=True)

    df["pred_mos"] = outs
    # mkstemp + close: no dangling file handle. The file is kept for download.
    fd, out_file = tempfile.mkstemp(suffix=".csv")
    os.close(fd)
    df.to_csv(out_file, index=False)

    if errors:
        gr.Warning("⚠️ Mõned failid jäid leidmata või tekkis viga:\n"
                   + "\n".join(errors[:15]))
    return df, out_file
# ---------------------------------------------------------------------------


# --------------------------- Gradio UI -------------------------------------
with gr.Blocks(title="UTMOS-v2 MOS-hinnang (3 domeeni)") as demo:
    gr.Markdown(
        """
        # UTMOS-v2 (eesti kõne)
        Ennustab objektiivse MOS-i **kolme treeningu-domeeni** keskmisena.
        """
    )

    # Single clip ------------------------------------------------------------
    with gr.Tab("1 WAV fail"):
        audio = gr.Audio(type="filepath", label="WAV (16 kHz või muu)")
        out_mos = gr.Number(label="Ennustatud MOS")
        gr.Button("Hinda").click(single_predict, inputs=audio, outputs=out_mos)

    # Batch ------------------------------------------------------------------
    with gr.Tab("Partii (CSV + ZIP)"):
        csv_in = gr.File(file_types=[".csv"], label="CSV (veerud: faili_nimi)")
        zip_in = gr.File(file_types=[".zip"], label="ZIP (WAV failid)")
        df_out = gr.Dataframe(label="Tulemused")
        file_dl = gr.File(label="Lae CSV ennustustega")
        gr.Button("Start").click(batch_predict, inputs=[csv_in, zip_in],
                                 outputs=[df_out, file_dl])

demo.queue(max_size=10).launch()