junnei committed (verified)
Commit ca699e1 · 1 Parent(s): 9c3cebd

Upload processing_gemma3mm.py

Files changed (1): processing_gemma3mm.py (+14 -12)

processing_gemma3mm.py CHANGED
@@ -131,15 +131,15 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
 
     def __call__(
         self,
-        audios: List[AudioInput],
+        audio: List[AudioInput],
         return_tensors: Optional[Union[str, TensorType]] = None,
     ):
         # Ref: https://github.com/huggingface/transformers/blob/v4.47.0/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py#L161
         returned_input_audio_embeds = []
         returned_audio_embed_sizes = []
         audio_frames_list = []
-
-        for audio_data, sample_rate in audios:
+        sample_rate = 16000
+        for audio_data in audio:
             audio_embeds = self._extract_features(audio_data, sample_rate)
             audio_frames = len(audio_embeds) * self.audio_feat_stride
             audio_embed_size = self._compute_audio_embed_size(audio_frames)
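With this change the feature extractor no longer receives (waveform, sample_rate) pairs: it takes bare waveforms and hard-codes a 16 kHz sample rate. A minimal usage sketch under that assumption; the repo id below is a placeholder, not part of this commit:

import numpy as np
from transformers import AutoProcessor

# Placeholder repo id; any checkpoint shipping this processing_gemma3mm.py
# with trust_remote_code enabled would behave the same way.
processor = AutoProcessor.from_pretrained("junnei/gemma-3-mm", trust_remote_code=True)

waveform = np.zeros(16000, dtype=np.float32)  # 1 s of silence, already at 16 kHz
audio_features = processor.feature_extractor([waveform], return_tensors="pt")

# Before this commit each entry had to be a (waveform, sample_rate) pair:
# processor.feature_extractor([(waveform, 16000)])

The next hunk applies the same rename to the attention-mask line; the mask itself is a plain broadcast comparison, sketched here with made-up frame counts:

import torch

audio_frames = torch.tensor([3, 5])  # hypothetical per-clip frame counts
mask = torch.arange(0, audio_frames.max()).unsqueeze(0) < audio_frames.unsqueeze(1)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])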
@@ -152,7 +152,7 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
         )
         returned_audio_embed_sizes = torch.stack(returned_audio_embed_sizes, dim=0)
         audio_frames = torch.tensor(audio_frames_list)
-        returned_audio_attention_mask = torch.arange(0, audio_frames.max()).unsqueeze(0) < audio_frames.unsqueeze(1) if len(audios) > 1 else None
+        returned_audio_attention_mask = torch.arange(0, audio_frames.max()).unsqueeze(0) < audio_frames.unsqueeze(1) if len(audio) > 1 else None
 
         data = {
             "input_audio_embeds": returned_input_audio_embeds,
@@ -291,6 +291,7 @@ class Gemma3MMProcessor(ProcessorMixin):
         self.image_seq_length = image_seq_length
         self.image_token_id = tokenizer.image_token_id
         self.boi_token = tokenizer.boi_token
+        self.image_token = tokenizer.boi_token
         image_tokens_expanded = "".join([tokenizer.image_token] * image_seq_length)
         self.full_image_sequence = f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n"
 
@@ -312,7 +313,7 @@ class Gemma3MMProcessor(ProcessorMixin):
         images: ImageInput = None,
         text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
         videos=None,
-        audios: List[AudioInput] = None,
+        audio: List[AudioInput] = None,
         **kwargs: Unpack[Gemma3ProcessorKwargs],
     ) -> BatchFeature:
         if text is None and images is None:
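The processor-level keyword is renamed the same way, so callers now pass audio= instead of audios=. Reusing the processor instance from the sketch above, a hedged end-to-end call might look like this (the prompt is a placeholder; a real prompt would contain the model's audio placeholder token):

waveform = np.zeros(16000, dtype=np.float32)  # 1 s of 16 kHz silence
inputs = processor(
    text=["Describe this clip."],  # placeholder prompt
    audio=[waveform],              # was `audios=[...]` before this commit
    return_tensors="pt",
)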
@@ -344,8 +345,8 @@ class Gemma3MMProcessor(ProcessorMixin):
                 )
 
             # Replace image tokens by the full expanded sequence
-            batch_num_crops = to_py_obj(image_inputs.pop("num_crops"))
-            text_with_crops = text
+            num_crops = to_py_obj(image_inputs.pop("num_crops"))
+            batch_num_crops = [[num_crops.pop(0) for _ in range(len(images))] for images in batched_images]
            for batch_idx, (prompt, images, num_crops) in enumerate(zip(text, batched_images, batch_num_crops)):
                 image_indexes = [m.start() for m in re.finditer(self.boi_token, prompt)]
 
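The rewritten crop bookkeeping turns the flat per-image num_crops list into per-sample lists so it can be zipped against text and batched_images. A toy sketch of just that list surgery, with made-up counts:

# Hypothetical values: three images spread over two samples in the batch.
num_crops = [0, 2, 4]                          # one entry per image, flattened
batched_images = [["img0"], ["img1", "img2"]]  # sample 0 has 1 image, sample 1 has 2

batch_num_crops = [[num_crops.pop(0) for _ in range(len(images))] for images in batched_images]
print(batch_num_crops)  # [[0], [2, 4]]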
@@ -362,14 +363,15 @@ class Gemma3MMProcessor(ProcessorMixin):
                             + " ".join([self.boi_token] * num)
                         )
                         prompt = prompt[:idx] + formatted_image_text + prompt[idx + len(self.boi_token) :]
-                        text_with_crops[batch_idx] = prompt
+                        text[batch_idx] = prompt
 
         # Expand placeholder image tokens to the full image token sequence
         text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]
+
         return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
 
         audio_inputs = {}
-        if audios is not None:
+        if audio is not None:
             def replace_tokens_sequentially(prompt, boa_token, audio_sequences):
                 parts = prompt.split(boa_token)
                 result = ""
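Only the head of replace_tokens_sequentially is visible in this diff; judging from its name and the visible split on boa_token, it substitutes the i-th occurrence of what reads as the begin-of-audio placeholder with the i-th expanded audio sequence. A hedged re-creation of that behaviour (not the file's exact body), using placeholder token strings:

def replace_tokens_sequentially(prompt, boa_token, audio_sequences):
    # Split on the placeholder and re-join, inserting one expanded sequence per occurrence.
    parts = prompt.split(boa_token)
    result = parts[0]
    for i, part in enumerate(parts[1:]):
        filler = audio_sequences[i] if i < len(audio_sequences) else boa_token
        result += filler + part
    return result

print(replace_tokens_sequentially("a <boa> b <boa> c", "<boa>", ["[AUD1]", "[AUD2]"]))
# a [AUD1] b [AUD2] c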
@@ -383,7 +385,7 @@ class Gemma3MMProcessor(ProcessorMixin):
                 return result
 
             full_audio_sequences = []
-            audio_inputs = self.feature_extractor(audios)
+            audio_inputs = self.feature_extractor(audio)
 
             for i, embed_size in enumerate(audio_inputs.audio_embed_sizes):
                 audio_tokens_expanded = "".join([self.audio_token] * embed_size)
@@ -395,7 +397,7 @@ class Gemma3MMProcessor(ProcessorMixin):
         text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np")
 
         # Add token type ids manually, as tokenizer can't do arbitrary position token types
-        array_ids = np.array(text_inputs["input_ids"])
+        array_ids = text_inputs["input_ids"]
         mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
         mm_token_type_ids[array_ids == self.image_token_id] = 1
         mm_token_type_ids[array_ids == self.audio_token_id] = 2
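The tokenizer is called with return_tensors="np" a few lines earlier, so text_inputs["input_ids"] already arrives as a NumPy array and the explicit np.array(...) wrapper becomes redundant. The masking itself works like this sketch, with 99 and 98 standing in for the real image and audio token ids:

import numpy as np

input_ids = np.array([[1, 99, 99, 5, 98, 2]])  # 99/98 are stand-in special-token ids
mm_token_type_ids = np.zeros_like(input_ids)
mm_token_type_ids[input_ids == 99] = 1  # mark image-token positions
mm_token_type_ids[input_ids == 98] = 2  # mark audio-token positions
print(mm_token_type_ids)  # [[0 1 1 0 2 0]]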
@@ -409,7 +411,7 @@ class Gemma3MMProcessor(ProcessorMixin):
         text_inputs["token_type_ids"] = mm_token_type_ids.tolist()
         text_inputs["input_modes"] = input_modes.tolist()
 
-        return BatchFeature(data={**text_inputs, **image_inputs, **audio_inputs, }, tensor_type=return_tensors)
+        return BatchFeature(data={**text_inputs, **image_inputs, **audio_inputs}, tensor_type=return_tensors)
 
     # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
     def batch_decode(self, *args, **kwargs):