monatolmats's picture
Update app.py
d5b5af5 verified
raw
history blame
5.67 kB
import os
import gradio as gr
import torch, torchaudio, importlib, pandas as pd, tempfile, zipfile, pathlib, shutil, numpy as np
from types import SimpleNamespace
from utmosv2.utils import get_model
# ----------------------------------------------------------
# Device selection – use the GPU when available, otherwise the CPU
# ----------------------------------------------------------
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Inference only – no backprop is needed, so autograd is switched off globally.
torch.set_grad_enabled(False)

# Free CPU Space: 2 vCPUs → cap the number of threads torch may use.
if DEVICE.type == "cpu":
    torch.set_num_threads(min(os.cpu_count() or 2, 2))

# 10 s @ 16 kHz (lower it, e.g. to 80 000, if faster inference is needed)
MAX_LEN = 160_000
# ----------------------------------------------------------
# Load the model once, when the Space starts up
# ----------------------------------------------------------
cfg_mod = importlib.import_module("utmosv2.config.fusion_stage3")

# Copy every non-dunder attribute of the config module into a mutable namespace.
cfg = SimpleNamespace(
    **{
        attr_name: getattr(cfg_mod, attr_name)
        for attr_name in dir(cfg_mod)
        if not attr_name.startswith("__")
    }
)
cfg.phase = "test"
cfg.data_config = None
cfg.print_config = False
cfg.weight = "utmosv2_estonian.pth"  # ← your checkpoint file

model = get_model(cfg, DEVICE)
model = model.to(DEVICE).eval()

# Configuration of the mel-spectrogram views
specs_cfg = cfg.dataset.specs
# ----------------------------------------------------------
# Helper functions
# ----------------------------------------------------------
def compute_spec(wav: torch.Tensor) -> torch.Tensor:
    """Build the multi-view mel-spectrogram tensor [V, 3, 512, 512].

    For every entry in ``specs_cfg`` a mel spectrogram of ``wav`` is computed
    at 16 kHz, converted to dB, resized to 512x512 if necessary, replicated
    to three channels and added to the output twice.

    Args:
        wav: Waveform tensor; assumed to be 1-D (samples,) — the code adds a
            leading batch dimension before the mel transform.

    Returns:
        All views stacked along a new first dimension.
    """
    target_size = (512, 512)
    to_db = torchaudio.transforms.AmplitudeToDB()
    batched_wav = wav[None]

    stacked_views = []
    for view_cfg in specs_cfg:
        mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=16_000,
            n_fft=view_cfg.n_fft,
            hop_length=view_cfg.hop_length,
            win_length=view_cfg.win_length,
            n_mels=view_cfg.n_mels,
        ).to(DEVICE)
        mel_db = to_db(mel_transform(batched_wav))[0]

        if mel_db.shape != target_size:
            resized = torch.nn.functional.interpolate(
                mel_db[None, None],
                size=target_size,
                mode="bilinear",
                align_corners=False,
            )
            mel_db = resized[0, 0]

        # Two "sides" with the same spectrum – same logic as in the
        # original model.
        three_channel = mel_db.repeat(3, 1, 1)
        stacked_views.append(three_channel)
        stacked_views.append(three_channel)

    return torch.stack(stacked_views)  # [V, 3, 512, 512]
def single_predict(audio_path, domain, quick):
    """Predict the MOS of a single audio file.

    The first channel of the file is resampled to 16 kHz if needed, then
    trimmed or zero-padded to exactly ``MAX_LEN`` samples. The model is
    evaluated once per domain (one-hot encoded) and the predictions are
    averaged.

    Args:
        audio_path: Path of the audio file to score.
        domain: Unused; kept so existing callers keep working.
        quick: When true, only the first domain is evaluated.

    Returns:
        The mean predicted MOS as a Python float.

    Raises:
        ValueError: If no audio path was supplied (e.g. nothing uploaded).
    """
    if not audio_path:
        raise ValueError("No audio file was provided.")

    wav, sr = torchaudio.load(audio_path)
    # Keep only the first channel *before* resampling so the other channels
    # are not resampled just to be thrown away.
    wav = wav[0]
    if sr != 16_000:
        wav = torchaudio.transforms.Resample(sr, 16_000)(wav)

    # Force a fixed length: truncate long clips, zero-pad short ones.
    n_samples = wav.numel()
    if n_samples > MAX_LEN:
        wav = wav[:MAX_LEN]
    else:
        wav = torch.nn.functional.pad(wav, (0, MAX_LEN - n_samples))
    wav = wav.to(DEVICE)

    spec = compute_spec(wav)

    NUM_DOMAINS = 3  # how many domains were used in training
    # Never ask for a domain index the model's one-hot encoding cannot hold.
    # NOTE(review): assumes model.num_dataset is the number of domain classes.
    n_domains = 1 if quick else min(NUM_DOMAINS, model.num_dataset)

    wav_batch = wav[None]
    spec_batch = spec[None]
    preds = []
    for dom in range(n_domains):
        d_oh = torch.nn.functional.one_hot(
            torch.tensor(dom, device=DEVICE),
            num_classes=model.num_dataset,
        ).float()[None]
        preds.append(model(wav_batch, spec_batch, d_oh).item())

    return float(np.mean(preds))
