from abc import ABC, abstractmethod
from typing import List

import torch

from nemo.core.classes import ModelPT
from nemo.core.classes.common import PretrainedModelInfo
from nemo.core.classes.exportable import Exportable
from nemo.core.classes.mixins import AccessMixin
from nemo.core.utils.neural_type_utils import get_io_names
from nemo.utils import logging, model_utils
from nemo.utils.cast_utils import cast_all

__all__ = ['ASRModel']


class ASRModel(ModelPT, ABC):
|
    @abstractmethod
    def transcribe(self, paths2audio_files: List[str], batch_size: int = 4) -> List[str]:
        """
        Takes paths to audio files and returns text transcriptions.

        Args:
            paths2audio_files: paths to the audio fragments to be transcribed.
            batch_size: number of audio files to process per inference batch.

        Returns:
            A list of transcription texts, one per input audio file.
        """
        pass
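    # A minimal usage sketch (hypothetical audio paths; this abstract method is
    # implemented by concrete subclasses such as the EncDec* ASR models):
    #
    #     asr_model = SomeASRSubclass.restore_from('model.nemo')  # hypothetical subclass
    #     texts = asr_model.transcribe(paths2audio_files=['a.wav', 'b.wav'], batch_size=2)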
|
    def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
        # Loss is averaged across batches; WER is aggregated as the ratio of summed
        # error counts to summed reference-word counts, i.e. a corpus-level WER.
        val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean()
        wer_num = torch.stack([x['val_wer_num'] for x in outputs]).sum()
        wer_denom = torch.stack([x['val_wer_denom'] for x in outputs]).sum()
        tensorboard_logs = {'val_loss': val_loss_mean, 'val_wer': wer_num / wer_denom}
        return {'val_loss': val_loss_mean, 'log': tensorboard_logs}

    def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
        test_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean()
        wer_num = torch.stack([x['test_wer_num'] for x in outputs]).sum()
        wer_denom = torch.stack([x['test_wer_denom'] for x in outputs]).sum()
        tensorboard_logs = {'test_loss': test_loss_mean, 'test_wer': wer_num / wer_denom}
        return {'test_loss': test_loss_mean, 'log': tensorboard_logs}
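    # Why sum numerators and denominators instead of averaging per-batch WERs: the two
    # are not equivalent when batches differ in length. A toy example (made-up counts):
    #
    #     batch 1: 2 errors / 10 words   -> WER 0.20
    #     batch 2: 30 errors / 100 words -> WER 0.30
    #     mean of per-batch WERs: 0.25; corpus-level WER: 32 / 110 ≈ 0.29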
|
    @classmethod
    def list_available_models(cls) -> 'List[PretrainedModelInfo]':
        """
        This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
        list_of_models = model_utils.resolve_subclass_pretrained_model_info(cls)
        return list_of_models
|
    def add_auxiliary_losses(self, loss: torch.Tensor, reset_registry: bool = False) -> torch.Tensor:
        """
        Utility method to enable calculation of auxiliary losses for ASR training.

        Args:
            loss: The output loss value prior to addition with auxiliary losses.
            reset_registry: Bool, whether to reset the AccessMixin registry after adding auxiliary losses.

        Returns:
            Loss tensor used for back propagation.
        """
        # Fold in any adapter auxiliary losses that submodules registered via AccessMixin.
        if AccessMixin.is_access_enabled():
            registry = AccessMixin.get_module_registry(self)
            log_dict = {}

            for loss_key, loss_registry in registry.items():
                # Add this module's auxiliary loss to the total loss
                if 'adapter_loss' in loss_registry:
                    loss_list = loss_registry['adapter_loss']
                    loss_value = sum(loss_list)
                    loss += loss_value

                    # Log the loss under a key derived from the module path
                    keys = loss_key.split(".")
                    key = "/".join(keys)
                    key = "adapter_loss/" + key
                    log_dict[key] = loss_value.detach()

            if len(log_dict) > 0:
                self.log_dict(log_dict)

        if reset_registry:
            AccessMixin.reset_registry(self)

        return loss
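    # Typical call pattern (sketch): a training_step computes its primary loss, then
    # folds in any registered adapter losses before backpropagation:
    #
    #     loss = self.add_auxiliary_losses(loss, reset_registry=True)
    #
    # (When to reset the registry is illustrative here; concrete models decide.)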
|
    def setup_optimization_flags(self):
        """
        Utility method that must be explicitly called by the subclass in order to support optional optimization flags.
        This method is the only valid place to access self.cfg prior to DDP training.

        The subclass may choose not to support this method, therefore all variables here must be checked via hasattr()
        """
        # Skip the update step if NaN/Inf gradients appear on any rank.
        self._skip_nan_grad = False
        if "skip_nan_grad" in self._cfg and self._cfg["skip_nan_grad"]:
            self._skip_nan_grad = self._cfg["skip_nan_grad"]
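    # Config sketch (YAML) for enabling the flag above, assuming the subclass calls
    # setup_optimization_flags() in its constructor:
    #
    #     model:
    #       skip_nan_grad: true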
|
    def on_after_backward(self):
        """
        Zero out the gradients if any of them contain NaN or Inf values.
        """
        super().on_after_backward()

        if hasattr(self, '_skip_nan_grad') and self._skip_nan_grad:
            device = next(self.parameters()).device
            valid_gradients = torch.tensor([1], device=device, dtype=torch.float32)

            # Check whether any gradient on this rank contains NaN or Inf values
            for param_name, param in self.named_parameters():
                if param.grad is not None:
                    is_not_nan_or_inf = not (torch.isnan(param.grad).any() or torch.isinf(param.grad).any())
                    if not is_not_nan_or_inf:
                        valid_gradients = valid_gradients * 0
                        break

            if torch.distributed.is_initialized():
                torch.distributed.all_reduce(valid_gradients, op=torch.distributed.ReduceOp.MIN)

            if valid_gradients < 1:
                logging.warning('detected inf or nan values in gradients! Setting gradients to zero.')
                self.zero_grad()
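    # Distributed note: the all_reduce with ReduceOp.MIN means that if any single rank
    # saw a NaN/Inf gradient, every rank observes valid_gradients == 0 and zeroes its
    # gradients, so DDP replicas skip the update together and stay in sync.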
|
|
class ExportableEncDecModel(Exportable):
    """
    Simple utility mix-in to export models that consist of an encoder/decoder pair
    plus pre/post processor, but have to be exported as the encoder/decoder pair only
    (covers most ASR classes)
    """

    @property
    def input_module(self):
        return self.encoder

    @property
    def output_module(self):
        return self.decoder

    @property
    def output_names(self):
        otypes = self.output_module.output_types
        if hasattr(self.input_module, 'export_cache_support') and self.input_module.export_cache_support:
            # For cache-aware (streaming) export, keep the decoder's first output and
            # append the encoder's cache outputs so the exported graph returns them.
            in_types = self.input_module.output_types
            otypes = {n: t for (n, t) in list(otypes.items())[:1]}
            for (n, t) in list(in_types.items())[1:]:
                otypes[n] = t
        return get_io_names(otypes, self.disabled_deployment_output_names)
|
    def forward_for_export(
        self, input, length=None, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None
    ):
        """
        This forward is used when we need to export the model to ONNX format.
        The inputs cache_last_channel and cache_last_time need to be passed for exporting streaming models.
        When they are passed, the inputs are run through the encoder part only; note that the ONNX
        conversion does not yet fully work for this case.

        Args:
            input: Tensor that represents a batch of raw audio signals of shape [B, T],
                where T is the number of timesteps.
            length: Vector of length B that contains the individual lengths of the audio sequences.
            cache_last_channel: Tensor of shape [N, B, T, H] which contains the cache for last channel layers.
            cache_last_time: Tensor of shape [N, B, H, T] which contains the cache for last time layers.
                N is the number of such layers which need caching, B is the batch size, H is the hidden
                size of activations, and T is the length of the cache.

        Returns:
            the output of the model
        """
        # Prefer the module's export-specific forward when it provides one.
        if hasattr(self.input_module, 'forward_for_export'):
            encoder_forward = self.input_module.forward_for_export
        else:
            encoder_forward = self.input_module

        if cache_last_channel is None and cache_last_time is None:
            encoder_output = encoder_forward(audio_signal=input, length=length)
        else:
            encoder_output = encoder_forward(
                audio_signal=input,
                length=length,
                cache_last_channel=cache_last_channel,
                cache_last_time=cache_last_time,
                cache_last_channel_len=cache_last_channel_len,
            )

        if isinstance(encoder_output, tuple):
            decoder_input = encoder_output[0]
        else:
            decoder_input = encoder_output

        if hasattr(self.output_module, 'forward_for_export'):
            ret = self.output_module.forward_for_export(encoder_output=decoder_input)
        else:
            ret = self.output_module(encoder_output=decoder_input)

        # For cache-aware export, append the updated cache tensors from the encoder
        # output so they appear among the exported graph's outputs.
        if cache_last_channel is not None or cache_last_time is not None:
            if isinstance(ret, tuple):
                ret = (ret[0], encoder_output[1], encoder_output[2], encoder_output[3], encoder_output[4])
            else:
                ret = (ret, encoder_output[1], encoder_output[2], encoder_output[3], encoder_output[4])
        return cast_all(ret, from_dtype=torch.float16, to_dtype=torch.float32)
|
    @property
    def disabled_deployment_input_names(self):
        return self.encoder.disabled_deployment_input_names

    @property
    def disabled_deployment_output_names(self):
        return self.encoder.disabled_deployment_output_names
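# A minimal export sketch (assumes a concrete NeMo ASR model class that mixes in
# ExportableEncDecModel; 'asr_model.onnx' is a hypothetical output path):
#
#     model = SomeEncDecASRModel.restore_from('model.nemo')  # hypothetical class name
#     model.export('asr_model.onnx')  # Exportable provides the export() entry point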