import copy
import os
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from tqdm.auto import tqdm

from nemo.collections.asr.data.audio_to_ctm_dataset import FrameCtmUnit
from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs
from nemo.collections.asr.models.asr_model import ASRModel
from nemo.utils import logging


class AlignerWrapperModel(ASRModel):
    """ASR model wrapper to perform alignment building.

    Functionality is limited to the components needed to build an alignment."""
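
    # Example usage (a minimal sketch; the checkpoint name and paths are
    # illustrative, and the config keys mirror the cfg.get(...) calls in
    # __init__ and _init_model_specific below):
    #
    #   from omegaconf import OmegaConf
    #   from nemo.collections.asr.models import EncDecCTCModel
    #
    #   asr_model = EncDecCTCModel.from_pretrained("stt_en_citrinet_256")
    #   cfg = OmegaConf.create(
    #       {
    #           "alignment_type": "forced",
    #           "word_output": True,
    #           "cpu_decoding": False,
    #           "decode_batch_size": 0,
    #           "ctc_cfg": {"prob_suppress_index": -1, "prob_suppress_value": 1.0},
    #       }
    #   )
    #   aligner = AlignerWrapperModel(model=asr_model, cfg=cfg)
    #   alignments = aligner.transcribe("/path/to/manifest.json", batch_size=4)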

    def __init__(self, model: ASRModel, cfg: DictConfig):
        model_cfg = model.cfg
        for ds in ("train_ds", "validation_ds", "test_ds"):
            if ds in model_cfg:
                model_cfg[ds] = None
        super().__init__(cfg=model_cfg, trainer=model.trainer)
        self._model = model
        self.alignment_type = cfg.get("alignment_type", "forced")
        self.word_output = cfg.get("word_output", True)
        self.cpu_decoding = cfg.get("cpu_decoding", False)
        self.decode_batch_size = cfg.get("decode_batch_size", 0)

        if self.alignment_type == "forced":
            pass
        elif self.alignment_type == "argmax":
            pass
        elif self.alignment_type == "loose":
            raise NotImplementedError(f"alignment_type=`{self.alignment_type}` is not supported at the moment.")
        elif self.alignment_type == "rnnt_decoding_aux":
            raise NotImplementedError(f"alignment_type=`{self.alignment_type}` is not supported at the moment.")
        else:
            raise RuntimeError(f"Unsupported alignment type: {self.alignment_type}")

        self._init_model_specific(cfg)

    def _init_ctc_alignment_specific(self, cfg: DictConfig):
        """Part of __init__ intended to initialize attributes specific to the alignment type for CTC models.

        This method is not supposed to be called outside of __init__.
        """
        # for argmax alignment, models without graph decoding support need no extra setup
        if self.alignment_type == "argmax" and not hasattr(self._model, "use_graph_lm"):
            return

        from nemo.collections.asr.modules.graph_decoder import ViterbiDecoderWithGraph

        if self.alignment_type == "forced":
            if hasattr(self._model, "use_graph_lm"):
                if self._model.use_graph_lm:
                    self.graph_decoder = self._model.transcribe_decoder
                    self._model.use_graph_lm = False
                else:
                    self.graph_decoder = ViterbiDecoderWithGraph(
                        num_classes=self.blank_id, backend="k2", dec_type="topo", return_type="1best"
                    )
                self.graph_decoder.split_batch_size = self.decode_batch_size
            else:
                self.graph_decoder = ViterbiDecoderWithGraph(
                    num_classes=self.blank_id, split_batch_size=self.decode_batch_size,
                )

            decoder_module_cfg = cfg.get("decoder_module_cfg", None)
            if decoder_module_cfg is not None:
                self.graph_decoder._decoder.intersect_pruned = decoder_module_cfg.get("intersect_pruned")
                self.graph_decoder._decoder.intersect_conf = decoder_module_cfg.get("intersect_conf")
            return

        if self.alignment_type == "argmax":
            # reuse (or create) the model's transcribe_decoder and make it emit aligned output labels
            if not self._model.use_graph_lm:
                self._model.transcribe_decoder = ViterbiDecoderWithGraph(
                    num_classes=self.blank_id, backend="k2", dec_type="topo", return_type="1best"
                )
            self._model.transcribe_decoder.return_ilabels = False
            self._model.transcribe_decoder.output_aligned = True
            self._model.transcribe_decoder.split_batch_size = self.decode_batch_size
            self._model.use_graph_lm = False
            return

    def _init_rnnt_alignment_specific(self, cfg: DictConfig):
        """Part of __init__ intended to initialize attributes specific to the alignment type for RNNT models.

        This method is not supposed to be called outside of __init__.
        """
        if self.alignment_type == "argmax":
            return

        from nemo.collections.asr.modules.graph_decoder import ViterbiDecoderWithGraph

        if self.alignment_type == "forced":
            self.predictor_window_size = cfg.rnnt_cfg.get("predictor_window_size", 0)
            self.predictor_step_size = cfg.rnnt_cfg.get("predictor_step_size", 0)

            from nemo.collections.asr.parts.k2.utils import apply_rnnt_prune_ranges, get_uniform_rnnt_prune_ranges

            self.prepare_pruned_outputs = lambda encoder_outputs, encoded_len, decoder_outputs, transcript_len: apply_rnnt_prune_ranges(
                encoder_outputs,
                decoder_outputs,
                get_uniform_rnnt_prune_ranges(
                    encoded_len,
                    transcript_len,
                    self.predictor_window_size + 1,
                    self.predictor_step_size,
                    encoder_outputs.size(1),
                ).to(device=encoder_outputs.device),
            )

            from nemo.collections.asr.parts.k2.classes import GraphModuleConfig

            self.graph_decoder = ViterbiDecoderWithGraph(
                num_classes=self.blank_id,
                backend="k2",
                dec_type="topo_rnnt_ali",
                split_batch_size=self.decode_batch_size,
                graph_module_cfg=OmegaConf.structured(
                    GraphModuleConfig(
                        topo_type="minimal",
                        predictor_window_size=self.predictor_window_size,
                        predictor_step_size=self.predictor_step_size,
                    )
                ),
            )

            decoder_module_cfg = cfg.get("decoder_module_cfg", None)
            if decoder_module_cfg is not None:
                self.graph_decoder._decoder.intersect_pruned = decoder_module_cfg.get("intersect_pruned")
                self.graph_decoder._decoder.intersect_conf = decoder_module_cfg.get("intersect_conf")
            return

    def _init_model_specific(self, cfg: DictConfig):
        """Part of __init__ intended to initialize attributes specific to the model type.

        This method is not supposed to be called outside of __init__.
        """
        from nemo.collections.asr.models.ctc_models import EncDecCTCModel

        if isinstance(self._model, EncDecCTCModel):
            self.model_type = "ctc"
            self.blank_id = self._model.decoder.num_classes_with_blank - 1
            self._predict_impl = self._predict_impl_ctc

            prob_suppress_index = cfg.ctc_cfg.get("prob_suppress_index", -1)
            prob_suppress_value = cfg.ctc_cfg.get("prob_suppress_value", 1.0)
            if prob_suppress_value > 1 or prob_suppress_value <= 0:
                raise ValueError(f"Suppression value has to be in (0,1]: {prob_suppress_value}")
            if prob_suppress_index < -(self.blank_id + 1) or prob_suppress_index > self.blank_id:
                raise ValueError(
                    f"Suppression index for the provided model has to be in [{-(self.blank_id + 1)},{self.blank_id}]: {prob_suppress_index}"
                )
            self.prob_suppress_index = (
                self._model.decoder.num_classes_with_blank + prob_suppress_index
                if prob_suppress_index < 0
                else prob_suppress_index
            )
            self.prob_suppress_value = prob_suppress_value

            self._init_ctc_alignment_specific(cfg)
            return

        from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel

        if isinstance(self._model, EncDecRNNTModel):
            self.model_type = "rnnt"
            self.blank_id = self._model.joint.num_classes_with_blank - 1
            # note the inversion: if the joint already applies log_softmax,
            # the wrapper must not apply it again, and vice versa
            self.log_softmax = None if self._model.joint.log_softmax is None else not self._model.joint.log_softmax
            self._predict_impl = self._predict_impl_rnnt

            decoding_config = copy.deepcopy(self._model.cfg.decoding)
            decoding_config.strategy = "greedy_batch"
            with open_dict(decoding_config):
                decoding_config.preserve_alignments = True
                decoding_config.fused_batch_size = -1
            self._model.change_decoding_strategy(decoding_config)
            self._init_rnnt_alignment_specific(cfg)
            return

        raise RuntimeError(f"Unsupported model type: {type(self._model)}")

    def _rnnt_joint_pruned(
        self,
        encoder_outputs: torch.Tensor,
        encoded_len: torch.Tensor,
        decoder_outputs: torch.Tensor,
        transcript_len: torch.Tensor,
    ) -> torch.Tensor:
        """A variant of the RNNT Joint tensor calculation with pruned Encoder and Predictor sum.
        Only uniform pruning is supported at the moment.
        """
        encoder_outputs = self._model.joint.enc(encoder_outputs.transpose(1, 2))  # (B, T, H)
        decoder_outputs = self._model.joint.pred(decoder_outputs.transpose(1, 2))  # (B, U, H)

        encoder_outputs_pruned, decoder_outputs_pruned = self.prepare_pruned_outputs(
            encoder_outputs, encoded_len, decoder_outputs, transcript_len
        )
        res = self._model.joint.joint_net(encoder_outputs_pruned + decoder_outputs_pruned)

        # if joint.log_softmax is None, log_softmax is applied automatically only on CPU
        if self._model.joint.log_softmax is None:
            if not res.is_cuda:
                res = res.log_softmax(dim=-1)
        else:
            if self._model.joint.log_softmax:
                res = res.log_softmax(dim=-1)
        return res
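
    # A worked example of the uniform pruning used above (a sketch; the exact
    # range construction is delegated to get_uniform_rnnt_prune_ranges). With
    # predictor_window_size=2 and predictor_step_size=1, every encoder frame t
    # is summed with a window of at most window_size + 1 = 3 consecutive
    # predictor states starting near position t instead of with all U + 1
    # states, so the summed hidden tensor fed to joint_net shrinks from
    # (B, T, U + 1, H) to (B, T, 3, H).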

    def _apply_prob_suppress(self, log_probs: torch.Tensor) -> torch.Tensor:
        """Multiplies the probability of the element at index self.prob_suppress_index by self.prob_suppress_value
        and rescales the remaining probabilities so that the log_probs tensor stays a valid log-distribution.

        Often used to suppress the <blank> probability of the output of a CTC model.

        Example:
            For
                - log_probs = torch.log(torch.tensor([0.015, 0.085, 0.9]))
                - self.prob_suppress_index = -1
                - self.prob_suppress_value = 0.5
            the result of _apply_prob_suppress(log_probs) is
                - torch.log(torch.tensor([0.0825, 0.4675, 0.45]))
        """
        exp_probs = log_probs.exp()
        # x: probability mass of the suppressed index; y: mass of all other indices
        x = exp_probs[:, :, self.prob_suppress_index]
        y = torch.cat(
            [exp_probs[:, :, : self.prob_suppress_index], exp_probs[:, :, self.prob_suppress_index + 1 :]], 2
        ).sum(-1)
        # b1 scales the suppressed index; b2 = (1 - value * x) / y rescales the rest
        # so that the total probability mass remains equal to one
        b1 = torch.full((exp_probs.shape[0], exp_probs.shape[1], 1), self.prob_suppress_value, device=log_probs.device)
        b2 = ((1 - self.prob_suppress_value * x) / y).unsqueeze(2).repeat(1, 1, exp_probs.shape[-1] - 1)
        return (
            exp_probs * torch.cat([b2[:, :, : self.prob_suppress_index], b1, b2[:, :, self.prob_suppress_index :]], 2)
        ).log()
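
    # Numerical check of the docstring example above (a minimal sketch; `wrapper`
    # is a hypothetical AlignerWrapperModel with prob_suppress_index resolved to
    # the last class and prob_suppress_value=0.5):
    #
    #   lp = torch.log(torch.tensor([[[0.015, 0.085, 0.9]]]))  # shape (B=1, T=1, C=3)
    #   wrapper._apply_prob_suppress(lp).exp()
    #   # -> tensor([[[0.0825, 0.4675, 0.4500]]])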

    def _prepare_ctc_argmax_predictions(
        self, log_probs: torch.Tensor, encoded_len: torch.Tensor
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Obtains argmax predictions with corresponding probabilities.
        Replaces consecutive repeated indices in the argmax predictions with the <blank> index.
        """
        if hasattr(self._model, "transcribe_decoder"):
            predictions, _, probs = self._model.transcribe_decoder.forward(
                log_probs=log_probs, log_probs_length=encoded_len
            )
        else:
            greedy_predictions = log_probs.argmax(dim=-1, keepdim=False)
            probs_tensor, _ = log_probs.exp().max(dim=-1, keepdim=False)
            predictions, probs = [], []
            for i in range(log_probs.shape[0]):
                utt_len = encoded_len[i]
                probs.append(probs_tensor[i, :utt_len])
                pred_candidate = greedy_predictions[i, :utt_len].cpu()
                # replace consecutive repeated indices with <blank>
                previous = self.blank_id
                for j in range(utt_len):
                    p = pred_candidate[j]
                    if p == previous and previous != self.blank_id:
                        pred_candidate[j] = self.blank_id
                    previous = p
                predictions.append(pred_candidate.to(device=greedy_predictions.device))
        return predictions, probs
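
    # Example of the repeat-collapsing rule above: with blank_id denoted B, the
    # argmax sequence [5, 5, 5, B, 7, 7] becomes [5, B, B, B, 7, B], so only the
    # first frame of each repeated run keeps its label.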

    def _predict_impl_rnnt_argmax(
        self,
        encoded: torch.Tensor,
        encoded_len: torch.Tensor,
        transcript: torch.Tensor,
        transcript_len: torch.Tensor,
        sample_id: torch.Tensor,
    ) -> List[Tuple[int, 'FrameCtmUnit']]:
        """Builds time alignment of an encoded sequence.
        This method assumes that the RNNT model is used and the alignment type is `argmax`.

        It produces a list of sample ids and quadruples (label, start_frame, length, probability), called FrameCtmUnit.
        """
        hypotheses = self._model.decoding.rnnt_decoder_predictions_tensor(
            encoded, encoded_len, return_hypotheses=True
        )[0]
        results = []
        for s_id, hypothesis in zip(sample_id, hypotheses):
            pred_ids = hypothesis.y_sequence.tolist()
            tokens = self._model.decoding.decode_ids_to_tokens(pred_ids)
            token_begin = hypothesis.timestep
            token_len = [j - i for i, j in zip(token_begin, token_begin[1:] + [len(hypothesis.alignments)])]
            # per-token probabilities are not available from this decoding path, so set them to 1.0
            token_prob = [1.0] * len(tokens)
            if self.word_output:
                words = [w for w in self._model.decoding.decode_tokens_to_str(pred_ids).split(" ") if w != ""]
                words, word_begin, word_len, word_prob = (
                    self._process_tokens_to_words(tokens, token_begin, token_len, token_prob, words)
                    if hasattr(self._model, "tokenizer")
                    else self._process_char_with_space_to_words(tokens, token_begin, token_len, token_prob, words)
                )
                results.append(
                    (s_id, [FrameCtmUnit(t, b, l, p) for t, b, l, p in zip(words, word_begin, word_len, word_prob)])
                )
            else:
                results.append(
                    (
                        s_id,
                        [FrameCtmUnit(t, b, l, p) for t, b, l, p in zip(tokens, token_begin, token_len, token_prob)],
                    )
                )
        return results
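
    # Timing example for the method above (a sketch): with
    #   hypothesis.timestep = [0, 0, 3, 5] and len(hypothesis.alignments) = 8,
    # token_begin = [0, 0, 3, 5] and token_len = [0, 3, 2, 3]; two tokens
    # emitted on the same frame give the first one zero length.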

    def _process_tokens_to_words(
        self,
        tokens: List[str],
        token_begin: List[int],
        token_len: List[int],
        token_prob: List[float],
        words: List[str],
    ) -> Tuple[List[str], List[int], List[int], List[float]]:
        """Transforms alignment information from token level to word level.

        Used when self._model.tokenizer is present.
        """
        # check that a trailing space does not change the tokenization length
        assert len(self._model.tokenizer.text_to_tokens(words[0])) == len(
            self._model.tokenizer.text_to_tokens(words[0] + " ")
        )
        word_begin, word_len, word_prob = [], [], []
        token_len_nonzero = [(t_l if t_l > 0 else 1) for t_l in token_len]
        i = 0
        for word in words:
            loc_tokens = self._model.tokenizer.text_to_tokens(word)
            step = len(loc_tokens)

            # a word that maps to zero tokens is assumed to occupy one token slot;
            # fold its timing into the next token
            if step == 0:
                token_begin[i + 1] = token_begin[i]
                token_len[i + 1] += token_len[i]
                token_len_nonzero[i + 1] += token_len_nonzero[i]
                del tokens[i], token_begin[i], token_len[i], token_len_nonzero[i], token_prob[i]
                continue

            # an out-of-vocabulary word is decoded as one token (the trailing "⁇" marks the unknown piece)
            if step == 2 and loc_tokens[-1] == "⁇":
                step -= 1
            j = i + step
            word_begin.append(token_begin[i])
            word_len.append(sum(token_len[i:j]))
            denominator = sum(token_len_nonzero[i:j])
            word_prob.append(sum(token_prob[k] * token_len_nonzero[k] for k in range(i, j)) / denominator)
            i = j
        return words, word_begin, word_len, word_prob
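
    # A worked example (a sketch with hypothetical wordpieces): for
    #   tokens      = ["▁hel", "lo", "▁wor", "ld"]
    #   token_begin = [0, 3, 5, 9]
    #   token_len   = [3, 2, 4, 2]
    #   token_prob  = [0.9, 0.8, 0.7, 0.6]
    #   words       = ["hello", "world"]
    # each word spans its tokens, giving word_begin = [0, 5], word_len = [5, 6]
    # and word_prob as token-length-weighted averages
    # [(0.9 * 3 + 0.8 * 2) / 5, (0.7 * 4 + 0.6 * 2) / 6].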

    def _process_char_with_space_to_words(
        self,
        tokens: List[str],
        token_begin: List[int],
        token_len: List[int],
        token_prob: List[float],
        words: List[str],
    ) -> Tuple[List[str], List[int], List[int], List[float]]:
        """Transforms alignment information from character level to word level.
        This method includes separator (typically the space) information in the results.

        Used with character-based models (no self._model.tokenizer).
        """
        # the number of words must be consistent with the number of spaces
        space_idx = (np.array(tokens) == " ").nonzero()[0].tolist()
        assert len(words) == len(space_idx) + 1
        token_len_nonzero = [(t_l if t_l > 0 else 1) for t_l in token_len]
        if len(space_idx) == 0:
            word_begin = [token_begin[0]]
            word_len = [sum(token_len)]
            denominator = sum(token_len_nonzero)
            word_prob = [sum(t_p * t_l for t_p, t_l in zip(token_prob, token_len_nonzero)) / denominator]
        else:
            space_word = "[SEP]"
            word_begin = [token_begin[0]]
            word_len = [sum(token_len[: space_idx[0]])]
            denominator = sum(token_len_nonzero[: space_idx[0]])
            word_prob = [sum(token_prob[k] * token_len_nonzero[k] for k in range(space_idx[0])) / denominator]
            words_with_space = [words[0]]
            for word, i, j in zip(words[1:], space_idx, space_idx[1:] + [len(tokens)]):
                # emit the separator itself as a pseudo-word
                word_begin.append(token_begin[i])
                word_len.append(token_len[i])
                word_prob.append(token_prob[i])
                words_with_space.append(space_word)
                # then the word that follows it
                word_begin.append(token_begin[i + 1])
                word_len.append(sum(token_len[i + 1 : j]))
                denominator = sum(token_len_nonzero[i + 1 : j])
                word_prob.append(sum(token_prob[k] * token_len_nonzero[k] for k in range(i + 1, j)) / denominator)
                words_with_space.append(word)
            words = words_with_space
        return words, word_begin, word_len, word_prob
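
    # A worked example (a sketch): for
    #   tokens      = ["h", "i", " ", "y", "o"]
    #   token_begin = [0, 2, 3, 5, 7]
    #   token_len   = [2, 1, 2, 2, 1]
    #   token_prob  = [0.9, 0.9, 0.5, 0.8, 0.7]
    #   words       = ["hi", "yo"]
    # the method returns words = ["hi", "[SEP]", "yo"] with
    #   word_begin = [0, 3, 5], word_len = [3, 2, 3]
    # and word_prob computed as length-weighted averages per span.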

    def _results_to_ctmUnits(
        self, s_id: int, pred: torch.Tensor, prob: torch.Tensor
    ) -> Tuple[int, List['FrameCtmUnit']]:
        """Transforms predictions with probabilities to a list of FrameCtmUnit objects,
        containing frame-level alignment information (label, start, duration, probability), for a given sample id.

        Alignment information can be either token-based (char, wordpiece, ...) or word-based.
        """
        if len(pred) == 0:
            return (s_id, [])

        non_blank_idx = (pred != self.blank_id).nonzero(as_tuple=True)[0].cpu()
        pred_ids = pred[non_blank_idx].tolist()
        prob_list = prob.tolist()
        if self.model_type == "rnnt":
            wer_module = self._model.decoding
            # in RNNT alignments only blanks advance time, so a label's frame index
            # is its position in pred minus the number of labels emitted before it
            token_begin = non_blank_idx - torch.arange(len(non_blank_idx))
            token_end = torch.cat((token_begin[1:], torch.tensor([len(pred) - len(non_blank_idx)])))
        else:
            wer_module = self._model._wer
            token_begin = non_blank_idx
            token_end = torch.cat((token_begin[1:], torch.tensor([len(pred)])))
        tokens = wer_module.decode_ids_to_tokens(pred_ids)
        token_len = (token_end - token_begin).tolist()
        token_begin = token_begin.tolist()
        token_prob = [
            sum(prob_list[i:j]) / (j - i)
            for i, j in zip(non_blank_idx.tolist(), non_blank_idx[1:].tolist() + [len(pred)])
        ]
        if self.word_output:
            words = wer_module.decode_tokens_to_str(pred_ids).split(" ")
            words, word_begin, word_len, word_prob = (
                self._process_tokens_to_words(tokens, token_begin, token_len, token_prob, words)
                if hasattr(self._model, "tokenizer")
                else self._process_char_with_space_to_words(tokens, token_begin, token_len, token_prob, words)
            )
            return s_id, [FrameCtmUnit(t, b, l, p) for t, b, l, p in zip(words, word_begin, word_len, word_prob)]
        return s_id, [FrameCtmUnit(t, b, l, p) for t, b, l, p in zip(tokens, token_begin, token_len, token_prob)]
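
    # A worked example for the RNNT branch above (a sketch, blank_id denoted B): for
    #   pred = [B, 3, 3, B, B, 7, B]
    # the non-blank positions are [1, 2, 5], and
    #   token_begin = [1 - 0, 2 - 1, 5 - 2] = [1, 1, 3]
    #   token_end   = [1, 3, 4]  ->  token_len = [0, 2, 1]
    # i.e. the two labels emitted back-to-back share frame 1, the first with
    # zero duration.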

    def _predict_impl_ctc(
        self,
        encoded: torch.Tensor,
        encoded_len: torch.Tensor,
        transcript: torch.Tensor,
        transcript_len: torch.Tensor,
        sample_id: torch.Tensor,
    ) -> List[Tuple[int, 'FrameCtmUnit']]:
        """Builds time alignment of an encoded sequence.
        This method assumes that the CTC model is used.

        It produces a list of sample ids and quadruples (label, start_frame, length, probability), called FrameCtmUnit.
        """
        log_probs = encoded

        if self.prob_suppress_value != 1.0:
            log_probs = self._apply_prob_suppress(log_probs)

        if self.alignment_type == "argmax":
            predictions, probs = self._prepare_ctc_argmax_predictions(log_probs, encoded_len)
        elif self.alignment_type == "forced":
            if self.cpu_decoding:
                log_probs, encoded_len, transcript, transcript_len = (
                    log_probs.cpu(),
                    encoded_len.cpu(),
                    transcript.cpu(),
                    transcript_len.cpu(),
                )
            predictions, probs = self.graph_decoder.align(log_probs, encoded_len, transcript, transcript_len)
        else:
            raise NotImplementedError()

        return [
            self._results_to_ctmUnits(s_id, pred, prob)
            for s_id, pred, prob in zip(sample_id.tolist(), predictions, probs)
        ]

    def _predict_impl_rnnt(
        self,
        encoded: torch.Tensor,
        encoded_len: torch.Tensor,
        transcript: torch.Tensor,
        transcript_len: torch.Tensor,
        sample_id: torch.Tensor,
    ) -> List[Tuple[int, 'FrameCtmUnit']]:
        """Builds time alignment of an encoded sequence.
        This method assumes that the RNNT model is used.

        It produces a list of sample ids and quadruples (label, start_frame, length, probability), called FrameCtmUnit.
        """
        if self.alignment_type == "argmax":
            return self._predict_impl_rnnt_argmax(encoded, encoded_len, transcript, transcript_len, sample_id)
        elif self.alignment_type == "forced":
            decoded = self._model.decoder(targets=transcript, target_length=transcript_len)[0]
            # use the pruned joint only if the pruning window is set and smaller than the longest transcript
            log_probs = (
                self._rnnt_joint_pruned(encoded, encoded_len, decoded, transcript_len)
                if self.predictor_window_size > 0 and self.predictor_window_size < transcript_len.max()
                else self._model.joint(encoder_outputs=encoded, decoder_outputs=decoded)
            )
            apply_log_softmax = True if self.log_softmax is None and encoded.is_cuda else self.log_softmax
            if apply_log_softmax:
                log_probs = log_probs.log_softmax(dim=-1)
            if self.cpu_decoding:
                log_probs, encoded_len, transcript, transcript_len = (
                    log_probs.cpu(),
                    encoded_len.cpu(),
                    transcript.cpu(),
                    transcript_len.cpu(),
                )
            predictions, probs = self.graph_decoder.align(log_probs, encoded_len, transcript, transcript_len)
            return [
                self._results_to_ctmUnits(s_id, pred, prob)
                for s_id, pred, prob in zip(sample_id.tolist(), predictions, probs)
            ]
        else:
            raise NotImplementedError()

    @torch.no_grad()
    def predict_step(self, batch, batch_idx, dataloader_idx=0) -> List[Tuple[int, 'FrameCtmUnit']]:
        signal, signal_len, transcript, transcript_len, sample_id = batch

        if isinstance(batch, DALIOutputs) and batch.has_processed_signal:
            encoded, encoded_len = self._model.forward(processed_signal=signal, processed_signal_length=signal_len)[:2]
        else:
            encoded, encoded_len = self._model.forward(input_signal=signal, input_signal_length=signal_len)[:2]

        return self._predict_impl(encoded, encoded_len, transcript, transcript_len, sample_id)

    @torch.no_grad()
    def transcribe(self, manifest: str, batch_size: int = 4, num_workers: Optional[int] = None,) -> List['FrameCtmUnit']:
        """
        Does alignment. Use this method for debugging and prototyping.

        Args:
            manifest: path to dataset JSON manifest file (in NeMo format). \
                Recommended length per audio file is between 5 and 25 seconds.
            batch_size: (int) batch size to use during inference. \
                Bigger will result in better throughput performance but would use more memory.
            num_workers: (int) number of workers for DataLoader

        Returns:
            A list of quadruples (label, start_frame, length, probability), called FrameCtmUnit, \
            in the same order as in the manifest.
        """
        hypotheses = []

        # store state to be restored after the run
        mode = self._model.training
        device = next(self._model.parameters()).device
        dither_value = self._model.preprocessor.featurizer.dither
        pad_to_value = self._model.preprocessor.featurizer.pad_to
        logging_level = logging.get_verbosity()

        if num_workers is None:
            num_workers = min(batch_size, os.cpu_count() - 1)

        try:
            self._model.preprocessor.featurizer.dither = 0.0
            self._model.preprocessor.featurizer.pad_to = 0

            # switch the model to evaluation mode
            self._model.eval()
            # freeze the encoder, decoder, and (if present) joint modules
            self._model.encoder.freeze()
            self._model.decoder.freeze()
            if hasattr(self._model, "joint"):
                self._model.joint.freeze()
            logging.set_verbosity(logging.WARNING)

            config = {
                'manifest_filepath': manifest,
                'batch_size': batch_size,
                'num_workers': num_workers,
            }
            temporary_datalayer = self._model._setup_transcribe_dataloader(config)
            for test_batch in tqdm(temporary_datalayer, desc="Aligning"):
                test_batch[0] = test_batch[0].to(device)
                test_batch[1] = test_batch[1].to(device)
                hypotheses += [unit for i, unit in self.predict_step(test_batch, 0)]
                del test_batch
        finally:
            # restore the original state of the model
            self._model.train(mode=mode)
            self._model.preprocessor.featurizer.dither = dither_value
            self._model.preprocessor.featurizer.pad_to = pad_to_value

            logging.set_verbosity(logging_level)
            if mode is True:
                self._model.encoder.unfreeze()
                self._model.decoder.unfreeze()
                if hasattr(self._model, "joint"):
                    self._model.joint.unfreeze()
        return hypotheses
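
    # The manifest for transcribe() is a standard NeMo JSON-lines file, one entry
    # per audio file (the path below is hypothetical):
    #   {"audio_filepath": "/path/to/audio.wav", "duration": 12.3, "text": "hello world"}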

    def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]):
        raise RuntimeError("This module cannot be used in training.")

    def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]):
        raise RuntimeError("This module cannot be used in validation.")

    def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]):
        raise RuntimeError("This module cannot be used in testing.")