# Spaces:
# Build error
# Build error
# ---------------- monkey-patch: CUDA-checkpoint → CPU -----------------------
import torch

# Keep a handle on the real loader so the wrapper can delegate to it.
_original_torch_load = torch.load


def _cpu_load(*args, **kwargs):
    """Call the original ``torch.load`` with ``map_location`` defaulting to CPU.

    All positional and keyword arguments are forwarded unchanged. A CPU
    ``map_location`` is added only when the caller did not supply one, either
    by keyword or positionally.

    Returns:
        Whatever the original ``torch.load`` returns.
    """
    # ``map_location`` is the second positional parameter of ``torch.load``.
    # Injecting the keyword when it was already given positionally would raise
    # "TypeError: got multiple values for argument 'map_location'".
    if len(args) < 2:
        kwargs.setdefault("map_location", torch.device("cpu"))
    return _original_torch_load(*args, **kwargs)


torch.load = _cpu_load
# ---------------------------------------------------------------------------
import glob
import importlib
import os
import pathlib
import shutil
import tempfile
import zipfile
from types import SimpleNamespace

import gradio as gr
import numpy as np
import pandas as pd
import torchaudio
from utmosv2.utils import get_model
# ---------- device & perf tweaks -------------------------------------------
# Run on the GPU when one is visible to torch, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# This script only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

if DEVICE.type == "cpu":
    # Cap torch at two threads (or fewer, if the host reports fewer CPUs).
    torch.set_num_threads(min(2, os.cpu_count() or 2))

MAX_LEN = 10 * 16_000  # 10 s @16 kHz
NUM_DOMAINS = 3  # predictions are always averaged over 3 domains
# ---------------------------------------------------------------------------
# ---------- load the model once ---------------------------------------------
cfg_mod = importlib.import_module("utmosv2.config.fusion_stage3")

# Copy every non-dunder attribute of the config module into a mutable
# namespace so individual settings can be overridden below.
cfg = SimpleNamespace(
    **{
        name: getattr(cfg_mod, name)
        for name in dir(cfg_mod)
        if not name.startswith("__")
    }
)
cfg.phase = "test"
cfg.data_config = None
cfg.print_config = False
cfg.weight = "utmosv2_estonian.pth"

model = get_model(cfg, DEVICE)
model = model.to(DEVICE).eval()

# Per-view spectrogram settings, read by compute_spec().
specs_cfg = cfg.dataset.specs
# ---------------------------------------------------------------------------
def compute_spec(wav: torch.Tensor) -> torch.Tensor:
    """Return a stacked multi-view mel spectrogram of shape [V, 3, 512, 512].

    For every entry in ``specs_cfg`` a 16 kHz mel spectrogram is computed,
    converted to dB, resized to 512x512 when it is not already that size,
    and replicated over 3 channels. Each view is added twice, so
    ``V == 2 * len(specs_cfg)``.

    Args:
        wav: Waveform tensor; assumed to be 1-D and already on ``DEVICE``
            (TODO confirm against callers).

    Returns:
        The views stacked along a new leading dimension.
    """
    target_size = (512, 512)
    stacked_views = []
    for spec in specs_cfg:
        to_mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=16_000,
            n_fft=spec.n_fft,
            hop_length=spec.hop_length,
            win_length=spec.win_length,
            n_mels=spec.n_mels,
        ).to(DEVICE)
        to_db = torchaudio.transforms.AmplitudeToDB()

        # Add a leading batch dimension for the transforms, then drop it again.
        mel_batch = to_mel(wav[None])
        db = to_db(mel_batch)[0]

        if db.shape != target_size:
            # interpolate() wants [N, C, H, W]; wrap and unwrap around the call.
            resized = torch.nn.functional.interpolate(
                db[None, None],
                size=target_size,
                mode="bilinear",
                align_corners=False,
            )
            db = resized[0, 0]

        three_channel = db.repeat(3, 1, 1)
        # The same view goes in twice, presumably to match the number of
        # views the model expects (TODO confirm).
        stacked_views.append(three_channel)
        stacked_views.append(three_channel)
    return torch.stack(stacked_views)
def single_predict(audio_path: str) -> float:
    """Predict the MOS of one audio file, averaged over ``NUM_DOMAINS`` domains.

    The file is loaded, resampled to 16 kHz if needed, reduced to its first
    row (``wav[0]``), and truncated or zero-padded to exactly ``MAX_LEN``
    samples. The model is then run once per domain id with a one-hot domain
    vector, and the scores are averaged.

    Args:
        audio_path: Path to the audio file to score.

    Returns:
        Mean of the per-domain model outputs.
    """
    wav, sr = torchaudio.load(audio_path)
    if sr != 16_000:
        wav = torchaudio.transforms.Resample(sr, 16_000)(wav)
    wav = wav[0]

    # Force a fixed length: cut long clips, right-pad short ones with zeros.
    n_samples = wav.numel()
    if n_samples > MAX_LEN:
        wav = wav[:MAX_LEN]
    else:
        wav = torch.nn.functional.pad(wav, (0, MAX_LEN - n_samples))
    wav = wav.to(DEVICE)

    spec = compute_spec(wav)

    scores = []
    for domain_id in range(NUM_DOMAINS):
        domain_index = torch.tensor(domain_id, device=DEVICE)
        domain_one_hot = torch.nn.functional.one_hot(
            domain_index, num_classes=model.num_dataset
        ).float()[None]
        score = model(wav[None], spec[None], domain_one_hot)
        scores.append(score.item())
    return float(np.mean(scores))