def batch_predict(csv_file, wav_zip, num_domains):
    """Score every audio file listed in a CSV using WAVs from a ZIP archive.

    The archive is extracted into a temporary directory that is removed
    again even when a prediction fails. Each value of the CSV's ``audio``
    column is treated as a path relative to the archive root.

    Args:
        csv_file: Uploaded CSV, either a path string or an object whose
            ``name`` attribute holds the path. Must have an ``audio`` column.
        wav_zip: Uploaded ZIP archive, given the same way as ``csv_file``.
        num_domains: ``1`` selects quick mode (first domain only); any other
            value averages over all domains.

    Returns:
        A tuple ``(df, out_file)``: the input table with an added
        ``pred_mos`` column, and the path of a CSV file holding that table.

    Raises:
        ValueError: If an upload is missing or an ``audio`` entry points
            outside the extracted archive.
        KeyError: If the CSV has no ``audio`` column.
    """
    if csv_file is None or wav_zip is None:
        raise ValueError("Both a CSV file and a ZIP archive are required.")

    # Accept both file-like upload objects (with .name) and plain paths.
    csv_path = getattr(csv_file, "name", csv_file)
    zip_path = getattr(wav_zip, "name", wav_zip)

    df = pd.read_csv(csv_path)
    if "audio" not in df.columns:
        raise KeyError("audio")

    quick = int(num_domains) == 1

    # The context manager guarantees cleanup on errors as well.
    with tempfile.TemporaryDirectory() as tdir:
        root = pathlib.Path(tdir).resolve()
        with zipfile.ZipFile(zip_path) as zf:
            zf.extractall(root)

        outs = []
        for rel in df["audio"]:
            path = (root / str(rel)).resolve()
            # CSV content is untrusted: refuse absolute paths and ".."
            # tricks that would read files outside the extracted archive.
            if root not in path.parents:
                raise ValueError(f"Audio path escapes the archive: {rel!r}")
            outs.append(single_predict(str(path), "dummy", quick))

    df["pred_mos"] = outs

    # Close the handle right away; only the name is needed for the download.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp:
        out_file = tmp.name
    df.to_csv(out_file, index=False)
    return df, out_file
# ----------------------------------------------------------
# Gradio user interface
# ----------------------------------------------------------
with gr.Blocks(title="UTMOS-v2 MOS-hinnang") as demo:
    gr.Markdown(
        """
# UTMOS-v2 (eesti kõne)
Ennusta objektiivseid MOS-skoore üksikutele või suurele failihulgale.
Mudel laetakse mällu korra Space'i käivitumisel (CPU-s töötab ~ paar s / klipp).
"""
    )

    # --- single WAV ---
    with gr.Tab("Üksik klipp"):
        audio = gr.Audio(type="filepath", label="Helifail (16 kHz WAV)")
        quick = gr.Checkbox(value=True, label="Kiire režiim (1 domään)")
        out_mos = gr.Number(label="Ennustatud MOS")
        single_btn = gr.Button("Hinda")
        # The domain argument is fixed; only the path and quick flag come
        # from the UI.
        single_btn.click(
            fn=lambda path, use_quick: single_predict(path, "default", use_quick),
            inputs=[audio, quick],
            outputs=out_mos,
        )

    # --- batch: CSV + ZIP ---
    with gr.Tab("Partii (CSV + ZIP)"):
        csv_in = gr.File(file_types=[".csv"], label="CSV (audio[, MOS, method])")
        zip_in = gr.File(file_types=[".zip"], label="ZIP kõigi WAV-idega")
        n_dom = gr.Number(
            value=3,
            precision=0,
            label="Domäänide arv (1 = quick, >1 = täiskeskmistamine)",
        )
        df_out = gr.Dataframe(label="Tulemused")
        file_dl = gr.File(label="Lae ennustused CSV-na")
        batch_btn = gr.Button("Start")
        batch_btn.click(
            batch_predict,
            inputs=[csv_in, zip_in, n_dom],
            outputs=[df_out, file_dl],
        )

# Queue → lets requests wait in line; launch() without the share flag
demo.queue(max_size=10).launch()