zhoukz committed
Commit 963f924 · 1 Parent(s): 0939826

Upload folder using huggingface_hub

chat_template.jinja CHANGED
@@ -1,25 +1,7 @@
-{%- for message in messages -%}
-{%- if loop.first and message["role"] != "system" -%}
-{{- "<|im_start|>system\nYou are a helpful language and speech assistant.<|im_end|>\n" -}}
-{%- endif -%}
-{{- "<|im_start|>" -}}
-{{- message["role"] -}}
-{{- "\n" -}}
-{%- if message["content"] is string -%}
-{{- message["content"] -}}
-{%- else -%}
-{%- for content in message["content"] -%}
-{%- if content["type"] == "text" -%}
-{{- content["text"] -}}
-{%- elif content["type"] == "audio" -%}
-{{- "<|audio_bos|><|AUDIO|><|audio_eos|>" -}}
-{%- endif -%}
-{%- endfor -%}
-{%- endif -%}
-{%- if not loop.last or loop.last and not continue_final_message -%}
-{{- "<|im_end|>\n" -}}
-{%- endif -%}
-{%- endfor -%}
-{%- if add_generation_prompt -%}
-{{- "<|im_start|>assistant\n" -}}
-{%- endif -%}
+{% set audio_count = namespace(value=0) %}{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system
+You are a helpful assistant.<|im_end|>
+{% endif %}<|im_start|>{{ message['role'] }}
+{% if message['content'] is string %}{{ message['content'] }}<|im_end|>
+{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_bos|><|IMAGE|><|vision_eos|>{% elif content['type'] == 'audio' or 'audio' in content or 'audio_url' in content %}{% set audio_count.value = audio_count.value + 1 %}{% if add_audio_id %}Audio {{ audio_count.value }}: {% endif %}<|audio_bos|><|AUDIO|><|audio_eos|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_bos|><|VIDEO|><|vision_eos|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>
+{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant
+{% endif %}
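For reference, a minimal sketch of rendering the updated template through the processor (not part of this commit; the repo id is a placeholder assumption, and `trust_remote_code=True` is needed because the config, model, and processor classes live in this repository):

```python
from transformers import AutoProcessor

# Placeholder repo id; substitute the actual model repository.
processor = AutoProcessor.from_pretrained(
    "mispeech/midashenglm-7b", trust_remote_code=True
)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio_url": "example.wav"},
            {"type": "text", "text": "What do you hear?"},
        ],
    }
]

# The new template injects the default system prompt, renders the audio item as
# <|audio_bos|><|AUDIO|><|audio_eos|>, and appends the assistant header when
# add_generation_prompt=True.
prompt = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt)
```

Compared with the old template, the new one also handles image and video placeholders (`<|vision_bos|><|IMAGE|><|vision_eos|>`, `<|VIDEO|>`), supports optional `add_audio_id`/`add_vision_id` numbering, and drops the "language and speech assistant" wording from the default system prompt.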
config.json CHANGED
@@ -32,6 +32,7 @@
  "target_length": 1008,
  "win_length": 512
  },
+ "audio_token_id": 151646,
  "auto_map": {
  "AutoConfig": "configuration_midashenglm.MiDashengLMConfig",
  "AutoModelForCausalLM": "modeling_midashenglm.MiDashengLMModel"
@@ -63,7 +64,6 @@
  },
  "rope_theta": 1000000.0,
  "sliding_window": 32768,
- "torch_dtype": "bfloat16",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 151936
configuration_midashenglm.py CHANGED
@@ -66,6 +66,7 @@ class MiDashengLMConfig(PretrainedConfig):
         audio_encoder_config: Dict = {},
         subsample_factor: int = 5,
         text_config: Dict = {},
+        audio_token_id: Optional[int] = None,
         **kwargs,
     ):
         self.audio_encoder_config = DashengConfig(**audio_encoder_config)
@@ -75,4 +76,5 @@
             if text_config
             else Qwen2_5OmniTextConfig()
         )
+        self.audio_token_id = audio_token_id
         super().__init__(**kwargs)
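A short sketch of the effect of the two changes above: `audio_token_id` is now persisted in `config.json` (151646 in this commit) and accepted by `MiDashengLMConfig.__init__`, defaulting to `None`. The repo id below is a placeholder assumption, and importing the class directly assumes `configuration_midashenglm.py` is available locally.

```python
from transformers import AutoConfig

# Read the value back from the hub config (placeholder repo id).
config = AutoConfig.from_pretrained(
    "mispeech/midashenglm-7b", trust_remote_code=True
)
print(config.audio_token_id)  # 151646 per the config.json diff above

# Or set it when constructing the config programmatically; if it is left as
# None, the model raises a ValueError as soon as audio inputs are passed
# (see modeling_midashenglm.py below).
from configuration_midashenglm import MiDashengLMConfig

local_config = MiDashengLMConfig(audio_token_id=151646)
assert local_config.audio_token_id == 151646
```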
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3ac83714f7a786cfe80cd40b86b64dc63063f8dbebc34c80298be63218c455ee
+size 4978372408
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:084430974214152e9658155dd21babb35413468bc9025a30820a723c0824ad28
+size 4932950784
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a9c20b898e857e682e490a80a602e4b61e79ec2db35ad19ba4cf5720c43301c
+size 4932950856
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2e44f1858a81ee7a8dd96cfad57cb0567ed2a5513f0a7d6344b0975579e62b17
+size 1334862432
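For reference, the four new shards sum to roughly 16.2 GB. A quick back-of-the-envelope check (the bf16 assumption, 2 bytes per parameter, is ours, hinted at by the `torch_dtype` field handled in config.json above):

```python
# Shard sizes in bytes, copied from the LFS pointers above.
shard_sizes = [4978372408, 4932950784, 4932950856, 1334862432]
total_bytes = sum(shard_sizes)
print(total_bytes)            # 16179136480
print(total_bytes / 1e9)      # ~16.2 GB on disk
print(total_bytes / 2 / 1e9)  # ~8.1e9 -> roughly 8B parameters if stored in bf16
```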
model.safetensors.index.json CHANGED
The diff for this file is too large to render. See raw diff
 
modeling_midashenglm.py CHANGED
@@ -1,7 +1,18 @@
 import collections
 import collections.abc
 from dataclasses import dataclass
-from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union, cast
+from typing import (
+    Any,
+    Callable,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+    Unpack,
+    cast,
+)
 
 import torch
 import torch.nn as nn
@@ -16,6 +27,7 @@ from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import (
 from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import (
     Qwen2_5OmniThinkerTextModel,
 )
+from transformers.utils import LossKwargs, can_return_tuple
 
 from .configuration_midashenglm import DashengConfig, MiDashengLMConfig
 
@@ -61,7 +73,7 @@ class AudioPatchEmbed(nn.Module):
         )
         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.proj(x)
         if self.flatten:
             x = torch.permute(
@@ -77,7 +89,7 @@ class LayerScale(nn.Module):
         self.inplace = inplace
         self.gamma = nn.Parameter(init_values * torch.ones(dim))
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x.mul_(self.gamma) if self.inplace else x * self.gamma
 
 
@@ -97,7 +109,7 @@ class DashengMlp(nn.Module):
         self.fc2 = nn.Linear(hidden_features, out_features)
         self.drop = nn.Dropout(drop)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.fc1(x)
         x = self.act(x)
         x = self.drop(x)
@@ -128,7 +140,7 @@ class DashengAttention(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)
         self.causal = causal
 
-    def forward(self, x, mask: Optional[torch.Tensor] = None):
+    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
         B, N, C = x.shape
         qkv = (
             self.qkv(x)
@@ -206,14 +218,19 @@ class DashengBlock(nn.Module):
         )
 
     # Kwargs usually has a mask parameter that is passed to Attention
-    def forward(self, x, **kwargs):
-        x = x + self.ls1(self.attn(self.norm1(x), **kwargs))
+    def forward(
+        self,
+        x: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        x = x + self.ls1(self.attn(self.norm1(x), mask))
         x = x + self.ls2(self.mlp(self.norm2(x)))
         return x
 
 
 class DashengAudioTransformer(PreTrainedModel):
     config_class = DashengConfig
+    supports_gradient_checkpointing = True
 
     def __init__(self, config: DashengConfig):
         super().__init__(config)
@@ -221,6 +238,7 @@ class DashengAudioTransformer(PreTrainedModel):
         self.target_length = config.target_length
         self.embed_dim = config.embed_dim
         self.hop_length = config.hop_length
+        self.gradient_checkpointing = False
 
         self.front_end = nn.Sequential(
             audio_transforms.MelSpectrogram(
@@ -271,7 +289,11 @@ class DashengAudioTransformer(PreTrainedModel):
 
         self.post_init()
 
-    def forward_features(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+    def forward_features(
+        self,
+        x: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         t = x.shape[-1]
         x = x + self.time_pos_embed[:, :, :, :t]
         x = (
@@ -282,7 +304,10 @@ class DashengAudioTransformer(PreTrainedModel):
         )  # rearrange(x, "b c f t -> b (f t) c")
         x = self.pos_drop(x)
         for block in self.blocks:
-            x = block(x, **kwargs)
+            if self.gradient_checkpointing and self.training:
+                x = self._gradient_checkpointing_func(block, x, mask)
+            else:
+                x = block(x, mask)
         x = self.norm(x)
         return x
 
@@ -334,13 +359,19 @@
 
 
 class AudioProjectorSubsample(nn.Module):
-    def __init__(self, in_dim: int, out_dim: int, downsample_rate=5):
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        downsample_rate=5,
+        dtype: Optional[torch.dtype] = None,
+    ):
         super().__init__()
         self.k = downsample_rate
         self.net = nn.Sequential(
-            nn.Linear(in_dim * self.k, out_dim),
+            nn.Linear(in_dim * self.k, out_dim, dtype=dtype),
             nn.GELU(),
-            nn.Linear(out_dim, out_dim),
+            nn.Linear(out_dim, out_dim, dtype=dtype),
         )
 
     def forward(self, x, mask=None):
@@ -365,6 +396,7 @@ class AudioProjectorSubsample(nn.Module):
 
 @dataclass
 class Qwen25OmniTextModelOutput(ModelOutput):
+    loss: Optional[torch.FloatTensor] = None
     logits: Optional[torch.FloatTensor] = None
     past_key_values: Optional[Cache] = None
     hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
@@ -390,12 +422,20 @@ class Qwen25OmniThinkerTextOnlyDecoder(PreTrainedModel, GenerationMixin):
         )
         self.post_init()
 
+    @can_return_tuple
     def forward(
         self,
-        attention_mask: Optional[Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        return_dict: Optional[bool] = None,
-        **kwargs: Any,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        **kwargs: Unpack[LossKwargs],
     ) -> Union[Tuple, Qwen25OmniTextModelOutput]:
         if attention_mask is not None and position_ids is None:
             position_ids = (
@@ -406,28 +446,33 @@ class Qwen25OmniThinkerTextOnlyDecoder(PreTrainedModel, GenerationMixin):
             )
 
         outputs: BaseModelOutputWithPast = self.model(
+            input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            cache_position=cache_position,
             return_dict=True,
-            **kwargs,
         )
         hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)
 
-        if not return_dict:
-            return tuple(
-                v
-                for v in [
-                    logits,
-                    outputs.last_hidden_state,
-                    outputs.past_key_values,
-                    outputs.hidden_states,
-                    outputs.attentions,
-                ]
-                if v is not None
+        loss = (
+            self.loss_function(
+                logits=logits,
+                labels=labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
             )
+            if labels is not None
+            else None
+        )
 
         return Qwen25OmniTextModelOutput(
+            loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
@@ -443,12 +488,17 @@ class MiDashengLMModel(PreTrainedModel):
     _supports_cache_class = Qwen2_5OmniThinkerTextModel._supports_cache_class
     _supports_static_cache = Qwen2_5OmniThinkerTextModel._supports_static_cache
     _supports_quantized_cache = Qwen2_5OmniThinkerTextModel._supports_quantized_cache
+    supports_gradient_checkpointing = (
+        Qwen2_5OmniThinkerTextModel.supports_gradient_checkpointing
+    )
 
     def __init__(self, config: MiDashengLMConfig):
         super().__init__(config)
 
+        self.audio_token_id = config.audio_token_id
+
         self.audio_encoder = DashengAudioTransformer._from_config(
-            config.audio_encoder_config
+            config.audio_encoder_config,
         )
         self.audio_projector = AudioProjectorSubsample(
             self.audio_encoder.embed_dim,
@@ -480,7 +530,6 @@
         input_values: Optional[torch.Tensor],
         inputs_embeds: Optional[torch.Tensor],
         audio_length: Optional[Iterable[int]] = None,
-        audio_token_id: Optional[int] = None,
     ) -> torch.Tensor:
         if input_ids is not None:
             if inputs_embeds is not None:
@@ -492,9 +541,9 @@
                 )
 
         if input_values is not None:
-            if audio_token_id is None:
+            if self.audio_token_id is None:
                 raise ValueError(
-                    "If `input_values` is provided, `audio_token_id` must also be provided."
+                    "Audio input is provided, but `audio_token_id` is not configured."
                 )
 
            audio_embeddings = self._forward_audio_encoder(
@@ -502,7 +551,7 @@
                audio_length=audio_length,
            ).to(inputs_embeds.dtype)
 
-            audio_mask = (input_ids == audio_token_id).flatten()
+            audio_mask = (input_ids == self.audio_token_id).flatten()
            diff = torch.diff(
                audio_mask.long(),
                prepend=torch.zeros(
@@ -540,7 +589,9 @@
        input_values: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        audio_length: Optional[Iterable[int]] = None,
-        audio_token_id: Optional[int] = None,
+        attention_mask: Optional[Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ):
        inputs_embeds = self._prepare_inputs_embeds(
@@ -548,11 +599,13 @@
            input_values=input_values,
            inputs_embeds=inputs_embeds,
            audio_length=audio_length,
-            audio_token_id=audio_token_id,
        )
        return self.decoder(
            input_ids=None,
            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            labels=labels,
            **kwargs,
        )
 
@@ -562,7 +615,6 @@
        input_values: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        audio_length: Optional[Iterable[int]] = None,
-        audio_token_id: Optional[int] = None,
        **kwargs,
    ):
        inputs_embeds = self._prepare_inputs_embeds(
@@ -570,7 +622,6 @@
            input_values=input_values,
            inputs_embeds=inputs_embeds,
            audio_length=audio_length,
-            audio_token_id=audio_token_id,
        )
        return self.decoder.generate(
            inputs_embeds=inputs_embeds,
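Taken together, the changes above mean the model now reads the audio placeholder id from its config instead of an `audio_token_id` kwarg, accepts `labels` and returns a loss, and supports gradient checkpointing in the audio encoder. A minimal sketch of the resulting call pattern (not part of the commit; the repo id is a placeholder assumption and the processor keys follow `processing_midashenglm.py`):

```python
import numpy as np
from transformers import AutoModelForCausalLM, AutoProcessor

repo = "mispeech/midashenglm-7b"  # placeholder repo id
processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
# model.gradient_checkpointing_enable()  # now also covers the Dasheng encoder

text = processor.apply_chat_template(
    [
        {
            "role": "user",
            "content": [
                {"type": "audio", "audio_url": "example.wav"},
                {"type": "text", "text": "Describe the audio."},
            ],
        }
    ],
    tokenize=False,
    add_generation_prompt=True,
)
waveform = np.zeros(16000, dtype=np.float32)  # 1 s of silence as stand-in audio
inputs = processor(text=[text], audio=[waveform], return_tensors="pt")

# New in this commit: labels flow through to the decoder and a loss comes back.
# Reusing input_ids as labels here only to demonstrate the signature.
outputs = model(**inputs, labels=inputs["input_ids"])
print(outputs.loss)

# Generation no longer needs audio_token_id passed explicitly.
generated = model.generate(**inputs, max_new_tokens=32)
print(processor.batch_decode(generated, skip_special_tokens=True))
```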
processing.py DELETED
@@ -1,277 +0,0 @@
-from __future__ import annotations
-
-from typing import List
-
-import numpy as np
-import torch
-from transformers import Qwen2Tokenizer, Qwen2TokenizerFast, Wav2Vec2FeatureExtractor
-from transformers.feature_extraction_utils import BatchFeature
-from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
-
-
-class MiAudioLLMProcessorKwargs(ProcessingKwargs):
-    _defaults = {
-        "text_kwargs": {
-            "padding": True,
-            "padding_side": "left",
-        },
-        "audio_kwargs": {},
-    }
-
-
-def calculate_mel_frames_dasheng(
-    audio_length_samples: int,
-    n_fft: int = 512,
-    hop_size: int = 160,
-    dasheng_subsampling: int = 4,
-    center=True,
-    model_subsampling: int = 5,
-) -> int:
-    """Calculate the number of Mel-spectrogram frames."""
-    if center:
-        audio_length_samples = audio_length_samples + n_fft
-
-    return (
-        int(1 + ((audio_length_samples - n_fft) / hop_size))
-        // dasheng_subsampling
-        // model_subsampling
-    )
-
-
-class MiAudioLLMProcessor(ProcessorMixin):
-    attributes = ["feature_extractor", "tokenizer"]
-    valid_kwargs = [
-        "chat_template",
-        "audio_token",
-        "audio_bos_token",
-        "audio_eos_token",
-    ]
-    feature_extractor_class = "Wav2Vec2FeatureExtractor"
-    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
-
-    def __init__(
-        self,
-        feature_extractor: Wav2Vec2FeatureExtractor | None = None,
-        tokenizer: Qwen2Tokenizer | Qwen2TokenizerFast | None = None,
-        model_subsampling: int = 5,
-        chat_template: str | None = None,
-        # TODO: can this be removed?
-        audio_token: str = "<|AUDIO|>",
-        audio_bos_token: str = "<|audio_bos|>",
-        audio_eos_token: str = "<|audio_eos|>",
-    ):
-        if chat_template is None:
-            chat_template = self.default_chat_template
-        assert tokenizer is not None, "Tokenizer Needs to be passed"
-        self.audio_token = (
-            tokenizer.audio_token if hasattr(tokenizer, "audio_token") else audio_token
-        )
-        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
-        self.audio_bos_token = (
-            tokenizer.audio_bos_token
-            if hasattr(tokenizer, "audio_bos_token")
-            else audio_bos_token
-        )
-        self.audio_eos_token = (
-            tokenizer.audio_eos_token
-            if hasattr(tokenizer, "audio_eos_token")
-            else audio_eos_token
-        )
-        self.model_subsampling = model_subsampling
-        # Fix Normalization
-        if feature_extractor is not None and feature_extractor.do_normalize is True:
-            feature_extractor.do_normalize = False
-        super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
-
-    def __call__(
-        self,
-        text: List[str] | None = None,
-        audio: List[np.ndarray] | List[torch.Tensor] | None = None,
-        **kwargs: Unpack[MiAudioLLMProcessorKwargs],
-    ) -> BatchFeature:
-        if text is None:
-            raise ValueError("You need to specify `text` input to process.")
-        elif isinstance(text, str):
-            text = [text]
-        elif not isinstance(text, list) and not isinstance(text[0], str):
-            raise ValueError(
-                "Invalid input text. Please provide a string, or a list of strings"
-            )
-
-        output_kwargs = self._merge_kwargs(
-            MiAudioLLMProcessorKwargs,
-            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
-            **kwargs,
-        )
-
-        if audio is not None:
-            if isinstance(audio[0], torch.Tensor):
-                audio = [sample_.numpy() for sample_ in audio]
-
-            if isinstance(audio[0], torch.Tensor):
-                audio = [sample_.squeeze(0) for sample_ in audio]
-                if not all(x_.ndim == 1 for x_ in audio):
-                    raise ValueError("All samples in a list must be 1D.")
-            if isinstance(audio[0], np.ndarray):
-                if not all(x_.ndim == 1 for x_ in audio):
-                    raise ValueError("All samples in a list must be 1D.")
-            # ensure we have as much audios as audio tokens
-            num_audio_tokens = sum(sample.count(self.audio_token) for sample in text)
-            num_audios = 1 if type(audio) is np.ndarray else len(audio)
-            if num_audio_tokens != num_audios:
-                raise ValueError(
-                    f"Found {num_audio_tokens} {self.audio_token} token{'s' if num_audio_tokens > 1 else ''} in provided text but received {num_audios} audio{'s' if num_audios > 1 else ''}"
-                )
-
-            # Some kwargs should not be changed so we can expand text with audio tokens below
-            output_kwargs["audio_kwargs"]["return_attention_mask"] = True
-            output_kwargs["audio_kwargs"]["padding"] = True
-            output_kwargs["audio_kwargs"]["return_tensors"] = "pt"
-
-            # + Padding
-            audio_inputs = self.feature_extractor(
-                audio, **output_kwargs["audio_kwargs"]
-            )
-
-            # remove attention mask, dasheng uses lengths
-            audio_feature_mask = audio_inputs.pop("attention_mask")
-
-            expanded_text = []
-            audio_lengths = audio_feature_mask.sum(-1).tolist()
-            audio_inputs["audio_length"] = torch.tensor(audio_lengths).long()
-            audio_inputs["audio_token_id"] = (
-                self.audio_token_id
-            )  # Pass to the model such that i knows what is the placeholder id
-
-            for sample in text:
-                replace_str = []
-                while self.audio_token in sample:
-                    audio_length = audio_lengths.pop(0)
-                    num_audio_tokens = calculate_mel_frames_dasheng(
-                        audio_length, model_subsampling=self.model_subsampling
-                    )
-
-                    expanded_audio_token = self.audio_token * num_audio_tokens
-
-                    audio_token_start_idx = sample.find(self.audio_token)
-                    audio_token_end_idx = audio_token_start_idx + len(self.audio_token)
-
-                    has_bos = (
-                        sample[
-                            audio_token_start_idx
-                            - len(self.audio_bos_token) : audio_token_start_idx
-                        ]
-                        == self.audio_bos_token
-                    )
-                    has_eos = (
-                        sample[
-                            audio_token_end_idx : audio_token_end_idx
-                            + len(self.audio_eos_token)
-                        ]
-                        == self.audio_eos_token
-                    )
-
-                    # Check if this audio token is surrounded by bos/eos tokens
-                    if not has_bos and not has_eos:
-                        expanded_audio_token = (
-                            self.audio_bos_token
-                            + expanded_audio_token
-                            + self.audio_eos_token
-                        )
-
-                    replace_str.append(expanded_audio_token)
-                    sample = sample.replace(self.audio_token, "<placeholder>", 1)
-
-                while "<placeholder>" in sample:
-                    sample = sample.replace("<placeholder>", replace_str.pop(0), 1)
-                expanded_text.append(sample)
-            text = expanded_text
-
-        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", "pt")
-        inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
-        if hasattr(self, "_check_special_mm_tokens"):
-            self._check_special_mm_tokens(text, inputs, modalities=["audio"])
-
-        if audio is not None:
-            inputs.update(audio_inputs)
-
-        return BatchFeature(data={**inputs}, tensor_type=return_tensors)
-
-    def batch_decode(self, *args, **kwargs):
-        """
-        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
-        refer to the docstring of this method for more information.
-        """
-        return self.tokenizer.batch_decode(*args, **kwargs)
-
-    def decode(self, *args, **kwargs):
-        """
-        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
-        the docstring of this method for more information.
-        """
-        return self.tokenizer.decode(*args, **kwargs)
-
-    @property
-    def model_input_names(self):
-        tokenizer_input_names = self.tokenizer.model_input_names
-        feature_extractor_input_names = self.feature_extractor.model_input_names
-        return list(
-            dict.fromkeys(
-                tokenizer_input_names + feature_extractor_input_names + ["audio_length"]
-            )
-        )
-
-    @property
-    # NOTE: we don't have default templates anymore, and the below is kept only because the hub config is not yet updated!
-    def default_chat_template(self):
-        """
-        This default vicuna template formats inputs in the form of a chat history. For each message in the chat history:
-        * the template will output the role of the speaker followed by the content of the message.
-        * content is a list of strings and audios.
-        * If the content element is an audio, the template will output a sequence of <|AUDIO|> tokens
-
-        Example:
-
-        ```python
-        messages = [
-            {'role': 'system', 'content': 'You are a helpful assistant.'},
-            {"role": "user", "content": [
-                {"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"},
-                {"type": "text", "text": "What's that sound?"},
-            ]},
-            {"role": "assistant", "content": "It is the sound of glass shattering."},
-            {"role": "user", "content": [
-                {"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/f2641_0_throatclearing.wav"},
-                {"type": "text", "text": "How about this one?"},
-            ]},
-        ]
-
-        result = template.render(messages=messages, add_generation_prompt=True)
-        ```
-        """
-        # fmt: off
-        return (
-            "{% set audio_count = namespace(value=0) %}"
-            "{% for message in messages %}"
-            "{% if loop.first and message['role'] != 'system' %}"
-            "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
-            "{% endif %}"
-            "<|im_start|>{{ message['role'] }}\n"
-            "{% if message['content'] is string %}"
-            "{{ message['content'] }}<|im_end|>\n"
-            "{% else %}"
-            "{% for content in message['content'] %}"
-            "{% if 'audio' in content or 'audio_url' in content or message['type'] == 'audio' %}"
-            "{% set audio_count.value = audio_count.value + 1 %}"
-            "Audio {{ audio_count.value }}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
-            "{% elif 'text' in content %}"
-            "{{ content['text'] }}"
-            "{% endif %}"
-            "{% endfor %}"
-            "<|im_end|>\n"
-            "{% endif %}"
-            "{% endfor %}"
-            "{% if add_generation_prompt %}"
-            "<|im_start|>assistant\n"
-            "{% endif %}"
-        )
processing_midashenglm.py CHANGED
@@ -207,9 +207,6 @@ class MiDashengLMProcessor(ProcessorMixin):
             expanded_text = []
             audio_lengths = audio_feature_mask.sum(-1).tolist()
             audio_inputs["audio_length"] = torch.tensor(audio_lengths).long()
-            audio_inputs["audio_token_id"] = (
-                self.audio_token_id
-            )  # Pass to the model such that i knows what is the placeholder id
 
             for sample in text:
                 replace_str = []
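With the hunk above, the processor no longer injects `audio_token_id` into its outputs; the model resolves the placeholder id from its config instead. A small sketch of the resulting output keys (the repo id is a placeholder assumption; keys follow `processing_midashenglm.py`):

```python
import numpy as np
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(
    "mispeech/midashenglm-7b", trust_remote_code=True
)
batch = processor(
    text=["<|audio_bos|><|AUDIO|><|audio_eos|>What do you hear?"],
    audio=[np.zeros(16000, dtype=np.float32)],
    return_tensors="pt",
)
print(sorted(batch.keys()))
# Expected: ['attention_mask', 'audio_length', 'input_ids', 'input_values']
# and, after this change, no 'audio_token_id'.
```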