cchaun committed
Commit a2148ac · 1 Parent(s): b63864d

improve model accuracy

.gitignore CHANGED
@@ -1,2 +1,3 @@
 venv
-__pycache__
+__pycache__
+flagged
app.py CHANGED
@@ -1,3 +1,4 @@
+# -*- coding: UTF-8 -*-
 import gradio as gr
 import torch, torchaudio
 from timeit import default_timer as timer
@@ -8,16 +9,21 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 SAMPLE_RATE = 44100
 AUDIO_LEN = 2.90
 
-model = torch.load("torch_efficientnet_b0_fold4.pth", map_location=torch.device('cpu'))
+model = torch.load("models/torch_efficientnet_fold2_CNN.pth", map_location=torch.device('cpu'))
 
 CHINESE_LABELS = [
     "大提琴", "單簧管", "長笛", "民謠吉他", "電吉他", "風琴", "鋼琴", "薩克斯風", "喇叭", "小提琴", "人聲"
 ]
 
+LABELS = [
+    "Cello", "Clarinet", "Flute", "Acoustic Guitar", "Electric Guitar", "Organ", "Piano", "Saxophone", "Trumpet", "Violin", "Voice"
+]
+
 example_list = [
     "samples/guitar_acoustic.wav",
     "samples/piano.wav",
-    "samples/guitar_electric.wav"
+    "samples/violin.wav",
+    "samples/flute.wav"
 ]
 
 def predict(audio_path):
@@ -31,7 +37,6 @@ def predict(audio_path):
         return
     # input Preprocessing
     img = audio_preprocess(wav, SAMPLE_RATE).unsqueeze(0)
-    print(img.shape)
     model.eval()
     with torch.inference_mode():
         pred_probs = torch.softmax(model(img), dim=1)
@@ -40,8 +45,8 @@ def predict(audio_path):
     return pred_labels_and_probs, pred_time
 
 
-title = "Musical Instrument Classification 🎺🎸🎹🎻"
-description = "An EfficientNetB0 feature extractor model to classify 11 different musical instruments"
+title = "樂器辨識🎺🎸🎹🎻"
+description = "使用IRMAS資料集訓練的深度學習模型,可辨識11種不同樂器,包含「大提琴, 單簧管, 長笛, 民謠吉他, 電吉他, 風琴, 鋼琴, 薩克斯風, 喇叭, 小提琴, 人聲」"
 article = ""
 
 demo = gr.Interface(fn=predict,
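
Note: the hunks above show only fragments of app.py. Below is a minimal sketch of how the changed pieces plausibly fit together after this commit. The audio_preprocess stand-in, the resampling step, and the Gradio component choices (gr.Audio(type="filepath"), gr.Label, gr.Number) are illustrative assumptions, not code confirmed by the diff.

    # -*- coding: UTF-8 -*-
    import gradio as gr
    import torch, torchaudio
    from timeit import default_timer as timer

    SAMPLE_RATE = 44100
    AUDIO_LEN = 2.90

    # As in the diff: the checkpoint is a fully pickled model, loaded on CPU.
    model = torch.load("models/torch_efficientnet_fold2_CNN.pth",
                       map_location=torch.device('cpu'))

    CHINESE_LABELS = ["大提琴", "單簧管", "長笛", "民謠吉他", "電吉他",
                      "風琴", "鋼琴", "薩克斯風", "喇叭", "小提琴", "人聲"]
    example_list = ["samples/guitar_acoustic.wav", "samples/piano.wav",
                    "samples/violin.wav", "samples/flute.wav"]

    def audio_preprocess(wav, sr):
        # Hypothetical stand-in for the repo's helper: trim/pad to AUDIO_LEN
        # seconds, then produce a spectrogram "image" for the CNN. The real
        # helper presumably matches the exact input shape the model expects.
        num_samples = int(AUDIO_LEN * sr)
        wav = torch.nn.functional.pad(wav, (0, max(0, num_samples - wav.shape[-1])))
        wav = wav[..., :num_samples]
        return torchaudio.transforms.MelSpectrogram(sample_rate=sr)(wav)

    def predict(audio_path):
        start_time = timer()
        wav, sr = torchaudio.load(audio_path)
        if sr != SAMPLE_RATE:  # assumed resampling step
            wav = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(wav)
        img = audio_preprocess(wav, SAMPLE_RATE).unsqueeze(0)
        model.eval()
        with torch.inference_mode():
            pred_probs = torch.softmax(model(img), dim=1)
        # Map each class probability to its Chinese label (assumed dict shape,
        # matching what gr.Label consumes).
        pred_labels_and_probs = {CHINESE_LABELS[i]: float(pred_probs[0][i])
                                 for i in range(len(CHINESE_LABELS))}
        pred_time = round(timer() - start_time, 5)
        return pred_labels_and_probs, pred_time

    demo = gr.Interface(fn=predict,
                        inputs=gr.Audio(type="filepath"),
                        outputs=[gr.Label(num_top_classes=3),
                                 gr.Number(label="Prediction time (s)")],
                        examples=example_list,
                        title="樂器辨識🎺🎸🎹🎻")
    demo.launch()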
torch_efficientnet_b0_fold4.pth → models/torch_efficientnet_b0_fold4.pth RENAMED
File without changes
models/torch_efficientnet_fold2_CNN.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1a55dbb25c9a1678bd3b5d2968695923931564e5b4e04839c5836a3ee5421c1a
+size 16419953
samples/flute.wav ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2aaa6c5640106826a4db1d7932f9edc3b0fbb0c68cbd4e7d7d544d2fdc28af17
+size 3528044
samples/violin.wav ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:690365b52ee8ca9f7b0147247270e375d70be31512c3ae591e52bf55605d3ece
+size 19105034
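
Note: the three ADDED entries above are Git LFS pointer files. The repository stores only the spec version, a sha256 oid, and the byte size; the actual blobs live in LFS storage and are fetched with git lfs pull. A small sketch of checking a fetched file against its pointer fields (the function name is hypothetical):

    import hashlib, os

    def verify_lfs_pointer(path, expected_oid, expected_size):
        # Compare a local file against the oid/size fields of its LFS pointer.
        h = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):
                h.update(chunk)
        return (h.hexdigest() == expected_oid
                and os.path.getsize(path) == expected_size)

    # e.g., for the checkpoint added in this commit:
    print(verify_lfs_pointer(
        "models/torch_efficientnet_fold2_CNN.pth",
        "1a55dbb25c9a1678bd3b5d2968695923931564e5b4e04839c5836a3ee5421c1a",
        16419953))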