yuzaa BUAADreamer committed on
Commit 2ea48c8 · verified · 1 parent: 686e87a

support audio finetuning (#22)


- support audio finetuning (29a824ea848aa2a72b26706a1c641bc339a51869)


Co-authored-by: Zhangchi Feng <[email protected]>

Files changed (1)
  1. modeling_minicpmo.py +14 -1
modeling_minicpmo.py CHANGED
@@ -466,7 +466,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         else:
             return []
 
-    def get_audio_embedding(self, data, chunk_length=-1):
+    def get_audio_embedding(self, data, chunk_length=-1, dummy=True):
         r"""
         Extract full audio embeddings with optional chunk-based attention.
 
@@ -484,6 +484,8 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         Returns:
             List[List[torch.Tensor]]: audio embeddings
         """
+        dtype = self.apm.embed_positions.weight.dtype
+        device = self.apm.embed_positions.weight.device
 
         wavforms = data.get("audio_features", [])  # (bs, 80, frames) or [], multi audios need filled in advance
         audio_feature_lens_raw = data.get("audio_feature_lens", [])  # list, [[x1, x2], [y1], [z1]]
@@ -544,6 +546,17 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
                     idx += 1
                 final_audio_embeds.append(target_audio_embeds)
             return final_audio_embeds
+        elif self.training and dummy:
+            dummy_wavs = torch.zeros((1, 80, 100), device=device, dtype=dtype)
+            audio_states = self.apm(dummy_wavs, output_hidden_states=True).hidden_states[self.audio_encoder_layer]
+
+            audio_embeds = self.audio_projection_layer(audio_states)
+
+            audio_embeds = audio_embeds.transpose(1, 2)
+            audio_embeds = self.audio_avg_pooler(audio_embeds)
+            audio_embeds = audio_embeds.transpose(1, 2)
+            return [audio_embeds]
+
         else:
             return []
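
The new `elif self.training and dummy:` branch pushes one all-zero mel spectrogram of shape (1, 80, 100) through the audio tower (`apm` → `audio_projection_layer` → `audio_avg_pooler`) whenever a training batch carries no audio at all. A common motivation for this pattern, though the commit message does not state it, is that distributed training wrappers such as DDP or DeepSpeed expect every parameter to participate in backward on each step; without the dummy pass, text-only batches would leave the audio parameters out of the autograd graph. Below is a minimal, self-contained sketch of the same pattern. TinyAudioTower, AUDIO_MELS, DUMMY_FRAMES, and _encode are hypothetical stand-ins for illustration, not MiniCPM-o APIs.

# Minimal sketch of the dummy-forward pattern the patch adds. Shapes mirror
# the (bs, 80, frames) audio-feature layout in the diff above.
import torch
import torch.nn as nn

AUDIO_MELS = 80      # mel bins, as in the real audio features
DUMMY_FRAMES = 100   # same dummy length the patch uses

class TinyAudioTower(nn.Module):
    def __init__(self, hidden=32):
        super().__init__()
        self.encoder = nn.Conv1d(AUDIO_MELS, hidden, kernel_size=3, padding=1)
        self.projection = nn.Linear(hidden, hidden)
        self.avg_pooler = nn.AvgPool1d(kernel_size=2, stride=2)

    def get_audio_embedding(self, wavforms=None, dummy=True):
        if wavforms is not None:
            # Normal path: encode the real audio features.
            return [self._encode(wavforms)]
        elif self.training and dummy:
            # Batch has no audio: run an all-zero input through the tower so
            # every audio parameter still appears in the autograd graph.
            p = next(self.parameters())
            zeros = torch.zeros((1, AUDIO_MELS, DUMMY_FRAMES),
                                device=p.device, dtype=p.dtype)
            return [self._encode(zeros)]
        else:
            return []

    def _encode(self, wavforms):
        states = self.encoder(wavforms)                    # (bs, hidden, frames)
        embeds = self.projection(states.transpose(1, 2))   # (bs, frames, hidden)
        return self.avg_pooler(embeds.transpose(1, 2)).transpose(1, 2)

tower = TinyAudioTower().train()
(embeds,) = tower.get_audio_embedding()                    # dummy branch taken
loss = (embeds * 0.0).sum()                                # zero the dummy contribution
loss.backward()
print(all(p.grad is not None for p in tower.parameters()))  # True

In the real patch the dummy path returns [audio_embeds] with the same structure as real embeddings; the hunk above does not show how the caller discards or neutralizes them, so the `* 0.0` scaling in the demo is just one common way to keep the dummy output from affecting the loss while still producing gradients for every parameter.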