codecho commited on
Commit
5aed2aa
·
verified ·
1 Parent(s): a574f67

Refactor Logits Naming

Browse files

This PR adjusts internal logits naming in MoonshotKimiaForCausalLM: lm_head output is now text_logits, and mimo_output output is audio_logits. This modification does not change the effective output (as variable names and their corresponding order in the returned tuple were adjusted synchronously), aiming to resolve naming confusion encountered during further development. The usage in the inference script kimia_infer/api/kimia.py at line 114 – audio_logits, text_logits, ... = self.alm.forward(...), which assigns the first returned value to audio_logits (for audio) and the second to text_logits (for text) – also validates the rationale for this change.

Files changed (1) hide show
  1. modeling_moonshot_kimia.py +4 -4
modeling_moonshot_kimia.py CHANGED
@@ -902,15 +902,15 @@ class MoonshotKimiaForCausalLM(Qwen2PreTrainedModel):
902
  else:
903
  hidden_states, mimo_hidden_states = outputs[0], outputs[1]
904
 
905
- audio_logits = self.lm_head(hidden_states)
906
- text_logits = self.mimo_output(mimo_hidden_states)
907
 
908
  if not return_dict:
909
- output = (text_logits, audio_logits) + outputs[2:]
910
  return output
911
  return CausalLMOutputWithPast(
912
  loss=None,
913
- logits=(text_logits, audio_logits),
914
  past_key_values=outputs.past_key_values,
915
  hidden_states=outputs.hidden_states,
916
  attentions=outputs.attentions,
 
902
  else:
903
  hidden_states, mimo_hidden_states = outputs[0], outputs[1]
904
 
905
+ text_logits = self.lm_head(hidden_states)
906
+ audio_logits = self.mimo_output(mimo_hidden_states)
907
 
908
  if not return_dict:
909
+ output = (audio_logits, text_logits) + outputs[2:]
910
  return output
911
  return CausalLMOutputWithPast(
912
  loss=None,
913
+ logits=(audio_logits, text_logits),
914
  past_key_values=outputs.past_key_values,
915
  hidden_states=outputs.hidden_states,
916
  attentions=outputs.attentions,