slslslrhfem commited on
Commit
508344f
·
1 Parent(s): 617e256

fix some model

Browse files
Files changed (2) hide show
  1. model.py +1 -4
  2. requirements.txt +2 -1
model.py CHANGED
@@ -45,10 +45,7 @@ class MusicAudioClassifier(pl.LightningModule):
45
  def _process_audio_batch(self, x: torch.Tensor) -> torch.Tensor:
46
  B, S = x.shape[:2] # [B, S, C, M, T] or [B, S, C, T] for wav, [B, S, 1?, embsize] for emb
47
  x = x.view(B*S, *x.shape[2:]) # [B*S, C, M, T]
48
- if self.is_emb == False:
49
- _, embeddings = self.emb_model(x) # [B*S, emb_dim]
50
- else:
51
- embeddings = x
52
  if embeddings.dim() == 3:
53
  pooled_features = embeddings.mean(dim=1) # transformer
54
  else:
 
45
  def _process_audio_batch(self, x: torch.Tensor) -> torch.Tensor:
46
  B, S = x.shape[:2] # [B, S, C, M, T] or [B, S, C, T] for wav, [B, S, 1?, embsize] for emb
47
  x = x.view(B*S, *x.shape[2:]) # [B*S, C, M, T]
48
+ embeddings=x
 
 
 
49
  if embeddings.dim() == 3:
50
  pooled_features = embeddings.mean(dim=1) # transformer
51
  else:
requirements.txt CHANGED
@@ -10,4 +10,5 @@ scipy>=1.10.0
10
  soundfile>=0.12.0
11
  datasets>=2.0.0
12
  accelerate>=0.20.0
13
- spaces
 
 
10
  soundfile>=0.12.0
11
  datasets>=2.0.0
12
  accelerate>=0.20.0
13
+ spaces
14
+ nnAudio