junnei committed · verified
Commit 9c3cebd · Parent(s): badc51c

Upload modeling_gemma3mm.py

Files changed (1): modeling_gemma3mm.py (+4 −4)
modeling_gemma3mm.py CHANGED
@@ -24,7 +24,7 @@ from transformers.utils import (
 from transformers.utils.deprecation import deprecate_kwarg
 from transformers import AutoModel, AutoModelForCausalLM
 
-from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast, Gemma3PreTrainedModel, Gemma3MultiModalProjector
+from transformers.models.gemma3.modeling_gemma3 import Gemma3PreTrainedModel, Gemma3MultiModalProjector
 
 from transformers import AutoConfig, AutoModelForCausalLM
 
@@ -337,7 +337,7 @@ class Gemma3MMForConditionalGeneration(Gemma3MMPreTrainedModel, GenerationMixin)
 
     @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
     @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=Gemma3MMCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -359,7 +359,7 @@ class Gemma3MMForConditionalGeneration(Gemma3MMPreTrainedModel, GenerationMixin)
         return_dict: Optional[bool] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
         **lm_kwargs,
-    ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
+    ) -> Union[Tuple, Gemma3MMCausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -551,7 +551,7 @@ class Gemma3MMForConditionalGeneration(Gemma3MMPreTrainedModel, GenerationMixin)
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
 
-        return Gemma3CausalLMOutputWithPast(
+        return Gemma3MMCausalLMOutputWithPast(
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
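The net effect of the +4/−4 change: Gemma3CausalLMOutputWithPast is no longer imported from the upstream modeling_gemma3 module; the @replace_return_docstrings decorator, the forward return annotation, and the final return statement now use Gemma3MMCausalLMOutputWithPast instead, which must therefore be defined elsewhere in modeling_gemma3mm.py (its definition is not part of this diff). Below is a minimal sketch of what such an output class might look like, assuming it follows the standard transformers ModelOutput dataclass pattern; the field set, in particular the extra modality slot, is an assumption and is not taken from this commit.

from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from transformers.modeling_outputs import ModelOutput


@dataclass
class Gemma3MMCausalLMOutputWithPast(ModelOutput):
    # Hypothetical sketch only: the real class is defined elsewhere in
    # modeling_gemma3mm.py and is not shown in this diff. The fields mirror
    # transformers' Gemma3CausalLMOutputWithPast; audio_hidden_states is an
    # assumed extra slot for the additional modality of the MM variant.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    image_hidden_states: Optional[torch.FloatTensor] = None
    audio_hidden_states: Optional[torch.FloatTensor] = None

With a definition along these lines, the return Gemma3MMCausalLMOutputWithPast(...) call in the last hunk works unchanged: ModelOutput subclasses accept their fields as keyword arguments and still support the tuple-style indexing used on the if-not-return_dict path above it.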