Spaces:
Build error
improve model accuracy
- .gitignore +2 -1
- app.py +10 -5
- torch_efficientnet_b0_fold4.pth → models/torch_efficientnet_b0_fold4.pth +0 -0
- models/torch_efficientnet_fold2_CNN.pth +3 -0
- samples/flute.wav +3 -0
- samples/violin.wav +3 -0
.gitignore
CHANGED
@@ -1,2 +1,3 @@
 venv
-__pycache__
+__pycache__
+flagged
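The new flagged entry keeps Gradio's flagging output out of version control: when someone presses the Flag button on the demo, Gradio writes the submitted audio and the logged result into a flagged/ directory next to app.py by default. A minimal sketch of the relevant knobs in the Gradio 3.x-era Interface API (parameter names come from that API, not from this repo's app.py):

import gradio as gr

# Flagging is enabled by default and logs to ./flagged, which is why the
# directory is now listed in .gitignore.
demo = gr.Interface(
    fn=lambda text: text,       # placeholder function, just for this sketch
    inputs="text",
    outputs="text",
    allow_flagging="manual",    # "never" would remove the Flag button entirely
    flagging_dir="flagged",     # default path, matching the new .gitignore entry
)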
app.py
CHANGED
@@ -1,3 +1,4 @@
+# -*- coding: UTF-8 -*-
 import gradio as gr
 import torch, torchaudio
 from timeit import default_timer as timer
@@ -8,16 +9,21 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 SAMPLE_RATE = 44100
 AUDIO_LEN = 2.90
 
-model = torch.load("
+model = torch.load("models/torch_efficientnet_fold2_CNN.pth", map_location=torch.device('cpu'))
 
 CHINESE_LABELS = [
     "大提琴", "單簧管", "長笛", "民謠吉他", "電吉他", "風琴", "鋼琴", "薩克斯風", "喇叭", "小提琴", "人聲"
 ]
 
+LABELS = [
+    "Cello", "Clarinet", "Flute", "Acoustic Guitar", "Electric Guitar", "Organ", "Piano", "Saxophone", "Trumpet", "Violin", "Voice"
+]
+
 example_list = [
     "samples/guitar_acoustic.wav",
     "samples/piano.wav",
-    "samples/
+    "samples/violin.wav",
+    "samples/flute.wav"
 ]
 
 def predict(audio_path):
@@ -31,7 +37,6 @@ def predict(audio_path):
         return
     # input Preprocessing
     img = audio_preprocess(wav, SAMPLE_RATE).unsqueeze(0)
-    print(img.shape)
     model.eval()
     with torch.inference_mode():
         pred_probs = torch.softmax(model(img), dim=1)
@@ -40,8 +45,8 @@ def predict(audio_path):
     return pred_labels_and_probs, pred_time
 
 
-title = "
-description = "
+title = "樂器辨識🎺🎸🎹🎻"
+description = "使用IRMAS資料集訓練的深度學習模型,可辨識11種不同樂器,包含「大提琴, 單簧管, 長笛, 民謠吉他, 電吉他, 風琴, 鋼琴, 薩克斯風, 喇叭, 小提琴, 人聲」"
 article = ""
 
 demo = gr.Interface(fn=predict,
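The two functional changes in app.py are the checkpoint load and the new English LABELS list. map_location=torch.device('cpu') remaps GPU-saved tensors onto the CPU, which matters on CPU-only Spaces hardware, and torch.load of a whole pickled model only works when the model's class definition is importable at load time. A minimal sketch of how these pieces typically fit together for a gr.Label output follows; audio_preprocess and the exact return wiring are assumptions, not this repo's actual code:

import torch
import torchaudio
import gradio as gr
from timeit import default_timer as timer

SAMPLE_RATE = 44100

# Remap any CUDA-saved tensors onto the CPU so the checkpoint loads on CPU-only hardware.
model = torch.load("models/torch_efficientnet_fold2_CNN.pth",
                   map_location=torch.device("cpu"))

LABELS = ["Cello", "Clarinet", "Flute", "Acoustic Guitar", "Electric Guitar",
          "Organ", "Piano", "Saxophone", "Trumpet", "Violin", "Voice"]

def predict(audio_path):
    start = timer()
    wav, sr = torchaudio.load(audio_path)            # load the uploaded clip from its file path
    img = audio_preprocess(wav, sr).unsqueeze(0)     # audio_preprocess is an assumed helper (spectrogram, resampling, etc.)
    model.eval()
    with torch.inference_mode():                     # no autograd bookkeeping during inference
        pred_probs = torch.softmax(model(img), dim=1)
    # gr.Label expects a {label: confidence} mapping
    pred_labels_and_probs = {LABELS[i]: float(pred_probs[0][i]) for i in range(len(LABELS))}
    pred_time = round(timer() - start, 4)
    return pred_labels_and_probs, pred_time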
torch_efficientnet_b0_fold4.pth → models/torch_efficientnet_b0_fold4.pth
RENAMED
File without changes
models/torch_efficientnet_fold2_CNN.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1a55dbb25c9a1678bd3b5d2968695923931564e5b4e04839c5836a3ee5421c1a
+size 16419953
samples/flute.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2aaa6c5640106826a4db1d7932f9edc3b0fbb0c68cbd4e7d7d544d2fdc28af17
+size 3528044
samples/violin.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:690365b52ee8ca9f7b0147247270e375d70be31512c3ae591e52bf55605d3ece
+size 19105034
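The new checkpoint and the two sample clips are committed as Git LFS pointer files (the version / oid / size stubs above), so the repository itself carries only a few bytes per binary while the real payload lives in LFS storage. For reference, a typical way to set this up before committing, assuming the *.pth and *.wav patterns are not already covered by the repo's .gitattributes:

# shell, one-time setup in the repository
git lfs install
git lfs track "*.pth" "*.wav"
git add .gitattributes models/torch_efficientnet_fold2_CNN.pth samples/flute.wav samples/violin.wav
git commit -m "improve model accuracy"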