Upload processing_gemma3mm.py
processing_gemma3mm.py  CHANGED  (+14 -12)
(Removed-line text that the page extract did not capture in full is shown as "…".)
@@ -131,15 +131,15 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
 
     def __call__(
         self,
-        …
+        audio: List[AudioInput],
         return_tensors: Optional[Union[str, TensorType]] = None,
     ):
         # Ref: https://github.com/huggingface/transformers/blob/v4.47.0/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py#L161
         returned_input_audio_embeds = []
         returned_audio_embed_sizes = []
         audio_frames_list = []
-        …
-        for audio_data …
+        sample_rate = 16000
+        for audio_data in audio:
            audio_embeds = self._extract_features(audio_data, sample_rate)
            audio_frames = len(audio_embeds) * self.audio_feat_stride
            audio_embed_size = self._compute_audio_embed_size(audio_frames)

@@ -152,7 +152,7 @@ class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
            )
        returned_audio_embed_sizes = torch.stack(returned_audio_embed_sizes, dim=0)
        audio_frames = torch.tensor(audio_frames_list)
-       returned_audio_attention_mask = torch.arange(0, audio_frames.max()).unsqueeze(0) < audio_frames.unsqueeze(1) if len(…
+       returned_audio_attention_mask = torch.arange(0, audio_frames.max()).unsqueeze(0) < audio_frames.unsqueeze(1) if len(audio) > 1 else None
 
        data = {
            "input_audio_embeds": returned_input_audio_embeds,

@@ -291,6 +291,7 @@ class Gemma3MMProcessor(ProcessorMixin):
        self.image_seq_length = image_seq_length
        self.image_token_id = tokenizer.image_token_id
        self.boi_token = tokenizer.boi_token
+       self.image_token = tokenizer.boi_token
        image_tokens_expanded = "".join([tokenizer.image_token] * image_seq_length)
        self.full_image_sequence = f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n"
 

@@ -312,7 +313,7 @@ class Gemma3MMProcessor(ProcessorMixin):
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        videos=None,
-       …
+       audio: List[AudioInput] = None,
        **kwargs: Unpack[Gemma3ProcessorKwargs],
    ) -> BatchFeature:
        if text is None and images is None:

@@ -344,8 +345,8 @@ class Gemma3MMProcessor(ProcessorMixin):
            )
 
        # Replace image tokens by the full expanded sequence
-       …
-       …
+       num_crops = to_py_obj(image_inputs.pop("num_crops"))
+       batch_num_crops = [[num_crops.pop(0) for _ in range(len(images))] for images in batched_images]
        for batch_idx, (prompt, images, num_crops) in enumerate(zip(text, batched_images, batch_num_crops)):
            image_indexes = [m.start() for m in re.finditer(self.boi_token, prompt)]
 

@@ -362,14 +363,15 @@ class Gemma3MMProcessor(ProcessorMixin):
                    + " ".join([self.boi_token] * num)
                )
                prompt = prompt[:idx] + formatted_image_text + prompt[idx + len(self.boi_token) :]
-           …
+           text[batch_idx] = prompt
 
        # Expand placeholder image tokens to the full image token sequence
        text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]
+
        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
 
        audio_inputs = {}
-       if …
+       if audio is not None:
            def replace_tokens_sequentially(prompt, boa_token, audio_sequences):
                parts = prompt.split(boa_token)
                result = ""

@@ -383,7 +385,7 @@ class Gemma3MMProcessor(ProcessorMixin):
                return result
 
            full_audio_sequences = []
-           audio_inputs = self.feature_extractor(…
+           audio_inputs = self.feature_extractor(audio)
 
            for i, embed_size in enumerate(audio_inputs.audio_embed_sizes):
                audio_tokens_expanded = "".join([self.audio_token] * embed_size)

@@ -395,7 +397,7 @@ class Gemma3MMProcessor(ProcessorMixin):
        text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np")
 
        # Add token type ids manually, as tokenizer can't do arbitrary position token types
-       array_ids = …
+       array_ids = text_inputs["input_ids"]
        mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
        mm_token_type_ids[array_ids == self.image_token_id] = 1
        mm_token_type_ids[array_ids == self.audio_token_id] = 2

@@ -409,7 +411,7 @@ class Gemma3MMProcessor(ProcessorMixin):
        text_inputs["token_type_ids"] = mm_token_type_ids.tolist()
        text_inputs["input_modes"] = input_modes.tolist()
 
-       return BatchFeature(data={**text_inputs, **image_inputs, **audio_inputs…
+       return BatchFeature(data={**text_inputs, **image_inputs, **audio_inputs}, tensor_type=return_tensors)
 
    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
    def batch_decode(self, *args, **kwargs):
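
The new returned_audio_attention_mask line in the feature extractor builds a per-clip padding mask purely by tensor broadcasting. Below is a minimal sketch of that comparison in isolation; the frame counts are made-up values, and nothing here comes from the file beyond the expression itself.

import torch

# Stand-in for audio_frames = torch.tensor(audio_frames_list): three clips of different lengths.
audio_frames = torch.tensor([120, 80, 100])

# Same expression as the added line: a row of positions 0..max-1 compared against each
# clip's frame count yields a (num_clips, max_frames) boolean mask, True where frames are real.
mask = torch.arange(0, audio_frames.max()).unsqueeze(0) < audio_frames.unsqueeze(1)

print(mask.shape)     # torch.Size([3, 120])
print(mask[1].sum())  # tensor(80): only the first 80 positions are valid for the second clip

Note that the diff only builds this mask when more than one clip is passed (len(audio) > 1); a single clip gets None.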
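End to end, the commit threads an audio keyword from Gemma3MMProcessor.__call__ through Gemma3AudioFeatureExtractor and marks the resulting positions in token_type_ids and input_modes. A minimal usage sketch follows, assuming the file ships as remote code in a Hub repo loadable via AutoProcessor; the repo id, the audio placeholder string, and the waveform type are assumptions, not taken from the diff.

import numpy as np
from transformers import AutoProcessor

# Hypothetical repo id; any checkpoint that bundles this processing_gemma3mm.py would do.
processor = AutoProcessor.from_pretrained("your-org/gemma-3-mm", trust_remote_code=True)

# The extractor hard-codes sample_rate = 16000 and does not resample,
# so waveforms should already be 16 kHz mono float arrays.
waveform = np.zeros(16000, dtype=np.float32)  # one second of silence as a stand-in

inputs = processor(
    text=["Transcribe this clip: <start_of_audio>"],  # audio placeholder string is assumed
    audio=[waveform],                                  # new keyword introduced by this commit
    return_tensors="pt",
)

# Per the diff, the returned BatchFeature mixes text, image, and audio entries,
# e.g. input_ids, token_type_ids (0=text, 1=image, 2=audio), input_modes,
# input_audio_embeds, and audio_embed_sizes.
print(inputs.keys())

Because the sample rate is fixed inside __call__, callers are responsible for resampling audio to 16 kHz before passing it in.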