# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
from copy import deepcopy
from typing import Dict

import torch
from omegaconf import DictConfig

from nemo.collections.asr.models import ClusteringDiarizer
from nemo.collections.asr.parts.utils.offline_clustering import get_scale_interpolated_embs, split_input_data
from nemo.collections.asr.parts.utils.online_clustering import OnlineSpeakerClustering
from nemo.collections.asr.parts.utils.speaker_utils import (
    OnlineSegmentor,
    audio_rttm_map,
    generate_cluster_labels,
    get_embs_and_timestamps,
)
from nemo.utils import logging, model_utils

__all__ = ['OnlineClusteringDiarizer']


def timeit(method):
    """
    Monitor the elapsed time of the decorated function and log it together with the method name.

    Args:
        method: function whose execution time is being measured

    Returns:
        `timed` wrapper function that measures the elapsed time
    """

    def timed(*args, **kwargs):
        ts = time.time()
        result = method(*args, **kwargs)
        te = time.time()
        if 'log_time' in kwargs:
            name = kwargs.get('log_name', method.__name__.upper())
            kwargs['log_time'][name] = int((te - ts) * 1000)
        else:
            logging.info('%2.2fms %r' % ((te - ts) * 1000, method.__name__))
        return result

    return timed
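# Usage sketch for `timeit` (illustrative only; `_demo_step` is a hypothetical function, not part
# of this module). When a `log_time` dict is passed, the elapsed milliseconds are stored under
# `log_name` (or the upper-cased method name) instead of being logged:
#
#     @timeit
#     def _demo_step(data, log_time=None, log_name=None):
#         ...
#
#     timings = {}
#     _demo_step(data, log_time=timings, log_name='DEMO')  # timings == {'DEMO': <elapsed ms>}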
class OnlineClusteringDiarizer(ClusteringDiarizer):
    """
    A class that enables online (streaming) clustering-based diarization.

    - The instance created from `OnlineClusteringDiarizer` sets aside a certain amount of memory
      to provide the upcoming inference with history information.

    - There are two major modules involved: `OnlineSegmentor` and `OnlineSpeakerClustering`.
        OnlineSegmentor:
            Takes the VAD timestamps and generates segments for each scale.
        OnlineSpeakerClustering:
            Updates the speaker labels of the entire online session while updating
            the speaker labels of the streaming inputs.

    - The overall diarization process is done by calling the `diarize_step` function.
      `diarize_step` goes through the following steps:
        (1) Segmentation (`OnlineSegmentor` class)
        (2) Embedding extraction (`_extract_online_embeddings` function call)
        (3) Online speaker counting and speaker clustering (`OnlineClusteringDiarizer` class)
        (4) Label generation (`generate_cluster_labels` function call)
    """

    def __init__(self, cfg: DictConfig):
        super().__init__(cfg)
        self.cfg = model_utils.convert_model_config_to_dict_config(cfg)
        self._cfg_diarizer = self.cfg.diarizer
        self.base_scale_index = max(self.multiscale_args_dict['scale_dict'].keys())

        self.uniq_id = self._cfg_diarizer.get('uniq_id', None)
        self.decimals = self._cfg_diarizer.get('decimals', 2)
        self.AUDIO_RTTM_MAP = audio_rttm_map(self.cfg.diarizer.manifest_filepath)
        self.sample_rate = self.cfg.sample_rate
        torch.manual_seed(0)

        self._out_dir = self._cfg_diarizer.out_dir
        if not os.path.exists(self._out_dir):
            os.mkdir(self._out_dir)

        if torch.cuda.is_available():
            self.cuda = True
            self.device = torch.device("cuda")
        else:
            self.cuda = False
            self.device = torch.device("cpu")

        self.reset()

        # Set speaker embedding model in eval mode
        self._speaker_model.eval()

    def _init_online_clustering_module(self, clustering_params):
        """
        Initialize the online speaker clustering module.

        Attributes:
            online_clus (OnlineSpeakerClustering):
                Online clustering diarizer class instance
            history_n (int):
                History buffer size for saving the history of speaker label inference:
                the total number of embedding vectors saved in the buffer that is kept
                until the end of the session
            current_n (int):
                Current buffer (FIFO queue) size for calculating the speaker label inference:
                the total number of embedding vectors saved in the FIFO queue for clustering inference
        """
        self.online_clus = OnlineSpeakerClustering(
            max_num_speakers=clustering_params.max_num_speakers,
            max_rp_threshold=clustering_params.max_rp_threshold,
            sparse_search_volume=clustering_params.sparse_search_volume,
            history_buffer_size=clustering_params.history_buffer_size,
            current_buffer_size=clustering_params.current_buffer_size,
        )
        self.history_n = clustering_params.history_buffer_size
        self.current_n = clustering_params.current_buffer_size

        self.max_num_speakers = self.online_clus.max_num_speakers

    def _init_online_segmentor_module(self, sample_rate):
        """
        Initialize an online segmentor module.

        Attributes:
            online_segmentor (OnlineSegmentor):
                Online segmentation module that generates short speech segments from the VAD input
        """
        self.online_segmentor = OnlineSegmentor(sample_rate)
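    # Illustrative example (values are placeholders, not verified defaults): the `clustering_params`
    # argument of `_init_online_clustering_module` above is the `diarizer.clustering.parameters`
    # sub-config, e.g. in YAML form:
    #
    #     clustering:
    #       parameters:
    #         max_num_speakers: 8
    #         max_rp_threshold: 0.25
    #         sparse_search_volume: 30
    #         history_buffer_size: 150
    #         current_buffer_size: 150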
    def _init_memory_buffer(self):
        """
        Variables are kept in memory for future updates.

        Attributes:
            memory_margin (int):
                The number of embeddings saved in the memory buffer.
                This memory margin is dependent on the base scale length:
                margin = (buffer_length) / (base scale shift).
                The memory margin is automatically calculated to have minimal memory usage.
            memory_segment_ranges (dict):
                The segment range information kept in the memory buffer
            memory_segment_indexes (dict):
                The segment indexes kept in the memory buffer
            memory_cluster_labels (Tensor):
                The cluster labels inferred in the previous diarization steps
        """
        self.memory_margin = 0
        self.memory_segment_ranges = {key: [] for key in self.multiscale_args_dict['scale_dict'].keys()}
        self.memory_segment_indexes = {key: [] for key in self.multiscale_args_dict['scale_dict'].keys()}
        self.memory_cluster_labels = torch.tensor([])

    def _init_temporal_major_voting_module(self, clustering_params):
        """
        Variables needed for taking majority votes for speaker labels.

        Attributes:
            use_temporal_label_major_vote (bool):
                Boolean for whether to use temporal majority voting
            temporal_label_major_vote_buffer_size (int):
                Buffer size for majority voting
            base_scale_label_dict (dict):
                Dictionary containing multiple speaker labels for majority voting.
                Speaker labels from multiple steps are saved for each segment index.
        """
        self.use_temporal_label_major_vote = clustering_params.get('use_temporal_label_major_vote', False)
        self.temporal_label_major_vote_buffer_size = clustering_params.get('temporal_label_major_vote_buffer_size', 1)
        self.base_scale_label_dict = {}

    def _init_segment_variables(self):
        """
        Initialize segment variables for each scale.
        Note that we have a `uniq_id` variable in case multiple sessions are handled.
        """
        self.emb_vectors = {}
        self.time_stamps = {}
        self.segment_range_ts = {}
        self.segment_raw_audio = {}
        self.segment_indexes = {}

        for scale_idx in self.multiscale_args_dict['scale_dict'].keys():
            self.multiscale_embeddings_and_timestamps[scale_idx] = [None, None]
            self.emb_vectors[scale_idx] = torch.tensor([])
            self.time_stamps[scale_idx] = []
            self.segment_range_ts[scale_idx] = []
            self.segment_raw_audio[scale_idx] = []
            self.segment_indexes[scale_idx] = []

    def _init_buffer_frame_timestamps(self):
        """
        Timing variables transferred from the OnlineDiarWithASR class.
        Buffer is the window region where the input signal is kept for ASR.
        Frame is the window region where the actual ASR decoded inference results are updated.

        Example:
            buffer_len = 5.0
            frame_len = 1.0

            |___Buffer___[___________]____________|
            |____________[   Frame   ]____________|
            | <- buffer_start
            |____________| <- frame_start
            |_____________________________________| <- buffer_end

            buffer_start = 12.0
            buffer_end = 17.0
            frame_start = 14.0

        These timestamps and index variables are updated by OnlineDiarWithASR.

        Attributes:
            frame_index (int):
                Integer index of the frame window
            frame_start (float):
                The start of the frame window
            buffer_start (float):
                The start of the buffer window
            buffer_end (float):
                The end of the buffer
        """
        self.frame_index = 0
        self.frame_start = 0.0
        self.buffer_start = 0.0
        self.buffer_end = 0.0

    def _transfer_timestamps_to_segmentor(self):
        """
        Pass the timing information from the streaming ASR buffers.
        """
        self.online_segmentor.frame_start = self.frame_start
        self.online_segmentor.buffer_start = self.buffer_start
        self.online_segmentor.buffer_end = self.buffer_end
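    # Hedged sketch of how a streaming caller (e.g. `OnlineDiarWithASR`) would drive these timing
    # variables before each diarization step, using the example values from
    # `_init_buffer_frame_timestamps` above (attribute names are real, values illustrative):
    #
    #     diarizer.frame_index += 1
    #     diarizer.buffer_start = 12.0
    #     diarizer.frame_start = 14.0
    #     diarizer.buffer_end = 17.0
    #     diarizer.diarize_step(audio_buffer, vad_timestamps)  # calls _transfer_timestamps_to_segmentor()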
    def reset(self):
        """
        Reset all the necessary variables and initialize classes.

        Attributes:
            n_embed_seg_len (int):
                Number of time-series samples in one base-scale segment window
                (sample rate times the base-scale window length)
        """
        self.n_embed_seg_len = int(
            self.sample_rate * self.multiscale_args_dict['scale_dict'][self.base_scale_index][0]
        )
        self._init_segment_variables()
        self._init_online_clustering_module(self._cfg_diarizer.clustering.parameters)
        self._init_online_segmentor_module(self.cfg.sample_rate)
        self._init_memory_buffer()
        self._init_temporal_major_voting_module(self._cfg_diarizer.clustering.parameters)
        self._init_buffer_frame_timestamps()

    def _clear_memory(self, scale_idx: int):
        """
        Calculate how many segments should be removed from memory (`memory_margin`) and save the
        necessary information. `keep_range` determines how many segments (and their corresponding
        embeddings, raw audio, and timestamps) are kept in the memory of the online diarizer instance.

        Args:
            scale_idx (int):
                Scale index in integer type
        """
        base_scale_shift = self.multiscale_args_dict['scale_dict'][self.base_scale_index][1]
        self.memory_margin = int((self.buffer_end - self.buffer_start) / base_scale_shift)

        scale_buffer_size = int(
            len(set(self.scale_mapping_dict[scale_idx].tolist()))
            / len(set(self.scale_mapping_dict[self.base_scale_index].tolist()))
            * (self.history_n + self.current_n)
        )
        keep_range = scale_buffer_size + self.memory_margin
        self.emb_vectors[scale_idx] = self.emb_vectors[scale_idx][-keep_range:]
        self.segment_raw_audio[scale_idx] = self.segment_raw_audio[scale_idx][-keep_range:]
        self.segment_range_ts[scale_idx] = self.segment_range_ts[scale_idx][-keep_range:]
        self.segment_indexes[scale_idx] = self.segment_indexes[scale_idx][-keep_range:]

    @timeit
    def _temporal_label_major_vote(self) -> torch.Tensor:
        """
        Take a majority vote for every segment over temporal steps. This feature significantly
        reduces the error coming from unstable speaker counting at the beginning of sessions.

        Returns:
            maj_vote_labels (list):
                List containing the major-voted speaker labels on the temporal domain
        """
        maj_vote_labels = []
        for seg_idx in self.memory_segment_indexes[self.base_scale_index]:
            if seg_idx not in self.base_scale_label_dict:
                self.base_scale_label_dict[seg_idx] = [self.memory_cluster_labels[seg_idx]]
            else:
                while len(self.base_scale_label_dict[seg_idx]) > self.temporal_label_major_vote_buffer_size:
                    self.base_scale_label_dict[seg_idx].pop(0)
                self.base_scale_label_dict[seg_idx].append(self.memory_cluster_labels[seg_idx])

            maj_vote_labels.append(torch.mode(torch.tensor(self.base_scale_label_dict[seg_idx]))[0].item())
        return maj_vote_labels
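    # Illustrative example of the majority vote above: if base-scale segment index 7 has been
    # labeled [1, 1, 0, 1] over the last few steps, `torch.mode` returns 1, suppressing the single
    # unstable label 0 that would otherwise flip the output for that segment.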
    def save_history_data(self, scale_idx: int, total_cluster_labels: torch.Tensor, is_online: bool) -> torch.Tensor:
        """
        Save the temporary input to the class memory buffer.

        - Clustering is done for (hist_N + curr_N) number of embeddings.
        - Thus, we need to remove the clustering results on the embedding memory.
        - If `self.diar.history_buffer_seg_end` is not None, the streaming diarization system has
          started to save embeddings to its memory. Thus, the newly incoming clustering labels
          should be separated.
        - If `is_online = True`, old embeddings outside the window are removed to save GPU memory.

        Args:
            scale_idx (int):
                Scale index in integer
            total_cluster_labels (Tensor):
                The speaker labels from the beginning of the session to the current position
            is_online (bool):
                Boolean variable that indicates whether the system is currently in online mode or not

        Returns:
            cluster_label_hyp (Tensor):
                Majority-voted speaker labels over multiple inferences
        """
        total_cluster_labels = total_cluster_labels.tolist()

        if not is_online:
            self.memory_segment_ranges[scale_idx] = deepcopy(self.segment_range_ts[scale_idx])
            self.memory_segment_indexes[scale_idx] = deepcopy(self.segment_indexes[scale_idx])
            if scale_idx == self.base_scale_index:
                self.memory_cluster_labels = deepcopy(total_cluster_labels)

        # Only if there are newly obtained embeddings, update ranges and embeddings.
        elif self.segment_indexes[scale_idx][-1] > self.memory_segment_indexes[scale_idx][-1]:
            # Get the global index of the first segment we want to keep in the buffer
            global_stt_idx = max(max(self.memory_segment_indexes[scale_idx]) - self.memory_margin, 0)

            # Convert global index global_stt_idx to buffer index buffer_stt_idx
            segment_indexes_mat = torch.tensor(self.segment_indexes[scale_idx])
            buffer_stt_idx = torch.where(segment_indexes_mat == global_stt_idx)[0][0]
            self.memory_segment_ranges[scale_idx][global_stt_idx:] = deepcopy(
                self.segment_range_ts[scale_idx][buffer_stt_idx:]
            )
            self.memory_segment_indexes[scale_idx][global_stt_idx:] = deepcopy(
                self.segment_indexes[scale_idx][buffer_stt_idx:]
            )
            if scale_idx == self.base_scale_index:
                self.memory_cluster_labels[global_stt_idx:] = deepcopy(total_cluster_labels[global_stt_idx:])
                if len(self.memory_cluster_labels) != len(self.memory_segment_ranges[scale_idx]):
                    raise ValueError(
                        "self.memory_cluster_labels and self.memory_segment_ranges should always have the same length, "
                        f"but they have {len(self.memory_cluster_labels)} and {len(self.memory_segment_ranges[scale_idx])}."
                    )

            # Remove unnecessary old values
            self._clear_memory(scale_idx)

        if not (
            len(self.emb_vectors[scale_idx])
            == len(self.segment_raw_audio[scale_idx])
            == len(self.segment_indexes[scale_idx])
            == len(self.segment_range_ts[scale_idx])
        ):
            raise ValueError(
                "self.emb_vectors, self.segment_raw_audio, self.segment_indexes, and self.segment_range_ts "
                "should always have the same length, "
                f"but they have {len(self.emb_vectors[scale_idx])}, {len(self.segment_raw_audio[scale_idx])}, "
                f"{len(self.segment_indexes[scale_idx])}, and {len(self.segment_range_ts[scale_idx])}, respectively."
            )

        if self.use_temporal_label_major_vote:
            cluster_label_hyp = self._temporal_label_major_vote()
        else:
            cluster_label_hyp = self.memory_cluster_labels
        return cluster_label_hyp

    @timeit
    @torch.no_grad()
    def _run_embedding_extractor(self, audio_signal: torch.Tensor) -> torch.Tensor:
        """
        Call the `forward` function of the speaker embedding model.

        Args:
            audio_signal (Tensor):
                Torch tensor containing the time-series signal

        Returns:
            Speaker embedding vectors for the given time-series input `audio_signal`.
        """
        audio_signal = torch.stack(audio_signal).float().to(self.device)
        audio_signal_lens = torch.tensor([self.n_embed_seg_len for k in range(audio_signal.shape[0])]).to(self.device)
        _, torch_embs = self._speaker_model.forward(input_signal=audio_signal, input_signal_length=audio_signal_lens)
        return torch_embs
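    # Shape sketch for `_run_embedding_extractor` above (the embedding dimension depends on the
    # loaded speaker model; 192 is only an example, e.g. for TitaNet-L):
    #
    #     audio_signal:      (num_segments, n_embed_seg_len)   stacked raw-audio segments
    #     audio_signal_lens: (num_segments,)                    all equal to n_embed_seg_len
    #     torch_embs:        (num_segments, 192)                one embedding vector per segment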
    @timeit
    def _extract_online_embeddings(
        self, audio_signal: torch.Tensor, segment_ranges: torch.Tensor, embeddings
    ) -> torch.Tensor:
        """
        Incrementally extract speaker embeddings based on the `audio_signal` and `segment_ranges` variables.
        Unlike offline speaker diarization, speaker embeddings and subsegment ranges are not saved to disk.
        Measures the mismatch between `segment_ranges` and `embeddings`, then extracts the necessary amount
        of speaker embeddings.

        Args:
            audio_signal (Tensor):
                Torch tensor containing the time-series audio signal
            embeddings (Tensor):
                Previously existing Torch tensor containing speaker embedding vectors
            segment_ranges (Tensor):
                Torch tensor containing the start and end of each segment

        Returns:
            embeddings (Tensor):
                Concatenated speaker embedding vectors that match the segment range information
                in `segment_ranges`.
        """
        stt_idx = 0 if embeddings is None else embeddings.shape[0]
        end_idx = len(segment_ranges)

        if end_idx > stt_idx:
            torch_embs = self._run_embedding_extractor(audio_signal[stt_idx:end_idx])
            if embeddings is None:
                embeddings = torch_embs
            else:
                embeddings = torch.vstack((embeddings[:stt_idx, :], torch_embs))
        elif end_idx < stt_idx:
            embeddings = embeddings[: len(segment_ranges)]

        if len(segment_ranges) != embeddings.shape[0]:
            raise ValueError("Segment ranges and embeddings shapes do not match.")
        return embeddings

    @timeit
    def _perform_online_clustering(
        self, uniq_embs_and_timestamps: Dict[str, torch.Tensor], cuda=False,
    ) -> torch.Tensor:
        """
        Launch online clustering for the `uniq_embs_and_timestamps` input variable.

        Args:
            uniq_embs_and_timestamps (dict):
                Dictionary containing embeddings, timestamps and multiscale weights.
                If `uniq_embs_and_timestamps` contains only one scale, single-scale diarization
                is performed.
            cuda (bool):
                Boolean indicator for CUDA usage
        """
        device = torch.device("cuda") if cuda else torch.device("cpu")

        # Get base-scale (the highest index) information from uniq_embs_and_timestamps.
        embeddings_in_scales, timestamps_in_scales = split_input_data(
            embeddings_in_scales=uniq_embs_and_timestamps['embeddings'],
            timestamps_in_scales=uniq_embs_and_timestamps['timestamps'],
            multiscale_segment_counts=uniq_embs_and_timestamps['multiscale_segment_counts'],
        )

        curr_emb, self.scale_mapping_dict = get_scale_interpolated_embs(
            multiscale_weights=uniq_embs_and_timestamps['multiscale_weights'],
            embeddings_in_scales=embeddings_in_scales,
            timestamps_in_scales=timestamps_in_scales,
            device=device,
        )

        concat_emb, add_new = self.online_clus.get_reduced_mat(
            emb_in=curr_emb, base_segment_indexes=self.segment_indexes[self.base_scale_index]
        )

        # Perform an online version of clustering with history-concatenated embedding vectors
        Y_concat = self.online_clus.forward_infer(emb=concat_emb, frame_index=self.frame_index, cuda=cuda,)

        # Match the permutation of the newly obtained speaker labels and the previous labels
        merged_clus_labels = self.online_clus.match_labels(Y_concat, add_new)

        # Update history data
        for scale_idx, (window, shift) in self.multiscale_args_dict['scale_dict'].items():
            cluster_label_hyp = self.save_history_data(scale_idx, merged_clus_labels, self.online_clus.is_online)

        return cluster_label_hyp
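    # Illustrative layout of the `uniq_embs_and_timestamps` dictionary consumed above (keys are
    # taken from the calls to `split_input_data` and `get_scale_interpolated_embs`):
    #
    #     {
    #         'embeddings': Tensor,                  # embeddings of all scales, concatenated
    #         'timestamps': Tensor,                  # [start, end] timestamps for each segment
    #         'multiscale_segment_counts': Tensor,   # number of segments per scale
    #         'multiscale_weights': Tensor,          # per-scale fusion weights
    #     }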
    def _get_interim_output(self) -> torch.Tensor:
        """
        In case the buffer is not filled or there is no speech activity in the input,
        generate a temporary output.

        Returns:
            diar_hyp (Tensor):
                Speaker labels based on the previously saved segments and speaker labels
        """
        if len(self.memory_cluster_labels) == 0 or self.buffer_start < 0:
            diar_hyp, _ = generate_cluster_labels([[0.0, self.total_buffer_in_secs]], [0])
        else:
            diar_hyp, _ = generate_cluster_labels(
                self.memory_segment_ranges[self.base_scale_index], self.memory_cluster_labels
            )
        return diar_hyp

    @timeit
    def diarize_step(self, audio_buffer: torch.Tensor, vad_timestamps: torch.Tensor) -> torch.Tensor:
        """
        A function for a unit diarization step. Each diarization step goes through the following steps:

        1. Segmentation:
            Using the `OnlineSegmentor` class, call the `run_online_segmentation` method to get the segments.
        2. Embedding extraction:
            Extract multiscale embeddings from the extracted speech segments.
        3. Online clustering and counting:
            Perform online speaker clustering by using the `OnlineSpeakerClustering` class.
        4. Generate speaker labels:
            Generate start and end timestamps of speaker labels based on the diarization results.

        c.f.) Also see the method `diarize` in the `ClusteringDiarizer` class.

        Args:
            audio_buffer (Tensor):
                Tensor variable containing the time-series signal at the current frame
                Dimensions: (Number of audio time-series samples) x 1
            vad_timestamps (Tensor):
                Tensor containing VAD timestamps.
                Dimensions: (Number of segments) x 2
                Example:
                    >>> vad_timestamps = torch.Tensor([[0.05, 2.52], [3.12, 6.85]])

        Returns:
            diar_hyp (Tensor):
                Speaker label hypothesis from the start of the session to the current position
        """
        self._transfer_timestamps_to_segmentor()

        # In case the buffer is not filled or there is no speech activity in the input
        if self.buffer_start < 0 or len(vad_timestamps) == 0:
            return self._get_interim_output()

        # Segmentation (c.f. see the `diarize` function in the ClusteringDiarizer class)
        for scale_idx, (window, shift) in self.multiscale_args_dict['scale_dict'].items():

            # Step 1: Get subsegments for embedding extraction.
            audio_sigs, segment_ranges, range_inds = self.online_segmentor.run_online_segmentation(
                audio_buffer=audio_buffer,
                vad_timestamps=vad_timestamps,
                segment_raw_audio=self.segment_raw_audio[scale_idx],
                segment_range_ts=self.segment_range_ts[scale_idx],
                segment_indexes=self.segment_indexes[scale_idx],
                window=window,
                shift=shift,
            )
            self.segment_raw_audio[scale_idx] = audio_sigs
            self.segment_range_ts[scale_idx] = segment_ranges
            self.segment_indexes[scale_idx] = range_inds

            # Step 2-1: Extract speaker embeddings from the extracted subsegment timestamps.
            embeddings = self._extract_online_embeddings(
                audio_signal=self.segment_raw_audio[scale_idx],
                segment_ranges=self.segment_range_ts[scale_idx],
                embeddings=self.emb_vectors[scale_idx],
            )

            # Step 2-2: Save the embeddings and segmentation timestamps in memory
            self.emb_vectors[scale_idx] = embeddings

            self.multiscale_embeddings_and_timestamps[scale_idx] = [
                {self.uniq_id: embeddings},
                {self.uniq_id: segment_ranges},
            ]

        embs_and_timestamps = get_embs_and_timestamps(
            self.multiscale_embeddings_and_timestamps, self.multiscale_args_dict
        )

        # Step 3 - Clustering: Perform an online version of the clustering algorithm
        cluster_label_hyp = self._perform_online_clustering(embs_and_timestamps[self.uniq_id], cuda=self.cuda,)

        # Step 4: Generate RTTM-style diarization labels from segment ranges and cluster labels
        diar_hyp, _ = generate_cluster_labels(self.memory_segment_ranges[self.base_scale_index], cluster_label_hyp)
        return diar_hyp
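# Minimal usage sketch (illustrative, not a complete streaming pipeline): in practice the timing
# variables and VAD timestamps are supplied by a streaming wrapper such as OnlineDiarWithASR;
# `cfg` and `stream` below are assumed to exist and are not part of this module.
#
#     online_diarizer = OnlineClusteringDiarizer(cfg)
#     for audio_buffer, vad_timestamps, timing in stream:
#         online_diarizer.frame_index = timing.frame_index
#         online_diarizer.frame_start = timing.frame_start
#         online_diarizer.buffer_start = timing.buffer_start
#         online_diarizer.buffer_end = timing.buffer_end
#         diar_hyp = online_diarizer.diarize_step(audio_buffer, vad_timestamps)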