# Copied from a Hugging Face Spaces page; page status shown at copy time: "Build error".
import functools
import importlib
import os
import pathlib
import shutil
import tempfile
import zipfile
from types import SimpleNamespace

import gradio as gr
import numpy as np
import pandas as pd
import torch
import torchaudio
from utmosv2.utils import get_model
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_LEN = 160_000  # 10 s @ 16 kHz

# ---- The model is loaded once, for the whole lifetime of the Space ----
# Copy the attributes of the UTMOS-v2 "fusion_stage3" config module into a
# mutable namespace (dunder names such as __name__/__file__ are skipped).
cfg_mod = importlib.import_module("utmosv2.config.fusion_stage3")
cfg = SimpleNamespace(
    **{k: getattr(cfg_mod, k) for k in dir(cfg_mod) if not k.startswith("__")}
)
cfg.phase, cfg.data_config, cfg.print_config = "test", None, False
# Fine-tuned checkpoint: upload it via the repo's "Files" view, then either
# replace the placeholder below or set the UTMOS_WEIGHT environment variable.
# NOTE(review): presumably get_model() cannot load weights while the
# placeholder is still in place — confirm against utmosv2.
cfg.weight = os.environ.get("UTMOS_WEIGHT", "YOUR_FINE_TUNED_WEIGHT.ckpt")
model = get_model(cfg, DEVICE).eval()
# Per-view spectrogram settings consumed by compute_spec().
specs_cfg = cfg.dataset.specs
@functools.lru_cache(maxsize=64)
def _mel_transform(sample_rate, n_fft, hop_length, win_length, n_mels, device):
    """Build (once per distinct setting) a MelSpectrogram transform on ``device``."""
    return torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        n_mels=n_mels,
    ).to(device)


def compute_spec(
    wav: torch.Tensor, *, sample_rate: int = 16000, size: tuple = (512, 512)
) -> torch.Tensor:
    """Turn a waveform into the stacked multi-view spectrogram model input.

    For every entry of ``specs_cfg`` a mel spectrogram is computed and converted
    to dB. It is bilinearly resized to ``size`` when its shape differs, then
    replicated to 3 channels. That 3-channel view is appended twice.

    Args:
        wav: Waveform tensor, assumed 1-D and sampled at ``sample_rate``.
            The transforms run on ``wav.device``.
        sample_rate: Sample rate passed to the mel transform (default 16 kHz).
        size: Target ``(height, width)`` of every view (default 512x512).

    Returns:
        For a 1-D ``wav``, a tensor of shape
        ``(2 * len(specs_cfg), 3, size[0], size[1])``.
    """
    # NOTE(review): views are ordered [spec0, spec0, spec1, spec1, ...] and use
    # torchaudio's AmplitudeToDB defaults; confirm this matches the
    # preprocessing the checkpoint was trained with.
    target = tuple(size)
    to_db = torchaudio.transforms.AmplitudeToDB()
    views = []
    for s in specs_cfg:
        # The transform depends only on these settings and the device, so it is
        # cached rather than rebuilt (filterbank included) on every request.
        mel = _mel_transform(
            sample_rate, s.n_fft, s.hop_length, s.win_length, s.n_mels, wav.device
        )
        db = to_db(mel(wav[None]))[0]
        if tuple(db.shape) != target:
            db = torch.nn.functional.interpolate(
                db[None, None], size=target, mode="bilinear", align_corners=False
            )[0, 0]
        views.extend([db.repeat(3, 1, 1)] * 2)
    return torch.stack(views)
def single_predict(audio_path, domain, quick):
    """Predict the MOS of a single audio file.

    Modeled on the sarulab-speech UTMOS-v2 Space (per the original author's
    note; not verified here).

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.
            Gradio passes ``None`` when nothing was uploaded.
        domain: Value of the UI domain dropdown. Currently unused; the
            prediction does not depend on it.
        quick: If true, score with the first domain one-hot only. Otherwise
            average the predictions over the first ``NUM_DOMAINS`` domain IDs,
            capped at ``model.num_dataset``.

    Returns:
        The predicted MOS as a Python float.

    Raises:
        gr.Error: If no audio file was provided.
    """
    if audio_path is None:
        raise gr.Error("No audio file provided.")

    wav, sr = torchaudio.load(audio_path)  # (channels, samples)
    # Average all channels to mono instead of silently keeping only the first.
    wav = wav.mean(dim=0)
    if sr != 16000:
        wav = torchaudio.transforms.Resample(sr, 16000)(wav)

    # Crop, or zero-pad at the end, to exactly MAX_LEN samples.
    n_samples = wav.numel()
    if n_samples > MAX_LEN:
        wav = wav[:MAX_LEN]
    else:
        wav = torch.nn.functional.pad(wav, (0, MAX_LEN - n_samples))
    wav = wav.to(DEVICE)
    spec = compute_spec(wav)

    # Change this if fine-tuning used a different number of domains (e.g. 8).
    NUM_DOMAINS = 3
    # one_hot raises for class indices >= num_classes, so never request more
    # domain IDs than the model's one-hot width.
    n_domains = 1 if quick else min(NUM_DOMAINS, model.num_dataset)

    preds = []
    with torch.no_grad():
        for dom in range(n_domains):
            d_oh = torch.nn.functional.one_hot(
                torch.tensor(dom, device=DEVICE), num_classes=model.num_dataset
            ).float()[None]
            preds.append(model(wav[None], spec[None], d_oh).item())
    return float(np.mean(preds))
