# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------

import itertools
import logging
import os

import numpy as np
import torch

from fairseq.data import (
    AppendTokenDataset,
    ConcatDataset,
    FairseqDataset,
    PrependTokenDataset,
    data_utils,
    indexed_dataset,
)

logger = logging.getLogger(__name__)


def load_langtriple_dataset(
    data_path,
    split,
    src,
    src_dict,
    ref,
    ref_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
    prepend_bos_src=None,
    lang_format="[{}]",
):
    assert not truncate_source

    def split_exists(split, src, ref, tgt, lang, data_path):
        filename = os.path.join(data_path, "{}.{}-{}-{}.{}".format(split, src, ref, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    ref_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")

        # infer langcode
        if split_exists(split_k, src, ref, tgt, src, data_path):
            prefix = os.path.join(data_path, "{}.{}-{}-{}.".format(split_k, src, ref, tgt))
        elif split_exists(split_k, tgt, ref, src, src, data_path):
            prefix = os.path.join(data_path, "{}.{}-{}-{}.".format(split_k, tgt, ref, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, data_path)
                )

        src_dataset = data_utils.load_indexed_dataset(
            prefix + src, src_dict, dataset_impl
        )
        src_datasets.append(src_dataset)

        ref_dataset = data_utils.load_indexed_dataset(
            prefix + ref, ref_dict, dataset_impl
        )
        ref_datasets.append(ref_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(
            prefix + tgt, tgt_dict, dataset_impl
        )
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        logger.info(
            "{} {} {}-{}-{} {} examples".format(
                data_path, split_k, src, ref, tgt, len(src_datasets[-1])
            )
        )

        if not combine:
            break

    assert len(src_datasets) == len(ref_datasets)
    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        ref_dataset = ref_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        ref_dataset = ConcatDataset(ref_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(ref_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        ref_dataset = PrependTokenDataset(ref_dataset, ref_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
    elif prepend_bos_src is not None:
        logger.info(f"prepending src bos: {prepend_bos_src}")
        src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src)
        ref_dataset = PrependTokenDataset(ref_dataset, prepend_bos_src)

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(
            src_dataset, src_dict.index(lang_format.format(src))
        )
        ref_dataset = AppendTokenDataset(
            ref_dataset, ref_dict.index(lang_format.format(ref))
        )
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(
                tgt_dataset, tgt_dict.index(lang_format.format(tgt))
            )
        eos = tgt_dict.index(lang_format.format(tgt))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(
                align_path, None, dataset_impl
            )

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return LanguageTripleDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        ref_dataset,
        ref_dataset.sizes,
        ref_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
    )
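

# --- Hypothetical usage sketch (not part of the original module) -------------
# Illustrates how a task might call load_langtriple_dataset for a
# source/reference/target triple. The split name, stream names ("phn", "km",
# "ltr"), dataset_impl, and position limits below are illustrative assumptions;
# the dictionaries and the binarized data files must already exist on disk.
def _load_langtriple_example(data_path, src_dict, ref_dict, tgt_dict):
    return load_langtriple_dataset(
        data_path,
        "train",
        "phn",                # source stream (assumed name)
        src_dict,
        "km",                 # reference stream (assumed name)
        ref_dict,
        "ltr",                # target stream (assumed name)
        tgt_dict,
        combine=True,
        dataset_impl="mmap",  # assumed binarization format
        upsample_primary=1,
        left_pad_source=False,
        left_pad_target=False,
        max_source_positions=3000,
        max_target_positions=3000,
    )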


def collate(
    samples,
    pad_idx,
    eos_idx,
    left_pad_source=True,
    left_pad_target=False,
    input_feeding=True,
    pad_to_length=None,
    pad_to_multiple=1,
):
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            None,
            left_pad,
            move_eos_to_beginning,
            pad_to_length=pad_to_length,
            pad_to_multiple=pad_to_multiple,
        )

    def check_alignment(alignment, src_len, tgt_len):
        if alignment is None or len(alignment) == 0:
            return False
        if (
            alignment[:, 0].max().item() >= src_len - 1
            or alignment[:, 1].max().item() >= tgt_len - 1
        ):
            logger.warning("alignment size mismatch found, skipping alignment!")
            return False
        return True

    def compute_alignment_weights(alignments):
        """
        Given a tensor of shape [:, 2] containing the source-target indices
        corresponding to the alignments, a weight vector containing the
        inverse frequency of each target index is computed.
        For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then
        a tensor containing [1., 0.5, 0.5, 1] should be returned (since target
        index 3 is repeated twice)
        """
        align_tgt = alignments[:, 1]
        _, align_tgt_i, align_tgt_c = torch.unique(
            align_tgt, return_inverse=True, return_counts=True
        )
        align_weights = align_tgt_c[align_tgt_i[np.arange(len(align_tgt))]]
        return 1.0 / align_weights.float()

    id = torch.LongTensor([s["id"] for s in samples])
    # collate the source and reference streams (both use source-side padding)
    src_tokens = merge(
        "source",
        left_pad=left_pad_source,
        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
    )
    ref_tokens = merge(
        "reference",
        left_pad=left_pad_source,
        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
    )
    # sort by descending source length
    src_lengths = torch.LongTensor(
        [s["source"].ne(pad_idx).long().sum() for s in samples]
    )
    ref_lengths = torch.LongTensor(
        [s["reference"].ne(pad_idx).long().sum() for s in samples]
    )
    src_lengths, sort_order = src_lengths.sort(descending=True)
    id = id.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)
    ref_lengths = ref_lengths.index_select(0, sort_order)
    ref_tokens = ref_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge(
            "target",
            left_pad=left_pad_target,
            pad_to_length=pad_to_length["target"]
            if pad_to_length is not None
            else None,
        )
        target = target.index_select(0, sort_order)
        tgt_lengths = torch.LongTensor(
            [s["target"].ne(pad_idx).long().sum() for s in samples]
        ).index_select(0, sort_order)
        ntokens = tgt_lengths.sum().item()

        if samples[0].get("prev_output_tokens", None) is not None:
            prev_output_tokens = merge("prev_output_tokens", left_pad=left_pad_target)
        elif input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = merge(
                "target",
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
                pad_to_length=pad_to_length["target"]
                if pad_to_length is not None
                else None,
            )
    else:
        ntokens = src_lengths.sum().item()

    batch = {
        "id": id,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
        },
        "target": target,
        "ref_tokens": ref_tokens,
        "ref_lengths": ref_lengths,
    }
    if prev_output_tokens is not None:
        batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select(
            0, sort_order
        )

    if samples[0].get("alignment", None) is not None:
        bsz, tgt_sz = batch["target"].shape
        src_sz = batch["net_input"]["src_tokens"].shape[1]

        offsets = torch.zeros((len(sort_order), 2), dtype=torch.long)
        offsets[:, 1] += torch.arange(len(sort_order), dtype=torch.long) * tgt_sz
        if left_pad_source:
            offsets[:, 0] += src_sz - src_lengths
        if left_pad_target:
            offsets[:, 1] += tgt_sz - tgt_lengths

        alignments = [
            alignment + offset
            for align_idx, offset, src_len, tgt_len in zip(
                sort_order, offsets, src_lengths, tgt_lengths
            )
            for alignment in [samples[align_idx]["alignment"].view(-1, 2)]
            if check_alignment(alignment, src_len, tgt_len)
        ]

        if len(alignments) > 0:
            alignments = torch.cat(alignments, dim=0)
            align_weights = compute_alignment_weights(alignments)

            batch["alignments"] = alignments
            batch["align_weights"] = align_weights

    if samples[0].get("constraints", None) is not None:
        # Collate the packed constraints across the samples, padding to
        # the length of the longest sample.
        lens = [sample.get("constraints").size(0) for sample in samples]
        max_len = max(lens)
        constraints = torch.zeros((len(samples), max_len)).long()
        for i, sample in enumerate(samples):
            constraints[i, 0 : lens[i]] = sample.get("constraints")
        batch["constraints"] = constraints.index_select(0, sort_order)

    return batch
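

# --- Minimal collate sketch (illustrative, not part of the original module) --
# Shows the per-sample dict layout that ``collate`` expects, including the
# extra "reference" stream it batches. The token ids, pad index, and eos index
# below are arbitrary assumptions; real samples come from LanguageTripleDataset.
def _collate_toy_example():
    pad_idx, eos_idx = 1, 2
    samples = [
        {
            "id": 0,
            "source": torch.LongTensor([4, 5, 6, eos_idx]),
            "reference": torch.LongTensor([7, 8, eos_idx]),
            "target": torch.LongTensor([9, 10, 11, eos_idx]),
        },
        {
            "id": 1,
            "source": torch.LongTensor([4, 5, eos_idx]),
            "reference": torch.LongTensor([7, eos_idx]),
            "target": torch.LongTensor([9, eos_idx]),
        },
    ]
    batch = collate(samples, pad_idx=pad_idx, eos_idx=eos_idx)
    # batch["net_input"]["src_tokens"] is left-padded to shape (2, 4);
    # batch["ref_tokens"] / batch["ref_lengths"] carry the reference stream, and
    # batch["net_input"]["prev_output_tokens"] is the eos-shifted target.
    return batch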


class LanguageTripleDataset(FairseqDataset):
    """
    A triple of torch.utils.data.Datasets (source, reference, target).

    Args:
        src (torch.utils.data.Dataset): source dataset to wrap
        src_sizes (List[int]): source sentence lengths
        src_dict (~fairseq.data.Dictionary): source vocabulary
        ref (torch.utils.data.Dataset): reference dataset to wrap
        ref_sizes (List[int]): reference sentence lengths
        ref_dict (~fairseq.data.Dictionary): reference vocabulary
        tgt (torch.utils.data.Dataset, optional): target dataset to wrap
        tgt_sizes (List[int], optional): target sentence lengths
        tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
        left_pad_source (bool, optional): pad source tensors on the left side
            (default: True).
        left_pad_target (bool, optional): pad target tensors on the left side
            (default: False).
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        input_feeding (bool, optional): create a shifted version of the targets
            to be passed into the model for teacher forcing (default: True).
        remove_eos_from_source (bool, optional): if set, removes eos from end
            of source if it's present (default: False).
        append_eos_to_target (bool, optional): if set, appends eos to end of
            target if it's absent (default: False).
        align_dataset (torch.utils.data.Dataset, optional): dataset
            containing alignments.
        constraints (Tensor, optional): 2d tensor with a concatenated, zero-
            delimited list of constraints for each sentence.
        append_bos (bool, optional): if set, appends bos to the beginning of
            source/target sentence.
        num_buckets (int, optional): if set to a value greater than 0, then
            batches will be bucketed into the given number of batch shapes.
        src_lang_id (int, optional): source language ID, if set, the collated batch
            will contain a field 'src_lang_id' in 'net_input' which indicates the
            source language of the samples.
        tgt_lang_id (int, optional): target language ID, if set, the collated batch
            will contain a field 'tgt_lang_id' which indicates the target language
            of the samples.
    """

    def __init__(
        self,
        src,
        src_sizes,
        src_dict,
        ref,
        ref_sizes,
        ref_dict,
        tgt=None,
        tgt_sizes=None,
        tgt_dict=None,
        left_pad_source=True,
        left_pad_target=False,
        shuffle=True,
        input_feeding=True,
        remove_eos_from_source=False,
        append_eos_to_target=False,
        align_dataset=None,
        constraints=None,
        append_bos=False,
        eos=None,
        num_buckets=0,
        src_lang_id=None,
        tgt_lang_id=None,
        pad_to_multiple=1,
    ):
        if tgt_dict is not None:
            assert src_dict.pad() == tgt_dict.pad()
            assert src_dict.eos() == tgt_dict.eos()
            assert src_dict.unk() == tgt_dict.unk()
        if tgt is not None:
            assert len(src) == len(
                tgt
            ), "Source and target must contain the same number of examples"
        assert len(src) == len(
            ref
        ), "Source and reference must contain the same number of examples"
        self.src = src
        self.ref = ref
        self.tgt = tgt
        self.src_sizes = np.array(src_sizes)
        self.ref_sizes = np.array(ref_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        self.sizes = (
            np.vstack((self.src_sizes, self.tgt_sizes)).T
            if self.tgt_sizes is not None
            else self.src_sizes
        )
        self.src_dict = src_dict
        self.ref_dict = ref_dict
        self.tgt_dict = tgt_dict
        self.left_pad_source = left_pad_source
        self.left_pad_target = left_pad_target
        self.shuffle = shuffle
        self.input_feeding = input_feeding
        self.remove_eos_from_source = remove_eos_from_source
        self.append_eos_to_target = append_eos_to_target
        self.align_dataset = align_dataset
        if self.align_dataset is not None:
            assert (
                self.tgt_sizes is not None
            ), "Both source and target needed when alignments are provided"
        self.constraints = constraints
        self.append_bos = append_bos
        self.eos = eos if eos is not None else src_dict.eos()
        self.src_lang_id = src_lang_id
        self.tgt_lang_id = tgt_lang_id
        if num_buckets > 0:
            from fairseq.data import BucketPadLengthDataset

            self.src = BucketPadLengthDataset(
                self.src,
                sizes=self.src_sizes,
                num_buckets=num_buckets,
                pad_idx=self.src_dict.pad(),
                left_pad=self.left_pad_source,
            )
            self.src_sizes = self.src.sizes
            logger.info("bucketing source lengths: {}".format(list(self.src.buckets)))
            self.ref = BucketPadLengthDataset(
                self.ref,
                sizes=self.ref_sizes,
                num_buckets=num_buckets,
                pad_idx=self.ref_dict.pad(),
                left_pad=self.left_pad_source,
            )
            self.ref_sizes = self.ref.sizes
            logger.info("bucketing reference lengths: {}".format(list(self.ref.buckets)))
            if self.tgt is not None:
                self.tgt = BucketPadLengthDataset(
                    self.tgt,
                    sizes=self.tgt_sizes,
                    num_buckets=num_buckets,
                    pad_idx=self.tgt_dict.pad(),
                    left_pad=self.left_pad_target,
                )
                self.tgt_sizes = self.tgt.sizes
                logger.info(
                    "bucketing target lengths: {}".format(list(self.tgt.buckets))
                )

            # determine bucket sizes using self.num_tokens, which will return
            # the padded lengths (thanks to BucketPadLengthDataset)
            num_tokens = np.vectorize(self.num_tokens, otypes=[np.compat.long])
            self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
            self.buckets = [
                (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens)
            ]
        else:
            self.buckets = None
        self.pad_to_multiple = pad_to_multiple

    def get_batch_shapes(self):
        return self.buckets

    def __getitem__(self, index):
        tgt_item = self.tgt[index] if self.tgt is not None else None
        src_item = self.src[index]
        ref_item = self.ref[index]
        # Append EOS to end of tgt sentence if it does not have an EOS and remove
        # EOS from end of src sentence if it exists. This is useful when we use
        # existing datasets for opposite directions, i.e., when we want to use
        # tgt_dataset as src_dataset and vice versa.
        if self.append_eos_to_target:
            eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos()
            if self.tgt and self.tgt[index][-1] != eos:
                tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])])

        if self.append_bos:
            bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()
            if self.tgt and self.tgt[index][0] != bos:
                tgt_item = torch.cat([torch.LongTensor([bos]), self.tgt[index]])

            bos = self.src_dict.bos()
            if self.src[index][0] != bos:
                src_item = torch.cat([torch.LongTensor([bos]), self.src[index]])
            if self.ref[index][0] != bos:
                ref_item = torch.cat([torch.LongTensor([bos]), self.ref[index]])

        if self.remove_eos_from_source:
            eos = self.src_dict.eos()
            if self.src[index][-1] == eos:
                src_item = self.src[index][:-1]
            if self.ref[index][-1] == eos:
                ref_item = self.ref[index][:-1]

        example = {
            "id": index,
            "source": src_item,
            "reference": ref_item,
            "target": tgt_item,
        }
        if self.align_dataset is not None:
            example["alignment"] = self.align_dataset[index]
        if self.constraints is not None:
            example["constraints"] = self.constraints[index]
        return example

    def __len__(self):
        return len(self.src)

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
            pad_to_length (dict, optional): a dictionary of
                {'source': source_pad_to_length, 'target': target_pad_to_length}
                to indicate the max length to pad to in source and target respectively.

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the source sentence of shape `(bsz, src_len)`. Padding will
                    appear on the left if *left_pad_source* is ``True``.
                  - `src_lengths` (LongTensor): 1D Tensor of the unpadded
                    lengths of each source sentence of shape `(bsz)`
                  - `prev_output_tokens` (LongTensor): a padded 2D Tensor of
                    tokens in the target sentence, shifted right by one
                    position for teacher forcing, of shape `(bsz, tgt_len)`.
                    This key will not be present if *input_feeding* is
                    ``False``. Padding will appear on the left if
                    *left_pad_target* is ``True``.
                  - `src_lang_id` (LongTensor): a long Tensor which contains source
                    language IDs of each sample in the batch

                - `target` (LongTensor): a padded 2D Tensor of tokens in the
                  target sentence of shape `(bsz, tgt_len)`. Padding will appear
                  on the left if *left_pad_target* is ``True``.
                - `ref_tokens` (LongTensor): a padded 2D Tensor of tokens in the
                  reference sentence of shape `(bsz, ref_len)`.
                - `ref_lengths` (LongTensor): 1D Tensor of the unpadded lengths
                  of each reference sentence of shape `(bsz)`
                - `tgt_lang_id` (LongTensor): a long Tensor which contains target language
                  IDs of each sample in the batch
        """
        res = collate(
            samples,
            pad_idx=self.src_dict.pad(),
            eos_idx=self.eos,
            left_pad_source=self.left_pad_source,
            left_pad_target=self.left_pad_target,
            input_feeding=self.input_feeding,
            pad_to_length=pad_to_length,
            pad_to_multiple=self.pad_to_multiple,
        )
        if self.src_lang_id is not None or self.tgt_lang_id is not None:
            src_tokens = res["net_input"]["src_tokens"]
            bsz = src_tokens.size(0)
            if self.src_lang_id is not None:
                res["net_input"]["src_lang_id"] = (
                    torch.LongTensor([[self.src_lang_id]]).expand(bsz, 1).to(src_tokens)
                )
            if self.tgt_lang_id is not None:
                res["tgt_lang_id"] = (
                    torch.LongTensor([[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens)
                )
        return res

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return max(
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def num_tokens_vec(self, indices):
        """Return the number of tokens for a set of positions defined by indices.
        This value is used to enforce ``--max-tokens`` during batching."""
        sizes = self.src_sizes[indices]
        if self.tgt_sizes is not None:
            sizes = np.maximum(sizes, self.tgt_sizes[indices])
        return sizes

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self)).astype(np.int64)
        else:
            indices = np.arange(len(self), dtype=np.int64)
        if self.buckets is None:
            # sort by target length, then source length
            if self.tgt_sizes is not None:
                indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
            return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
        else:
            # sort by bucketed_num_tokens, which is:
            #   max(padded_src_len, padded_tgt_len)
            return indices[
                np.argsort(self.bucketed_num_tokens[indices], kind="mergesort")
            ]

    @property
    def supports_prefetch(self):
        # prefetch only when every wrapped stream (source, reference, target)
        # supports it
        return (
            getattr(self.src, "supports_prefetch", False)
            and getattr(self.ref, "supports_prefetch", False)
            and (getattr(self.tgt, "supports_prefetch", False) or self.tgt is None)
        )

    def prefetch(self, indices):
        self.src.prefetch(indices)
        self.ref.prefetch(indices)
        if self.tgt is not None:
            self.tgt.prefetch(indices)
        if self.align_dataset is not None:
            self.align_dataset.prefetch(indices)

    def filter_indices_by_size(self, indices, max_sizes):
        """Filter a list of sample indices. Remove those that are longer
        than specified in max_sizes.

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        return data_utils.filter_paired_dataset_indices_by_size(
            self.src_sizes,
            self.tgt_sizes,
            indices,
            max_sizes,
        )
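

# --- Hypothetical iteration sketch (illustrative, not part of the original
# module). LanguageTripleDataset is a map-style torch dataset, so outside of
# fairseq's own epoch iterators it can be wrapped in a plain DataLoader with
# its ``collater`` as collate_fn. ``dataset`` is assumed to be built e.g. by
# load_langtriple_dataset above; the batch size is an arbitrary assumption.
def _iterate_batches_example(dataset, batch_size=8):
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=dataset.collater,
    )
    for batch in loader:
        # net_input feeds the model; ref_tokens/ref_lengths hold the reference
        # stream that the triple dataset adds on top of the usual pair batch.
        yield batch["net_input"]["src_tokens"], batch["ref_tokens"], batch["target"]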