import copy
import os
from typing import Dict, List, Optional, Union

import torch
from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict
from pytorch_lightning import Trainer

from nemo.collections.asr.data import audio_to_text_dataset
from nemo.collections.asr.data.audio_to_text_dali import AudioToBPEDALIDataset
from nemo.collections.asr.losses.rnnt import RNNTLoss
from nemo.collections.asr.metrics.rnnt_wer_bpe import RNNTBPEWER, RNNTBPEDecoding, RNNTBPEDecodingConfig
from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel
from nemo.collections.asr.parts.mixins import ASRBPEMixin
from nemo.core.classes.common import PretrainedModelInfo
from nemo.utils import logging, model_utils


class EncDecRNNTBPEModel(EncDecRNNTModel, ASRBPEMixin):
    """Base class for encoder decoder RNNT-based models with subword tokenization."""

    @classmethod
    def list_available_models(cls) -> List[PretrainedModelInfo]:
        """
        This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
        results = []

        model = PretrainedModelInfo(
            pretrained_model_name="stt_en_contextnet_256",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_contextnet_256",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_contextnet_256/versions/1.6.0/files/stt_en_contextnet_256.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_en_contextnet_512",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_contextnet_512",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_contextnet_512/versions/1.6.0/files/stt_en_contextnet_512.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_en_contextnet_1024",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_contextnet_1024",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_contextnet_1024/versions/1.9.0/files/stt_en_contextnet_1024.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_en_contextnet_256_mls",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_contextnet_256_mls",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_contextnet_256_mls/versions/1.0.0/files/stt_en_contextnet_256_mls.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_en_contextnet_512_mls",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_contextnet_512_mls",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_contextnet_512_mls/versions/1.0.0/files/stt_en_contextnet_512_mls.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_en_contextnet_1024_mls",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_contextnet_1024_mls",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_contextnet_1024_mls/versions/1.0.0/files/stt_en_contextnet_1024_mls.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_en_conformer_transducer_small",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_transducer_small",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_conformer_transducer_small/versions/1.6.0/files/stt_en_conformer_transducer_small.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_en_conformer_transducer_medium",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_transducer_medium",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_conformer_transducer_medium/versions/1.6.0/files/stt_en_conformer_transducer_medium.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_en_conformer_transducer_large",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_transducer_large",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_conformer_transducer_large/versions/1.10.0/files/stt_en_conformer_transducer_large.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_en_conformer_transducer_large_ls",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_transducer_large_ls",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_conformer_transducer_large_ls/versions/1.8.0/files/stt_en_conformer_transducer_large_ls.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_en_conformer_transducer_xlarge",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_transducer_xlarge",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_conformer_transducer_xlarge/versions/1.10.0/files/stt_en_conformer_transducer_xlarge.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_en_conformer_transducer_xxlarge",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_conformer_transducer_xxlarge",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_conformer_transducer_xxlarge/versions/1.8.0/files/stt_en_conformer_transducer_xxlarge.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_de_contextnet_1024",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_de_contextnet_1024",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_de_contextnet_1024/versions/1.4.0/files/stt_de_contextnet_1024.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_fr_contextnet_1024",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_fr_contextnet_1024",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_fr_contextnet_1024/versions/1.5/files/stt_fr_contextnet_1024.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_es_contextnet_1024",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_es_contextnet_1024",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_es_contextnet_1024/versions/1.8.0/files/stt_es_contextnet_1024.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_de_conformer_transducer_large",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_de_conformer_transducer_large",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_de_conformer_transducer_large/versions/1.5.0/files/stt_de_conformer_transducer_large.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_fr_conformer_transducer_large",
            description="For details about this model, please visit https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_fr_conformer_transducer_large",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_fr_conformer_transducer_large/versions/1.5/files/stt_fr_conformer_transducer_large.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_es_conformer_transducer_large",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_es_conformer_transducer_large",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_es_conformer_transducer_large/versions/1.8.0/files/stt_es_conformer_transducer_large.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_enes_conformer_transducer_large",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_enes_conformer_transducer_large",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_enes_conformer_transducer_large/versions/1.0.0/files/stt_enes_conformer_transducer_large.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_enes_contextnet_large",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_enes_contextnet_large",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_enes_contextnet_large/versions/1.0.0/files/stt_enes_contextnet_large.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_ca_conformer_transducer_large",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_ca_conformer_transducer_large",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_ca_conformer_transducer_large/versions/1.11.0/files/stt_ca_conformer_transducer_large.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_rw_conformer_transducer_large",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_rw_conformer_transducer_large",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_rw_conformer_transducer_large/versions/1.11.0/files/stt_rw_conformer_transducer_large.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_enes_conformer_transducer_large_codesw",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_enes_conformer_transducer_large_codesw",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_enes_conformer_transducer_large_codesw/versions/1.0.0/files/stt_enes_conformer_transducer_large_codesw.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_kab_conformer_transducer_large",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_kab_conformer_transducer_large",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_kab_conformer_transducer_large/versions/1.12.0/files/stt_kab_conformer_transducer_large.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_be_conformer_transducer_large",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_be_conformer_transducer_large",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_be_conformer_transducer_large/versions/1.12.0/files/stt_be_conformer_transducer_large.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_hr_conformer_transducer_large",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_hr_conformer_transducer_large",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_hr_conformer_transducer_large/versions/1.11.0/files/stt_hr_conformer_transducer_large.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_it_conformer_transducer_large",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_it_conformer_transducer_large",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_it_conformer_transducer_large/versions/1.13.0/files/stt_it_conformer_transducer_large.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_ru_conformer_transducer_large",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_ru_conformer_transducer_large",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_ru_conformer_transducer_large/versions/1.13.0/files/stt_ru_conformer_transducer_large.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_eo_conformer_transducer_large",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_eo_conformer_transducer_large",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_eo_conformer_transducer_large/versions/1.14.0/files/stt_eo_conformer_transducer_large.nemo",
        )
        results.append(model)

        return results

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
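        # Sketch of the `cfg` sections this constructor consumes, inferred from the accesses
        # below (field values are illustrative assumptions, not NeMo defaults):
        #
        #   tokenizer:        # required; consumed by self._setup_tokenizer()
        #     dir: /path/to/tokenizer_dir   # hypothetical path
        #     type: bpe
        #   model_defaults:
        #     enc_hidden: ...               # copied into joint.jointnet.encoder_hidden
        #     pred_hidden: ...              # copied into joint.jointnet.pred_hidden
        #   decoder: {...}    # vocab_size is overwritten from the tokenizer vocabulary
        #   joint: {...}      # num_classes / vocabulary are overwritten below
        #   decoding: {...}   # used to build RNNTBPEDecoding after super().__init__()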
        # Convert to a Hydra 1.0 compatible DictConfig and update the config version if needed
        cfg = model_utils.convert_model_config_to_dict_config(cfg)
        cfg = model_utils.maybe_update_config_version(cfg)

        # A tokenizer is required for this model
        if 'tokenizer' not in cfg:
            raise ValueError("`cfg` must have a `tokenizer` config to create a tokenizer!")

        if not isinstance(cfg, DictConfig):
            cfg = OmegaConf.create(cfg)

        # Set up the tokenizer
        self._setup_tokenizer(cfg.tokenizer)

        # Fetch the vocabulary from the tokenizer
        vocabulary = self.tokenizer.tokenizer.get_vocab()

        # Propagate the vocabulary and its size into the labels, decoder and joint configs
        with open_dict(cfg):
            cfg.labels = ListConfig(list(vocabulary))

        with open_dict(cfg.decoder):
            cfg.decoder.vocab_size = len(vocabulary)

        with open_dict(cfg.joint):
            cfg.joint.num_classes = len(vocabulary)
            cfg.joint.vocabulary = ListConfig(list(vocabulary))
            cfg.joint.jointnet.encoder_hidden = cfg.model_defaults.enc_hidden
            cfg.joint.jointnet.pred_hidden = cfg.model_defaults.pred_hidden

        super().__init__(cfg=cfg, trainer=trainer)

        # Set up the decoding object
        self.decoding = RNNTBPEDecoding(
            decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer,
        )

        # Set up WER calculation
        self.wer = RNNTBPEWER(
            decoding=self.decoding,
            batch_dim_index=0,
            use_cer=self._cfg.get('use_cer', False),
            log_prediction=self._cfg.get('log_prediction', True),
            dist_sync_on_step=True,
        )

        # Set up the fused joint step if the flag is set
        if self.joint.fuse_loss_wer:
            self.joint.set_loss(self.loss)
            self.joint.set_wer(self.wer)

    def change_vocabulary(
        self,
        new_tokenizer_dir: Union[str, DictConfig],
        new_tokenizer_type: str,
        decoding_cfg: Optional[DictConfig] = None,
    ):
        """
        Changes the vocabulary used during RNNT decoding. Use this method when fine-tuning from a pre-trained model.
        It changes only the decoder and leaves the encoder and pre-processing modules unchanged. For example, you would
        use it if you want to reuse a pretrained encoder when fine-tuning on data in another language, or when you need
        the model to learn capitalization, punctuation and/or special characters.

        Args:
            new_tokenizer_dir: Directory path to the tokenizer, or a config for a new tokenizer (if the tokenizer type is `agg`)
            new_tokenizer_type: Type of tokenizer. Can be either `agg`, `bpe` or `wpe`.
            decoding_cfg: A config for the decoder, which is optional. If the decoding type
                needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here.

        Returns: None
        """
        if isinstance(new_tokenizer_dir, DictConfig):
            if new_tokenizer_type == 'agg':
                new_tokenizer_cfg = new_tokenizer_dir
            else:
                raise ValueError(
                    f'New tokenizer dir should be a string unless the tokenizer is `agg`, but this tokenizer type is: {new_tokenizer_type}'
                )
        else:
            new_tokenizer_cfg = None

        if new_tokenizer_cfg is not None:
            tokenizer_cfg = new_tokenizer_cfg
        else:
            if not os.path.isdir(new_tokenizer_dir):
                raise NotADirectoryError(
                    f'New tokenizer dir must be a non-empty path to a directory, but got: {new_tokenizer_dir}'
                )

            if new_tokenizer_type.lower() not in ('bpe', 'wpe'):
                raise ValueError('New tokenizer type must be either `bpe` or `wpe`')

            tokenizer_cfg = OmegaConf.create({'dir': new_tokenizer_dir, 'type': new_tokenizer_type})

        # Set up the new tokenizer
        self._setup_tokenizer(tokenizer_cfg)

        # Fetch the vocabulary from the new tokenizer
        vocabulary = self.tokenizer.tokenizer.get_vocab()

        # Rebuild the joint network with the new vocabulary
        joint_config = self.joint.to_config_dict()
        new_joint_config = copy.deepcopy(joint_config)
        if self.tokenizer_type == "agg":
            new_joint_config["vocabulary"] = ListConfig(vocabulary)
        else:
            new_joint_config["vocabulary"] = ListConfig(list(vocabulary.keys()))

        new_joint_config['num_classes'] = len(vocabulary)
        del self.joint
        self.joint = EncDecRNNTBPEModel.from_config_dict(new_joint_config)

        # Rebuild the prediction network (decoder) with the new vocabulary size
        decoder_config = self.decoder.to_config_dict()
        new_decoder_config = copy.deepcopy(decoder_config)
        new_decoder_config.vocab_size = len(vocabulary)
        del self.decoder
        self.decoder = EncDecRNNTBPEModel.from_config_dict(new_decoder_config)

        # Rebuild the RNNT loss for the new number of classes
        del self.loss
        self.loss = RNNTLoss(num_classes=self.joint.num_classes_with_blank - 1)

        if decoding_cfg is None:
            # Assume the current decoding config is still valid
            decoding_cfg = self.cfg.decoding

        # Merge with the structured RNNTBPEDecodingConfig schema to fill in defaults
        decoding_cls = OmegaConf.structured(RNNTBPEDecodingConfig)
        decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls))
        decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg)

        self.decoding = RNNTBPEDecoding(
            decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer,
        )

        self.wer = RNNTBPEWER(
            decoding=self.decoding,
            batch_dim_index=self.wer.batch_dim_index,
            use_cer=self.wer.use_cer,
            log_prediction=self.wer.log_prediction,
            dist_sync_on_step=True,
        )

        # Set up the fused joint step if the flag is set
        if self.joint.fuse_loss_wer or (
            self.decoding.joint_fused_batch_size is not None and self.decoding.joint_fused_batch_size > 0
        ):
            self.joint.set_loss(self.loss)
            self.joint.set_wer(self.wer)

        # Update the model config with the rebuilt modules
        with open_dict(self.cfg.joint):
            self.cfg.joint = new_joint_config

        with open_dict(self.cfg.decoder):
            self.cfg.decoder = new_decoder_config

        with open_dict(self.cfg.decoding):
            self.cfg.decoding = decoding_cfg

        logging.info(f"Changed decoder to output to {self.joint.vocabulary} vocabulary.")

    def change_decoding_strategy(self, decoding_cfg: DictConfig):
        """
        Changes the decoding strategy used during RNNT decoding.

        Args:
            decoding_cfg: A config for the decoding strategy, which is optional. If the decoding type
                needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here.
        """
        if decoding_cfg is None:
            # Assume the current decoding config is still valid
            logging.info("No `decoding_cfg` passed when changing decoding strategy, using internal config")
            decoding_cfg = self.cfg.decoding

        # Merge with the structured RNNTBPEDecodingConfig schema to fill in defaults
        decoding_cls = OmegaConf.structured(RNNTBPEDecodingConfig)
        decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls))
        decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg)

        self.decoding = RNNTBPEDecoding(
            decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer,
        )

        self.wer = RNNTBPEWER(
            decoding=self.decoding,
            batch_dim_index=self.wer.batch_dim_index,
            use_cer=self.wer.use_cer,
            log_prediction=self.wer.log_prediction,
            dist_sync_on_step=True,
        )

        # Set up the fused joint step if the flag is set
        if self.joint.fuse_loss_wer or (
            self.decoding.joint_fused_batch_size is not None and self.decoding.joint_fused_batch_size > 0
        ):
            self.joint.set_loss(self.loss)
            self.joint.set_wer(self.wer)

        # Update the model config with the new decoding strategy
        with open_dict(self.cfg.decoding):
            self.cfg.decoding = decoding_cfg

        logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}")

    def _setup_dataloader_from_config(self, config: Optional[Dict]):
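        # Minimal example of the `config` dict this method consumes (an illustrative sketch;
        # the path and values are hypothetical, and further keys are handled by the dataset
        # factory and DataLoader construction below):
        #
        #   config = {
        #       'manifest_filepath': '/data/train_manifest.json',
        #       'sample_rate': 16000,
        #       'batch_size': 16,
        #       'shuffle': True,
        #   }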
        dataset = audio_to_text_dataset.get_audio_to_text_bpe_dataset_from_config(
            config=config,
            local_rank=self.local_rank,
            global_rank=self.global_rank,
            world_size=self.world_size,
            tokenizer=self.tokenizer,
            preprocessor_cfg=self.cfg.get("preprocessor", None),
        )

        if dataset is None:
            return None

        if isinstance(dataset, AudioToBPEDALIDataset):
            # DALI datasets act as their own data loader, so return them directly
            return dataset

        shuffle = config['shuffle']
        if config.get('is_tarred', False):
            # For tarred datasets, shuffling at the DataLoader level is disabled
            shuffle = False

        if hasattr(dataset, 'collate_fn'):
            collate_fn = dataset.collate_fn
        else:
            # Concatenated datasets expose the collate_fn through their first member dataset
            collate_fn = dataset.datasets[0].collate_fn

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=config['batch_size'],
            collate_fn=collate_fn,
            drop_last=config.get('drop_last', False),
            shuffle=shuffle,
            num_workers=config.get('num_workers', 0),
            pin_memory=config.get('pin_memory', False),
        )

    def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
        """
        Setup function for a temporary data loader which wraps the provided audio files.

        Args:
            config: A python dictionary which contains the following keys:
                paths2audio_files: (a list) of paths to audio files. The files should be relatively short fragments. \
                    Recommended length per file is between 5 and 25 seconds.
                batch_size: (int) batch size to use during inference. \
                    Larger batch sizes yield better throughput but use more memory.
                temp_dir: (str) A temporary directory where the audio manifest is stored.

        Returns:
            A pytorch DataLoader for the given audio file(s).
        """
        if 'manifest_filepath' in config:
            manifest_filepath = config['manifest_filepath']
            batch_size = config['batch_size']
        else:
            manifest_filepath = os.path.join(config['temp_dir'], 'manifest.json')
            batch_size = min(config['batch_size'], len(config['paths2audio_files']))

        dl_config = {
            'manifest_filepath': manifest_filepath,
            'sample_rate': self.preprocessor._sample_rate,
            'batch_size': batch_size,
            'shuffle': False,
            'num_workers': config.get('num_workers', min(batch_size, os.cpu_count() - 1)),
            'pin_memory': True,
            'channel_selector': config.get('channel_selector', None),
            'use_start_end_token': self.cfg.validation_ds.get('use_start_end_token', False),
        }

        if config.get("augmentor"):
            dl_config['augmentor'] = config.get("augmentor")

        temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config))
        return temporary_datalayer