duzx16
commited on
Commit
·
5c64357
1
Parent(s):
8127ab6
Set ignore_index for CrossEntropyLoss
Browse files- modeling_chatglm.py +1 -1
modeling_chatglm.py
CHANGED
|
@@ -1124,7 +1124,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 1124 |
shift_logits = lm_logits[..., :-1, :].contiguous()
|
| 1125 |
shift_labels = labels[..., 1:].contiguous()
|
| 1126 |
# Flatten the tokens
|
| 1127 |
-
loss_fct = CrossEntropyLoss()
|
| 1128 |
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
| 1129 |
|
| 1130 |
lm_logits = lm_logits.to(hidden_states.dtype)
|
|
|
|
| 1124 |
shift_logits = lm_logits[..., :-1, :].contiguous()
|
| 1125 |
shift_labels = labels[..., 1:].contiguous()
|
| 1126 |
# Flatten the tokens
|
| 1127 |
+
loss_fct = CrossEntropyLoss(ignore_index=-100)
|
| 1128 |
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
| 1129 |
|
| 1130 |
lm_logits = lm_logits.to(hidden_states.dtype)
|