Spaces: Build error
Commit 236a9d1 · cchuan committed · first commit
Parent(s): cdcd86e

Files changed:
- app.py +56 -0
- data_setups.py +80 -0
- requirements.txt +5 -0
- torch_efficientnet_b0_fold4.pth +3 -0
app.py
ADDED
@@ -0,0 +1,56 @@
import gradio as gr
import torch, torchaudio
from timeit import default_timer as timer
from data_setups import audio_preprocess, resample

device = "cuda" if torch.cuda.is_available() else "cpu"

SAMPLE_RATE = 44100
AUDIO_LEN = 2.90

# The checkpoint stores a full nn.Module, so torch.load returns the model directly
model = torch.load("torch_efficientnet_b0_fold4.pth", map_location=torch.device('cpu'))

# Gloss: cello, clarinet, flute, acoustic guitar, electric guitar,
# organ, piano, saxophone, trumpet, violin, vocals
CHINESE_LABELS = [
    "大提琴", "單簧管", "長笛", "民謠吉他", "電吉他", "風琴", "鋼琴", "薩克斯風", "喇叭", "小提琴", "人聲"
]

example_list = [
    "samples/guitar_acoustic.wav",
    "samples/piano.wav",
    "samples/guitar_electric.wav"
]

def predict(audio_path):
    start_time = timer()
    waveform, sample_rate = torchaudio.load(audio_path)
    # Downmix to mono and resample to the training rate
    wav = resample(waveform, sample_rate, SAMPLE_RATE)
    # Trim to a fixed-length window; reject clips that are too short
    if len(wav) > int(AUDIO_LEN * SAMPLE_RATE):
        wav = wav[:int(AUDIO_LEN * SAMPLE_RATE)]
    else:
        print(f"Input length {len(wav)} is too short; need more than {int(AUDIO_LEN * SAMPLE_RATE)} samples.")
        return
    # Preprocess the waveform into a 3-channel mel-spectrogram "image"
    img = audio_preprocess(wav, SAMPLE_RATE).unsqueeze(0)
    model.eval()
    with torch.inference_mode():
        pred_probs = torch.softmax(model(img), dim=1)
    pred_labels_and_probs = {CHINESE_LABELS[i]: float(pred_probs[0][i]) for i in range(len(CHINESE_LABELS))}
    pred_time = round(timer() - start_time, 5)
    return pred_labels_and_probs, pred_time


title = "Musical Instrument Classification 🎺🎸🎹🎻"
description = "An EfficientNetB0 feature extractor model to classify 11 different musical instruments"
article = ""

demo = gr.Interface(fn=predict,
                    inputs=gr.Audio(type="filepath"),
                    outputs=[gr.Label(num_top_classes=11, label="Predictions"),
                             gr.Number(label="Prediction time (s)")],
                    examples=example_list,
                    title=title,
                    description=description,
                    article=article)

demo.launch(debug=False)
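For a quick check outside the Gradio UI, the same inference path can be exercised directly. The following is a minimal sketch, not part of this commit; it assumes the checkpoint holds a full nn.Module (as the torch.load call above implies) and that one of the bundled sample files is present.

# local_check.py — a sketch only, not part of this commit.
import torch, torchaudio
from data_setups import audio_preprocess, resample

SAMPLE_RATE, AUDIO_LEN = 44100, 2.90

model = torch.load("torch_efficientnet_b0_fold4.pth", map_location="cpu")
model.eval()

waveform, sr = torchaudio.load("samples/piano.wav")  # assumes the sample exists
wav = resample(waveform, sr, SAMPLE_RATE)[: int(AUDIO_LEN * SAMPLE_RATE)]
img = audio_preprocess(wav, SAMPLE_RATE).unsqueeze(0)  # (1, 3, n_mels, frames)
with torch.inference_mode():
    probs = torch.softmax(model(img), dim=1)
print(f"top class index: {int(probs.argmax())}, prob: {float(probs.max()):.3f}")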
data_setups.py
ADDED
@@ -0,0 +1,80 @@
# Data loading and preprocessing utilities for the instrument classifier
import os
import librosa
import torch
import numpy as np
from torchaudio.transforms import Resample

SAMPLE_RATE = 44100
AUDIO_LEN = 2.90

# Parameters controlling mel-spectrogram generation
N_MELS = 128
F_MIN = 20
F_MAX = 16000
N_FFT = 1024
HOP_LEN = 512

# Make function to find classes in target directory
def find_classes(directory: str):
    # 1. Get the class names by scanning the target directory
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    # 2. Raise an error if no class folders are found
    if not classes:
        raise FileNotFoundError(f"Couldn't find any classes in {directory}.")
    # 3. Create a dictionary mapping class names to integer labels
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx

def resample(wav, sample_rate, new_sample_rate):
    # Downmix multi-channel audio to mono
    if wav.shape[0] >= 2:
        wav = torch.mean(wav, dim=0)
    else:
        wav = wav.squeeze(0)
    # Resample whenever the rates differ (the original only handled downsampling)
    if sample_rate != new_sample_rate:
        resampler = Resample(sample_rate, new_sample_rate)
        wav = resampler(wav)
    return wav

def mono_to_color(X, eps=1e-6, mean=None, std=None):
    # Replicate the spectrogram across three channels to form an RGB-like image
    X = np.stack([X, X, X], axis=-1)
    # Standardize (`is None` checks so an explicit mean/std of 0 is honored)
    mean = mean if mean is not None else X.mean()
    std = std if std is not None else X.std()
    X = (X - mean) / (std + eps)
    # Normalize to [0, 255]
    _min, _max = X.min(), X.max()
    if (_max - _min) > eps:
        V = np.clip(X, _min, _max)
        V = 255 * (V - _min) / (_max - _min)
        V = V.astype(np.uint8)
    else:
        V = np.zeros_like(X, dtype=np.uint8)
    return V

def normalize(image, mean=None, std=None):
    image = image / 255.0
    if mean is not None and std is not None:
        image = (image - mean) / std
    # Move channels first: (H, W, C) -> (C, H, W)
    return np.moveaxis(image, 2, 0).astype(np.float32)

def compute_melspec(wav, sample_rate=SAMPLE_RATE):
    melspec = librosa.feature.melspectrogram(
        y=wav,
        sr=sample_rate,
        n_fft=N_FFT,
        fmin=F_MIN,
        fmax=F_MAX,
        n_mels=N_MELS,
        hop_length=HOP_LEN
    )
    # Convert the power spectrogram to decibels
    melspec = librosa.power_to_db(melspec).astype(np.float32)
    return melspec

def audio_preprocess(wav, sample_rate):
    wav = wav.numpy()
    melspec = compute_melspec(wav, sample_rate)
    image = mono_to_color(melspec)
    image = normalize(image, mean=None, std=None)
    image = torch.from_numpy(image)
    return image
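The pipeline above turns a mono clip into a 3-channel, channels-first tensor. A quick shape sanity check on synthetic audio, as a sketch assuming only the constants defined in this file:

# Sanity-check the preprocessing shapes on synthetic audio; a sketch only.
import torch
from data_setups import audio_preprocess, resample, SAMPLE_RATE, AUDIO_LEN

wav = torch.randn(2, int(AUDIO_LEN * SAMPLE_RATE))  # fake 2.9 s stereo clip
wav = resample(wav, SAMPLE_RATE, SAMPLE_RATE)       # downmix to mono, rate unchanged
img = audio_preprocess(wav, SAMPLE_RATE)
# n_fft=1024, hop=512 over 127890 samples -> 1 + 127890 // 512 = 250 frames
print(img.shape)  # expected: torch.Size([3, 128, 250])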
requirements.txt
ADDED
@@ -0,0 +1,5 @@
torch==1.12.0
torchvision==0.13.0
torchaudio==0.12.1
gradio==3.1.4
librosa==0.9.2
torch_efficientnet_b0_fold4.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2083c006db34451c229a820aae94091ca7417856b8601da008b20048ccebfb2e
size 16419889