Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import torchaudio | |
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification | |
# Load the fine-tuned GTZAN checkpoint together with its feature extractor.
model_name = "rakib730/finetuned-gtzan"
extractor = AutoFeatureExtractor.from_pretrained(model_name)
# The app only runs inference, so switch the model to eval mode as it loads.
model = AutoModelForAudioClassification.from_pretrained(model_name).eval()
# Audio classification callback used by the Gradio interface.
def _to_mono_float(samples):
    """Convert raw audio samples to a 1-D float32 tensor."""
    waveform = torch.as_tensor(samples)
    if waveform.is_floating_point():
        waveform = waveform.to(torch.float32)
    else:
        # Integer PCM (e.g. int16): rescale by the dtype's maximum value,
        # because the resampler needs floating-point input.
        scale = float(torch.iinfo(waveform.dtype).max)
        waveform = waveform.to(torch.float32) / scale
    if waveform.dim() == 2:
        # Multi-channel clips are laid out (n_samples, n_channels); average
        # the channels so the time axis is the only (last) dimension.
        waveform = waveform.mean(dim=1)
    return waveform.reshape(-1)


def classify_music(audio):
    """Predict the music genre of an uploaded audio clip.

    Args:
        audio: Value produced by ``gr.Audio(type="numpy")``: a
            ``(sample_rate, samples)`` tuple, or ``None`` when nothing was
            uploaded. ``samples`` may be 1-D (mono) or 2-D
            ``(n_samples, n_channels)``.

    Returns:
        The predicted genre label taken from ``model.config.id2label``, or a
        short message when no usable audio was supplied.
    """
    no_audio_message = "Please upload an audio clip first."
    if audio is None:
        return no_audio_message

    # Gradio's numpy format puts the sample rate first, then the data.
    sample_rate, samples = audio
    waveform = _to_mono_float(samples)
    if waveform.numel() == 0:
        return no_audio_message

    # Resample to the rate the feature extractor expects (16 kHz fallback).
    target_rate = getattr(extractor, "sampling_rate", 16000)
    if sample_rate != target_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=target_rate
        )
        waveform = resampler(waveform)

    inputs = extractor(
        waveform.numpy(), sampling_rate=target_rate, return_tensors="pt"
    )
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_id = torch.argmax(logits, dim=1).item()
    return model.config.id2label[predicted_class_id]
# Gradio UI: a single audio upload in, the predicted genre as text out.
demo = gr.Interface(
    title="🎵 Music Genre Classifier",
    description="Upload a short music clip and get the predicted genre using a fine-tuned GTZAN model.",
    fn=classify_music,
    inputs=gr.Audio(label="Upload a Music Clip (WAV/MP3)", type="numpy"),
    outputs=gr.Textbox(label="Predicted Genre"),
    live=False,
)
demo.launch()