# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import List

import torch

from nemo.core.classes import ModelPT
from nemo.core.classes.common import PretrainedModelInfo
from nemo.core.classes.exportable import Exportable
from nemo.core.classes.mixins import AccessMixin
from nemo.core.utils.neural_type_utils import get_io_names
from nemo.utils import logging, model_utils
from nemo.utils.cast_utils import cast_all

__all__ = ['ASRModel']


class ASRModel(ModelPT, ABC):
    @abstractmethod
    def transcribe(self, paths2audio_files: List[str], batch_size: int = 4) -> List[str]:
        """
        Takes paths to audio files and returns their text transcriptions.

        Args:
            paths2audio_files: Paths to the audio files to be transcribed.
            batch_size: Number of files to process per inference batch.

        Returns:
            A list of transcription texts, in the same order as the input paths.
        """
        pass
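
    # A minimal usage sketch (the model name below is illustrative; any pretrained
    # ASRModel subclass downloadable from NGC works the same way):
    #
    #   asr_model = ASRModel.from_pretrained(model_name="stt_en_conformer_ctc_small")
    #   transcripts = asr_model.transcribe(paths2audio_files=["sample.wav"], batch_size=4)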

    def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
        val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean()
        # WER is micro-averaged: sum the edit counts and reference lengths across batches, then divide.
        wer_num = torch.stack([x['val_wer_num'] for x in outputs]).sum()
        wer_denom = torch.stack([x['val_wer_denom'] for x in outputs]).sum()
        tensorboard_logs = {'val_loss': val_loss_mean, 'val_wer': wer_num / wer_denom}
        return {'val_loss': val_loss_mean, 'log': tensorboard_logs}

    def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
        test_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean()
        wer_num = torch.stack([x['test_wer_num'] for x in outputs]).sum()
        wer_denom = torch.stack([x['test_wer_denom'] for x in outputs]).sum()
        tensorboard_logs = {'test_loss': test_loss_mean, 'test_wer': wer_num / wer_denom}
        return {'test_loss': test_loss_mean, 'log': tensorboard_logs}

    @classmethod
    def list_available_models(cls) -> 'List[PretrainedModelInfo]':
        """
        This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
        # recursively walk the subclasses to generate pretrained model info
        list_of_models = model_utils.resolve_subclass_pretrained_model_info(cls)
        return list_of_models
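
    # A minimal sketch of querying the registry (the exact attribute name on
    # PretrainedModelInfo is an assumption here):
    #
    #   for model_info in ASRModel.list_available_models():
    #       print(model_info.pretrained_model_name)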

    def add_auxiliary_losses(self, loss: torch.Tensor, reset_registry: bool = False) -> torch.Tensor:
        """
        Utility method to enable calculation of auxiliary losses for ASR training.

        Args:
            loss: The output loss value prior to addition with auxiliary losses.
            reset_registry: Bool, whether to reset the AccessMixin registry after adding auxiliary losses.

        Returns:
            Loss tensor used for back propagation.
        """
        # Add adapter auxiliary losses, if registered
        if AccessMixin.is_access_enabled():
            registry = AccessMixin.get_module_registry(self)
            log_dict = {}

            for loss_key, loss_registry in registry.items():
                # Add auxiliary loss to total loss
                if 'adapter_loss' in loss_registry:
                    loss_list = loss_registry['adapter_loss']
                    loss_value = sum(loss_list)
                    loss += loss_value

                    # Log current loss name and value
                    keys = loss_key.split(".")
                    key = "/".join(keys)
                    key = "adapter_loss/" + key
                    log_dict[key] = loss_value.detach()

            if len(log_dict) > 0:
                self.log_dict(log_dict)

        if reset_registry:
            AccessMixin.reset_registry(self)

        # return total loss
        return loss
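
    # A minimal sketch of how a subclass's training_step would typically fold these
    # in (the loss computation itself is model-specific and elided here):
    #
    #   def training_step(self, batch, batch_idx):
    #       loss = ...  # model-specific loss computation
    #       loss = self.add_auxiliary_losses(loss, reset_registry=True)
    #       return loss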

    def setup_optimization_flags(self):
        """
        Utility method that must be explicitly called by the subclass in order to support optional optimization flags.
        This method is the only valid place to access self.cfg before DDP training occurs.

        The subclass may choose not to support this method, therefore all variables here must be checked via hasattr()
        """
        # Skip update if nan/inf grads appear on any rank.
        self._skip_nan_grad = False
        if "skip_nan_grad" in self._cfg and self._cfg["skip_nan_grad"]:
            self._skip_nan_grad = self._cfg["skip_nan_grad"]
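
    # A minimal YAML sketch for enabling the flag above (the key name comes from
    # this method; the surrounding layout is the usual NeMo model config):
    #
    #   model:
    #     skip_nan_grad: true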

    def on_after_backward(self):
        """
        Zero out all gradients if any of them contain NaN or Inf values on any rank.
        """
        super().on_after_backward()

        if hasattr(self, '_skip_nan_grad') and self._skip_nan_grad:
            device = next(self.parameters()).device
            valid_gradients = torch.tensor([1], device=device, dtype=torch.float32)

            for param_name, param in self.named_parameters():
                if param.grad is not None:
                    is_not_nan_or_inf = not (torch.isnan(param.grad).any() or torch.isinf(param.grad).any())
                    if not is_not_nan_or_inf:
                        valid_gradients = valid_gradients * 0
                        break

            # Reduce with MIN so that a bad gradient on any single rank zeroes the step everywhere.
            if torch.distributed.is_initialized():
                torch.distributed.all_reduce(valid_gradients, op=torch.distributed.ReduceOp.MIN)

            if valid_gradients < 1:
                logging.warning('detected inf or nan values in gradients! Setting gradients to zero.')
                self.zero_grad()


class ExportableEncDecModel(Exportable):
    """
    Simple utility mix-in to export models that consist of an encoder/decoder pair,
    plus pre/post processors, but have to be exported as the encoder/decoder pair only
    (covers most ASR classes).
    """

    @property
    def input_module(self):
        return self.encoder

    @property
    def output_module(self):
        return self.decoder

    @property
    def output_names(self):
        otypes = self.output_module.output_types
        if hasattr(self.input_module, 'export_cache_support') and self.input_module.export_cache_support:
            in_types = self.input_module.output_types
            # Keep only the first decoder output, then append the encoder's cache outputs.
            otypes = {n: t for (n, t) in list(otypes.items())[:1]}
            for (n, t) in list(in_types.items())[1:]:
                otypes[n] = t
        return get_io_names(otypes, self.disabled_deployment_output_names)

    def forward_for_export(
        self, input, length=None, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None
    ):
        """
        This forward is used when we need to export the model to ONNX format.
        The cache_last_channel and cache_last_time inputs need to be passed when exporting streaming models;
        when they are passed, the inputs are simply routed through the encoder part. Note that ONNX
        conversion does not yet fully work for this streaming case.

        Args:
            input: Tensor that represents a batch of raw audio signals,
                of shape [B, T]. T here represents timesteps.
            length: Vector of length B, that contains the individual lengths of the audio sequences.
            cache_last_channel: Tensor of shape [N, B, T, H] which contains the cache for last channel layers.
            cache_last_time: Tensor of shape [N, B, H, T] which contains the cache for last time layers.
                N is the number of such layers which need caching, B is batch size, H is the hidden size of
                activations, and T is the length of the cache.

        Returns:
            the output of the model
        """
        if hasattr(self.input_module, 'forward_for_export'):
            if cache_last_channel is None and cache_last_time is None:
                encoder_output = self.input_module.forward_for_export(audio_signal=input, length=length)
            else:
                encoder_output = self.input_module.forward_for_export(
                    audio_signal=input,
                    length=length,
                    cache_last_channel=cache_last_channel,
                    cache_last_time=cache_last_time,
                    cache_last_channel_len=cache_last_channel_len,
                )
        else:
            if cache_last_channel is None and cache_last_time is None:
                encoder_output = self.input_module(audio_signal=input, length=length)
            else:
                encoder_output = self.input_module(
                    audio_signal=input,
                    length=length,
                    cache_last_channel=cache_last_channel,
                    cache_last_time=cache_last_time,
                    cache_last_channel_len=cache_last_channel_len,
                )

        if isinstance(encoder_output, tuple):
            decoder_input = encoder_output[0]
        else:
            decoder_input = encoder_output

        # The decoder call is identical with or without caching, so no branching on the cache inputs is needed here.
        if hasattr(self.output_module, 'forward_for_export'):
            ret = self.output_module.forward_for_export(encoder_output=decoder_input)
        else:
            ret = self.output_module(encoder_output=decoder_input)

        # When exporting a streaming model, append the encoder's updated cache tensors to the outputs.
        if cache_last_channel is not None or cache_last_time is not None:
            if isinstance(ret, tuple):
                ret = (ret[0], encoder_output[1], encoder_output[2], encoder_output[3], encoder_output[4])
            else:
                ret = (ret, encoder_output[1], encoder_output[2], encoder_output[3], encoder_output[4])

        return cast_all(ret, from_dtype=torch.float16, to_dtype=torch.float32)

    @property
    def disabled_deployment_input_names(self):
        return self.encoder.disabled_deployment_input_names

    @property
    def disabled_deployment_output_names(self):
        return self.encoder.disabled_deployment_output_names
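
# A minimal export sketch (illustrative; assumes a concrete ASR class such as
# EncDecCTCModel, which combines ASRModel with this mix-in):
#
#   import nemo.collections.asr as nemo_asr
#   model = nemo_asr.models.EncDecCTCModel.from_pretrained(model_name="stt_en_conformer_ctc_small")
#   model.export("asr_model.onnx")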