katuni4ka committed d21e013 (verified) · 1 parent: 9ffce0b

Update modeling_phi4mm.py

Files changed (1):
  1. modeling_phi4mm.py +54 -1
modeling_phi4mm.py CHANGED
@@ -1505,7 +1505,7 @@ PHI4MM_START_DOCSTRING = r"""
     "The bare Phi-4-MM model outputting raw hidden-states without any specific head on top.",
     PHI4MM_START_DOCSTRING,
 )
-class Phi4MMPreTrainedModel(PreTrainedModel):
+class Phi4MMPreTrainedModel(PreTrainedModel, GenerationMixin):
     config_class = Phi4MMConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
@@ -1932,6 +1932,59 @@ class Phi4MMModel(Phi4MMPreTrainedModel):
         )
         return causal_mask
 
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        input_image_embeds=None,
+        image_sizes=None,
+        image_attention_mask=None,
+        input_audio_embeds=None,
+        audio_embed_sizes=None,
+        audio_attention_mask=None,
+        input_mode=None,
+        cache_position=None,
+        position_ids=None,
+        use_cache=True,
+        num_logits_to_keep=0,
+        **kwargs,
+    ):
+        # Overwritten -- this model may need to switch between the short and long RoPE scaling factors,
+        # invalidating the cache in the process.
+
+        # The first time the input length reaches the long/short factor switching point, force the cache
+        # to be re-computed. This slows down that single token position, but is better than the failure
+        # that would occur otherwise.
+        if (
+            past_key_values
+            and self.config.rope_scaling
+            and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
+        ):
+            past_length = cache_position[0]
+            if past_length <= self.config.original_max_position_embeddings:
+                past_key_values = None
+
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids=input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            input_image_embeds=input_image_embeds,
+            image_sizes=image_sizes,
+            image_attention_mask=image_attention_mask,
+            input_audio_embeds=input_audio_embeds,
+            audio_embed_sizes=audio_embed_sizes,
+            audio_attention_mask=audio_attention_mask,
+            input_mode=input_mode,
+            cache_position=cache_position,
+            position_ids=position_ids,
+            use_cache=use_cache,
+            num_logits_to_keep=num_logits_to_keep or 0,
+            **kwargs,
+        )
+        return model_inputs
+
 
 class Phi4MMForCausalLM(Phi4MMPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
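For readers following the change: the guard added in prepare_inputs_for_generation is the usual LongRoPE cache workaround. Once the sequence grows past config.original_max_position_embeddings, the model switches from the short to the long RoPE scaling factors, so a KV cache built with the short factors becomes stale exactly once, at the switching point. Below is a minimal, self-contained sketch of that predicate; SketchConfig, should_drop_cache, and the value 4096 are illustrative stand-ins, not part of modeling_phi4mm.py.

# Illustrative sketch only -- SketchConfig and should_drop_cache are hypothetical names;
# the predicate mirrors the condition in the committed prepare_inputs_for_generation.
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchConfig:
    rope_scaling: Optional[dict]
    original_max_position_embeddings: int


def should_drop_cache(config: SketchConfig, seq_len: int, past_length: int) -> bool:
    # seq_len plays the role of input_ids.shape[1]; past_length plays the role of cache_position[0].
    # Once seq_len exceeds original_max_position_embeddings, the long RoPE factors apply, so a cache
    # that was still built with the short factors (past_length <= threshold) has to be re-computed.
    crossed_switch_point = (
        config.rope_scaling is not None
        and seq_len >= config.original_max_position_embeddings + 1
    )
    return crossed_switch_point and past_length <= config.original_max_position_embeddings


config = SketchConfig(rope_scaling={"type": "longrope"}, original_max_position_embeddings=4096)
print(should_drop_cache(config, seq_len=4097, past_length=4096))  # True: cache was built with short factors
print(should_drop_cache(config, seq_len=5000, past_length=4999))  # False: long factors already in effect

Dropping past_key_values once at the switching point trades a single slow, full re-prefill for correctness; every later step sees past_length above the threshold and keeps its cache.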