# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np
import pytest
import torch
from numpy.testing import assert_array_equal
from omegaconf import OmegaConf

from scripts.nlp_language_modeling.build_knn_map_index import build_map, dedup

from nemo.collections.nlp.data.language_modeling.megatron.indexed_retrieval_dataset import (
    KNNIndex,
    MMapRetrievalIndexedDataset,
    MMapRetrievalIndexedDatasetBuilder,
    merge_knn_files,
)
from nemo.collections.nlp.data.language_modeling.megatron.retro_dataset import RETRODataset

try:
    from apex.transformer import parallel_state

    HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
    HAVE_APEX = False


@pytest.mark.run_only_on('GPU')
@pytest.mark.skipif(not HAVE_APEX, reason="apex is not installed")
class TestRetrievalIndexFiles:
    @classmethod
    def setup_class(cls):
        init_method = 'tcp://'
        master_ip = 'localhost'
        master_port = '6000'
        init_method += master_ip + ':' + master_port
        torch.distributed.init_process_group(backend='gloo', world_size=1, rank=0, init_method=init_method)
        parallel_state.initialize_model_parallel(1, 1)

    @pytest.mark.unit
    def test_index(self):
        chunk_size = 64
        stride = 32
        sizes = np.array([128, 256], dtype=np.int32)
        dtype = np.int64
        itemsize = dtype().itemsize
        index_file = '/tmp/test.idx'
        try:
            with MMapRetrievalIndexedDataset.Index.writer(index_file, dtype, False) as index:
                index.write(sizes, chunk_size, stride=stride)
            index_load = MMapRetrievalIndexedDataset.Index(index_file)
            assert index_load.chunk_size == chunk_size
            assert not index_load.retrieval_db
            assert np.array_equal(index_load.sizes, sizes)
            assert np.array_equal(
                index_load._chunk_id_start,
                np.array([0, len(range(0, sizes[0] - chunk_size + 1, stride))], dtype=np.int64),
            )
            add1 = [i * itemsize for i in list(range(0, sizes[0] - chunk_size + 1, stride))]
            start = max(add1) + chunk_size * itemsize
            add2 = [i * itemsize + start for i in list(range(0, sizes[1] - chunk_size + 1, stride))]
            addr = add1 + add2
            assert np.array_equal(index_load._chunk_address, np.array(addr, dtype=np.int64))
            assert np.array_equal(index_load._pointers, np.array([0, sizes[0] * itemsize], dtype=np.int64))
            assert len(index_load._chunk_address) == index_load.num_chunks
        finally:
            os.remove(index_file)

        chunk_size = 64
        stride = 64
        sizes = np.array([128, 256], dtype=np.int32)
        dtype = np.int64
        itemsize = dtype().itemsize
        index_file = '/tmp/test.idx'
        try:
            with MMapRetrievalIndexedDataset.Index.writer(index_file, dtype, False) as index:
                index.write(sizes, chunk_size, stride=stride)
            index_load = MMapRetrievalIndexedDataset.Index(index_file)
            assert index_load.chunk_size == chunk_size
            assert not index_load.retrieval_db
            assert np.array_equal(index_load.sizes, sizes)
            assert np.array_equal(
                index_load._chunk_id_start,
                np.array([0, len(range(0, sizes[0] - chunk_size + 1, stride))], dtype=np.int64),
            )
            add1 = [i * itemsize for i in list(range(0, sizes[0] - chunk_size + 1, stride))]
            start = max(add1) + chunk_size * itemsize
            add2 = [i * itemsize + start for i in list(range(0, sizes[1] - chunk_size + 1, stride))]
            addr = add1 + add2
            assert np.array_equal(index_load._chunk_address, np.array(addr, dtype=np.int64))
            assert np.array_equal(index_load._pointers, np.array([0, sizes[0] * itemsize], dtype=np.int64))
            assert len(index_load._chunk_address) == index_load.num_chunks
        finally:
            os.remove(index_file)
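
    # Builder tests: sentences are padded with pad_id up to a multiple of chunk_size. With
    # stride=32 (half of chunk_size) the chunk windows overlap, so the padded 128-token document
    # splits into 3 chunks and the 256-token document into 7, matching the chunk-id assertions below.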
    @pytest.mark.unit
    def test_create_data_index_stride32(self):
        chunk_size = 64
        pad_id = 0
        stride = 32
        sentence1 = torch.arange(0, 200, 2, dtype=torch.int64)
        padded_size = chunk_size - (len(sentence1) % chunk_size)
        gt1 = np.pad(sentence1, (0, padded_size), 'constant', constant_values=pad_id)
        sentence2 = torch.arange(1, 500, 2, dtype=torch.int64)
        padded_size = chunk_size - (len(sentence2) % chunk_size)
        gt2 = np.pad(sentence2, (0, padded_size), 'constant', constant_values=pad_id)
        data_file = '/tmp/test'
        index_file = data_file + '.idx'
        bin_file = data_file + '.bin'
        try:
            builder = MMapRetrievalIndexedDatasetBuilder(
                bin_file, chunk_size, pad_id, False, dtype=np.int64, stride=stride
            )
            builder.add_item(sentence1)
            builder.add_item(sentence2)
            builder.finalize(index_file)
            # load the data
            ds = MMapRetrievalIndexedDataset(data_file)
            assert np.array_equal(ds.get(0), gt1)
            assert np.array_equal(ds.get(1), gt2)
            fetch1, fetch2 = ds[0:2]
            assert np.array_equal(fetch1, gt1)
            assert np.array_equal(fetch2, gt2)
            chunk_id = ds.get_chunk_id(0, 64)
            assert chunk_id == 2
            assert ds.from_chunk_id_to_doc_id(0) == 0
            assert ds.from_chunk_id_to_doc_id(1) == 0
            assert ds.from_chunk_id_to_doc_id(2) == 0
            with pytest.raises(ValueError):
                ds.get_chunk_id(0, 128)
            assert np.array_equal(ds.get_chunk(chunk_id), gt1[64 : 64 + 64])
            chunk_id = ds.get_chunk_id(1, 0)
            assert chunk_id == 3
            assert ds.from_chunk_id_to_doc_id(3) == 1
            assert ds.from_chunk_id_to_doc_id(4) == 1
            assert ds.from_chunk_id_to_doc_id(5) == 1
            assert ds.from_chunk_id_to_doc_id(6) == 1
            assert ds.from_chunk_id_to_doc_id(7) == 1
            assert ds.from_chunk_id_to_doc_id(8) == 1
            assert ds.from_chunk_id_to_doc_id(9) == 1
            with pytest.raises(ValueError):
                ds.from_chunk_id_to_doc_id(10)
            assert np.array_equal(ds.get_chunk(chunk_id), gt2[0:64])
            assert np.array_equal(ds.get_chunk(chunk_id + 1), gt2[stride : stride + chunk_size])
            assert np.array_equal(ds.get_chunk(chunk_id + 2), gt2[stride * 2 : stride * 2 + chunk_size])
            assert np.array_equal(ds.get_chunk(chunk_id + 3), gt2[stride * 3 : stride * 3 + chunk_size])
            assert ds.get_chunk_id(1, 64) == 5
            assert ds.get_chunk_id(1, 128) == 7
            assert ds.get_chunk_id(1, 192) == 9
            with pytest.raises(ValueError):
                ds.get_chunk_id(0, 256)
        finally:
            os.remove(index_file)
            os.remove(bin_file)
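
    # With the default stride (equal to chunk_size) chunks do not overlap: the same two padded
    # documents split into 2 and 4 chunks respectively.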
    @pytest.mark.unit
    def test_create_data_index(self):
        chunk_size = 64
        pad_id = 0
        sentence1 = torch.arange(0, 200, 2, dtype=torch.int64)
        padded_size = chunk_size - (len(sentence1) % chunk_size)
        gt1 = np.pad(sentence1, (0, padded_size), 'constant', constant_values=pad_id)
        sentence2 = torch.arange(1, 500, 2, dtype=torch.int64)
        padded_size = chunk_size - (len(sentence2) % chunk_size)
        gt2 = np.pad(sentence2, (0, padded_size), 'constant', constant_values=pad_id)
        data_file = '/tmp/test'
        index_file = data_file + '.idx'
        bin_file = data_file + '.bin'
        try:
            builder = MMapRetrievalIndexedDatasetBuilder(bin_file, chunk_size, pad_id, False)
            builder.add_item(sentence1)
            builder.add_item(sentence2)
            builder.finalize(index_file)
            # load the data
            ds = MMapRetrievalIndexedDataset(data_file)
            assert np.array_equal(ds.get(0), gt1)
            assert np.array_equal(ds.get(1), gt2)
            fetch1, fetch2 = ds[0:2]
            assert np.array_equal(fetch1, gt1)
            assert np.array_equal(fetch2, gt2)
            chunk_id = ds.get_chunk_id(0, 64)
            assert chunk_id == 1
            assert ds.from_chunk_id_to_doc_id(0) == 0
            assert ds.from_chunk_id_to_doc_id(1) == 0
            with pytest.raises(ValueError):
                ds.get_chunk_id(0, 128)
            assert np.array_equal(ds.get_chunk(chunk_id), gt1[64 : 64 + 64])
            chunk_id = ds.get_chunk_id(1, 0)
            assert chunk_id == 2
            assert ds.from_chunk_id_to_doc_id(2) == 1
            assert ds.from_chunk_id_to_doc_id(3) == 1
            assert ds.from_chunk_id_to_doc_id(4) == 1
            assert ds.from_chunk_id_to_doc_id(5) == 1
            with pytest.raises(ValueError):
                ds.from_chunk_id_to_doc_id(6)
            assert np.array_equal(ds.get_chunk(chunk_id), gt2[0:64])
            assert np.array_equal(ds.get_chunk(chunk_id + 1), gt2[64:128])
            assert np.array_equal(ds.get_chunk(chunk_id + 2), gt2[128:192])
            assert np.array_equal(ds.get_chunk(chunk_id + 3), gt2[192:256])
            assert ds.get_chunk_id(1, 64) == 3
            assert ds.get_chunk_id(1, 128) == 4
            assert ds.get_chunk_id(1, 192) == 5
            with pytest.raises(ValueError):
                ds.get_chunk_id(0, 256)
        finally:
            os.remove(index_file)
            os.remove(bin_file)
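
    # Retrieval-DB builders (retrieval_db=True) append an extra chunk_size of padding after each
    # document so that get_chunk() can always return chunk_size * 2 tokens (a chunk plus its
    # continuation); padded_gt1 / padded_gt2 are the correspondingly extended references.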
    @pytest.mark.unit
    def test_create_retrieval_data_index_stride32(self):
        stride = 32
        chunk_size = 64
        pad_id = 0
        sentence1 = torch.arange(0, 200, 2, dtype=torch.int64)
        padded_size = chunk_size - (len(sentence1) % chunk_size)
        gt1 = np.pad(sentence1, (0, padded_size), 'constant', constant_values=pad_id)
        padded_gt1 = np.pad(sentence1, (0, padded_size + chunk_size), 'constant', constant_values=pad_id)
        sentence2 = torch.arange(1, 500, 2, dtype=torch.int64)
        padded_size = chunk_size - (len(sentence2) % chunk_size)
        gt2 = np.pad(sentence2, (0, padded_size), 'constant', constant_values=pad_id)
        padded_gt2 = np.pad(sentence2, (0, padded_size + chunk_size), 'constant', constant_values=pad_id)
        data_file = '/tmp/test'
        index_file = data_file + '.idx'
        bin_file = data_file + '.bin'
        try:
            builder = MMapRetrievalIndexedDatasetBuilder(bin_file, chunk_size, pad_id, True, stride=stride)
            builder.add_item(sentence1)
            builder.add_item(sentence2)
            builder.finalize(index_file)
            # load the data
            ds = MMapRetrievalIndexedDataset(data_file)
            assert np.array_equal(ds.get(0), gt1)
            assert np.array_equal(ds.get(1), gt2)
            fetch1, fetch2 = ds[0:2]
            assert np.array_equal(fetch1, gt1)
            assert np.array_equal(fetch2, gt2)
            chunk_id = ds.get_chunk_id(0, 64)
            assert chunk_id == 2
            assert ds.from_chunk_id_to_doc_id(0) == 0
            assert ds.from_chunk_id_to_doc_id(1) == 0
            assert ds.from_chunk_id_to_doc_id(2) == 0
            with pytest.raises(ValueError):
                ds.get_chunk_id(0, 128)
            assert np.array_equal(ds.get_chunk(chunk_id), padded_gt1[64 : 64 + 64 * 2])
            chunk_id = ds.get_chunk_id(1, 0)
            assert chunk_id == 3
            assert ds.from_chunk_id_to_doc_id(3) == 1
            assert ds.from_chunk_id_to_doc_id(4) == 1
            assert ds.from_chunk_id_to_doc_id(5) == 1
            assert ds.from_chunk_id_to_doc_id(6) == 1
            assert ds.from_chunk_id_to_doc_id(7) == 1
            assert ds.from_chunk_id_to_doc_id(8) == 1
            assert ds.from_chunk_id_to_doc_id(9) == 1
            with pytest.raises(ValueError):
                ds.from_chunk_id_to_doc_id(10)
            assert np.array_equal(ds.get_chunk(chunk_id), padded_gt2[0 : chunk_size * 2])
            assert np.array_equal(ds.get_chunk(chunk_id + 1), gt2[stride : stride + chunk_size * 2])
            assert np.array_equal(ds.get_chunk(chunk_id + 2), gt2[stride * 2 : stride * 2 + chunk_size * 2])
            assert np.array_equal(ds.get_chunk(chunk_id + 3), gt2[stride * 3 : stride * 3 + chunk_size * 2])
            assert ds.get_chunk_id(1, 64) == 5
            assert ds.get_chunk_id(1, 128) == 7
            assert ds.get_chunk_id(1, 192) == 9
            with pytest.raises(ValueError):
                ds.get_chunk_id(0, 256)
            chunk_id = ds.get_chunk_id(1, 64)
            assert np.array_equal(ds.get_chunk(chunk_id), padded_gt2[64:192])
            multi_chunks = ds.get_chunk(slice(0, ds.chunks))
            assert np.array_equal(multi_chunks[0], padded_gt1[0 : chunk_size * 2])
            assert np.array_equal(multi_chunks[1], padded_gt1[stride : stride + chunk_size * 2])
            assert np.array_equal(multi_chunks[2], padded_gt1[stride * 2 : stride * 2 + chunk_size * 2])
            assert np.array_equal(multi_chunks[3], padded_gt2[0 : chunk_size * 2])
            assert np.array_equal(multi_chunks[4], padded_gt2[stride : stride + chunk_size * 2])
            assert np.array_equal(multi_chunks[5], padded_gt2[stride * 2 : stride * 2 + chunk_size * 2])
            assert np.array_equal(multi_chunks[6], padded_gt2[stride * 3 : stride * 3 + chunk_size * 2])
            assert np.array_equal(multi_chunks[7], padded_gt2[stride * 4 : stride * 4 + chunk_size * 2])
            assert np.array_equal(multi_chunks[8], padded_gt2[stride * 5 : stride * 5 + chunk_size * 2])
            assert np.array_equal(multi_chunks[9], padded_gt2[stride * 6 : stride * 6 + chunk_size * 2])
        finally:
            os.remove(index_file)
            os.remove(bin_file)

    @pytest.mark.unit
    def test_create_retrieval_data_index(self):
        chunk_size = 64
        pad_id = 0
        sentence1 = torch.arange(0, 200, 2, dtype=torch.int64)
        padded_size = chunk_size - (len(sentence1) % chunk_size)
        gt1 = np.pad(sentence1, (0, padded_size), 'constant', constant_values=pad_id)
        padded_gt1 = np.pad(sentence1, (0, padded_size + chunk_size), 'constant', constant_values=pad_id)
        sentence2 = torch.arange(1, 500, 2, dtype=torch.int64)
        padded_size = chunk_size - (len(sentence2) % chunk_size)
        gt2 = np.pad(sentence2, (0, padded_size), 'constant', constant_values=pad_id)
        padded_gt2 = np.pad(sentence2, (0, padded_size + chunk_size), 'constant', constant_values=pad_id)
        data_file = '/tmp/test'
        index_file = data_file + '.idx'
        bin_file = data_file + '.bin'
        try:
            builder = MMapRetrievalIndexedDatasetBuilder(bin_file, chunk_size, pad_id, True)
            builder.add_item(sentence1)
            builder.add_item(sentence2)
            builder.finalize(index_file)
            # load the data
            ds = MMapRetrievalIndexedDataset(data_file)
            assert np.array_equal(ds.get(0), gt1)
            assert np.array_equal(ds.get(1), gt2)
            fetch1, fetch2 = ds[0:2]
            assert np.array_equal(fetch1, gt1)
            assert np.array_equal(fetch2, gt2)
            chunk_id = ds.get_chunk_id(0, 64)
            assert chunk_id == 1
            assert ds.from_chunk_id_to_doc_id(0) == 0
            assert ds.from_chunk_id_to_doc_id(1) == 0
            with pytest.raises(ValueError):
                ds.get_chunk_id(0, 128)
            assert np.array_equal(ds.get_chunk(chunk_id), padded_gt1[64 : 64 + 64 * 2])
            chunk_id = ds.get_chunk_id(1, 0)
            assert chunk_id == 2
            assert ds.from_chunk_id_to_doc_id(2) == 1
            assert ds.from_chunk_id_to_doc_id(3) == 1
            assert ds.from_chunk_id_to_doc_id(4) == 1
            assert ds.from_chunk_id_to_doc_id(5) == 1
            with pytest.raises(ValueError):
                ds.from_chunk_id_to_doc_id(6)
            assert np.array_equal(ds.get_chunk(chunk_id), padded_gt2[0:128])
            assert np.array_equal(ds.get_chunk(chunk_id + 1), padded_gt2[64:192])
            assert np.array_equal(ds.get_chunk(chunk_id + 2), padded_gt2[128:256])
            assert np.array_equal(ds.get_chunk(chunk_id + 3), padded_gt2[192:320])
            assert ds.get_chunk_id(1, 64) == 3
            assert ds.get_chunk_id(1, 128) == 4
            assert ds.get_chunk_id(1, 192) == 5
            with pytest.raises(ValueError):
                ds.get_chunk_id(0, 256)
            chunk_id = ds.get_chunk_id(1, 64)
            assert np.array_equal(ds.get_chunk(chunk_id), padded_gt2[64:192])
            multi_chunks = ds.get_chunk(slice(0, ds.chunks))
            assert np.array_equal(multi_chunks[0], padded_gt1[0:128])
            assert np.array_equal(multi_chunks[1], padded_gt1[64 : 64 + 128])
            assert np.array_equal(multi_chunks[2], padded_gt2[0:128])
            assert np.array_equal(multi_chunks[3], padded_gt2[64 : 64 + 128])
            assert np.array_equal(multi_chunks[4], padded_gt2[128 : 128 + 128])
            assert np.array_equal(multi_chunks[5], padded_gt2[192 : 192 + 128])
        finally:
            os.remove(index_file)
            os.remove(bin_file)
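
    # KNNIndex stores, for every training chunk id, the ids of its K nearest retrieval-DB chunks.
    # The writer accepts an optional chunk-id offset (used for sharded maps), and merge_knn_files
    # concatenates several shard files into a single index.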
    @pytest.mark.unit
    def test_knn_index(self):
        data_file = '/tmp/test'
        index_file = data_file + '.idx'
        K = 8
        index_files = [f'{data_file}_{i}.idx' for i in range(3)]
        merged_file = '/tmp/merged.idx'
        try:
            with KNNIndex.writer(index_file, K) as w:
                map_np0 = np.random.randint(0, 100, (50, K))
                w.write(map_np0)
                map_np1 = np.random.randint(0, 100, (50, K))
                w.write(map_np1)
                map_np2 = np.random.randint(0, 100, (50, K))
                w.write(map_np2)
            f = KNNIndex(index_file)
            assert f.K == K
            assert f.len == map_np0.shape[0] + map_np1.shape[0] + map_np2.shape[0]
            assert np.array_equal(map_np0, f.knn_map[:50])
            assert np.array_equal(map_np1, f.knn_map[50:100])
            assert np.array_equal(map_np2, f.knn_map[100:])
            assert np.array_equal(f.get_KNN_chunk_ids(5), map_np0[5])
            assert f.chunk_start_id == 0
            assert f.chunk_end_id == f.len

            with KNNIndex.writer(index_file, K, 100) as w:
                map_np0 = np.random.randint(0, 100, (50, K))
                w.write(map_np0)
                map_np1 = np.random.randint(0, 100, (50, K))
                w.write(map_np1)
                map_np2 = np.random.randint(0, 100, (50, K))
                w.write(map_np2)
            f = KNNIndex(index_file)
            assert f.K == K
            assert f.len == map_np0.shape[0] + map_np1.shape[0] + map_np2.shape[0]
            assert np.array_equal(map_np0, f.knn_map[:50])
            assert np.array_equal(map_np1, f.knn_map[50:100])
            assert np.array_equal(map_np2, f.knn_map[100:])
            assert np.array_equal(f.get_KNN_chunk_ids(5 + 100), map_np0[5])
            assert f.chunk_start_id == 100
            assert f.chunk_end_id == f.len + 100

            # test multiple sharding indices
            inputs = []
            start = 0
            for i in range(3):
                with KNNIndex.writer(index_files[i], K, offset=start) as w:
                    map_np0 = np.random.randint(0, 100, (50, K))
                    inputs.append(map_np0)
                    w.write(map_np0)
                    map_np1 = np.random.randint(0, 100, (50, K))
                    inputs.append(map_np1)
                    w.write(map_np1)
                f = KNNIndex(index_files[i])
                start += f.len
            merge_knn_files(index_files, merged_file)
            f = KNNIndex(merged_file)
            input_array = np.vstack(inputs)
            assert f.len == 100 * 3
            for i in range(300):
                assert np.array_equal(f.get_KNN_chunk_ids(i), input_array[i])
            assert f.chunk_start_id == 0
            assert f.chunk_end_id == f.len
            assert f.K == K
        finally:
            os.remove(index_file)
            for i in range(3):
                os.remove(index_files[i])
            os.remove(merged_file)
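
    # RETRODataset combines the training dataset, the retrieval DB and the KNN map. Every sample
    # carries seq_len tokens/labels and, for each of the seq_len // chunk_size chunks, K retrieved
    # neighbours of chunk_size * 2 tokens; the random KNN map here deliberately allows negative ids.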
    @pytest.mark.unit
    @pytest.mark.skipif(not HAVE_APEX, reason="apex is not installed")
    def test_retro_dataset(self):
        chunk_size = 64
        pad_id = 0
        sentence1 = torch.arange(0, 200, 2, dtype=torch.int64)
        sentence2 = torch.arange(1, 500, 2, dtype=torch.int64)
        sentence3 = torch.arange(0, 300, 2, dtype=torch.int64)
        sentence4 = torch.arange(1, 400, 2, dtype=torch.int64)
        # test the case where training data and retrieval data are different
        data_file = '/tmp/test_data'
        data_index_file = data_file + '.idx'
        data_bin_file = data_file + '.bin'
        db_file = '/tmp/test_db_data'
        db_index_file = db_file + '.idx'
        db_bin_file = db_file + '.bin'
        K = 8
        map_index_file = '/tmp/test_map.idx'
        index_path = '/tmp'
        cfg = OmegaConf.create({'data': {"index_mapping_dir": index_path}})

        # dummy tokenizer
        class Tokenizer:
            eos_id = 1
            pad_id = 0

        tokenizer = Tokenizer()
        num_samples = 100
        seq_len = 192
        name = 'test'
        data_prefix = 'pref'
        seed = 1
        _filename = index_path + '/' + data_prefix
        _filename += '_{}_indexmap'.format(name)
        _filename += '_{}ns'.format(num_samples)
        _filename += '_{}sl'.format(seq_len)
        _filename += '_{}s'.format(seed)
        doc_idx_filename = _filename + '_doc_idx.npy'
        sample_idx_filename = _filename + '_sample_idx.npy'
        shuffle_idx_filename = _filename + '_shuffle_idx.npy'
        try:
            builder = MMapRetrievalIndexedDatasetBuilder(data_bin_file, chunk_size, pad_id, False)
            builder.add_item(sentence1)
            builder.add_item(sentence2)
            builder.finalize(data_index_file)
            builder = MMapRetrievalIndexedDatasetBuilder(db_bin_file, chunk_size, pad_id, True)
            builder.add_item(sentence3)
            builder.add_item(sentence4)
            builder.finalize(db_index_file)
            # load the data
            data_index = MMapRetrievalIndexedDataset(data_file)
            db_index = MMapRetrievalIndexedDataset(db_file)
            with KNNIndex.writer(map_index_file, K) as w:
                map_np = np.random.randint(-3, db_index.chunks, (data_index.chunks, K))
                w.write(map_np)
            map_index = KNNIndex(map_index_file)
            documents = np.arange(0, data_index.sizes.shape[0])
            d = RETRODataset(
                cfg,
                None,
                tokenizer,
                name,
                data_prefix,
                documents,
                data_index,
                num_samples,
                seq_len,
                seed,
                map_index,
                db_index,
            )
            for i in range(len(d)):
                record = d[i]
                assert record['tokens'].shape[0] == seq_len
                assert record['labels'].shape[0] == seq_len
                assert record['retrieved_ids'].shape[0] == seq_len // chunk_size
                assert record['retrieved_ids'].shape[1] == K
                assert record['retrieved_ids'].shape[2] == chunk_size * 2
                assert record['tokens_mask'].shape[0] == seq_len
        finally:
            os.remove(data_bin_file)
            os.remove(data_index_file)
            os.remove(db_bin_file)
            os.remove(db_index_file)
            os.remove(map_index_file)
            os.remove(doc_idx_filename)
            os.remove(sample_idx_filename)
            os.remove(shuffle_idx_filename)

        # test the case where training data and retrieval data are the same
        try:
            builder = MMapRetrievalIndexedDatasetBuilder(db_bin_file, chunk_size, pad_id, True)
            builder.add_item(sentence1)
            builder.add_item(sentence2)
            builder.add_item(sentence3)
            builder.add_item(sentence4)
            builder.finalize(db_index_file)
            # load the data
            data_index = MMapRetrievalIndexedDataset(db_file)
            db_index = MMapRetrievalIndexedDataset(db_file)
            with KNNIndex.writer(map_index_file, K) as w:
                map_np = np.random.randint(-3, db_index.chunks, (data_index.chunks, K))
                w.write(map_np)
            map_index = KNNIndex(map_index_file)
            documents = np.arange(0, data_index.sizes.shape[0])
            d = RETRODataset(
                cfg,
                None,
                tokenizer,
                name,
                data_prefix,
                documents,
                data_index,
                num_samples,
                seq_len,
                seed,
                map_index,
                db_index,
            )
            for i in range(len(d)):
                record = d[i]
                assert record['tokens'].shape[0] == seq_len
                assert record['labels'].shape[0] == seq_len
                assert record['retrieved_ids'].shape[0] == seq_len // chunk_size
                assert record['retrieved_ids'].shape[1] == K
                assert record['retrieved_ids'].shape[2] == chunk_size * 2
                assert record['tokens_mask'].shape[0] == seq_len
        finally:
            os.remove(db_bin_file)
            os.remove(db_index_file)
            os.remove(map_index_file)
            os.remove(doc_idx_filename)
            os.remove(sample_idx_filename)
            os.remove(shuffle_idx_filename)
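
    # Same RETRODataset checks as above, but both the training data and the retrieval DB are built
    # with stride=32, so the KNN map's chunk ids refer to overlapping chunks.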
    @pytest.mark.unit
    @pytest.mark.skipif(not HAVE_APEX, reason="apex is not installed")
    def test_retro_dataset_stride32(self):
        chunk_size = 64
        pad_id = 0
        sentence1 = torch.arange(0, 200, 2, dtype=torch.int64)
        sentence2 = torch.arange(1, 500, 2, dtype=torch.int64)
        sentence3 = torch.arange(0, 300, 2, dtype=torch.int64)
        sentence4 = torch.arange(1, 400, 2, dtype=torch.int64)
        # test the case where training data and retrieval data are different
        data_file = '/tmp/test_data'
        data_index_file = data_file + '.idx'
        data_bin_file = data_file + '.bin'
        db_file = '/tmp/test_db_data'
        db_index_file = db_file + '.idx'
        db_bin_file = db_file + '.bin'
        K = 8
        map_index_file = '/tmp/test_map.idx'
        index_path = '/tmp'
        cfg = OmegaConf.create({'data': {"index_mapping_dir": index_path}})

        # dummy tokenizer
        class Tokenizer:
            eos_id = 1
            pad_id = 0

        tokenizer = Tokenizer()
        num_samples = 100
        stride = 32
        seq_len = 192
        name = 'test'
        data_prefix = 'pref'
        seed = 1
        _filename = index_path + '/' + data_prefix
        _filename += '_{}_indexmap'.format(name)
        _filename += '_{}ns'.format(num_samples)
        _filename += '_{}sl'.format(seq_len)
        _filename += '_{}s'.format(seed)
        doc_idx_filename = _filename + '_doc_idx.npy'
        sample_idx_filename = _filename + '_sample_idx.npy'
        shuffle_idx_filename = _filename + '_shuffle_idx.npy'
        try:
            builder = MMapRetrievalIndexedDatasetBuilder(data_bin_file, chunk_size, pad_id, False, stride=32)
            builder.add_item(sentence1)
            builder.add_item(sentence2)
            builder.finalize(data_index_file)
            builder = MMapRetrievalIndexedDatasetBuilder(db_bin_file, chunk_size, pad_id, True, stride=32)
            builder.add_item(sentence3)
            builder.add_item(sentence4)
            builder.finalize(db_index_file)
            # load the data
            data_index = MMapRetrievalIndexedDataset(data_file)
            db_index = MMapRetrievalIndexedDataset(db_file)
            with KNNIndex.writer(map_index_file, K) as w:
                map_np = np.random.randint(-3, db_index.chunks, (data_index.chunks, K))
                w.write(map_np)
            map_index = KNNIndex(map_index_file)
            documents = np.arange(0, data_index.sizes.shape[0])
            d = RETRODataset(
                cfg,
                None,
                tokenizer,
                name,
                data_prefix,
                documents,
                data_index,
                num_samples,
                seq_len,
                seed,
                map_index,
                db_index,
            )
            for i in range(len(d)):
                record = d[i]
                assert record['tokens'].shape[0] == seq_len
                assert record['labels'].shape[0] == seq_len
                assert record['retrieved_ids'].shape[0] == seq_len // chunk_size
                assert record['retrieved_ids'].shape[1] == K
                assert record['retrieved_ids'].shape[2] == chunk_size * 2
                assert record['tokens_mask'].shape[0] == seq_len
        finally:
            os.remove(data_bin_file)
            os.remove(data_index_file)
            os.remove(db_bin_file)
            os.remove(db_index_file)
            os.remove(map_index_file)
            os.remove(doc_idx_filename)
            os.remove(sample_idx_filename)
            os.remove(shuffle_idx_filename)

        # test the case where training data and retrieval data are the same
        try:
            builder = MMapRetrievalIndexedDatasetBuilder(db_bin_file, chunk_size, pad_id, True, stride=32)
            builder.add_item(sentence1)
            builder.add_item(sentence2)
            builder.add_item(sentence3)
            builder.add_item(sentence4)
            builder.finalize(db_index_file)
            # load the data
            data_index = MMapRetrievalIndexedDataset(db_file)
            db_index = MMapRetrievalIndexedDataset(db_file)
            with KNNIndex.writer(map_index_file, K) as w:
                map_np = np.random.randint(-3, db_index.chunks, (data_index.chunks, K))
                w.write(map_np)
            map_index = KNNIndex(map_index_file)
            documents = np.arange(0, data_index.sizes.shape[0])
            d = RETRODataset(
                cfg,
                None,
                tokenizer,
                name,
                data_prefix,
                documents,
                data_index,
                num_samples,
                seq_len,
                seed,
                map_index,
                db_index,
            )
            for i in range(len(d)):
                record = d[i]
                assert record['tokens'].shape[0] == seq_len
                assert record['labels'].shape[0] == seq_len
                assert record['retrieved_ids'].shape[0] == seq_len // chunk_size
                assert record['retrieved_ids'].shape[1] == K
                assert record['retrieved_ids'].shape[2] == chunk_size * 2
                assert record['tokens_mask'].shape[0] == seq_len
        finally:
            os.remove(db_bin_file)
            os.remove(db_index_file)
            os.remove(map_index_file)
            os.remove(doc_idx_filename)
            os.remove(sample_idx_filename)
            os.remove(shuffle_idx_filename)
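
    # build_map records, for each chunk id in [beg, end), the [start, end) boundaries of the
    # document that the chunk belongs to; dedup then drops neighbour chunk ids that fall inside the
    # query chunk's own document and pads the remainder of each row with -1.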
    @pytest.mark.unit
    @pytest.mark.skipif(not HAVE_APEX, reason="apex is not installed")
    def test_dedup(self):
        total = 1000
        id_start = np.array([0, 100, 200, 300, 500, 900])
        beg = 30
        end = 210
        chunk_id_to_doc_id_map = np.zeros((end - beg, 2), dtype=np.int64)
        build_map(id_start, chunk_id_to_doc_id_map, total, beg, end)
        for i in range(30, 100):
            assert_array_equal(chunk_id_to_doc_id_map[i - beg], id_start[0:2])
        for i in range(100, 200):
            assert_array_equal(chunk_id_to_doc_id_map[i - beg], id_start[1:3])
        for i in range(200, 210):
            assert_array_equal(chunk_id_to_doc_id_map[i - beg], id_start[2:4])

        beg = 5
        end = 100
        chunk_id_to_doc_id_map = np.zeros((end - beg, 2), dtype=np.int64)
        build_map(id_start, chunk_id_to_doc_id_map, total, beg, end)
        for i in range(beg, end):
            assert_array_equal(chunk_id_to_doc_id_map[i - beg], id_start[0:2])

        beg = 100
        end = 200
        chunk_id_to_doc_id_map = np.zeros((end - beg, 2), dtype=np.int64)
        build_map(id_start, chunk_id_to_doc_id_map, total, beg, end)
        for i in range(beg, end):
            assert_array_equal(chunk_id_to_doc_id_map[i - beg], id_start[1:3])

        beg = 900
        end = 1000
        chunk_id_to_doc_id_map = np.zeros((end - beg, 2), dtype=np.int64)
        build_map(id_start, chunk_id_to_doc_id_map, total, beg, end)
        for i in range(beg, end):
            assert_array_equal(chunk_id_to_doc_id_map[i - beg], np.array([900, 1000]))

        beg = 150
        end = 250
        chunk_id_to_doc_id_map = np.zeros((end - beg, 2), dtype=np.int64)
        build_map(id_start, chunk_id_to_doc_id_map, total, beg, end)
        for i in range(beg, 200):
            assert_array_equal(chunk_id_to_doc_id_map[i - beg], id_start[1:3])
        for i in range(200, end):
            assert_array_equal(chunk_id_to_doc_id_map[i - beg], id_start[2:4])

        I = np.arange(1000)[None, :]
        tmp_neighbors = np.ones_like(I) * -1
        with pytest.raises(ValueError):
            dedup(chunk_id_to_doc_id_map, I, tmp_neighbors, 0, beg)

        I = np.arange(1000)[None, :]
        tmp_neighbors = np.ones_like(I) * -1
        with pytest.raises(ValueError):
            dedup(chunk_id_to_doc_id_map, I, tmp_neighbors, 250, beg)

        for i in range(beg, 200):
            I = np.arange(1000)[None, :]
            tmp_neighbors = np.ones_like(I) * -1
            dedup(chunk_id_to_doc_id_map, I, tmp_neighbors, i, beg)
            gt = np.array(list(range(100)) + list(range(200, 1000)) + ([-1] * 100))
            assert_array_equal(tmp_neighbors[0], gt)

        for i in range(200, 250):
            I = np.arange(1000)[None, :]
            tmp_neighbors = np.ones_like(I) * -1
            dedup(chunk_id_to_doc_id_map, I, tmp_neighbors, i, beg)
            gt = np.array(list(range(200)) + list(range(300, 1000)) + ([-1] * 100))
            assert_array_equal(tmp_neighbors[0], gt)

        I = np.arange(1000)[None, :]
        I = np.repeat(I, 70, axis=0)
        tmp_neighbors = np.ones_like(I) * -1
        dedup(chunk_id_to_doc_id_map, I, tmp_neighbors, 180, beg)
        gt0 = np.array(list(range(100)) + list(range(200, 1000)) + ([-1] * 100))
        gt1 = np.array(list(range(200)) + list(range(300, 1000)) + ([-1] * 100))
        for i in range(20):
            assert_array_equal(tmp_neighbors[i], gt0)
        for i in range(20, 70):
            assert_array_equal(tmp_neighbors[i], gt1)