Commit: Update modeling_phi4mm.py (+54 −1) — add GenerationMixin to Phi4MMPreTrainedModel and introduce prepare_inputs_for_generation.
modeling_phi4mm.py (CHANGED)
@@ -1505,7 +1505,7 @@ PHI4MM_START_DOCSTRING = r"""
|
|
1505 |
"The bare Phi-4-MM model outputting raw hidden-states without any specific head on top.",
|
1506 |
PHI4MM_START_DOCSTRING,
|
1507 |
)
|
1508 |
-
class Phi4MMPreTrainedModel(PreTrainedModel):
|
1509 |
config_class = Phi4MMConfig
|
1510 |
base_model_prefix = "model"
|
1511 |
supports_gradient_checkpointing = True
|
@@ -1932,6 +1932,59 @@ class Phi4MMModel(Phi4MMPreTrainedModel):
|
|
1932 |
)
|
1933 |
return causal_mask
|
1934 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1935 |
|
1936 |
class Phi4MMForCausalLM(Phi4MMPreTrainedModel, GenerationMixin):
|
1937 |
_tied_weights_keys = ["lm_head.weight"]
|
|
|
1505 |
"The bare Phi-4-MM model outputting raw hidden-states without any specific head on top.",
|
1506 |
PHI4MM_START_DOCSTRING,
|
1507 |
)
|
1508 |
+
class Phi4MMPreTrainedModel(PreTrainedModel, GenerationMixin):
|
1509 |
config_class = Phi4MMConfig
|
1510 |
base_model_prefix = "model"
|
1511 |
supports_gradient_checkpointing = True
|
|
|
1932 |
)
|
1933 |
return causal_mask
|
1934 |
|
1935 |
+
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    input_image_embeds=None,
    image_sizes=None,
    image_attention_mask=None,
    input_audio_embeds=None,
    audio_embed_sizes=None,
    audio_attention_mask=None,
    input_mode=None,
    cache_position=None,
    position_ids=None,
    use_cache=True,
    num_logits_to_keep=0,
    **kwargs,
):
    """Prepare model inputs for one generation step.

    Overwritten because this model may need to switch between the short and
    long RoPE scaling factors, which invalidates the KV cache at the switch
    point (see the condition below). Everything else is delegated to the
    parent ``GenerationMixin.prepare_inputs_for_generation``, forwarding the
    multimodal arguments (image/audio embeddings, their sizes and attention
    masks, and ``input_mode``) unchanged.

    Args:
        input_ids: Token ids for the current step.
        past_key_values: Cached key/value states from previous steps; may be
            dropped here to force a cache re-compute (see below).
        attention_mask, inputs_embeds, cache_position, position_ids,
        use_cache, num_logits_to_keep: Standard generation plumbing, passed
            through to the parent implementation.
        input_image_embeds, image_sizes, image_attention_mask: Vision inputs,
            passed through untouched.
        input_audio_embeds, audio_embed_sizes, audio_attention_mask: Audio
            inputs, passed through untouched.
        input_mode: Modality selector, passed through untouched.
        **kwargs: Any extra arguments the parent accepts.

    Returns:
        The ``model_inputs`` dict produced by the parent
        ``prepare_inputs_for_generation``.
    """
    # When the input length first crosses the long/short-factor switching
    # point (original_max_position_embeddings), the cached states were built
    # with the *short* rope factors and are no longer valid. Dropping the
    # cache forces a full re-compute at this single token position — slower
    # for that one step, but correct (better than the current failure).
    if (
        past_key_values
        and self.config.rope_scaling
        and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
    ):
        past_length = cache_position[0]
        if past_length <= self.config.original_max_position_embeddings:
            past_key_values = None

    # Delegate the rest; `num_logits_to_keep or 0` guards against a caller
    # explicitly passing None.
    model_inputs = super().prepare_inputs_for_generation(
        input_ids=input_ids,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        inputs_embeds=inputs_embeds,
        input_image_embeds=input_image_embeds,
        image_sizes=image_sizes,
        image_attention_mask=image_attention_mask,
        input_audio_embeds=input_audio_embeds,
        audio_embed_sizes=audio_embed_sizes,
        audio_attention_mask=audio_attention_mask,
        input_mode=input_mode,
        cache_position=cache_position,
        position_ids=position_ids,
        use_cache=use_cache,
        num_logits_to_keep=num_logits_to_keep or 0,
        **kwargs,
    )
    return model_inputs
|
1987 |
+
|
1988 |
|
1989 |
class Phi4MMForCausalLM(Phi4MMPreTrainedModel, GenerationMixin):
|
1990 |
_tied_weights_keys = ["lm_head.weight"]
|