cchuan committed
Commit 236a9d1 · 1 Parent(s): cdcd86e

first commit

Files changed (4)
  1. app.py +56 -0
  2. data_setups.py +80 -0
  3. requirements.txt +5 -0
  4. torch_efficientnet_b0_fold4.pth +3 -0
app.py ADDED
@@ -0,0 +1,56 @@
+ import gradio as gr
+ import torch, torchaudio
+ from timeit import default_timer as timer
+ from data_setups import audio_preprocess, resample
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ SAMPLE_RATE = 44100
+ AUDIO_LEN = 2.90
+
+ # Load the serialized model (architecture + weights) onto the CPU
+ model = torch.load("torch_efficientnet_b0_fold4.pth", map_location=torch.device('cpu'))
+
+ # Labels (Chinese): cello, clarinet, flute, acoustic guitar, electric guitar,
+ # organ, piano, saxophone, trumpet, violin, voice
+ CHINESE_LABELS = [
+     "大提琴", "單簧管", "長笛", "民謠吉他", "電吉他", "風琴", "鋼琴", "薩克斯風", "喇叭", "小提琴", "人聲"
+ ]
+
+ example_list = [
+     "samples/guitar_acoustic.wav",
+     "samples/piano.wav",
+     "samples/guitar_electric.wav"
+ ]
+
+ def predict(audio_path):
+     start_time = timer()
+     waveform, sample_rate = torchaudio.load(audio_path)
+     wav = resample(waveform, sample_rate, SAMPLE_RATE)
+     # Trim to a fixed clip length; reject inputs that are too short
+     if len(wav) >= int(AUDIO_LEN * SAMPLE_RATE):
+         wav = wav[:int(AUDIO_LEN * SAMPLE_RATE)]
+     else:
+         print(f"Input length {len(wav)} is too short; need at least {int(AUDIO_LEN * SAMPLE_RATE)} samples")
+         return
+     # Input preprocessing: waveform -> mel-spectrogram image with a batch dimension
+     img = audio_preprocess(wav, SAMPLE_RATE).unsqueeze(0)
+     print(img.shape)
+     model.eval()
+     with torch.inference_mode():
+         pred_probs = torch.softmax(model(img), dim=1)
+     pred_labels_and_probs = {CHINESE_LABELS[i]: float(pred_probs[0][i]) for i in range(len(CHINESE_LABELS))}
+     pred_time = round(timer() - start_time, 5)
+     return pred_labels_and_probs, pred_time
+
+
+ title = "Musical Instrument Classification 🎺🎸🎹🎻"
+ description = "An EfficientNetB0 feature extractor model that classifies 11 different musical instruments"
+ article = ""
+
+ demo = gr.Interface(fn=predict,
+                     inputs=gr.Audio(type="filepath"),
+                     outputs=[gr.Label(num_top_classes=11, label="Predictions"),
+                              gr.Number(label="Prediction time (s)")],
+                     examples=example_list,
+                     title=title,
+                     description=description,
+                     article=article)
+
+ demo.launch(debug=False)
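
For reference, the fixed clip length that predict enforces works out as follows (a small illustrative snippet, not part of the committed file):

# Illustration only: the minimum number of samples predict() accepts
SAMPLE_RATE = 44100
AUDIO_LEN = 2.90
print(int(AUDIO_LEN * SAMPLE_RATE))  # 127890 samples, i.e. ~2.9 s of audio at 44.1 kHz
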
data_setups.py ADDED
@@ -0,0 +1,80 @@
+ # Data setup utilities: class discovery and audio-to-spectrogram preprocessing
+ import os
+ import librosa
+ import torch
+ import numpy as np
+ from torchaudio.transforms import Resample
+
+ SAMPLE_RATE = 44100
+ AUDIO_LEN = 2.90
+
+ # Parameters that control the mel-spectrogram generation
+ N_MELS = 128
+ F_MIN = 20
+ F_MAX = 16000
+ N_FFT = 1024
+ HOP_LEN = 512
+
+ # Find the class folders in a target directory
+ def find_classes(directory: str):
+     # 1. Get the class names by scanning the target directory
+     classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
+     # 2. Raise an error if no class names are found
+     if not classes:
+         raise FileNotFoundError(f"Couldn't find any classes in {directory}.")
+     # 3. Create a dictionary of index labels (computers prefer numerical rather than string labels)
+     class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
+     return classes, class_to_idx
+
+ def resample(wav, sample_rate, new_sample_rate):
+     # Mix multi-channel audio down to mono
+     if wav.shape[0] >= 2:
+         wav = torch.mean(wav, dim=0)
+     else:
+         wav = wav.squeeze(0)
+     # Resample to the target rate when the rates differ
+     if sample_rate != new_sample_rate:
+         resampler = Resample(sample_rate, new_sample_rate)
+         wav = resampler(wav)
+     return wav
+
+ def mono_to_color(X, eps=1e-6, mean=None, std=None):
+     # Stack the single-channel spectrogram into a 3-channel image
+     X = np.stack([X, X, X], axis=-1)
+     # Standardize
+     mean = mean if mean is not None else X.mean()
+     std = std if std is not None else X.std()
+     X = (X - mean) / (std + eps)
+     # Normalize to [0, 255]
+     _min, _max = X.min(), X.max()
+     if (_max - _min) > eps:
+         V = np.clip(X, _min, _max)
+         V = 255 * (V - _min) / (_max - _min)
+         V = V.astype(np.uint8)
+     else:
+         V = np.zeros_like(X, dtype=np.uint8)
+     return V
+
+ def normalize(image, mean=None, std=None):
+     image = image / 255.0
+     if mean is not None and std is not None:
+         image = (image - mean) / std
+     # Move channels first (HWC -> CHW) for PyTorch
+     return np.moveaxis(image, 2, 0).astype(np.float32)
+
+ def compute_melspec(wav, sample_rate=SAMPLE_RATE):
+     melspec = librosa.feature.melspectrogram(
+         y=wav,
+         sr=sample_rate,
+         n_fft=N_FFT,
+         fmin=F_MIN,
+         fmax=F_MAX,
+         n_mels=N_MELS,
+         hop_length=HOP_LEN
+     )
+     # Convert the power spectrogram to decibel units
+     melspec = librosa.power_to_db(melspec).astype(np.float32)
+     return melspec
+
+ def audio_preprocess(wav, sample_rate):
+     wav = wav.numpy()
+     melspec = compute_melspec(wav, sample_rate)
+     image = mono_to_color(melspec)
+     image = normalize(image, mean=None, std=None)
+     image = torch.from_numpy(image)
+     return image
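
A minimal sketch of what audio_preprocess returns, using a synthetic mono waveform in place of a real recording (the dummy input and the exact frame count are assumptions, not part of the commit):

# Sketch: shape of the model input produced by the preprocessing pipeline
import torch
from data_setups import audio_preprocess, SAMPLE_RATE, AUDIO_LEN

dummy_wav = torch.randn(int(AUDIO_LEN * SAMPLE_RATE))  # ~2.9 s of noise at 44.1 kHz
image = audio_preprocess(dummy_wav, SAMPLE_RATE)
print(image.shape)  # torch.Size([3, 128, 250]): 3 channels, N_MELS mel bins, ~250 time frames
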
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch==1.12.0
+ torchvision==0.13.0
+ torchaudio==0.12.1
+ gradio==3.1.4
+ librosa==0.9.2
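
These pinned versions can be installed with pip install -r requirements.txt; since app.py calls demo.launch() at module level, running python app.py then starts the demo.
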
torch_efficientnet_b0_fold4.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2083c006db34451c229a820aae94091ca7417856b8601da008b20048ccebfb2e
+ size 16419889