farzadab committed · Commit 42c9246 · verified · 1 Parent(s): 4215e51

Upload ultravox_model.py with huggingface_hub

Files changed (1): ultravox_model.py (+116 -42)
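
For context, a commit like this one is typically produced with the huggingface_hub client. A minimal sketch, assuming you are already authenticated and with a placeholder repo id (the actual target repo is not named on this page):

from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="ultravox_model.py",    # local file to upload
    path_in_repo="ultravox_model.py",       # destination path inside the repo
    repo_id="your-org/your-ultravox-repo",  # placeholder; replace with the real repo id
    commit_message="Upload ultravox_model.py with huggingface_hub",
)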
ultravox_model.py CHANGED
@@ -1,6 +1,6 @@
 import logging
 import re
-from typing import Any, Dict, Generator, Optional, Set, Tuple, Union
+from typing import Any, Dict, Generator, Optional, Set, Tuple, TypeVar, Union
 
 import peft
 import torch
@@ -56,6 +56,11 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
         self.multi_modal_projector = self._create_multi_modal_projector(config)
         self.language_model = self._create_language_model(config)
 
+        if self.language_model._tied_weights_keys is not None:
+            self._tied_weights_keys = [
+                f"language_model.{k}" for k in self.language_model._tied_weights_keys
+            ]
+
         # Determine no_split_modules dynamically to use with FSDP auto_wrap policy.
         # FSDP throws an error if some of the layer types are not found in the model.
         # This would be something like ["LlamaDecoderLayer"] as we don't split audio encoder layers.
@@ -64,6 +69,44 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
         self.loss_config = LossConfig()
         self.post_init()
 
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
+        model = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
+        model._load_child_model_weights(*args, **kwargs)
+        return model
+
+    def _load_child_model_weights(self, *args, **kwargs) -> "UltravoxModel":
+        if "torch_dtype" in kwargs:
+            self.config.torch_dtype = kwargs.pop("torch_dtype")
+
+        kwargs.pop("config", None)
+
+        if (
+            self.config.text_model_id is not None
+            and self.language_model.device.type == "meta"
+        ):
+            # Load the language model weights
+            self.language_model = transformers.AutoModelForCausalLM.from_pretrained(
+                self.config.text_model_id,
+                torch_dtype=self.config.torch_dtype,
+                *args,
+                **kwargs,
+            )
+
+        if (
+            self.config.audio_model_id is not None
+            and self.audio_tower.device.type == "meta"
+        ):
+            # Load the audio tower weights
+            self.audio_tower = transformers.AutoModel.from_pretrained(
+                self.config.audio_model_id,
+                torch_dtype=self.config.torch_dtype,
+                *args,
+                **kwargs,
+            )
+
+        return self
+
     def get_input_embeddings(self):
         return self.language_model.get_input_embeddings()
 
@@ -110,6 +153,30 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
         self.vocab_size = model_embeds.num_embeddings
         return model_embeds
 
+    def _get_prediction_mask(self, labels: Optional[torch.Tensor]) -> torch.Tensor:
+        """Get a boolean mask for positions where we want to compute KL divergence.
+
+        For each label position, we want the position before it since that's where
+        the model makes the prediction for that label.
+
+        Args:
+            labels: Tensor of shape (B, T) where B is batch size and T is sequence length,
+                with -100 for masked positions and token ids for label positions
+
+        Returns:
+            Boolean tensor of shape (B, T) that's True for positions where we want to compute KL divergence
+        """
+        if labels is None:
+            raise ValueError("labels must be provided")
+        # Shift the label mask right by 1 along the sequence dimension
+        # This gives us positions where we make predictions for the next token
+        label_mask = labels != -100
+        pred_mask = torch.zeros_like(label_mask)
+        pred_mask[:, :-1] = label_mask[
+            :, 1:
+        ]  # shift right by 1 along sequence dimension
+        return pred_mask
+
     def _compute_kl_loss(
         self,
         lm_output: transformers.modeling_outputs.CausalLMOutputWithPast,
@@ -134,11 +201,12 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
         # compute the KL divergence loss between the two models
         kl_loss = F.kl_div(
             F.log_softmax(
-                lm_output.logits[labels != -100] / self.loss_config.kl_temperature,
+                lm_output.logits[self._get_prediction_mask(labels)]
+                / self.loss_config.kl_temperature,
                 dim=-1,
             ),
             F.softmax(
-                alt_lm_output.logits[alt_labels != -100]
+                alt_lm_output.logits[self._get_prediction_mask(alt_labels)]
                 / self.loss_config.kl_temperature,
                 dim=-1,
             ),
@@ -289,7 +357,9 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
 
         # include audio information in model_input only when it is needed during prefilling
        # audio_token_start_idx should always be relative to the current cache position
-        prefill_start_idx = 0 if cache_position is None else cache_position[0]
+        prefill_start_idx: int | torch.Tensor = (
+            0 if cache_position is None else cache_position[0]
+        )
         if (
             audio_values is not None
             and audio_token_start_idx is not None
@@ -317,23 +387,9 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
     def _create_audio_tower(
         cls, config: UltravoxConfig
     ) -> Union[transformers.Wav2Vec2Model, "ModifiedWhisperEncoder"]:
-        if config.audio_model_id is not None:
-            if "whisper" in config.audio_model_id.lower():
-                audio_tower = ModifiedWhisperEncoder.from_pretrained(
-                    config.audio_model_id, torch_dtype=config.torch_dtype
-                )
-                audio_tower.init_latency_mask(
-                    config.audio_latency_block_size, dtype=config.torch_dtype
-                )
-            else:
-                assert config.audio_latency_block_size in (
-                    None,
-                    0,
-                ), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"
-                audio_tower = transformers.AutoModel.from_pretrained(
-                    config.audio_model_id, torch_dtype=config.torch_dtype
-                )
-        else:
+        with transformers.modeling_utils.no_init_weights():
+            # we only ever use from_config if the weights are retrained, hence initializing is not
+            # required. This makes the model quite creation faster since init on CPU is quite slow.
             if "whisper" in config.audio_config._name_or_path.lower():
                 audio_tower = ModifiedWhisperEncoder(config.audio_config)
                 audio_tower.init_latency_mask(
@@ -344,12 +400,7 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
                     None,
                     0,
                 ), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"
-                with transformers.modeling_utils.no_init_weights():
-                    # we only ever use from_config if the weights are retrained, hence initializing is not
-                    # required. This makes the model quite creation faster since init on CPU is quite slow.
-                    audio_tower = transformers.AutoModel.from_config(
-                        config.audio_config
-                    )
+                audio_tower = transformers.AutoModel.from_config(config.audio_config)
 
         if isinstance(
             audio_tower,
@@ -367,21 +418,14 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
     def _create_language_model(
         cls, config: UltravoxConfig
     ) -> transformers.LlamaForCausalLM:
-        if config.text_model_id is not None:
-            language_model = transformers.AutoModelForCausalLM.from_pretrained(
-                config.text_model_id,
-                attn_implementation=config._attn_implementation,
+        with transformers.modeling_utils.no_init_weights():
+            # we only ever use from_config if the weights are retrained, hence initializing is not
+            # required. This makes the model quite creation faster since init on CPU is quite slow.
+            language_model = transformers.AutoModelForCausalLM.from_config(
+                config.text_config,
+                attn_implementation=config.text_config._attn_implementation,
                 torch_dtype=config.torch_dtype,
             )
-        else:
-            with transformers.modeling_utils.no_init_weights():
-                # we only ever use from_config if the weights are retrained, hence initializing is not
-                # required. This makes the model quite creation faster since init on CPU is quite slow.
-                language_model = transformers.AutoModelForCausalLM.from_config(
-                    config.text_config,
-                    attn_implementation=config._attn_implementation,
-                    torch_dtype=config.torch_dtype,
-                )
 
         language_model = apply_lora(language_model, config.text_model_lora_config)
         return language_model
@@ -495,7 +539,10 @@ def is_cache_empty(
     return past_key_values.get_seq_length() == 0
 
 
-def apply_lora(model: torch.nn.Module, lora_config: dict) -> torch.nn.Module:
+T = TypeVar("T", bound=torch.nn.Module)
+
+
+def apply_lora(model: T, lora_config: dict) -> T:
     """
     Applies LoRA finetuning to the model. If the `r` parameter is set to 0, the model is frozen instead.
     """
@@ -574,11 +621,35 @@ class UltravoxProjector(nn.Module):
         self.ln_post = RMSNorm(dim_out, init=config.norm_init)
 
     def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
+        """
+        Takes in audio features from the audio tower and projects them to the text model's embedding space.
+        It reduces the number of frames by a factor of `stack_factor` and increases the number of channels by the same factor.
+        If the number of audio frames are not a multiple of the stack factor, the last few frames will be padded with zeros.
+
+        Input shape:
+            audio_features: B, T*S, C
+        Output shape:
+            hidden_states: B, T, D
+        Where:
+            B: batch size
+            F: number of frames in the audio tower
+            T: number of output embeddings
+                T = ceil(F / S)
+            S: stack factor
+            C: number of channels out of the encoder (aka audio tower)
+            H: hidden size of the projector (config.hidden_size)
+            D: dimension of the text model (config.text_config.hidden_size)
+
+        """
+        # B, F, C -> B, T, C*S
         audio_features = self._pad_and_stack(audio_features)
         audio_features = self.ln_pre(audio_features)
+        # B, T, C*S -> B, T, H
         hidden_states = self.linear_1(audio_features)
+        # B, T, H -> B, T, H/2 (assuming swiglu)
         hidden_states = self.act(hidden_states)
         hidden_states = self.ln_mid(hidden_states)
+        # B, T, H/2 -> B, T, D
        hidden_states = self.linear_2(hidden_states)
         hidden_states = self.ln_post(hidden_states)
         return hidden_states
@@ -601,6 +672,7 @@ class ModifiedWhisperEncoder(
 
     base_model_prefix = "model.encoder"
     _no_split_modules = ["WhisperEncoderLayer"]
+    _keys_to_ignore_on_load_unexpected = ["model.decoder.*"]
 
     def __init__(self, config: transformers.WhisperConfig):
         super().__init__(config)
@@ -614,7 +686,9 @@ class ModifiedWhisperEncoder(
             * self.conv2.stride[0]
        )
 
-    def init_latency_mask(self, audio_latency_block_size: int, dtype: torch.dtype):
+    def init_latency_mask(
+        self, audio_latency_block_size: int | None, dtype: torch.dtype
+    ):
         if audio_latency_block_size is None:
             self.audio_streaming_mask = None
             return
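
For readers skimming the KL-loss change: the new _get_prediction_mask selects the logits one position before each label, since a causal LM produces the prediction for token t at position t-1, whereas the old code indexed the logits at the label positions themselves. A standalone sketch of the same shift (not part of the commit) on a toy labels tensor:

import torch

# toy labels: -100 marks ignored positions, other values are target token ids
labels = torch.tensor([[-100, -100, 5, 7, -100]])

# same shift as UltravoxModel._get_prediction_mask: a position is selected
# when the *next* position holds a label, i.e. where that label is predicted
label_mask = labels != -100
pred_mask = torch.zeros_like(label_mask)
pred_mask[:, :-1] = label_mask[:, 1:]

print(pred_mask)
# tensor([[False,  True,  True, False, False]])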