Refactor Logits Naming
Browse files
This PR renames the internal logits variables in MoonshotKimiaForCausalLM: the output of lm_head is now named text_logits, and the output of mimo_output is now named audio_logits. The change does not alter the effective output, because the variable names and their order in the returned tuple were adjusted together; its purpose is to resolve naming confusion encountered during further development. The inference script kimia_infer/api/kimia.py (line 114) unpacks the result as audio_logits, text_logits, ... = self.alm.forward(...), assigning the first returned value to audio_logits (audio) and the second to text_logits (text), which supports the rationale for this change.
modeling_moonshot_kimia.py
CHANGED
@@ -902,15 +902,15 @@ class MoonshotKimiaForCausalLM(Qwen2PreTrainedModel):
|
|
902 |
else:
|
903 |
hidden_states, mimo_hidden_states = outputs[0], outputs[1]
|
904 |
|
905 |
-
|
906 |
-
|
907 |
|
908 |
if not return_dict:
|
909 |
-
output = (
|
910 |
return output
|
911 |
return CausalLMOutputWithPast(
|
912 |
loss=None,
|
913 |
-
logits=(
|
914 |
past_key_values=outputs.past_key_values,
|
915 |
hidden_states=outputs.hidden_states,
|
916 |
attentions=outputs.attentions,
|
|
|
902 |
else:
|
903 |
hidden_states, mimo_hidden_states = outputs[0], outputs[1]
|
904 |
|
905 |
+
text_logits = self.lm_head(hidden_states)
|
906 |
+
audio_logits = self.mimo_output(mimo_hidden_states)
|
907 |
|
908 |
if not return_dict:
|
909 |
+
output = (audio_logits, text_logits) + outputs[2:]
|
910 |
return output
|
911 |
return CausalLMOutputWithPast(
|
912 |
loss=None,
|
913 |
+
logits=(audio_logits, text_logits),
|
914 |
past_key_values=outputs.past_key_values,
|
915 |
hidden_states=outputs.hidden_states,
|
916 |
attentions=outputs.attentions,
|