Make Phi4MMForCausalLM.forward's num_logits_to_keep actually optional

#20
Files changed (1)
  1. modeling_phi4mm.py +4 -1
modeling_phi4mm.py CHANGED
@@ -2134,7 +2134,10 @@ class Phi4MMForCausalLM(Phi4MMPreTrainedModel, GenerationMixin):
 
         hidden_states = outputs[0]
         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        if num_logits_to_keep:
+            logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        else:
+            logits = self.lm_head(hidden_states)
 
         loss = None
         if labels is not None: