# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os from dataclasses import dataclass, is_dataclass from typing import Optional import torch from omegaconf import OmegaConf from utils.data_prep import ( get_batch_starts_ends, get_batch_tensors_and_boundary_info, get_manifest_lines_batch, is_entry_in_all_lines, is_entry_in_any_lines, ) from utils.make_output_files import make_ctm, make_new_manifest from utils.viterbi_decoding import viterbi_decoding from nemo.collections.asr.models.ctc_models import EncDecCTCModel from nemo.collections.asr.parts.utils.transcribe_utils import setup_model from nemo.core.config import hydra_runner from nemo.utils import logging """ Align the utterances in manifest_filepath. Results are saved in ctm files in output_dir. Arguments: pretrained_name: string specifying the name of a CTC NeMo ASR model which will be automatically downloaded from NGC and used for generating the log-probs which we will use to do alignment. Note: NFA can only use CTC models (not Transducer models) at the moment. model_path: string specifying the local filepath to a CTC NeMo ASR model which will be used to generate the log-probs which we will use to do alignment. Note: NFA can only use CTC models (not Transducer models) at the moment. Note: if a model_path is provided, it will override the pretrained_name. 
model_downsample_factor: an int indicating the downsample factor of the ASR model, ie the ratio of input 
        timesteps to output timesteps. 
        If the ASR model is a QuartzNet model, its downsample factor is 2.
        If the ASR model is a Conformer CTC model, its downsample factor is 4.
        If the ASR model is a Citrinet model, its downsample factor is 8.
    manifest_filepath: filepath to the manifest of the data you want to align, 
        containing 'audio_filepath' and 'text' fields.
    output_dir: the folder where output CTM files and new JSON manifest will be saved.
    align_using_pred_text: if True, will transcribe the audio using the specified model and then use that transcription 
        as the 'ground truth' for the forced alignment.
    transcribe_device: None, or a string specifying the device that will be used for generating log-probs 
        (i.e. "transcribing"). The string needs to be in a format recognized by torch.device(). If None, NFA will set 
        it to 'cuda' if it is available (otherwise will set it to 'cpu').
    viterbi_device: None, or string specifying the device that will be used for doing Viterbi decoding. The string 
        needs to be in a format recognized by torch.device(). If None, NFA will set it to 'cuda' if it is available 
        (otherwise will set it to 'cpu').
    batch_size: int specifying batch size that will be used for generating log-probs and doing Viterbi decoding.
    additional_ctm_grouping_separator: the string used to separate CTM segments if you want to obtain CTM files at a 
        level that is not the token level or the word level. NFA will always produce token-level and word-level CTM 
        files in: `<output_dir>/tokens/<utt_id>.ctm` and `<output_dir>/words/<utt_id>.ctm`. 
        If `additional_ctm_grouping_separator` is specified, an additional folder 
        `<output_dir>/{tokens/words/additional_segments}/<utt_id>.ctm` will be created containing CTMs 
        for `additional_ctm_grouping_separator`-separated segments.
    remove_blank_tokens_from_ctm: a boolean denoting whether to remove <blank> tokens from token-level output CTMs. 
audio_filepath_parts_in_utt_id: int specifying how many of the 'parts' of the audio_filepath we will use (starting from the final part of the audio_filepath) to determine the utt_id that will be used in the CTM files. Note also that any spaces that are present in the audio_filepath will be replaced with dashes, so as not to change the number of space-separated elements in the CTM files. e.g. if audio_filepath is "/a/b/c/d/e 1.wav" and audio_filepath_parts_in_utt_id is 1 => utt_id will be "e1" e.g. if audio_filepath is "/a/b/c/d/e 1.wav" and audio_filepath_parts_in_utt_id is 2 => utt_id will be "d_e1" e.g. if audio_filepath is "/a/b/c/d/e 1.wav" and audio_filepath_parts_in_utt_id is 3 => utt_id will be "c_d_e1" minimum_timestamp_duration: a float indicating a minimum duration (in seconds) for timestamps in the CTM. If any line in the CTM has a duration lower than the `minimum_timestamp_duration`, it will be enlarged from the middle outwards until it meets the minimum_timestamp_duration, or reaches the beginning or end of the audio file. Note that this may cause timestamps to overlap. 
""" @dataclass class AlignmentConfig: # Required configs pretrained_name: Optional[str] = None model_path: Optional[str] = None model_downsample_factor: Optional[int] = None manifest_filepath: Optional[str] = None output_dir: Optional[str] = None # General configs align_using_pred_text: bool = False transcribe_device: Optional[str] = None viterbi_device: Optional[str] = None batch_size: int = 1 additional_ctm_grouping_separator: Optional[str] = None remove_blank_tokens_from_ctm: bool = False minimum_timestamp_duration: float = 0 audio_filepath_parts_in_utt_id: int = 1 @hydra_runner(config_name="AlignmentConfig", schema=AlignmentConfig) def main(cfg: AlignmentConfig): logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') if is_dataclass(cfg): cfg = OmegaConf.structured(cfg) # Validate config if cfg.model_path is None and cfg.pretrained_name is None: raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None") if cfg.model_path is not None and cfg.pretrained_name is not None: raise ValueError("One of cfg.model_path and cfg.pretrained_name must be None") if cfg.model_downsample_factor is None: raise ValueError("cfg.model_downsample_factor must be specified") if cfg.manifest_filepath is None: raise ValueError("cfg.manifest_filepath must be specified") if cfg.output_dir is None: raise ValueError("cfg.output_dir must be specified") if cfg.batch_size < 1: raise ValueError("cfg.batch_size cannot be zero or a negative number") if cfg.additional_ctm_grouping_separator == "" or cfg.additional_ctm_grouping_separator == " ": raise ValueError("cfg.additional_grouping_separator cannot be empty string or space character") if cfg.minimum_timestamp_duration < 0: raise ValueError("cfg.minimum_timestamp_duration cannot be a negative number") # Validate manifest contents if not is_entry_in_all_lines(cfg.manifest_filepath, "audio_filepath"): raise RuntimeError( "At least one line in cfg.manifest_filepath does not contain an 'audio_filepath' entry. 
" "All lines must contain an 'audio_filepath' entry." ) if cfg.align_using_pred_text: if is_entry_in_any_lines(cfg.manifest_filepath, "pred_text"): raise RuntimeError( "Cannot specify cfg.align_using_pred_text=True when the manifest at cfg.manifest_filepath " "contains 'pred_text' entries. This is because the audio will be transcribed and may produce " "a different 'pred_text'. This may cause confusion." ) else: if not is_entry_in_all_lines(cfg.manifest_filepath, "text"): raise RuntimeError( "At least one line in cfg.manifest_filepath does not contain a 'text' entry. " "NFA requires all lines to contain a 'text' entry when cfg.align_using_pred_text=True." ) # init devices if cfg.transcribe_device is None: transcribe_device = torch.device("cuda" if torch.cuda.is_available else "cpu") else: transcribe_device = torch.device(cfg.transcribe_device) logging.info(f"Device to be used for transcription step (`transcribe_device`) is {transcribe_device}") if cfg.viterbi_device is None: viterbi_device = torch.device("cuda" if torch.cuda.is_available else "cpu") else: viterbi_device = torch.device(cfg.viterbi_device) logging.info(f"Device to be used for viterbi step (`viterbi_device`) is {viterbi_device}") if transcribe_device.type == 'cuda' or viterbi_device.type == 'cuda': logging.warning( 'One or both of transcribe_device and viterbi_device are GPUs. If you run into OOM errors ' 'it may help to change both devices to be the CPU.' ) # load model model, _ = setup_model(cfg, transcribe_device) model.eval() if not isinstance(model, EncDecCTCModel): raise NotImplementedError( f"Model {cfg.model_name} is not an instance of NeMo EncDecCTCModel." " Currently only instances of EncDecCTCModels are supported" ) if cfg.minimum_timestamp_duration > 0: logging.warning( f"cfg.minimum_timestamp_duration has been set to {cfg.minimum_timestamp_duration} seconds. " "This may cause the alignments for some tokens/words/additional segments to be overlapping." 
) # get start and end line IDs of batches starts, ends = get_batch_starts_ends(cfg.manifest_filepath, cfg.batch_size) if cfg.align_using_pred_text: # record pred_texts to save them in the new manifest at the end of this script pred_text_all_lines = [] else: pred_text_all_lines = None # get alignment and save in CTM batch-by-batch for start, end in zip(starts, ends): manifest_lines_batch = get_manifest_lines_batch(cfg.manifest_filepath, start, end) ( log_probs_batch, y_batch, T_batch, U_batch, token_info_batch, word_info_batch, segment_info_batch, pred_text_batch, ) = get_batch_tensors_and_boundary_info( manifest_lines_batch, model, cfg.additional_ctm_grouping_separator, cfg.align_using_pred_text, ) if cfg.align_using_pred_text: pred_text_all_lines.extend(pred_text_batch) alignments_batch = viterbi_decoding(log_probs_batch, y_batch, T_batch, U_batch, viterbi_device) make_ctm( token_info_batch, alignments_batch, manifest_lines_batch, model, cfg.model_downsample_factor, os.path.join(cfg.output_dir, "tokens"), cfg.remove_blank_tokens_from_ctm, cfg.audio_filepath_parts_in_utt_id, cfg.minimum_timestamp_duration, ) make_ctm( word_info_batch, alignments_batch, manifest_lines_batch, model, cfg.model_downsample_factor, os.path.join(cfg.output_dir, "words"), False, # dont try to remove blank tokens because we dont expect them to be there anyway cfg.audio_filepath_parts_in_utt_id, cfg.minimum_timestamp_duration, ) if cfg.additional_ctm_grouping_separator: make_ctm( segment_info_batch, alignments_batch, manifest_lines_batch, model, cfg.model_downsample_factor, os.path.join(cfg.output_dir, "additional_segments"), False, # dont try to remove blank tokens because we dont expect them to be there anyway cfg.audio_filepath_parts_in_utt_id, cfg.minimum_timestamp_duration, ) make_new_manifest( cfg.output_dir, cfg.manifest_filepath, cfg.additional_ctm_grouping_separator, cfg.audio_filepath_parts_in_utt_id, pred_text_all_lines, ) return None if __name__ == "__main__": main()