PartyMusicAgent / app.py
rakib730's picture
Create app.py
89c09cb verified
raw
history blame
1.51 kB
import gradio as gr
import torch
import torchaudio
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
# Checkpoint identifier handed to both from_pretrained() calls below
model_name = "rakib730/finetuned-gtzan"

# Load the feature extractor that prepares raw audio for the classifier
extractor = AutoFeatureExtractor.from_pretrained(model_name)

# Load the classifier and switch it to eval mode in one step
# (eval() returns the module itself, so the call can be chained)
model = AutoModelForAudioClassification.from_pretrained(model_name).eval()
# Audio classification callback used by the Gradio interface
def classify_music(audio):
    """Predict the music genre of an audio clip.

    Args:
        audio: The value Gradio passes for ``gr.Audio(type="numpy")``: a
            ``(sample_rate, data)`` tuple where ``data`` holds the samples
            as a 1-D (mono) or 2-D (multi-channel) array. The reversed
            ``(data, sample_rate)`` order is accepted as well. ``None``
            means that no clip was submitted.

    Returns:
        The label looked up in ``model.config.id2label`` for the highest
        scoring class, or a short message when there is nothing to classify.
    """
    if audio is None:
        return "No audio provided. Please upload a music clip."

    sample_rate, samples = audio
    if getattr(sample_rate, "ndim", 0) > 0:
        # The array came first: the caller used the (data, sample_rate) order.
        sample_rate, samples = samples, sample_rate

    waveform = torch.as_tensor(samples)
    if waveform.numel() == 0:
        return "The uploaded clip is empty. Please upload a music clip with audio."

    if waveform.is_floating_point():
        waveform = waveform.to(torch.float32)
    else:
        # Signed integer PCM (what Gradio typically delivers) is rescaled to
        # floats in [-1, 1); resampling requires a floating-point tensor.
        full_scale = float(torch.iinfo(waveform.dtype).max) + 1.0
        waveform = waveform.to(torch.float32) / full_scale

    if waveform.ndim == 2:
        # Down-mix to mono. The shorter axis is taken to be the channel axis,
        # which covers both (samples, channels) and (channels, samples).
        channel_dim = 0 if waveform.shape[0] < waveform.shape[1] else 1
        waveform = waveform.mean(dim=channel_dim)

    # Resample to the rate the feature extractor expects (16 kHz fallback).
    target_rate = getattr(extractor, "sampling_rate", 16000)
    sample_rate = int(sample_rate)
    if sample_rate != target_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=target_rate
        )
        waveform = resampler(waveform)

    inputs = extractor(
        waveform.numpy(), sampling_rate=target_rate, return_tensors="pt"
    )
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class_id = logits.argmax(dim=-1).item()
    return model.config.id2label[predicted_class_id]
# Gradio UI: declare the input/output widgets, wire them to the classifier,
# then start serving the app.
audio_input = gr.Audio(type="numpy", label="Upload a Music Clip (WAV/MP3)")
genre_output = gr.Textbox(label="Predicted Genre")

demo = gr.Interface(
    fn=classify_music,
    inputs=audio_input,
    outputs=genre_output,
    title="🎵 Music Genre Classifier",
    description="Upload a short music clip and get the predicted genre using a fine-tuned GTZAN model.",
    live=False,
)
demo.launch()