Update modeling_dicow.py
Browse files- modeling_dicow.py +0 -3
modeling_dicow.py
CHANGED
|
@@ -332,9 +332,6 @@ class DiCoWForConditionalGeneration(DiCoWGenerationMixin, WhisperForConditionalG
|
|
| 332 |
dec_loss1 = loss_fct(dec_lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))
|
| 333 |
dec_loss2 = loss_fct(dec_lm_logits.view(-1, self.config.vocab_size), upp_labels.reshape(-1))
|
| 334 |
dec_loss = torch.hstack((dec_loss1[..., None], dec_loss2[..., None])).min(dim=-1).values.mean()
|
| 335 |
-
if wandb.run is not None:
|
| 336 |
-
wandb.log({"dec_loss": dec_loss})
|
| 337 |
-
wandb.log({"ctc_loss": ctc_loss})
|
| 338 |
loss = (1 - self.config.ctc_weight) * dec_loss + self.config.ctc_weight * ctc_loss
|
| 339 |
|
| 340 |
|
|
|
|
| 332 |
dec_loss1 = loss_fct(dec_lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))
|
| 333 |
dec_loss2 = loss_fct(dec_lm_logits.view(-1, self.config.vocab_size), upp_labels.reshape(-1))
|
| 334 |
dec_loss = torch.hstack((dec_loss1[..., None], dec_loss2[..., None])).min(dim=-1).values.mean()
|
|
|
|
|
|
|
|
|
|
| 335 |
loss = (1 - self.config.ctc_weight) * dec_loss + self.config.ctc_weight * ctc_loss
|
| 336 |
|
| 337 |
|