| """Search algorithms for transducer models.""" | |
| from typing import List | |
| from typing import Union | |
| import numpy as np | |
| import torch | |
| from espnet.nets.pytorch_backend.transducer.utils import create_lm_batch_state | |
| from espnet.nets.pytorch_backend.transducer.utils import init_lm_state | |
| from espnet.nets.pytorch_backend.transducer.utils import is_prefix | |
| from espnet.nets.pytorch_backend.transducer.utils import recombine_hyps | |
| from espnet.nets.pytorch_backend.transducer.utils import select_lm_state | |
| from espnet.nets.pytorch_backend.transducer.utils import substract | |
| from espnet.nets.transducer_decoder_interface import Hypothesis | |
| from espnet.nets.transducer_decoder_interface import NSCHypothesis | |
| from espnet.nets.transducer_decoder_interface import TransducerDecoderInterface | |
class BeamSearchTransducer:
    """Beam search implementation for transducer models.

    Wraps a transducer decoder and joint network and dispatches, at
    construction time, to one of several decoding algorithms selected by
    ``search_type``: greedy (beam_size <= 1), default beam search,
    time-synchronous decoding ("tsd"), alignment-length synchronous
    decoding ("alsd"), or N-step constrained beam search ("nsc").
    An optional external LM can be shallow-fused into the scores.
    """
    def __init__(
        self,
        decoder: Union[TransducerDecoderInterface, torch.nn.Module],
        joint_network: torch.nn.Module,
        beam_size: int,
        lm: torch.nn.Module = None,
        lm_weight: float = 0.1,
        search_type: str = "default",
        max_sym_exp: int = 2,
        u_max: int = 50,
        nstep: int = 1,
        prefix_alpha: int = 1,
        score_norm: bool = True,
        nbest: int = 1,
    ):
        """Initialize transducer beam search.

        Args:
            decoder: Decoder module to use (must expose ``dunits``, ``odim``,
                ``blank``, ``init_state``, ``score``/``batch_score``, ...)
            joint_network: Joint network module
            beam_size: Number of hypotheses kept during search
            lm: LM module to use for shallow fusion (None disables fusion)
            lm_weight: LM weight for shallow fusion
            search_type: Type of algorithm to use for search
                ("default", "tsd", "alsd" or "nsc")
            max_sym_exp: Number of maximum symbol expansions at each time step ("tsd")
            u_max: Maximum output sequence length ("alsd")
            nstep: Number of maximum expansion steps at each time step ("nsc")
            prefix_alpha: Maximum prefix length in prefix search ("nsc")
            score_norm: Normalize final scores by length ("default")
            nbest: Number of returned final hypotheses

        Raises:
            NotImplementedError: If ``search_type`` is not one of the
                supported algorithm names.
        """
        self.decoder = decoder
        self.joint_network = joint_network
        self.beam_size = beam_size
        # Cache decoder-exposed sizes; ``blank`` is the blank symbol index
        # (also used as the start-of-sequence token in every hypothesis).
        self.hidden_size = decoder.dunits
        self.vocab_size = decoder.odim
        self.blank = decoder.blank
        # A beam of one degenerates to greedy decoding regardless of
        # the requested search_type.
        if self.beam_size <= 1:
            self.search_algorithm = self.greedy_search
        elif search_type == "default":
            self.search_algorithm = self.default_beam_search
        elif search_type == "tsd":
            self.search_algorithm = self.time_sync_decoding
        elif search_type == "alsd":
            self.search_algorithm = self.align_length_sync_decoding
        elif search_type == "nsc":
            self.search_algorithm = self.nsc_beam_search
        else:
            raise NotImplementedError
        self.lm = lm
        self.lm_weight = lm_weight
        if lm is not None:
            self.use_lm = True
            # Word-level LMs wrap the actual predictor in a ``wordlm``
            # attribute; their hidden state is handled differently by the
            # select_lm_state / create_lm_batch_state helpers.
            self.is_wordlm = True if hasattr(lm.predictor, "wordlm") else False
            self.lm_predictor = lm.predictor.wordlm if self.is_wordlm else lm.predictor
            # NOTE(review): assumes an RNN LM predictor exposing ``rnn`` —
            # confirm against the LM implementations used with this class.
            self.lm_layers = len(self.lm_predictor.rnn)
        else:
            self.use_lm = False
        self.max_sym_exp = max_sym_exp
        self.u_max = u_max
        self.nstep = nstep
        self.prefix_alpha = prefix_alpha
        self.score_norm = score_norm
        self.nbest = nbest
    def __call__(self, h: torch.Tensor) -> Union[List[Hypothesis], List[NSCHypothesis]]:
        """Perform beam search.

        Args:
            h: Encoded speech features (T_max, D_enc)

        Returns:
            nbest_hyps: N-best decoding results
        """
        self.decoder.set_device(h.device)
        # Custom (non-RNN) decoders have no ``decoders`` attribute and need
        # the data type propagated explicitly.
        if not hasattr(self.decoder, "decoders"):
            self.decoder.set_data_type(h.dtype)
        nbest_hyps = self.search_algorithm(h)
        return nbest_hyps
    def sort_nbest(
        self, hyps: Union[List[Hypothesis], List[NSCHypothesis]]
    ) -> Union[List[Hypothesis], List[NSCHypothesis]]:
        """Sort hypotheses by score or score given sequence length.

        Args:
            hyps: List of hypotheses

        Return:
            hyps: Sorted list of hypotheses (best first), truncated to nbest
        """
        if self.score_norm:
            # Length-normalized score avoids a bias toward short sequences.
            hyps.sort(key=lambda x: x.score / len(x.yseq), reverse=True)
        else:
            hyps.sort(key=lambda x: x.score, reverse=True)
        return hyps[: self.nbest]
    def greedy_search(self, h: torch.Tensor) -> List[Hypothesis]:
        """Greedy search implementation.

        At each encoder frame, emit the single most probable symbol; the
        decoder is only re-scored when a non-blank symbol is emitted.

        Args:
            h: Encoded speech features (T_max, D_enc)

        Returns:
            hyp: 1-best decoding result (as a one-element list)
        """
        dec_state = self.decoder.init_state(1)
        # yseq starts with the blank symbol acting as <sos>.
        hyp = Hypothesis(score=0.0, yseq=[self.blank], dec_state=dec_state)
        cache = {}
        y, state, _ = self.decoder.score(hyp, cache)
        for i, hi in enumerate(h):
            ytu = torch.log_softmax(self.joint_network(hi, y), dim=-1)
            logp, pred = torch.max(ytu, dim=-1)
            # On blank, advance to the next frame without updating the
            # decoder; otherwise emit the symbol and re-score.
            if pred != self.blank:
                hyp.yseq.append(int(pred))
                hyp.score += float(logp)
                hyp.dec_state = state
                y, state, _ = self.decoder.score(hyp, cache)
        return [hyp]
    def default_beam_search(self, h: torch.Tensor) -> List[Hypothesis]:
        """Beam search implementation.

        Classic breadth-first transducer beam search: per frame, hypotheses
        are repeatedly expanded best-first until ``beam`` kept (blank-ended)
        hypotheses outscore every expandable one.

        Args:
            h: Encoded speech features (T_max, D_enc)

        Returns:
            nbest_hyps: N-best decoding results
        """
        beam = min(self.beam_size, self.vocab_size)
        # Non-blank expansions can use at most vocab_size - 1 symbols.
        beam_k = min(beam, (self.vocab_size - 1))
        dec_state = self.decoder.init_state(1)
        kept_hyps = [Hypothesis(score=0.0, yseq=[self.blank], dec_state=dec_state)]
        cache = {}
        for hi in h:
            hyps = kept_hyps
            kept_hyps = []
            while True:
                # Expand the current best hypothesis.
                max_hyp = max(hyps, key=lambda x: x.score)
                hyps.remove(max_hyp)
                y, state, lm_tokens = self.decoder.score(max_hyp, cache)
                ytu = torch.log_softmax(self.joint_network(hi, y), dim=-1)
                # Index 0 is blank; top-k over the remaining symbols.
                top_k = ytu[1:].topk(beam_k, dim=-1)
                # Blank branch: hypothesis ends at this frame and is kept;
                # decoder/LM state is NOT advanced (no symbol emitted).
                kept_hyps.append(
                    Hypothesis(
                        score=(max_hyp.score + float(ytu[0:1])),
                        yseq=max_hyp.yseq[:],
                        dec_state=max_hyp.dec_state,
                        lm_state=max_hyp.lm_state,
                    )
                )
                if self.use_lm:
                    lm_state, lm_scores = self.lm.predict(max_hyp.lm_state, lm_tokens)
                else:
                    lm_state = max_hyp.lm_state
                for logp, k in zip(*top_k):
                    score = max_hyp.score + float(logp)
                    if self.use_lm:
                        # k is an index into ytu[1:], so shift by 1 to get
                        # the actual vocabulary index for the LM score.
                        score += self.lm_weight * lm_scores[0][k + 1]
                    hyps.append(
                        Hypothesis(
                            score=score,
                            yseq=max_hyp.yseq[:] + [int(k + 1)],
                            dec_state=state,
                            lm_state=lm_state,
                        )
                    )
                # Stop expanding once ``beam`` kept hypotheses outscore the
                # best still-expandable hypothesis (no kept result can be
                # overtaken by further expansion).
                hyps_max = float(max(hyps, key=lambda x: x.score).score)
                kept_most_prob = sorted(
                    [hyp for hyp in kept_hyps if hyp.score > hyps_max],
                    key=lambda x: x.score,
                )
                if len(kept_most_prob) >= beam:
                    kept_hyps = kept_most_prob
                    break
        return self.sort_nbest(kept_hyps)
    def time_sync_decoding(self, h: torch.Tensor) -> List[Hypothesis]:
        """Time synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        At each frame, up to ``max_sym_exp`` symbol expansions are performed
        on a batch of hypotheses (C); blank-ended extensions accumulate in A
        with log-sum-exp merging of duplicate label sequences.

        Args:
            h: Encoded speech features (T_max, D_enc)

        Returns:
            nbest_hyps: N-best decoding results
        """
        beam = min(self.beam_size, self.vocab_size)
        beam_state = self.decoder.init_state(beam)
        # B: beam carried across frames. A: blank-ended candidates for the
        # current frame. C: hypotheses being expanded within the frame.
        B = [
            Hypothesis(
                yseq=[self.blank],
                score=0.0,
                dec_state=self.decoder.select_state(beam_state, 0),
            )
        ]
        cache = {}
        # Word LMs create their state lazily; others need explicit init.
        if self.use_lm and not self.is_wordlm:
            B[0].lm_state = init_lm_state(self.lm_predictor)
        for hi in h:
            A = []
            C = B
            h_enc = hi.unsqueeze(0)
            for v in range(self.max_sym_exp):
                D = []
                beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
                    C,
                    beam_state,
                    cache,
                    self.use_lm,
                )
                beam_logp = torch.log_softmax(self.joint_network(h_enc, beam_y), dim=-1)
                beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)
                seq_A = [h.yseq for h in A]
                for i, hyp in enumerate(C):
                    # Blank extension: either add as a new candidate, or
                    # merge probability mass into the existing duplicate
                    # label sequence via log-sum-exp.
                    if hyp.yseq not in seq_A:
                        A.append(
                            Hypothesis(
                                score=(hyp.score + float(beam_logp[i, 0])),
                                yseq=hyp.yseq[:],
                                dec_state=hyp.dec_state,
                                lm_state=hyp.lm_state,
                            )
                        )
                    else:
                        dict_pos = seq_A.index(hyp.yseq)
                        A[dict_pos].score = np.logaddexp(
                            A[dict_pos].score, (hyp.score + float(beam_logp[i, 0]))
                        )
                # Non-blank expansions are only needed if another expansion
                # round follows; the last round keeps blank endings only.
                if v < (self.max_sym_exp - 1):
                    if self.use_lm:
                        beam_lm_states = create_lm_batch_state(
                            [c.lm_state for c in C], self.lm_layers, self.is_wordlm
                        )
                        beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                            beam_lm_states, beam_lm_tokens, len(C)
                        )
                    for i, hyp in enumerate(C):
                        # beam_topk indices are over ytu[1:]; +1 restores
                        # the true vocabulary index.
                        for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                            new_hyp = Hypothesis(
                                score=(hyp.score + float(logp)),
                                yseq=(hyp.yseq + [int(k)]),
                                dec_state=self.decoder.select_state(beam_state, i),
                                lm_state=hyp.lm_state,
                            )
                            if self.use_lm:
                                new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
                                new_hyp.lm_state = select_lm_state(
                                    beam_lm_states, i, self.lm_layers, self.is_wordlm
                                )
                            D.append(new_hyp)
                    C = sorted(D, key=lambda x: x.score, reverse=True)[:beam]
            B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
        return self.sort_nbest(B)
    def align_length_sync_decoding(self, h: torch.Tensor) -> List[Hypothesis]:
        """Alignment-length synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        Iterates over alignment length i = t + u (frame index plus number of
        emitted symbols), so hypotheses with different output lengths are
        compared at equal alignment lengths.

        Args:
            h: Encoded speech features (T_max, D_enc)

        Returns:
            nbest_hyps: N-best decoding results
        """
        beam = min(self.beam_size, self.vocab_size)
        h_length = int(h.size(0))
        # Cap the output length at the number of frames minus one.
        u_max = min(self.u_max, (h_length - 1))
        beam_state = self.decoder.init_state(beam)
        B = [
            Hypothesis(
                yseq=[self.blank],
                score=0.0,
                dec_state=self.decoder.select_state(beam_state, 0),
            )
        ]
        final = []
        cache = {}
        if self.use_lm and not self.is_wordlm:
            B[0].lm_state = init_lm_state(self.lm_predictor)
        for i in range(h_length + u_max):
            A = []
            B_ = []
            h_states = []
            for hyp in B:
                # u: emitted symbols so far (yseq includes the leading
                # blank); t: the frame this hypothesis reads at alignment
                # length i. Hypotheses past the last frame are dropped.
                u = len(hyp.yseq) - 1
                t = i - u + 1
                if t > (h_length - 1):
                    continue
                B_.append(hyp)
                h_states.append((t, h[t]))
            if B_:
                beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
                    B_,
                    beam_state,
                    cache,
                    self.use_lm,
                )
                # Each hypothesis reads its own frame t (they may differ).
                h_enc = torch.stack([h[1] for h in h_states])
                beam_logp = torch.log_softmax(self.joint_network(h_enc, beam_y), dim=-1)
                beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)
                if self.use_lm:
                    beam_lm_states = create_lm_batch_state(
                        [b.lm_state for b in B_], self.lm_layers, self.is_wordlm
                    )
                    beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                        beam_lm_states, beam_lm_tokens, len(B_)
                    )
                # NOTE(review): the loop variable ``i`` shadows the outer
                # alignment-length index; harmless here since the outer ``i``
                # is not read again before the next ``range`` value.
                for i, hyp in enumerate(B_):
                    # Blank branch: stay at the same label sequence; if this
                    # hypothesis just consumed the last frame, it is final.
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(beam_logp[i, 0])),
                        yseq=hyp.yseq[:],
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                    )
                    A.append(new_hyp)
                    if h_states[i][0] == (h_length - 1):
                        final.append(new_hyp)
                    # Non-blank branches (+1 restores vocabulary index).
                    for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                        new_hyp = Hypothesis(
                            score=(hyp.score + float(logp)),
                            yseq=(hyp.yseq[:] + [int(k)]),
                            dec_state=self.decoder.select_state(beam_state, i),
                            lm_state=hyp.lm_state,
                        )
                        if self.use_lm:
                            new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
                            new_hyp.lm_state = select_lm_state(
                                beam_lm_states, i, self.lm_layers, self.is_wordlm
                            )
                        A.append(new_hyp)
                B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
                # Merge hypotheses that share the same label sequence.
                B = recombine_hyps(B)
        if final:
            return self.sort_nbest(final)
        else:
            # No hypothesis consumed the full encoder output; fall back to
            # the current (unsorted by nbest) beam.
            return B
    def nsc_beam_search(self, h: torch.Tensor) -> List[NSCHypothesis]:
        """N-step constrained beam search implementation.

        Based and modified from https://arxiv.org/pdf/2002.03577.pdf.
        Please reference ESPnet (b-flo, PR #2444) for any usage outside ESPnet
        until further modifications.

        Note: the algorithm is not in its "complete" form but works almost as
        intended.

        Args:
            h: Encoded speech features (T_max, D_enc)

        Returns:
            nbest_hyps: N-best decoding results
        """
        beam = min(self.beam_size, self.vocab_size)
        beam_k = min(beam, (self.vocab_size - 1))
        beam_state = self.decoder.init_state(beam)
        init_tokens = [
            NSCHypothesis(
                yseq=[self.blank],
                score=0.0,
                dec_state=self.decoder.select_state(beam_state, 0),
            )
        ]
        cache = {}
        # Prime the decoder (and LM) once so the initial hypothesis carries
        # a decoder output ``y`` and LM scores for the first frame.
        beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
            init_tokens,
            beam_state,
            cache,
            self.use_lm,
        )
        state = self.decoder.select_state(beam_state, 0)
        if self.use_lm:
            beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                None, beam_lm_tokens, 1
            )
            lm_state = select_lm_state(
                beam_lm_states, 0, self.lm_layers, self.is_wordlm
            )
            lm_scores = beam_lm_scores[0]
        else:
            lm_state = None
            lm_scores = None
        kept_hyps = [
            NSCHypothesis(
                yseq=[self.blank],
                score=0.0,
                dec_state=state,
                y=[beam_y[0]],
                lm_state=lm_state,
                lm_scores=lm_scores,
            )
        ]
        for hi in h:
            # Longest-first ordering so prefix search below compares each
            # hypothesis against the shorter ones that may be its prefix.
            hyps = sorted(kept_hyps, key=lambda x: len(x.yseq), reverse=True)
            kept_hyps = []
            h_enc = hi.unsqueeze(0)
            # Prefix search: fold the probability of reaching hyp_j through
            # its prefix hyp_i (within prefix_alpha extra symbols) into
            # hyp_j's score via log-sum-exp.
            for j, hyp_j in enumerate(hyps[:-1]):
                for hyp_i in hyps[(j + 1) :]:
                    curr_id = len(hyp_j.yseq)
                    next_id = len(hyp_i.yseq)
                    if (
                        is_prefix(hyp_j.yseq, hyp_i.yseq)
                        and (curr_id - next_id) <= self.prefix_alpha
                    ):
                        ytu = torch.log_softmax(
                            self.joint_network(hi, hyp_i.y[-1]), dim=-1
                        )
                        # Score of emitting the remaining symbols of hyp_j
                        # from the end of prefix hyp_i, one step at a time.
                        curr_score = hyp_i.score + float(ytu[hyp_j.yseq[next_id]])
                        for k in range(next_id, (curr_id - 1)):
                            ytu = torch.log_softmax(
                                self.joint_network(hi, hyp_j.y[k]), dim=-1
                            )
                            curr_score += float(ytu[hyp_j.yseq[k + 1]])
                        hyp_j.score = np.logaddexp(hyp_j.score, curr_score)
            # S: blank-ended hypotheses; V: non-blank expansions.
            S = []
            V = []
            for n in range(self.nstep):
                beam_y = torch.stack([hyp.y[-1] for hyp in hyps])
                beam_logp = torch.log_softmax(self.joint_network(h_enc, beam_y), dim=-1)
                beam_topk = beam_logp[:, 1:].topk(beam_k, dim=-1)
                for i, hyp in enumerate(hyps):
                    S.append(
                        NSCHypothesis(
                            yseq=hyp.yseq[:],
                            score=hyp.score + float(beam_logp[i, 0:1]),
                            y=hyp.y[:],
                            dec_state=hyp.dec_state,
                            lm_state=hyp.lm_state,
                            lm_scores=hyp.lm_scores,
                        )
                    )
                    for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                        score = hyp.score + float(logp)
                        if self.use_lm:
                            score += self.lm_weight * float(hyp.lm_scores[k])
                        V.append(
                            NSCHypothesis(
                                yseq=hyp.yseq[:] + [int(k)],
                                score=score,
                                y=hyp.y[:],
                                dec_state=hyp.dec_state,
                                lm_state=hyp.lm_state,
                                lm_scores=hyp.lm_scores,
                            )
                        )
                V.sort(key=lambda x: x.score, reverse=True)
                # Drop expansions whose label sequence duplicates an input
                # hypothesis (their mass was already merged by prefix search).
                V = substract(V, hyps)[:beam]
                beam_state = self.decoder.create_batch_states(
                    beam_state,
                    [v.dec_state for v in V],
                    [v.yseq for v in V],
                )
                beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
                    V,
                    beam_state,
                    cache,
                    self.use_lm,
                )
                if self.use_lm:
                    beam_lm_states = create_lm_batch_state(
                        [v.lm_state for v in V], self.lm_layers, self.is_wordlm
                    )
                    beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                        beam_lm_states, beam_lm_tokens, len(V)
                    )
                if n < (self.nstep - 1):
                    # Intermediate step: carry updated decoder/LM state into
                    # the next expansion round.
                    for i, v in enumerate(V):
                        v.y.append(beam_y[i])
                        v.dec_state = self.decoder.select_state(beam_state, i)
                        if self.use_lm:
                            v.lm_state = select_lm_state(
                                beam_lm_states, i, self.lm_layers, self.is_wordlm
                            )
                            v.lm_scores = beam_lm_scores[i]
                    hyps = V[:]
                else:
                    # Final step: add the blank probability of the last
                    # expansion (only when more than one step was taken).
                    beam_logp = torch.log_softmax(
                        self.joint_network(h_enc, beam_y), dim=-1
                    )
                    for i, v in enumerate(V):
                        if self.nstep != 1:
                            v.score += float(beam_logp[i, 0])
                        v.y.append(beam_y[i])
                        v.dec_state = self.decoder.select_state(beam_state, i)
                        if self.use_lm:
                            v.lm_state = select_lm_state(
                                beam_lm_states, i, self.lm_layers, self.is_wordlm
                            )
                            v.lm_scores = beam_lm_scores[i]
            kept_hyps = sorted((S + V), key=lambda x: x.score, reverse=True)[:beam]
        return self.sort_nbest(kept_hyps)