# ---------- helper for locating a file --------------------------------------
def locate_audio(root: pathlib.Path, rel_path: str) -> pathlib.Path:
    """Find an audio file inside the extraction directory ``root``.

    Lookup order:
      1. ``root / rel_path``, accepted only if it is a regular file that
         really lives under ``root``.
      2. Otherwise, search the whole tree below ``root`` for a regular file
         with the same file name. With several matches the first one in
         sorted path order is returned.

    Args:
        root: Directory the archive was extracted into.
        rel_path: Path taken from the CSV. Surrounding whitespace and leading
            slashes or backslashes are ignored.

    Returns:
        Path of the located file.

    Raises:
        FileNotFoundError: If no matching regular file exists under ``root``.
    """
    rel_path = rel_path.strip().lstrip("/\\")

    if rel_path:
        direct = root / rel_path
        # The CSV is untrusted input: reject paths that climb out of ``root``
        # through ".." components.
        if direct.is_file() and root.resolve() in direct.resolve().parents:
            return direct

    name = pathlib.Path(rel_path).name
    if name and name != "..":
        # Escape glob metacharacters so names such as "take[1].wav" match
        # literally, and drop directories that merely share the name.
        matches = sorted(
            p for p in root.rglob(glob.escape(name)) if p.is_file()
        )
        if matches:
            return matches[0]

    raise FileNotFoundError(f"'{rel_path}' ei leitud ZIP-ist – "
                            f"kontrolli CSV ja ZIP teede vastavust.")
# ---------- batch processing -------------------------------------------------
def batch_predict(csv_file, wav_zip):
    """Score every file listed in a CSV against the audio inside a ZIP.

    The ZIP is extracted into a temporary directory, the CSV column
    ``faili_nimi`` is read as the list of audio paths, and each file is
    scored with ``single_predict``. A file that cannot be located or scored
    gets ``NaN`` and its error is collected. Up to 15 collected errors are
    shown through ``gr.Warning``.

    Args:
        csv_file: Uploaded CSV; its path is read from ``csv_file.name``.
        wav_zip: Uploaded ZIP; its path is read from ``wav_zip.name``.

    Returns:
        A tuple ``(df, out_file)``: the input table with an added
        ``pred_mos`` column, and the path of a temporary CSV file holding
        the same table.
    """
    col = "faili_nimi"
    outs, errors = [], []

    # The context manager removes the extraction directory even when the ZIP,
    # the CSV, or the column lookup raises part-way through.
    with tempfile.TemporaryDirectory() as tdir:
        with zipfile.ZipFile(wav_zip.name) as zf:
            zf.extractall(tdir)
        df = pd.read_csv(csv_file.name)
        root = pathlib.Path(tdir)
        for rel in df[col]:
            try:
                full = locate_audio(root, str(rel))
                outs.append(single_predict(str(full)))
            except Exception as e:
                # Deliberate best effort: one bad row must not abort the batch.
                outs.append(np.nan)
                errors.append(f"{rel}: {e}")

    df["pred_mos"] = outs

    # Reserve a path that outlives this call, and close the handle before
    # pandas writes to it.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as handle:
        out_file = handle.name
    df.to_csv(out_file, index=False)

    if errors:
        gr.Warning("⚠️ Mõned failid jäid leidmata või tekkis viga:\n" +
                   "\n".join(errors[:15]))
    return df, out_file
# ---------------------------------------------------------------------------
# --------------------------- Gradio UI -------------------------------------
with gr.Blocks(title="UTMOS-v2 MOS-hinnang (3 domeeni)") as demo:
    gr.Markdown(
        """
# UTMOS-v2 (eesti kõne)
Ennustab objektiivse MOS-i **kolme treeningu-domeeni** keskmisena.
"""
    )

    # Single clip: one audio input, one numeric MOS output.
    with gr.Tab("1 WAV fail"):
        audio = gr.Audio(type="filepath", label="WAV (16 kHz või muu)")
        out_mos = gr.Number(label="Ennustatud MOS")
        single_button = gr.Button("Hinda")
        single_button.click(single_predict, inputs=audio, outputs=out_mos)

    # Batch: CSV of file names plus a ZIP of audio in, table and CSV out.
    with gr.Tab("Partii (CSV + ZIP)"):
        csv_in = gr.File(file_types=[".csv"], label="CSV (veerud: faili_nimi)")
        zip_in = gr.File(file_types=[".zip"], label="ZIP (WAV failid)")
        df_out = gr.Dataframe(label="Tulemused")
        file_dl = gr.File(label="Lae CSV ennustustega")
        batch_button = gr.Button("Start")
        batch_button.click(
            batch_predict,
            inputs=[csv_in, zip_in],
            outputs=[df_out, file_dl],
        )

# Enable the request queue with at most 10 entries, then start the app.
demo.queue(max_size=10).launch()