import importlib
from types import SimpleNamespace

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchaudio

from utmosv2.utils import get_model

# ⚙️ Constants
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_LEN = 160000  # 10 seconds at 16 kHz
NUM_DOMAINS = 3  # <- change to match your training setup (must not exceed model.num_dataset)
MODEL_WEIGHT = "utmosv2_finetuned.ckpt"  # <- upload this file to the Space

# 🔁 Load config and model
cfg_mod = importlib.import_module("utmosv2.config.fusion_stage3")
cfg = SimpleNamespace(
    **{k: getattr(cfg_mod, k) for k in dir(cfg_mod) if not k.startswith("__")}
)
cfg.print_config = False
cfg.data_config = None
cfg.phase = "test"
cfg.weight = MODEL_WEIGHT
specs_cfg = cfg.dataset.specs

model = get_model(cfg, DEVICE).to(DEVICE)
model.eval()


# 🔊 Spectrogram computation
def compute_spectrogram(
    wav: torch.Tensor,
    specs_cfg,
    device: torch.device,
    num_frames: int = 2,
    target_size=(512, 512),
) -> torch.Tensor:
    """Build a stack of 3-channel mel-spectrogram views, one group per spec config.

    Returns a tensor of shape (len(specs_cfg) * num_frames, 3, *target_size).
    """
    views = []
    for spec in specs_cfg:
        mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=16000,
            n_fft=spec.n_fft,
            hop_length=spec.hop_length,
            win_length=spec.win_length,
            n_mels=spec.n_mels,
        ).to(device)
        wav_in = wav.unsqueeze(0) if wav.dim() == 1 else wav
        spec_t = mel(wav_in)
        db = torchaudio.transforms.AmplitudeToDB()(spec_t).squeeze(0)
        # Resize every spectrogram to a common size so the views can be stacked.
        if db.shape != target_size:
            db = nn.functional.interpolate(
                db.unsqueeze(0).unsqueeze(0),
                size=target_size,
                mode="bilinear",
                align_corners=False,
            ).squeeze()
        # Repeat the single channel 3x so image-pretrained backbones get RGB-shaped input.
        for _ in range(num_frames):
            views.append(db.unsqueeze(0).repeat(3, 1, 1))
    return torch.stack(views, dim=0)


# 🌟 MOS prediction
def predict_mos(audio_file):
    wav, sr = torchaudio.load(audio_file)
    if sr != 16000:
        wav = torchaudio.transforms.Resample(sr, 16000)(wav)
    wav = wav[0]  # keep the first channel only (mono)
    # Trim or zero-pad to exactly MAX_LEN samples.
    if wav.shape[0] > MAX_LEN:
        wav = wav[:MAX_LEN]
    else:
        wav = nn.functional.pad(wav, (0, MAX_LEN - wav.shape[0]))
    wav = wav.to(DEVICE)

    spec = compute_spectrogram(wav, specs_cfg, DEVICE)
    wav_b = wav.unsqueeze(0)
    spec_b = spec.unsqueeze(0)

    # Run one forward pass per training domain (one-hot domain id) and average.
    preds = []
    for dom in range(NUM_DOMAINS):
        dom_oh = nn.functional.one_hot(
            torch.tensor(dom, device=DEVICE), num_classes=model.num_dataset
        ).float().unsqueeze(0)
        with torch.no_grad():
            p = model(wav_b, spec_b, dom_oh).squeeze().item()
        preds.append(p)

    avg_pred = float(np.mean(preds))
    return f"Predicted MOS: {avg_pred:.2f}"


# 🎛️ Gradio interface
demo = gr.Interface(
    fn=predict_mos,
    inputs=gr.Audio(type="filepath", label="Upload WAV file (16 kHz mono preferred)"),
    outputs="text",
    title="UTMOSv2 - Speech Quality Estimator",
    description="Estimate the Mean Opinion Score (MOS) of synthetic speech using UTMOSv2",
)

if __name__ == "__main__":
    demo.launch()
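
# Quick local smoke test (a sketch, not part of the app itself): because demo.launch()
# is behind the __main__ guard, predict_mos can be imported and called directly.
# Assumes this script is saved as app.py next to utmosv2_finetuned.ckpt, and that
# "sample.wav" is a hypothetical audio file you supply yourself:
#
#   python -c "from app import predict_mos; print(predict_mos('sample.wav'))"
#
# Expected dependencies for the Space (versions unpinned here, an assumption):
# gradio, torch, torchaudio, numpy, and the utmosv2 package.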