# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
#
# ## Citations
#
# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic},
#   organization={IEEE}
# }
#
# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
# ```

import torch

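# NOTE: the triple-quoted block below is an earlier, tril-based implementation
# of subsequent_mask that appears to be kept only for reference; the active
# definition further down builds the same lower-triangular mask with
# arange/expand instead of torch.tril.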
'''
def subsequent_mask(
    size: int,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size).

    This mask is used only in the decoder, which works in an auto-regressive
    mode: the current step can only attend to the steps on its left.
    In the encoder, full attention is used when streaming is not necessary and
    the sequence is not long; in that case no attention mask is needed.
    When streaming is needed, chunk-based attention is used in the encoder; see
    subsequent_chunk_mask for the chunk-based attention mask.

    Args:
        size (int): size of mask
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_mask(3)
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    ret = torch.ones(size, size, device=device, dtype=torch.bool)
    return torch.tril(ret)
'''

def subsequent_mask(
    size: int,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size).

    This mask is used only in the decoder, which works in an auto-regressive
    mode: the current step can only attend to the steps on its left.
    In the encoder, full attention is used when streaming is not necessary and
    the sequence is not long; in that case no attention mask is needed.
    When streaming is needed, chunk-based attention is used in the encoder; see
    subsequent_chunk_mask for the chunk-based attention mask.

    Args:
        size (int): size of mask
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_mask(3)
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    arange = torch.arange(size, device=device)
    mask = arange.expand(size, size)
    arange = arange.unsqueeze(-1)
    mask = mask <= arange
    return mask

def subsequent_chunk_mask(
    size: int,
    chunk_size: int,
    num_left_chunks: int = -1,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size) with a chunk size;
    this is for the streaming encoder.

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use full chunk
            >=0: use num_left_chunks
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
    for i in range(size):
        if num_left_chunks < 0:
            start = 0
        else:
            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
        ending = min((i // chunk_size + 1) * chunk_size, size)
        ret[i, start:ending] = True
    return ret

def add_optional_chunk_mask(
    xs: torch.Tensor,
    masks: torch.Tensor,
    use_dynamic_chunk: bool,
    use_dynamic_left_chunk: bool,
    decoding_chunk_size: int,
    static_chunk_size: int,
    num_decoding_left_chunks: int,
) -> torch.Tensor:
    """Apply optional mask for encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        masks (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0; if use_dynamic_chunk is true, this
            parameter will be ignored.
        num_decoding_left_chunks (int): number of left chunks, this is for
            decoding, the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks

    Returns:
        torch.Tensor: chunk mask of the input xs.
    """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # Chunk size is either in [1, 25] or the full context (max_len).
            # Since we use 4x subsampling and allow up to 1 s (100 frames) of
            # delay, the maximum chunk size is 100 / 4 = 25.
            chunk_size = torch.randint(1, max_len, (1,)).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = torch.randint(
                        0, max_left_chunks, (1,)
                    ).item()
        chunk_masks = subsequent_chunk_mask(
            xs.size(1), chunk_size, num_left_chunks, xs.device
        )  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(
            xs.size(1), static_chunk_size, num_left_chunks, xs.device
        )  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    else:
        chunk_masks = masks
    return chunk_masks

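# Illustrative usage sketch for add_optional_chunk_mask (toy shapes and values,
# assumed for the example): combine a padding mask with a static chunk mask so
# each frame attends only within its own chunk and the chunks to its left.
#
#     xs = torch.zeros(2, 6, 80)                                     # (B, L, D)
#     masks = make_non_pad_mask(torch.tensor([6, 4])).unsqueeze(1)   # (B, 1, L)
#     chunk_masks = add_optional_chunk_mask(
#         xs, masks,
#         use_dynamic_chunk=False, use_dynamic_left_chunk=False,
#         decoding_chunk_size=0, static_chunk_size=2,
#         num_decoding_left_chunks=-1,
#     )  # (B, L, L) boolean mask, True where attention is allowed
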
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
        max_len (int): maximum length; if 0, lengths.max() is used instead.

    Returns:
        torch.Tensor: Mask tensor containing indices of padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask

def make_non_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Make mask tensor containing indices of non-padded part.

    The sequences in a batch may have different lengths. To enable batch
    computing, padding is needed to make all sequences the same length.
    To keep the padded part from passing values into context-dependent
    blocks such as attention or convolution, the padded part is masked.

    This pad_mask is used in both encoder and decoder.
    1 for non-padded part and 0 for padded part.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).

    Returns:
        torch.Tensor: Mask tensor containing indices of non-padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]
    """
    return ~make_pad_mask(lengths)

def mask_finished_scores(score: torch.Tensor, flag: torch.Tensor) -> torch.Tensor:
    """
    If a sequence is finished, we only allow one alive branch. This function
    gives one branch a zero score and the rest a -inf score.

    Args:
        score (torch.Tensor): A real value array with shape
            (batch_size * beam_size, beam_size).
        flag (torch.Tensor): A bool array with shape
            (batch_size * beam_size, 1).

    Returns:
        torch.Tensor: (batch_size * beam_size, beam_size).
    """
    beam_size = score.size(-1)
    zero_mask = torch.zeros_like(flag, dtype=torch.bool)
    if beam_size > 1:
        unfinished = torch.cat((zero_mask, flag.repeat([1, beam_size - 1])), dim=1)
        finished = torch.cat((flag, zero_mask.repeat([1, beam_size - 1])), dim=1)
    else:
        unfinished = zero_mask
        finished = flag
    score.masked_fill_(unfinished, -float("inf"))
    score.masked_fill_(finished, 0)
    return score

def mask_finished_preds(
    pred: torch.Tensor, flag: torch.Tensor, eos: int
) -> torch.Tensor:
    """
    If a sequence is finished, all of its branches should be <eos>.

    Args:
        pred (torch.Tensor): An int array with shape
            (batch_size * beam_size, beam_size).
        flag (torch.Tensor): A bool array with shape
            (batch_size * beam_size, 1).
        eos (int): end-of-sentence token id.

    Returns:
        torch.Tensor: (batch_size * beam_size, beam_size).
    """
    beam_size = pred.size(-1)
    finished = flag.repeat([1, beam_size])
    return pred.masked_fill_(finished, eos)
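

if __name__ == "__main__":
    # Tiny sanity check with toy beam-search values (assumed for illustration;
    # shapes follow the docstrings above): one utterance with beam size 2, so
    # batch_size * beam_size = 2 rows, and the first hypothesis is finished.
    scores = torch.tensor([[0.5, 0.3], [0.2, 0.1]])
    preds = torch.tensor([[3, 7], [4, 9]])
    flag = torch.tensor([[True], [False]])

    print(mask_finished_scores(scores, flag))
    # Finished row keeps a single alive branch: [[0.0, -inf], [0.2, 0.1]]
    print(mask_finished_preds(preds, flag, eos=2))
    # All branches of the finished row are forced to <eos>: [[2, 2], [4, 9]]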