def batch_predict(csv_file, wav_zip, num_domains):
    """Score every clip listed in an uploaded CSV, reading audio from an uploaded ZIP.

    Args:
        csv_file: Uploaded CSV with an ``audio`` column of paths relative to
            the ZIP root. May be a path string (Gradio 4 ``gr.File`` default)
            or a tempfile-like object exposing ``.name`` (older Gradio).
        wav_zip: Uploaded ZIP archive containing the audio files, in either of
            the same two forms.
        num_domains: Value of the "number of training domains" UI field.
            Currently unused, because every clip is scored with ``quick=True``.

    Returns:
        ``(df, out_file)``: the input table with an added ``pred_mos`` column,
        and the path of a CSV copy of it for download.

    Raises:
        gr.Error: If an input is missing, the CSV has no ``audio`` column, or
            an ``audio`` entry points outside the extracted archive.
    """
    if csv_file is None or wav_zip is None:
        raise gr.Error("Please upload both a CSV file and a ZIP archive.")
    csv_path = getattr(csv_file, "name", csv_file)
    zip_path = getattr(wav_zip, "name", wav_zip)

    df = pd.read_csv(csv_path)
    if "audio" not in df.columns:
        raise gr.Error("The CSV must contain an 'audio' column.")

    # TemporaryDirectory removes the extracted files even if scoring fails midway.
    with tempfile.TemporaryDirectory() as tdir:
        root = pathlib.Path(tdir).resolve()
        with zipfile.ZipFile(zip_path) as zf:
            zf.extractall(root)
        outs = []
        for rel in df["audio"]:
            path = (root / str(rel)).resolve()
            # The CSV is user-supplied: reject absolute paths and ".." entries
            # that would read files outside the extraction directory.
            try:
                path.relative_to(root)
            except ValueError:
                raise gr.Error(f"Audio path escapes the archive: {rel!r}") from None
            # single_predict ignores its domain argument, hence the dummy value.
            outs.append(single_predict(str(path), "dummy", quick=True))
    df["pred_mos"] = outs

    # Gradio serves this file for download after we return, so keep it (delete=False).
    with tempfile.NamedTemporaryFile(
        "w", suffix=".csv", delete=False, newline="", encoding="utf-8"
    ) as fh:
        df.to_csv(fh, index=False)
        out_file = fh.name
    return df, out_file
# ---- Gradio UI: one tab for a single clip, one for CSV + ZIP batches ----
with gr.Blocks(title="UTMOS-v2 MOS-hinnang") as demo:
    gr.Markdown(
        """
# UTMOS-v2
Laadi üksik `.wav` või kogu partii ning saa ennustatud MOS-id.
Mudel laetakse GPU-le ühe korra, seega järgmised päringud on kiiremad.
"""
    )
    with gr.Tab("Üksik klipp"):
        audio = gr.Audio(type="filepath", label="Helifail (16 kHz WAV)")
        # Placeholder choice only: single_predict does not use this value.
        domain = gr.Dropdown(
            ["default"],
            value="default",
            label="Domään (valikuline, kui ise muutsid koodi)",
        )
        quick = gr.Checkbox(value=True, label="Kiire (1 iteratsioon/fold)")
        out_mos = gr.Number(label="Ennustatud MOS")
        gr.Button("Hinda").click(single_predict, [audio, domain, quick], out_mos)
    with gr.Tab("Partii (CSV + ZIP)"):
        csv_in = gr.File(file_types=[".csv"], label="CSV (audio[, MOS, method])")
        zip_in = gr.File(file_types=[".zip"], label="ZIP kõikide WAV-idega")
        # NOTE(review): batch_predict does not use this value, and its default
        # (8) differs from NUM_DOMAINS = 3 in single_predict.
        n_dom = gr.Number(value=8, precision=0, label="Treeningu domäänide arv")
        df_out = gr.Dataframe(label="Tulemused")
        file_dl = gr.File(label="Lae CSV ennustustega")
        gr.Button("Start").click(
            batch_predict,
            [csv_in, zip_in, n_dom],
            [df_out, file_dl],
        )

# Queue at most 10 pending requests, then start the web server.
demo.queue(max_size=10).launch()