Make Phi4MMForCausalLM.forward's num_logits_to_keep actually optional
#20
by
phh
- opened
- modeling_phi4mm.py +4 -1
modeling_phi4mm.py
CHANGED
@@ -2134,7 +2134,10 @@ class Phi4MMForCausalLM(Phi4MMPreTrainedModel, GenerationMixin):
|
|
2134 |
|
2135 |
hidden_states = outputs[0]
|
2136 |
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
2137 |
-
|
|
|
|
|
|
|
2138 |
|
2139 |
loss = None
|
2140 |
if labels is not None:
|
|
|
2134 |
|
2135 |
hidden_states = outputs[0]
|
2136 |
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
2137 |
+
if num_logits_to_keep:
|
2138 |
+
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
2139 |
+
else:
|
2140 |
+
logits = self.lm_head(hidden_states)
|
2141 |
|
2142 |
loss = None
|
2143 |
if labels is not None:
|