Upload modeling_gemma3mm.py

modeling_gemma3mm.py (+4 -4)
@@ -24,7 +24,7 @@ from transformers.utils import (
 from transformers.utils.deprecation import deprecate_kwarg
 from transformers import AutoModel, AutoModelForCausalLM
 
-from transformers.models.gemma3.modeling_gemma3 import
+from transformers.models.gemma3.modeling_gemma3 import Gemma3PreTrainedModel, Gemma3MultiModalProjector
 
 from transformers import AutoConfig, AutoModelForCausalLM
 
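Aside: the new import pulls Gemma3PreTrainedModel and Gemma3MultiModalProjector straight from the upstream Gemma 3 implementation. A minimal sketch of why a custom multimodal wrapper typically wants these two; the class bodies below are illustrative assumptions, not the actual contents of modeling_gemma3mm.py:

from transformers import GenerationMixin
from transformers.models.gemma3.modeling_gemma3 import (
    Gemma3MultiModalProjector,
    Gemma3PreTrainedModel,
)

class Gemma3MMPreTrainedModel(Gemma3PreTrainedModel):
    # Assumed: inherits weight initialization and checkpoint loading
    # from the upstream pretrained-model base class.
    pass

class Gemma3MMForConditionalGeneration(Gemma3MMPreTrainedModel, GenerationMixin):
    def __init__(self, config):
        super().__init__(config)
        # Reuses the upstream projector that maps vision-tower features
        # into the language model's embedding space.
        self.multi_modal_projector = Gemma3MultiModalProjector(config)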
@@ -337,7 +337,7 @@ class Gemma3MMForConditionalGeneration(Gemma3MMPreTrainedModel, GenerationMixin)
 
     @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
     @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=
+    @replace_return_docstrings(output_type=Gemma3MMCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -359,7 +359,7 @@ class Gemma3MMForConditionalGeneration(Gemma3MMPreTrainedModel, GenerationMixin)
         return_dict: Optional[bool] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
         **lm_kwargs,
-    ) -> Union[Tuple,
+    ) -> Union[Tuple, Gemma3MMCausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
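Note on the unchanged signature tail above: logits_to_keep (with num_logits_to_keep kept alive until v4.50 by the @deprecate_kwarg decorator) limits logit computation to the last N positions, which saves memory during generation. A usage sketch, assuming the standard transformers semantics and a hypothetical loaded model:

# Hypothetical model/inputs; assumes standard transformers behaviour.
out = model(input_ids=input_ids, logits_to_keep=1)       # preferred keyword
assert out.logits.shape[1] == 1                          # only the last position's logits
out = model(input_ids=input_ids, num_logits_to_keep=1)   # old keyword still works, warns and remaps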
@@ -551,7 +551,7 @@ class Gemma3MMForConditionalGeneration(Gemma3MMPreTrainedModel, GenerationMixin)
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
 
-        return
+        return Gemma3MMCausalLMOutputWithPast(
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
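All four changed lines reference the same new return type, Gemma3MMCausalLMOutputWithPast: the @replace_return_docstrings decorator advertises it, the forward return annotation names it, and the final return constructs it when return_dict is in effect (the plain-tuple path above stays untouched). For orientation, output classes like this in transformers are ModelOutput dataclasses; only loss, logits, and past_key_values are visible in this diff, so the trailing fields in the sketch below are assumptions following the usual *CausalLMOutputWithPast pattern:

from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from transformers.utils import ModelOutput

@dataclass
class Gemma3MMCausalLMOutputWithPast(ModelOutput):
    # Fields visible in this diff:
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Assumed fields, per the usual pattern:
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None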