Replace the inplace operation
modeling_minicpmo.py  +23 -7
@@ -377,6 +377,8 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         else:
             vllm_embedding = self.llm.model.embed_tokens(data["input_ids"])
 
+        new_vllm_embedding = vllm_embedding.clone()
+
         vision_hidden_states = [
             i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
         ]
@@ -392,15 +394,16 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
                         [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
                     ).to(vllm_embedding.device)
 
-                    cur_vllm_emb.scatter_(
+                    new_vllm_embedding[i] = cur_vllm_emb.scatter(
                         0,
                         image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
                         cur_vs_hs.view(-1, cur_vs_hs.shape[-1]),
                     )
+
                 elif self.training:
-                    cur_vllm_emb += cur_vs_hs[0].mean() * 0
+                    new_vllm_embedding[i] += cur_vs_hs[0].mean() * 0
 
-        return vllm_embedding, vision_hidden_states
+        return new_vllm_embedding, vision_hidden_states
 
     def get_audio_embedding_streaming(self, data):
         r"""
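The two hunks above replace the in-place scatter_ into vllm_embedding with an out-of-place scatter whose result is written into a clone (new_vllm_embedding). Below is a minimal, self-contained sketch of that pattern; the shapes, toy values, and single-image setup are illustrative assumptions, not the model's real dimensions.

    import torch

    # Toy sizes: batch 1, sequence length 6, hidden size 4, one image spanning positions 2-3.
    vllm_embedding = torch.randn(1, 6, 4, requires_grad=True)   # stand-in for the embed_tokens output
    vision_hidden_states = [torch.randn(2, 4)]                  # vision features for the two image tokens
    image_indices = torch.tensor([2, 3], dtype=torch.long)

    new_vllm_embedding = vllm_embedding.clone()
    cur_vllm_emb = vllm_embedding[0]
    cur_vs_hs = vision_hidden_states[0]

    # scatter (out-of-place) returns a new tensor; scatter_ would mutate cur_vllm_emb,
    # a view of a tensor autograd still tracks, and raises an error in this setup.
    new_vllm_embedding[0] = cur_vllm_emb.scatter(
        0,
        image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
        cur_vs_hs.view(-1, cur_vs_hs.shape[-1]),
    )

    new_vllm_embedding.sum().backward()  # gradients flow back to vllm_embedding

Writing into the clone leaves the tensor produced by embed_tokens untouched, which is the usual way to avoid errors from mutating a tensor that autograd (or an inference-mode context) still needs.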
@@ -463,7 +466,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         else:
             return []
 
-    def get_audio_embedding(self, data, chunk_length=-1):
+    def get_audio_embedding(self, data, chunk_length=-1, dummy=True):
         r"""
         Extract full audio embeddings with optional chunk-based attention.
 
@@ -481,6 +484,8 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         Returns:
             List[List[torch.Tensor]]: audio embeddings
         """
+        dtype = self.apm.embed_positions.weight.dtype
+        device = self.apm.embed_positions.weight.device
 
         wavforms = data.get("audio_features", [])  # (bs, 80, frames) or [], multi audios need filled in advance
         audio_feature_lens_raw = data.get("audio_feature_lens", [])  # list, [[x1, x2], [y1], [z1]]
@@ -541,6 +546,17 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
                     idx += 1
                 final_audio_embeds.append(target_audio_embeds)
             return final_audio_embeds
+        elif self.training and dummy:
+            dummy_wavs = torch.zeros((1, 80, 100), device=device, dtype=dtype)
+            audio_states = self.apm(dummy_wavs, output_hidden_states=True).hidden_states[self.audio_encoder_layer]
+
+            audio_embeds = self.audio_projection_layer(audio_states)
+
+            audio_embeds = audio_embeds.transpose(1, 2)
+            audio_embeds = self.audio_avg_pooler(audio_embeds)
+            audio_embeds = audio_embeds.transpose(1, 2)
+            return [audio_embeds]
+
         else:
             return []
 
@@ -595,7 +611,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         elif self.training:
             for i in range(bs):
                 # dummy audio_embeddings
-                input_embeddings += audio_embeddings[0].mean() * 0
+                input_embeddings = input_embeddings + audio_embeddings[0].mean() * 0
 
         return input_embeddings
 
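The new elif self.training and dummy: branch and the input_embeddings = input_embeddings + audio_embeddings[0].mean() * 0 line both use the "dummy forward scaled by zero" pattern. A common reason for it, stated here as an assumption rather than something the commit documents, is to keep otherwise-unused audio parameters in the autograd graph during audio-free training steps (for example so DDP does not flag unused parameters) while leaving the embeddings numerically unchanged. A toy sketch with made-up modules:

    import torch
    import torch.nn as nn

    # Stand-ins for an audio tower and its projection; the names are hypothetical.
    encoder = nn.Linear(80, 16)
    proj = nn.Linear(16, 16)

    def embed_text_only(text_embeds: torch.Tensor) -> torch.Tensor:
        # Run the unused branch on dummy input and add its output scaled by zero:
        # values stay identical, but every parameter of encoder/proj receives
        # a (zero) gradient instead of being left out of the graph.
        dummy = torch.zeros(1, 80)
        dummy_out = proj(encoder(dummy))
        return text_embeds + dummy_out.mean() * 0

    x = torch.randn(2, 16, requires_grad=True)
    embed_text_only(x).sum().backward()
    print(encoder.weight.grad.abs().max())  # tensor(0.): gradients exist but are zero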
@@ -751,7 +767,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         input_ids=None,
         pixel_values=None,
         tgt_sizes=None,
-        audio_features=
+        audio_features=[],
         audio_feature_lens=None,
         image_bound=None,
         audio_bounds=None,
@@ -2982,7 +2998,7 @@ class ConditionalChatTTS(PreTrainedModel):
         inputs_embeds = torch.stack(code_emb, 3).sum(3)
 
         position_ids = torch.tensor(
-            [past_key_values[0][0].shape[2]
+            [past_key_values[0][0].shape[2]], dtype=torch.long, device=self.device
         ).unsqueeze(0)
 
         cache_position = position_ids.clone()
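The final hunk rebuilds position_ids for the single token being decoded from the current key/value cache length. A small sketch of that arithmetic, assuming the legacy tuple cache layout where each key tensor is (batch, num_heads, seq_len, head_dim); the concrete sizes below are made up:

    import torch

    # One layer of a legacy tuple KV cache holding 12 already-decoded positions.
    past_key_values = [(torch.zeros(1, 8, 12, 64), torch.zeros(1, 8, 12, 64))]

    # The next token's position index equals the number of cached positions.
    position_ids = torch.tensor(
        [past_key_values[0][0].shape[2]], dtype=torch.long
    ).unsqueeze(0)

    cache_position = position_ids.clone()
    print(position_ids)  # tensor([[12]])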