import gradio as gr
import importlib, pathlib, shutil, tempfile, zipfile
import numpy as np
import pandas as pd
import torch, torchaudio
from types import SimpleNamespace
from utmosv2.utils import get_model
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_LEN = 160_000 # 10 s @16 kHz
# ---- the model is loaded once for the entire lifetime of the Space ----
cfg_mod = importlib.import_module("utmosv2.config.fusion_stage3")
cfg = SimpleNamespace(**{k: getattr(cfg_mod, k) for k in dir(cfg_mod) if not k.startswith("__")})
cfg.phase, cfg.data_config, cfg.print_config = "test", None, False
cfg.weight = "YOUR_FINE_TUNED_WEIGHT.ckpt"  # add the checkpoint to the repo's "Files" view
model = get_model(cfg, DEVICE).eval()
specs_cfg = cfg.dataset.specs
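# each entry in specs_cfg defines one mel-spectrogram view (n_fft, hop_length, win_length, n_mels)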
def compute_spec(wav: torch.Tensor):
    """Build the stacked mel-spectrogram views expected by the fusion model."""
    views = []
    for s in specs_cfg:
        mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=16000, n_fft=s.n_fft, hop_length=s.hop_length,
            win_length=s.win_length, n_mels=s.n_mels).to(DEVICE)
        db = torchaudio.transforms.AmplitudeToDB()(mel(wav[None]))[0]
        if db.shape != (512, 512):
            db = torch.nn.functional.interpolate(db[None, None], size=(512, 512),
                                                 mode="bilinear", align_corners=False)[0, 0]
        # repeat to 3 channels and duplicate each spectrogram view twice
        views.extend([db.repeat(3, 1, 1)] * 2)
    return torch.stack(views)
def single_predict(audio_path, domain, quick):
    # identical logic to the sarulab-speech Space
    wav, sr = torchaudio.load(audio_path)
    if sr != 16000:
        wav = torchaudio.transforms.Resample(sr, 16000)(wav)[0]
    else:
        wav = wav[0]
    # trim or zero-pad to exactly MAX_LEN samples
    wav = (wav[:MAX_LEN] if wav.numel() > MAX_LEN
           else torch.nn.functional.pad(wav, (0, MAX_LEN - wav.numel()))).to(DEVICE)
    spec = compute_spec(wav)
    preds = []
    # if your fine-tuning used e.g. 8 domains, change this value here
    NUM_DOMAINS = 3
    for dom in range(NUM_DOMAINS):
        d_oh = torch.nn.functional.one_hot(torch.tensor(dom, device=DEVICE),
                                           num_classes=model.num_dataset).float()[None]
        with torch.no_grad():
            p = model(wav[None], spec[None], d_oh).item()
        preds.append(p)
        if quick:
            break
    return float(np.mean(preds))
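# Batch mode: the CSV needs an "audio" column with WAV paths relative to the ZIP root.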
def batch_predict(csv_file, wav_zip, num_domains):
    # note: num_domains comes from the UI but is not used here; single_predict hard-codes NUM_DOMAINS
    tdir = tempfile.mkdtemp()
    with zipfile.ZipFile(wav_zip.name) as zf:
        zf.extractall(tdir)
    df = pd.read_csv(csv_file.name)
    outs = []
    for rel in df["audio"]:
        path = pathlib.Path(tdir) / rel
        outs.append(single_predict(str(path), "dummy", quick=True))  # the domain value does not matter here
    df["pred_mos"] = outs
    out_file = tempfile.NamedTemporaryFile(delete=False, suffix=".csv").name
    df.to_csv(out_file, index=False)
    shutil.rmtree(tdir)
    return df, out_file
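# ---- Gradio UI ----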
with gr.Blocks(title="UTMOS-v2 MOS estimation") as demo:
    gr.Markdown(
        """
        # UTMOS-v2
        Upload a single `.wav` or a whole batch and get predicted MOS scores.
        The model is loaded onto the GPU once, so subsequent requests are faster.
        """
    )
    with gr.Tab("Single clip"):
        audio = gr.Audio(type="filepath", label="Audio file (16 kHz WAV)")
        domain = gr.Dropdown(["default"], value="default",
                             label="Domain (optional, only relevant if you modified the code)")
        quick = gr.Checkbox(value=True, label="Quick (1 iteration/fold)")
        out_mos = gr.Number(label="Predicted MOS")
        gr.Button("Evaluate").click(single_predict, [audio, domain, quick], out_mos)
    with gr.Tab("Batch (CSV + ZIP)"):
        csv_in = gr.File(file_types=[".csv"], label="CSV (audio[, MOS, method])")
        zip_in = gr.File(file_types=[".zip"], label="ZIP with all WAV files")
        n_dom = gr.Number(value=8, precision=0, label="Number of training domains")
        df_out = gr.Dataframe(label="Results")
        file_dl = gr.File(label="Download the CSV with predictions")
        gr.Button("Start").click(batch_predict,
                                 [csv_in, zip_in, n_dom],
                                 [df_out, file_dl])
demo.queue(max_size=10).launch()