farzadab committed
Commit 80fc598 · verified · 1 Parent(s): b400365

Update ultravox_processing.py

Files changed (1):
  1. ultravox_processing.py +218 -64
ultravox_processing.py CHANGED
@@ -1,12 +1,69 @@
-from typing import Optional, Union
+import dataclasses
+from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 import torch
+import torch.nn.functional as F
 import transformers
 
 from .ultravox_config import UltravoxConfig
 
 
+@dataclasses.dataclass
+class DataCollatorForSeq2SeqWithAudio(transformers.DataCollatorForSeq2Seq):
+    # when enabled, the alt_input_ids, alt_attention_mask, and alt_labels fields are used for computing the KL loss in UltravoxModel
+    include_alt_fields: bool = False
+
+    def __call__(self, features, *args, **kwargs):
+        audio_values = [x for f in features for x in f.pop("audio_values", [])]
+        audio_lens = [x for f in features for x in f.pop("audio_lens", [])]
+        audio_token_len = [x for f in features for x in f.pop("audio_token_len", [])]
+        audio_token_start_idx = [
+            x for f in features for x in f.pop("audio_token_start_idx", [])
+        ]
+
+        if self.include_alt_fields:
+            # these fields are hard-coded in the transformer data collator, so they need special handling before calling the super method
+            alt_features = [
+                {
+                    "input_ids": f.pop("alt_input_ids"),
+                    "attention_mask": f.pop("alt_attention_mask"),
+                    "labels": f.pop("alt_labels"),
+                }
+                for f in features
+            ]
+
+        batch = super().__call__(features, *args, **kwargs)
+        if self.include_alt_fields:
+            alt_batch = super().__call__(alt_features, *args, **kwargs)
+            batch["alt_input_ids"] = alt_batch["input_ids"]
+            batch["alt_attention_mask"] = alt_batch["attention_mask"]
+            batch["alt_labels"] = alt_batch["labels"]
+
+        batch["audio_token_start_idx"] = torch.stack(audio_token_start_idx)
+        batch["audio_lens"] = torch.stack(audio_lens)
+        batch["audio_token_len"] = torch.stack(audio_token_len)
+
+        # Pad the last dimension of all audio_values to the same length, with 0s on the right.
+        if audio_values:
+            max_len = max([x.shape[-1] for x in audio_values])
+            batch["audio_values"] = torch.stack(
+                [F.pad(x, (0, max_len - x.shape[-1])) for x in audio_values]
+            )
+            if self.tokenizer.padding_side == "left":
+                input_ids_lens = torch.LongTensor(
+                    [f["input_ids"].shape[-1] for f in features]
+                )
+                displacement = batch["input_ids"].shape[-1] - input_ids_lens
+                displacement = displacement.repeat_interleave(
+                    batch["audio_batch_size"].squeeze(-1)
+                )
+                batch["audio_token_start_idx"] += displacement.to(
+                    batch["audio_token_start_idx"].device
+                )
+        return batch
+
+
 class UltravoxProcessor(transformers.ProcessorMixin):
     """
     Constructs an Ultravox processor which wraps an audio processor and a tokenizer into a single processor.
@@ -17,11 +74,7 @@ class UltravoxProcessor(transformers.ProcessorMixin):
     """
 
     attributes = ["audio_processor", "tokenizer"]
-    audio_processor_class = (
-        "Wav2Vec2Processor",
-        "SeamlessM4TFeatureExtractor",
-        "WhisperProcessor",
-    )
+    audio_processor_class = ("WhisperProcessor",)
     tokenizer_class = (
         "PreTrainedTokenizer",
         "PreTrainedTokenizerFast",
@@ -35,41 +88,45 @@ class UltravoxProcessor(transformers.ProcessorMixin):
         audio_processor=None,
         tokenizer=None,
         audio_padding: str = "longest",
-        encoder_ds_factor: int = 320,
+        encoder_ds_factor: int = 2,
         stack_factor: int = 8,
         audio_placeholder: str = "<|audio|>",
+        # Defaults to whisper encoder context size
+        audio_context_size: Optional[int] = 3000,
     ):
         """
        Args:
             audio_processor: The audio processor for the audio encoder.
             tokenizer: The tokenizer for the language model.
             audio_padding: The padding strategy for the audio encoder.
-            encoder_ds_factor: The downsample factor of the audio encoder.
             stack_factor: The factor by which the audio encoder output is stacked in the multimodal projector.
+            encoder_ds_factor: The downsampling factor of the audio encoder.
             audio_placeholder: The placeholder for the audio in the text.
+            audio_context_size: The maximum number of frames that the audio encoder can handle.
         """
         self.audio_padding = audio_padding
         self.encoder_ds_factor = encoder_ds_factor
         self.stack_factor = stack_factor
         self.audio_placeholder = audio_placeholder
-        self.audio_token_replacement = tokenizer.eos_token
+        self.audio_context_size = audio_context_size
         assert (
-            self.audio_token_replacement is not None
+            tokenizer.eos_token is not None
         ), "The tokenizer has no EOS token. Cannot recover."
+        self.audio_replacement_token_id = tokenizer.get_vocab()[tokenizer.eos_token]
         if tokenizer.pad_token_id is None:
             tokenizer.pad_token_id = tokenizer.eos_token_id
 
         super().__init__(audio_processor=audio_processor, tokenizer=tokenizer)
 
     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
         config: UltravoxConfig = transformers.AutoConfig.from_pretrained(
             pretrained_model_name_or_path, **kwargs
         )
         audio_processor = transformers.AutoProcessor.from_pretrained(
             config.audio_model_id
             or config.audio_config._name_or_path
-            or "facebook/wav2vec2-base-960h"
+            or "openai/whisper-tiny"
         )
 
         tokenizer = transformers.AutoTokenizer.from_pretrained(
@@ -84,10 +141,69 @@ class UltravoxProcessor(transformers.ProcessorMixin):
             stack_factor=config.stack_factor,
         )
 
+    def _chunk_and_pad_audio(
+        self, audio_values: torch.Tensor, audio_lens: torch.Tensor
+    ) -> Dict[str, Any]:
+        """
+        Processes the audio batch by chunking any items in the batch according to the audio_context_size,
+        padding the last chunk if needed, and returns a dictionary with updated audio data.
+
+        Args:
+            audio_values (torch.Tensor): A tensor of audio values (e.g., in B, D, T format).
+            audio_lens (torch.Tensor): A tensor of audio lengths.
+
+        Returns:
+            Dict[str, Any]: Dictionary with the following keys:
+                - "audio_values": The concatenated audio tensor after chunking and padding.
+                - "audio_lens": Tensor of lengths for each chunk.
+                - "audio_is_continuation": Tensor of booleans indicating if the chunk is a continuation of the previous chunk.
+                - "audio_batch_size": A Tensor with one integer representing the number of chunks.
+
+        """
+        chunked_audio_values: List[torch.Tensor] = []
+        chunked_audio_lens: List[int] = []
+        is_continuation_list: List[bool] = []
+        context_size = self.audio_context_size or audio_values.shape[-1]
+
+        for i in range(audio_values.shape[0]):  # iterate over the batch
+            for offset in range(0, audio_lens[i], context_size):
+                is_continuation = offset > 0
+                chunk = audio_values[i, :, offset : offset + context_size]
+                if is_continuation and chunk.shape[-1] < context_size:
+                    # N.B. We only need to pad continuation chunks. If none of the samples require chunking, the
+                    # batch might not (need to) be padded all the way to the audio_context_size, in which case
+                    # we've already included the padding above. On the other hand, if we have any continuation
+                    # chunks we know that the batch needs to be padded to audio_context_size because that's what
+                    # we're slicing to.
+                    chunk = F.pad(chunk, (0, context_size - chunk.shape[-1]))
+                chunked_audio_values.append(chunk)
+                chunked_audio_lens.append(
+                    min(int(audio_lens[i].item()) - offset, context_size)
+                )
+                is_continuation_list.append(is_continuation)
+
+        return {
+            "audio_values": torch.stack(chunked_audio_values, dim=0),
+            "audio_lens": torch.tensor(
+                chunked_audio_lens, dtype=torch.int64, device=audio_values.device
+            ),
+            "audio_is_continuation": torch.tensor(
+                is_continuation_list, dtype=torch.bool, device=audio_values.device
+            ),
+            "audio_batch_size": torch.tensor(
+                [len(chunked_audio_values)], device=audio_values.device
+            ),
+        }
+
     def __call__(
         self,
         text: Optional[str] = None,
         audio: Optional[Union[np.ndarray, torch.Tensor]] = None,
+        audios: Optional[
+            Union[
+                List[Union[np.ndarray, torch.Tensor]], Union[np.ndarray, torch.Tensor]
+            ]
+        ] = None,
         sampling_rate: Optional[int] = None,
         return_tensors: Optional[
             Union[str, transformers.TensorType]
@@ -98,16 +214,16 @@ class UltravoxProcessor(transformers.ProcessorMixin):
         Main method to prepare for the model one text sequence and audio. This method forwards the `text`
         and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to encode
         the text. To prepare the audio(s), this method forwards the `audio`, `sampling_rate` and `kwargs` arguments to
-        audio processor's [`~Wav2Vec2Processor.__call__`] if `audio` is not `None`. Please refer to the docstring
+        audio processor's [`~WhisperProcessor.__call__`] if `audio` is not `None`. Please refer to the docstring
         of the above two methods for more information.
 
         Args:
             text (`str`, `List[str]`):
                 The sequence to be encoded. Sequence can be a string or (pretokenized string).
             audio (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The audio to be prepared. Audio can be NumPy array or PyTorch tensor. In case of a
-                NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is a number of channels, and T the
-                sample length of the audio.
+                The audio to be prepared. Audio can be a single-channel (1-dimensional) NumPy array or PyTorch tensor.
+            audios (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                A list or two dimensional array of audio to be prepared.
             sampling_rate (`int`, *optional*, defaults to 16000):
                 Sampling rate of the input audio. We expect 16kHz audio. Don't change this value unless you know what
                 you are doing.
@@ -131,64 +247,102 @@ class UltravoxProcessor(transformers.ProcessorMixin):
              Returned when `audio` is not `None`.
            - **audio_token_start_idx** -- The index in the tokenized text where the audio starts. Returned when `audio` is not `None`.
         """
-        # TODO: Add support for multiple audio and text inputs.
+        # TODO: Add support for multiple text inputs.
+        if audio is not None and audios is not None:
+            raise ValueError("Only one of `audio` or `audios` should be provided.")
+        elif audio is not None:
+            audios = audio if isinstance(audio, list) or audio.ndim == 2 else [audio]
+        elif audios is None:
+            audios = []
+
         data = {}
-        audio_embed_frames = 0
-        if audio is not None and len(audio) > 0:
-            if self.audio_padding == "max_length":
-                # 30 seconds is the expected length for Whisper
-                assert sampling_rate is not None, "Sampling rate must be provided."
-                audio_len = 30 * sampling_rate
-            else:
-                audio_len = audio.shape[-1]
-            # It's guaranteed that the number of frames is less than or equal to this amount.
-            # For Whisper this is exact AFAICT, but for Wav2Vec2 it's an upper bound.
-            # Currently, StackAudioFrames makes sure an over-estimation won't cause issues by padding the audio embeddings.
-            nb_encoder_frames = int(round(audio_len / self.encoder_ds_factor + 1e-4))
-            audio_embed_frames = int(np.ceil(nb_encoder_frames / self.stack_factor))
-            data["audio_token_len"] = [audio_embed_frames]
+        audio_is_continuation = []
+        if len(audios) > 0:
+            audios = [x.numpy() if isinstance(x, torch.Tensor) else x for x in audios]
+
+            # Pad out each audio to at least 2 hops (the minimum required by the processor).
+            hop_length = self.audio_processor.feature_extractor.hop_length
+            audios = [
+                (
+                    np.pad(x, (0, 2 * hop_length - len(x)), mode="constant")
+                    if len(x) < 2 * hop_length
+                    else x
+                )
+                for x in audios
+            ]
 
             # Main audio processing. The processor is model-specific.
-            x = self.audio_processor(
-                audio,
+            x: transformers.BatchFeature = self.audio_processor(
+                audios,
                 sampling_rate=sampling_rate,
                 padding="longest",
-                max_length=audio_len,
+                pad_to_multiple_of=hop_length,  # The attention mask effectively gets padded to the hop length, so pad the audio to be consistent.
+                truncation=False,
+                return_attention_mask=True,
                 **kwargs,
             )
-            if "input_features" in x:
-                data["audio_values"] = x.input_features
-            else:
-                data["audio_values"] = x.input_values
 
-        if text is not None:
-            assert isinstance(
-                text, str
-            ), "Text must be a string. Batch mode not supported yet."
-            if self.audio_placeholder in text:
-                if "audio_token_len" not in data:
-                    raise ValueError(
-                        f"audio must be provided when using audio placeholder ({self.audio_placeholder}) in text."
-                    )
-
-                start_idx = len(
-                    self.tokenizer.encode(
-                        text[: text.index(self.audio_placeholder)],
-                        add_special_tokens=False,
-                    )
-                )
-                data["audio_token_start_idx"] = [start_idx]
-
-                # Replace the audio placeholder with the audio token.
-                # e.g. "Transcribe\n<|audio|>" -> "Transcribe </s></s></s></s></s></s></s></s>"
-                # where the number of </s> is the number of audio frames.
-                text = text.replace(
-                    self.audio_placeholder,
-                    self.audio_token_replacement * audio_embed_frames,
+            data.update(
+                self._chunk_and_pad_audio(
+                    audio_values=torch.as_tensor(
+                        x.input_features if "input_features" in x else x.input_values
+                    ),
+                    audio_lens=torch.as_tensor(x.attention_mask).sum(-1),
                 )
+            )
+
+            audio_is_continuation = data.pop("audio_is_continuation")
+            data["audio_token_len"] = torch.ceil(
+                data["audio_lens"] / (self.encoder_ds_factor * self.stack_factor)
+            ).to(dtype=torch.int)
+
+        if text is not None:
+            if not isinstance(text, str):
+                raise ValueError("Text must be a string. Batch mode not supported yet.")
 
             # Special tokens like BOS should already have been added by the caller.
-            data.update(self.tokenizer([text], add_special_tokens=False, **kwargs))
+            tokenized_parts = self.tokenizer(
+                text.split(
+                    "<|audio|>"  # The placeholder isn't part of the vocabulary, so split the text around it.
+                ),
+                add_special_tokens=False,
+                **kwargs,
+            )
+
+            audio_token_start_idx = []
+            placeholder_index = -1
+            split_input_ids = tokenized_parts["input_ids"]
+            input_ids: List[int] = []
+
+            for i, token_len in enumerate(data.get("audio_token_len", [])):
+                if not audio_is_continuation[i]:
+                    placeholder_index += 1
+                    if placeholder_index >= len(split_input_ids):
+                        raise ValueError(
+                            f"Text contains too few audio placeholders. (Expected {len(audios)} placeholders)"
+                        )
+
+                    input_ids.extend(split_input_ids[placeholder_index])
+
+                audio_token_start_idx.append(len(input_ids))
+
+                input_ids.extend([self.audio_replacement_token_id] * token_len)
+
+            # Include any tokens after the last audio.
+            placeholder_index += 1
+            if placeholder_index != len(split_input_ids) - 1:
+                raise ValueError(
+                    f"Text contains too many audio placeholders. (Expected {len(audios)} placeholders)"
+                )
+            input_ids.extend(split_input_ids[placeholder_index])
+
+            if "audio_token_len" in data:
+                data["audio_token_start_idx"] = torch.as_tensor(audio_token_start_idx)
+
+            data["input_ids"] = [input_ids]
+            data["attention_mask"] = [[1] * len(input_ids)]
+
+        # Ensure that there are no audio placeholders after the last audio.
 
         return transformers.BatchFeature(data=data, tensor_type=return_tensors)
 
@@ -207,4 +361,4 @@ class UltravoxProcessor(transformers.ProcessorMixin):
 
 UltravoxProcessor.register_for_auto_class()
 
-transformers.AutoProcessor.register(UltravoxConfig, UltravoxProcessor)
+transformers.AutoProcessor.register(UltravoxConfig, UltravoxProcessor)
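
For context (not part of the commit), here is a minimal usage sketch of the processor after this change. The checkpoint id is a placeholder, and the exact shapes depend on the audio length and the Whisper feature extractor; the point is the <|audio|> placeholder expansion and the chunked audio outputs produced by the new __call__.

import numpy as np
import torch
import transformers

# "<ultravox-checkpoint>" is a placeholder; use a repo that ships this processing code.
processor = transformers.AutoProcessor.from_pretrained(
    "<ultravox-checkpoint>", trust_remote_code=True
)

# `audio` must be a single-channel, 1-D waveform at 16 kHz (5 s of silence here).
audio = np.zeros(16_000 * 5, dtype=np.float32)

inputs = processor(
    text="Transcribe the following:\n<|audio|>",
    audio=audio,
    sampling_rate=16_000,
    return_tensors="pt",
)

# Expected keys: audio_values, audio_lens, audio_token_len, audio_token_start_idx,
# audio_batch_size, input_ids, attention_mask.
for key, value in inputs.items():
    print(key, tuple(value.shape) if isinstance(value, torch.Tensor) else value)

During training, a list of such feature dicts would then be collated by DataCollatorForSeq2SeqWithAudio, which stacks audio_token_start_idx, audio_lens, and audio_token_len and right-pads audio_values to a common length.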