Automatic Speech Recognition
Transformers
Safetensors
meralion2
meralion
meralion-2
custom_code
YingxuHe committed · verified
Commit e8b7cca · 1 Parent(s): dde0903

Upload MERaLiON2ForConditionalGeneration

Files changed (2)
  1. config.json +6 -2
  2. modeling_meralion2.py +571 -0
config.json CHANGED
@@ -1,7 +1,10 @@
 {
-  "_attn_implementation_autoset": true,
+  "architectures": [
+    "MERaLiON2ForConditionalGeneration"
+  ],
   "auto_map": {
-    "AutoConfig": "configuration_meralion2.MERaLiON2Config"
+    "AutoConfig": "configuration_meralion2.MERaLiON2Config",
+    "AutoModelForSpeechSeq2Seq": "modeling_meralion2.MERaLiON2ForConditionalGeneration"
   },
   "head_dim": 256,
   "hidden_size": 2304,
@@ -91,5 +94,6 @@
     "use_cache": true,
     "vocab_size": 256000
   },
+  "torch_dtype": "bfloat16",
   "transformers_version": "4.50.1"
 }
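
With the new `auto_map` entry, the checkpoint can be loaded through the Transformers Auto classes once remote code is enabled. A minimal loading sketch (the repository id below is a placeholder, not taken from this commit):

import torch
from transformers import AutoModelForSpeechSeq2Seq

repo_id = "<this-repo-id-on-the-hub>"  # placeholder: substitute this repository's actual Hub id

# trust_remote_code=True lets the Auto class resolve
# modeling_meralion2.MERaLiON2ForConditionalGeneration via the auto_map above;
# bfloat16 matches the new "torch_dtype" entry.
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    repo_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)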
modeling_meralion2.py ADDED
@@ -0,0 +1,571 @@
"""PyTorch MERaLiON2 model."""

from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from transformers import Gemma2ForCausalLM
from transformers.models.whisper.modeling_whisper import WhisperEncoder
from transformers.cache_utils import HybridCache
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)

from .configuration_meralion2 import MERaLiON2Config


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "MERaLiON2Config"


# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    min_dtype: float,
    cache_position: torch.Tensor,
    batch_size: int,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        min_dtype (`float`):
            The minimum value representable with the dtype `dtype`.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`torch.Tensor`):
            Batch size.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        causal_mask = attention_mask
    else:
        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
            )
    return causal_mask


# copied from Qwen2AudioCausalLMOutputWithPast
@dataclass
class MERaLiON2OutputWithPast(ModelOutput):
    """
    Base class for MERaLiON2 causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        attention_mask (`torch.FloatTensor`, *optional*):
            Attention mask, used to update the attention mask and position_ids.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    attention_mask: Optional[torch.FloatTensor] = None


MERALION_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`MERaLiON2Config`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare MERaLiON2 Model outputting raw hidden-states without any specific head on top.",
    MERALION_START_DOCSTRING,
)
class MERaLiON2PreTrainedModel(PreTrainedModel):
    config_class = MERaLiON2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer", "Gemma2DecoderLayer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_static_cache = True

    def _init_weights(self, module):
        # important: this ported version of Qwen2Audio isn't meant for training from scratch - only
        # inference and fine-tuning - so the proper init weights code has been removed
        std = self.config.init_std if hasattr(self.config, "init_std") else self.config.speech_config.init_std

        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def _supports_sdpa(self):
        """
        Retrieve language_model's attribute to check whether the model supports
        SDPA or not.
        """
        return self.text_decoder._supports_sdpa


class MERaLiON2SpeechAudioAdaper(nn.Module):
    def __init__(
        self,
        config,
        **kwargs
    ):
        super(MERaLiON2SpeechAudioAdaper, self).__init__()
        speech_audio_encoder_output_dim = config.speech_config.d_model
        llm_input_hidden_size = config.text_config.hidden_size
        speech_mlp_scale_factor = config.speech_mlp_scale_factor

        self.speech_mlp_scale_factor = speech_mlp_scale_factor
        self.mlp_adapter = nn.Sequential(
            nn.Linear(
                in_features=speech_audio_encoder_output_dim * speech_mlp_scale_factor,
                out_features=speech_audio_encoder_output_dim
            ),
            nn.SiLU(),
            nn.Dropout(0.1),
        )

        self.speech_llm_proj = nn.Sequential(
            nn.Linear(
                speech_audio_encoder_output_dim,
                speech_audio_encoder_output_dim * 4
            ),
            nn.SiLU(),
            nn.Dropout(0.1),

            nn.Linear(
                speech_audio_encoder_output_dim * 4,
                llm_input_hidden_size
            ),
        )

    def forward(self, speech_embeds, **kwargs):
        B, T, C = speech_embeds.shape
        # Fold `speech_mlp_scale_factor` consecutive frames into the channel axis,
        # downsampling the encoder output along time before projecting to the LLM width.
        speech_embeds = self.mlp_adapter(
            speech_embeds.reshape(
                B,
                T // self.speech_mlp_scale_factor,
                C * self.speech_mlp_scale_factor,
            )
        )
        return self.speech_llm_proj(speech_embeds)


class MERaLiON2SpeechAudioAdaperLarge(nn.Module):
    def __init__(
        self,
        config,
        **kwargs
    ):
        super(MERaLiON2SpeechAudioAdaperLarge, self).__init__()
        speech_audio_encoder_output_dim = config.speech_config.d_model
        llm_input_hidden_size = config.text_config.hidden_size
        speech_mlp_scale_factor = config.speech_mlp_scale_factor

        self.speech_mlp_scale_factor = speech_mlp_scale_factor
        self.mlp_adapter = nn.Sequential(
            nn.Linear(
                in_features=speech_audio_encoder_output_dim * speech_mlp_scale_factor,
                out_features=speech_audio_encoder_output_dim * 5,
            ),
            nn.SiLU(),
            nn.Dropout(0.01),
        )

        self.gate_proj = nn.Linear(
            in_features=speech_audio_encoder_output_dim * 5,
            out_features=speech_audio_encoder_output_dim * 5,
        )

        self.pool_proj = nn.Linear(
            in_features=speech_audio_encoder_output_dim * 5,
            out_features=speech_audio_encoder_output_dim * 5,
        )
        self.act_fn = nn.SiLU()
        self.out_proj = nn.Linear(
            speech_audio_encoder_output_dim * 5,
            llm_input_hidden_size,
        )

    def forward(self, speech_embeds, **kwargs):
        B, T, C = speech_embeds.shape
        speech_embeds = self.mlp_adapter(
            speech_embeds.reshape(
                B,
                T // self.speech_mlp_scale_factor,
                C * self.speech_mlp_scale_factor,
            )
        )
        # Gated projection (SiLU gate times pooled features), then map to the text decoder hidden size.
        speech_embeds = self.act_fn(self.gate_proj(speech_embeds)) * self.pool_proj(speech_embeds)
        speech_embeds = self.out_proj(speech_embeds)
        return speech_embeds


MERALION_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, feature_sequence_length)`, *optional*):
            Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
            loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
            the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
            [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
            tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`, *optional*):
            Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    """The MERALION model which consists of an audio backbone and a language model.""",
    MERALION_START_DOCSTRING,
)
class MERaLiON2ForConditionalGeneration(MERaLiON2PreTrainedModel, GenerationMixin):
    def __init__(self, config: MERaLiON2Config):
        config.text_config._attn_implementation = config._attn_implementation
        config.speech_config._attn_implementation = config._attn_implementation

        super().__init__(config)

        self.speech_encoder = WhisperEncoder(config.speech_config)
        # self.speech_encoder = AutoModel.from_config(config.audio_config, attn_implementation=config._attn_implementation)

        self.ln_speech = nn.LayerNorm(config.speech_config.d_model)
        self.speech_audio_adapter = MERaLiON2SpeechAudioAdaperLarge(config)
        self.vocab_size = config.text_config.vocab_size
        self.text_decoder = Gemma2ForCausalLM(config.text_config)
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self._padding_side = "left"  # set it to left by default, user can use setter to change padding side
        self.post_init()

    @property
    def padding_side(self):
        return self._padding_side

    @padding_side.setter
    def padding_side(self, padding_side: str):
        if padding_side not in ["left", "right"]:
            raise ValueError(f"{padding_side} is not `left` or `right`.")
        self._padding_side = padding_side

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings
    def get_input_embeddings(self):
        return self.text_decoder.get_input_embeddings()

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings
    def set_input_embeddings(self, value):
        self.text_decoder.set_input_embeddings(value)

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings
    def get_output_embeddings(self):
        return self.text_decoder.get_output_embeddings()

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings
    def set_output_embeddings(self, new_embeddings):
        self.text_decoder.set_output_embeddings(new_embeddings)

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder
    def set_decoder(self, decoder):
        self.text_decoder.set_decoder(decoder)

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder
    def get_decoder(self):
        return self.text_decoder.get_decoder()

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights
    def tie_weights(self):
        return self.text_decoder.tie_weights()

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.resize_token_embeddings
    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
        model_embeds = self.text_decoder.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        # update vocab size
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds

    @add_start_docstrings_to_model_forward(MERALION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=MERaLiON2OutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        input_features: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        feature_attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MERaLiON2OutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
        """

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        speech_encoder_device = self.speech_encoder.device

        if input_features is not None:
            input_features = input_features.to(speech_encoder_device)
            feature_attention_mask = feature_attention_mask.to(speech_encoder_device)

        if inputs_embeds is None:
            speech_contexts_embeds = self.speech_encoder(input_features, attention_mask=feature_attention_mask).last_hidden_state
            speech_contexts_embeds = self.ln_speech(speech_contexts_embeds)
            speech_audio_contexts_embeds = self.speech_audio_adapter(speech_contexts_embeds)

            inputs_embeds = self.text_decoder.base_model.embed_tokens(input_ids)

            speech_mask = (input_ids == self.config.speech_token_index).unsqueeze(-1)
            speech_mask = speech_mask.expand_as(inputs_embeds).to(inputs_embeds.device)

            # Replace the speech placeholder token embeddings with the adapted speech embeddings.
            inputs_embeds = inputs_embeds.masked_scatter(speech_mask, speech_audio_contexts_embeds)

            input_ids = None

        outputs = self.text_decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            labels=labels
        )

        return outputs

    # from transformers.models.gemma2.modeling_gemma2.Gemma2ForCausalLM.prepare_inputs_for_generation
    def prepare_inputs_for_generation(
        self,
        input_ids,
        attention_mask=None,
        input_features=None,
        feature_attention_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=None,
        **kwargs,
    ):
        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        is_first_step = cache_position[0].item() == 0
        if past_key_values is not None:
            if inputs_embeds is not None:  # Exception 1
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]
                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
                # `mode="reduce-overhead"`, as otherwise the input `position_ids` would have various stride
                # during the decoding. Here, simply using `.contiguous()` is not sufficient as in the
                # batch size = 1 case, `position_ids` is already contiguous but with varying stride
                # which retriggers a capture.
                position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and is_first_step:
            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
        else:
            # The clone here is for the same reason as for `position_ids`.
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

        if (
            isinstance(past_key_values, HybridCache)
            and attention_mask.ndim == 2
            and not self.config._attn_implementation == "flash_attention_2"
        ):
            if model_inputs["inputs_embeds"] is not None:
                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
                device = model_inputs["inputs_embeds"].device
            else:
                batch_size, sequence_length = model_inputs["input_ids"].shape
                device = model_inputs["input_ids"].device
            dtype = self.text_decoder.lm_head.weight.dtype
            min_dtype = torch.finfo(dtype).min
            attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                attention_mask,
                sequence_length=sequence_length,
                target_length=past_key_values.get_max_cache_shape(),
                dtype=dtype,
                device=device,
                min_dtype=min_dtype,
                cache_position=cache_position,
                batch_size=batch_size,
            )

        model_inputs.update(
            {
                "attention_mask": attention_mask,
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache
            }
        )

        # Input ids will only be used from the second step.
        if is_first_step:
            model_inputs["input_features"] = input_features
            model_inputs["feature_attention_mask"] = feature_attention_mask

        return model_inputs

    def _reorder_cache(self, *args, **kwargs):
        return self.text_decoder._reorder_cache(*args, **kwargs)
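
The forward pass above replaces every speech placeholder token (config.speech_token_index) in the prompt with the adapter's downsampled Whisper-encoder embeddings, so generation needs both tokenized text containing those placeholders and Whisper-style input_features. A minimal inference sketch, assuming the repository ships a processor that prepares both; the repo id, prompt format, and the processor's keyword arguments are assumptions, not taken from this commit:

import numpy as np
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

repo_id = "<this-repo-id-on-the-hub>"  # placeholder

processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    repo_id, trust_remote_code=True, torch_dtype=torch.bfloat16
)

# Dummy 16 kHz mono waveform; in practice load audio with soundfile or librosa.
audio_array = np.zeros(16000, dtype=np.float32)

# The prompt must contain the speech placeholder tokens the processor defines;
# the `audios=` keyword is assumed and may differ in the actual processor.
inputs = processor(text="<prompt with speech placeholders>", audios=audio_array, return_tensors="pt")

# prepare_inputs_for_generation forwards input_features / feature_attention_mask
# only on the first decoding step; later steps decode from the text KV cache.
outputs = model.generate(**inputs, max_new_tokens=128)
print(processor.batch_decode(outputs, skip_special_tokens=True))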