rakib730 committed on
Commit
89c09cb
·
verified ·
1 Parent(s): e5b1b0f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -0
app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchaudio
4
+ from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
5
+
# Load the fine-tuned GTZAN genre classifier together with the feature
# extractor that was saved alongside it on the Hugging Face Hub.
model_name = "rakib730/finetuned-gtzan"
extractor = AutoFeatureExtractor.from_pretrained(model_name)

# Put the model in evaluation mode; this app only runs inference.
model = AutoModelForAudioClassification.from_pretrained(model_name).eval()
13
+
# Audio classification function
def classify_music(audio):
    """Predict the music genre of an uploaded audio clip.

    Args:
        audio: Value delivered by ``gr.Audio(type="numpy")``: a
            ``(sample_rate, samples)`` tuple, or ``None`` when nothing was
            uploaded. ``samples`` is a NumPy array shaped ``(n,)`` for mono
            or ``(n, channels)`` for multi-channel audio, usually int16 PCM.

    Returns:
        The predicted genre label taken from ``model.config.id2label``, or a
        short prompt string when no usable audio was supplied.
    """
    no_audio_message = "Please upload an audio clip."
    if audio is None:
        return no_audio_message

    # Gradio's numpy audio format is (sample_rate, data) -- in that order.
    sample_rate, samples = audio

    waveform = torch.as_tensor(samples)
    if waveform.numel() == 0:
        return no_audio_message

    # Integer PCM must become float in roughly [-1, 1]: the resampler only
    # works on floating-point tensors and the model was trained on floats.
    if waveform.is_floating_point():
        waveform = waveform.to(torch.float32)
    else:
        full_scale = float(torch.iinfo(waveform.dtype).max) + 1.0
        waveform = waveform.to(torch.float32) / full_scale

    # Mix multi-channel audio (samples, channels) down to mono so the
    # feature extractor receives a single 1-D signal, not a batch.
    if waveform.ndim > 1:
        waveform = waveform.mean(dim=-1)

    # Resample to the rate the feature extractor expects (16 kHz fallback).
    target_rate = getattr(extractor, "sampling_rate", 16000)
    if sample_rate != target_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=target_rate
        )
        waveform = resampler(waveform)

    inputs = extractor(
        waveform.numpy(), sampling_rate=target_rate, return_tensors="pt"
    )

    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class_id = int(logits.argmax(dim=-1).item())
    return model.config.id2label[predicted_class_id]
31
+
# Gradio UI: a single audio upload in, the predicted genre out as text.
audio_input = gr.Audio(type="numpy", label="Upload a Music Clip (WAV/MP3)")
genre_output = gr.Textbox(label="Predicted Genre")

demo = gr.Interface(
    fn=classify_music,
    inputs=audio_input,
    outputs=genre_output,
    title="🎵 Music Genre Classifier",
    description="Upload a short music clip and get the predicted genre using a fine-tuned GTZAN model.",
    live=False,
)
demo.launch()