"""Pure-PyTorch implementations of the RNN-T transducer loss and the multi-blank
transducer loss, computed with explicit loops over the (t, u) lattice."""

import torch

from nemo.core.classes import Loss
from nemo.core.neural_types import LabelsType, LengthsType, LogprobsType, LossType, NeuralType


class RNNTLossPytorch(Loss):
    """Computes the RNN-T (transducer) loss of Graves, 2012
    (https://arxiv.org/abs/1211.3711) with an explicit forward-variable
    recursion in log space. A usage sketch follows the class definition.
    """

    @property
    def input_types(self):
        """Input types definitions for RNNTLossPytorch."""
        return {
            "acts": NeuralType(('B', 'T', 'T', 'D'), LogprobsType()),
            "labels": NeuralType(('B', 'T'), LabelsType()),
            "act_lens": NeuralType(tuple('B'), LengthsType()),
            "label_lens": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        """Output types definitions for RNNTLossPytorch.

        loss:
            NeuralType(None)
        """
        return {"loss": NeuralType(elements_type=LossType())}

    def __init__(self, blank, reduction):
        super().__init__()
        self.blank = blank
        self.reduction = reduction

    def forward(self, acts, labels, act_lens, label_lens):
        # Normalize the joint-network logits into log-probabilities.
        acts = torch.log_softmax(acts, -1)
        forward_logprob = self.compute_forward_prob(acts, labels, act_lens, label_lens)
        losses = -forward_logprob
        if self.reduction == 'mean_batch':
            losses = losses.mean()
        elif self.reduction == 'mean':
            # Normalize each utterance's loss by its label length, then average.
            losses = torch.div(losses, label_lens).mean()
        elif self.reduction == 'sum':
            losses = losses.sum()
        elif self.reduction == 'mean_volume':
            # Length-weighted mean over the batch.
            losses = losses.sum() / label_lens.sum()

        return losses

    def compute_forward_prob(self, acts, labels, act_lens, label_lens):
        B, T, U, _ = acts.shape

        log_alpha = torch.zeros(B, T, U, device=acts.device)
        for t in range(T):
            for u in range(U):
                if u == 0:
                    if t == 0:
                        # base case: (t=0, u=0), log-alpha is 0.
                        log_alpha[:, t, u] = 0.0
                    else:
                        # (t > 0, u = 0): reachable only from (t - 1, u) by
                        # emitting a blank.
                        log_alpha[:, t, u] = log_alpha[:, t - 1, u] + acts[:, t - 1, 0, self.blank]
                else:
                    if t == 0:
                        # (t = 0, u > 0): reachable only from (t, u - 1) by
                        # emitting label u - 1.
                        gathered = torch.gather(
                            acts[:, t, u - 1], dim=1, index=labels[:, u - 1].view(-1, 1).type(torch.int64)
                        ).reshape(-1)
                        log_alpha[:, t, u] = log_alpha[:, t, u - 1] + gathered
                    else:
                        # general case: log-sum-exp of the blank path from
                        # (t - 1, u) and the label path from (t, u - 1).
                        log_alpha[:, t, u] = torch.logsumexp(
                            torch.stack(
                                [
                                    log_alpha[:, t - 1, u] + acts[:, t - 1, u, self.blank],
                                    log_alpha[:, t, u - 1]
                                    + torch.gather(
                                        acts[:, t, u - 1], dim=1, index=labels[:, u - 1].view(-1, 1).type(torch.int64)
                                    ).reshape(-1),
                                ]
                            ),
                            dim=0,
                        )

        # Total log-probability: the alpha at the last valid lattice cell
        # (act_lens[b] - 1, label_lens[b]) plus the terminating blank emission.
        log_probs = []
        for b in range(B):
            to_append = (
                log_alpha[b, act_lens[b] - 1, label_lens[b]] + acts[b, act_lens[b] - 1, label_lens[b], self.blank]
            )
            log_probs.append(to_append)
        log_prob = torch.stack(log_probs)

        return log_prob
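
# A minimal usage sketch for RNNTLossPytorch (illustrative; the shapes and the
# blank index below are assumptions chosen for the example, not NeMo defaults):
#
#     loss_fn = RNNTLossPytorch(blank=4, reduction='mean_batch')
#     acts = torch.randn(2, 4, 3, 5)        # (B, T, U, V) joint-network logits
#     labels = torch.randint(0, 4, (2, 2))  # (B, U - 1) target ids, blank-free
#     act_lens = torch.tensor([4, 3])       # valid encoder frames per utterance
#     label_lens = torch.tensor([2, 2])     # valid labels per utterance
#     loss = loss_fn(acts=acts, labels=labels, act_lens=act_lens, label_lens=label_lens)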


class MultiblankRNNTLossPytorch(Loss):
    """
    Pure Python implementation of the multi-blank transducer loss
    (https://arxiv.org/pdf/2211.03541.pdf). Big blanks advance the time axis by
    several frames at once; their logits are read from the vocabulary indices
    just below the standard blank (self.blank - 1 - i for the i-th configured
    duration). A smoke test at the end of this module exercises the class.
    """

    @property
    def input_types(self):
        """Input types definitions for MultiblankRNNTLossPytorch."""
        return {
            "acts": NeuralType(('B', 'T', 'T', 'D'), LogprobsType()),
            "labels": NeuralType(('B', 'T'), LabelsType()),
            "act_lens": NeuralType(tuple('B'), LengthsType()),
            "label_lens": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        """Output types definitions for MultiblankRNNTLossPytorch.

        loss:
            NeuralType(None)
        """
        return {"loss": NeuralType(elements_type=LossType())}

    def __init__(self, blank, big_blank_durations, reduction, sigma):
        super().__init__()
        self.blank = blank
        self.big_blank_durations = big_blank_durations
        self.reduction = reduction
        # Constant subtracted from all log-probs ("logit under-normalization"
        # from the multi-blank paper); applied in forward().
        self.sigma = sigma

    def forward(self, acts, labels, act_lens, label_lens):
        # Logit under-normalization: subtract the constant sigma from every
        # log-probability before the forward recursion.
        acts = torch.log_softmax(acts, -1) - self.sigma
        forward_logprob = self.compute_forward_prob(acts, labels, act_lens, label_lens)

        losses = -forward_logprob
        if self.reduction == 'mean_batch':
            losses = losses.mean()
        elif self.reduction == 'mean':
            losses = torch.div(losses, label_lens).mean()
        elif self.reduction == 'sum':
            losses = losses.sum()
        elif self.reduction == 'mean_volume':
            losses = losses.sum() / label_lens.sum()

        return losses

    def compute_forward_prob(self, acts, labels, act_lens, label_lens):
        B, T, U, _ = acts.shape

        log_alpha = torch.zeros(B, T, U, device=acts.device)
        for t in range(T):
            for u in range(U):
                if u == 0:
                    if t == 0:
                        # base case: (t=0, u=0), log-alpha is 0.
                        log_alpha[:, t, u] = 0.0
                    else:
                        # (t > 0, u = 0): reachable from (t - 1, u) with a
                        # standard blank, or from (t - d, u) with a big blank
                        # of duration d.
                        log_alpha[:, t, u] = log_alpha[:, t - 1, u] + acts[:, t - 1, 0, self.blank]
                        for i, d in enumerate(self.big_blank_durations):
                            if t >= d:
                                tt = log_alpha[:, t - d, u] + acts[:, t - d, 0, self.blank - 1 - i]
                                # 1.0 * makes a copy so the in-place write below
                                # does not alias the stacked slice.
                                log_alpha[:, t, u] = torch.logsumexp(
                                    torch.stack([1.0 * log_alpha[:, t, u], tt]), dim=0
                                )
                else:
                    if t == 0:
                        # (t = 0, u > 0): reachable only from (t, u - 1) by
                        # emitting label u - 1.
                        gathered = torch.gather(
                            acts[:, t, u - 1], dim=1, index=labels[:, u - 1].view(-1, 1).type(torch.int64)
                        ).reshape(-1)
                        log_alpha[:, t, u] = log_alpha[:, t, u - 1] + gathered
                    else:
                        # general case: combine the standard-blank path from
                        # (t - 1, u) and the label path from (t, u - 1), ...
                        log_alpha[:, t, u] = torch.logsumexp(
                            torch.stack(
                                [
                                    log_alpha[:, t - 1, u] + acts[:, t - 1, u, self.blank],
                                    log_alpha[:, t, u - 1]
                                    + torch.gather(
                                        acts[:, t, u - 1], dim=1, index=labels[:, u - 1].view(-1, 1).type(torch.int64)
                                    ).reshape(-1),
                                ]
                            ),
                            dim=0,
                        )

                        # ... then add the big-blank paths from (t - d, u) for
                        # every duration d that fits.
                        for i, d in enumerate(self.big_blank_durations):
                            if t >= d:
                                tt = log_alpha[:, t - d, u] + acts[:, t - d, u, self.blank - 1 - i]
                                log_alpha[:, t, u] = torch.logsumexp(
                                    torch.stack([1.0 * log_alpha[:, t, u], tt]), dim=0
                                )

        # Terminate each utterance with either the standard blank or any big
        # blank whose duration fits within the utterance.
        log_probs = []
        for b in range(B):
            to_append = (
                log_alpha[b, act_lens[b] - 1, label_lens[b]] + acts[b, act_lens[b] - 1, label_lens[b], self.blank]
            )

            for i, d in enumerate(self.big_blank_durations):
                if act_lens[b] >= d:
                    tt = (
                        log_alpha[b, act_lens[b] - d, label_lens[b]]
                        + acts[b, act_lens[b] - d, label_lens[b], self.blank - 1 - i]
                    )
                    to_append = torch.logsumexp(torch.stack([1.0 * to_append, tt]), dim=0)

            log_probs.append(to_append)
        log_prob = torch.stack(log_probs)

        return log_prob
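

# A hedged smoke test (illustrative only: the shapes, vocabulary layout, and
# sigma below are assumptions for the example, not values prescribed by NeMo).
# With big-blank durations [2, 4], the indexing acts[..., self.blank - 1 - i]
# implies a vocabulary laid out as [labels..., big_blank_d4, big_blank_d2, blank].
if __name__ == "__main__":
    B, T, U, V = 2, 6, 3, 8  # V counts the labels, 1 standard blank, 2 big blanks
    acts = torch.randn(B, T, U, V)
    labels = torch.randint(0, V - 3, (B, U - 1))  # label ids, excluding all blanks
    act_lens = torch.tensor([T, T - 1])
    label_lens = torch.tensor([U - 1, U - 1])

    multiblank = MultiblankRNNTLossPytorch(
        blank=V - 1, big_blank_durations=[2, 4], reduction='mean_batch', sigma=0.05
    )
    print(multiblank(acts=acts, labels=labels, act_lens=act_lens, label_lens=label_lens))

    # With no big blanks and sigma = 0, the multi-blank recursion degenerates to
    # the standard one, so the two classes should agree.
    rnnt = RNNTLossPytorch(blank=V - 1, reduction='mean_batch')
    no_big_blanks = MultiblankRNNTLossPytorch(
        blank=V - 1, big_blank_durations=[], reduction='mean_batch', sigma=0.0
    )
    assert torch.allclose(
        rnnt(acts=acts, labels=labels, act_lens=act_lens, label_lens=label_lens),
        no_big_blanks(acts=acts, labels=labels, act_lens=act_lens, label_lens=label_lens),
    )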