import os

import numpy as np
import pytest
import torch
from numpy.testing import assert_array_equal
from omegaconf import OmegaConf
from scripts.nlp_language_modeling.build_knn_map_index import build_map, dedup

from nemo.collections.nlp.data.language_modeling.megatron.indexed_retrieval_dataset import (
    KNNIndex,
    MMapRetrievalIndexedDataset,
    MMapRetrievalIndexedDatasetBuilder,
    merge_knn_files,
)
from nemo.collections.nlp.data.language_modeling.megatron.retro_dataset import RETRODataset

try:
    from apex.transformer import parallel_state

    HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
    HAVE_APEX = False
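
# These tests exercise the building blocks of NeMo's RETRO data pipeline: the memory-mapped
# retrieval index (MMapRetrievalIndexedDataset), the KNN neighbor maps (KNNIndex), and the
# RETRODataset wrapper that stitches them together.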
@pytest.mark.run_only_on('GPU')
@pytest.mark.skipif(not HAVE_APEX, reason="apex is not installed")
class TestRetrievalIndexFiles:
    @classmethod
    def setup_class(cls):
        # Bring up a single-rank torch.distributed group (gloo backend) so that apex
        # model parallelism can be initialized for the dataset tests below.
        init_method = 'tcp://'
        master_ip = 'localhost'
        master_port = '6000'
        init_method += master_ip + ':' + master_port
        torch.distributed.init_process_group(backend='gloo', world_size=1, rank=0, init_method=init_method)
        parallel_state.initialize_model_parallel(1, 1)
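
    # test_index: write a bare .idx header for two documents (128 and 256 tokens) and check
    # the chunk bookkeeping (_chunk_id_start, _chunk_address, _pointers), once with an
    # overlapping stride of 32 and once with a non-overlapping stride equal to the chunk size.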
    @pytest.mark.unit
    def test_index(self):
        chunk_size = 64
        stride = 32
        sizes = np.array([128, 256], dtype=np.int32)
        dtype = np.int64
        itemsize = dtype().itemsize
        index_file = '/tmp/test.idx'
        try:
            with MMapRetrievalIndexedDataset.Index.writer(index_file, dtype, False) as index:
                index.write(sizes, chunk_size, stride=stride)

            index_load = MMapRetrievalIndexedDataset.Index(index_file)
            assert index_load.chunk_size == chunk_size
            assert not index_load.retrieval_db
            assert np.array_equal(index_load.sizes, sizes)
            assert np.array_equal(
                index_load._chunk_id_start,
                np.array([0, len(range(0, sizes[0] - chunk_size + 1, stride))], dtype=np.int64),
            )
            add1 = [i * itemsize for i in list(range(0, sizes[0] - chunk_size + 1, stride))]
            start = max(add1) + chunk_size * itemsize
            add2 = [i * itemsize + start for i in list(range(0, sizes[1] - chunk_size + 1, stride))]
            addr = add1 + add2
            assert np.array_equal(index_load._chunk_address, np.array(addr, dtype=np.int64))
            assert np.array_equal(index_load._pointers, np.array([0, sizes[0] * itemsize], dtype=np.int64))
            assert len(index_load._chunk_address) == index_load.num_chunks
        finally:
            os.remove(index_file)

        chunk_size = 64
        stride = 64
        sizes = np.array([128, 256], dtype=np.int32)
        dtype = np.int64
        itemsize = dtype().itemsize
        index_file = '/tmp/test.idx'
        try:
            with MMapRetrievalIndexedDataset.Index.writer(index_file, dtype, False) as index:
                index.write(sizes, chunk_size, stride=stride)

            index_load = MMapRetrievalIndexedDataset.Index(index_file)
            assert index_load.chunk_size == chunk_size
            assert not index_load.retrieval_db
            assert np.array_equal(index_load.sizes, sizes)
            assert np.array_equal(
                index_load._chunk_id_start,
                np.array([0, len(range(0, sizes[0] - chunk_size + 1, stride))], dtype=np.int64),
            )
            add1 = [i * itemsize for i in list(range(0, sizes[0] - chunk_size + 1, stride))]
            start = max(add1) + chunk_size * itemsize
            add2 = [i * itemsize + start for i in list(range(0, sizes[1] - chunk_size + 1, stride))]
            addr = add1 + add2
            assert np.array_equal(index_load._chunk_address, np.array(addr, dtype=np.int64))
            assert np.array_equal(index_load._pointers, np.array([0, sizes[0] * itemsize], dtype=np.int64))
            assert len(index_load._chunk_address) == index_load.num_chunks
        finally:
            os.remove(index_file)
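
    # test_create_data_index_stride32: build a non-retrieval .bin/.idx pair with an
    # overlapping stride of 32, then check padding, chunk-id lookups, and chunk contents.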
    @pytest.mark.unit
    def test_create_data_index_stride32(self):
        chunk_size = 64
        pad_id = 0
        stride = 32
        sentence1 = torch.arange(0, 200, 2, dtype=torch.int64)
        padded_size = chunk_size - (len(sentence1) % chunk_size)
        gt1 = np.pad(sentence1, (0, padded_size), 'constant', constant_values=pad_id)

        sentence2 = torch.arange(1, 500, 2, dtype=torch.int64)
        padded_size = chunk_size - (len(sentence2) % chunk_size)
        gt2 = np.pad(sentence2, (0, padded_size), 'constant', constant_values=pad_id)

        data_file = '/tmp/test'
        index_file = data_file + '.idx'
        bin_file = data_file + '.bin'
        try:
            builder = MMapRetrievalIndexedDatasetBuilder(
                bin_file, chunk_size, pad_id, False, dtype=np.int64, stride=stride
            )
            builder.add_item(sentence1)
            builder.add_item(sentence2)
            builder.finalize(index_file)

            ds = MMapRetrievalIndexedDataset(data_file)
            assert np.array_equal(ds.get(0), gt1)
            assert np.array_equal(ds.get(1), gt2)
            fetch1, fetch2 = ds[0:2]
            assert np.array_equal(fetch1, gt1)
            assert np.array_equal(fetch2, gt2)
            chunk_id = ds.get_chunk_id(0, 64)
            assert chunk_id == 2
            assert ds.from_chunk_id_to_doc_id(0) == 0
            assert ds.from_chunk_id_to_doc_id(1) == 0
            assert ds.from_chunk_id_to_doc_id(2) == 0
            with pytest.raises(ValueError):
                ds.get_chunk_id(0, 128)
            assert np.array_equal(ds.get_chunk(chunk_id), gt1[64 : 64 + 64])
            chunk_id = ds.get_chunk_id(1, 0)
            assert chunk_id == 3
            assert ds.from_chunk_id_to_doc_id(3) == 1
            assert ds.from_chunk_id_to_doc_id(4) == 1
            assert ds.from_chunk_id_to_doc_id(5) == 1
            assert ds.from_chunk_id_to_doc_id(6) == 1
            assert ds.from_chunk_id_to_doc_id(7) == 1
            assert ds.from_chunk_id_to_doc_id(8) == 1
            assert ds.from_chunk_id_to_doc_id(9) == 1
            with pytest.raises(ValueError):
                ds.from_chunk_id_to_doc_id(10)
            assert np.array_equal(ds.get_chunk(chunk_id), gt2[0:64])
            assert np.array_equal(ds.get_chunk(chunk_id + 1), gt2[stride : stride + chunk_size])
            assert np.array_equal(ds.get_chunk(chunk_id + 2), gt2[stride * 2 : stride * 2 + chunk_size])
            assert np.array_equal(ds.get_chunk(chunk_id + 3), gt2[stride * 3 : stride * 3 + chunk_size])
            assert ds.get_chunk_id(1, 64) == 5
            assert ds.get_chunk_id(1, 128) == 7
            assert ds.get_chunk_id(1, 192) == 9
            with pytest.raises(ValueError):
                ds.get_chunk_id(0, 256)
        finally:
            os.remove(index_file)
            os.remove(bin_file)
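
    # test_create_data_index: the same checks with the builder's default, non-overlapping
    # chunking, so consecutive chunk ids map to disjoint 64-token slices.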
    @pytest.mark.unit
    def test_create_data_index(self):
        chunk_size = 64
        pad_id = 0
        sentence1 = torch.arange(0, 200, 2, dtype=torch.int64)
        padded_size = chunk_size - (len(sentence1) % chunk_size)
        gt1 = np.pad(sentence1, (0, padded_size), 'constant', constant_values=pad_id)

        sentence2 = torch.arange(1, 500, 2, dtype=torch.int64)
        padded_size = chunk_size - (len(sentence2) % chunk_size)
        gt2 = np.pad(sentence2, (0, padded_size), 'constant', constant_values=pad_id)

        data_file = '/tmp/test'
        index_file = data_file + '.idx'
        bin_file = data_file + '.bin'
        try:
            builder = MMapRetrievalIndexedDatasetBuilder(bin_file, chunk_size, pad_id, False)
            builder.add_item(sentence1)
            builder.add_item(sentence2)
            builder.finalize(index_file)

            ds = MMapRetrievalIndexedDataset(data_file)
            assert np.array_equal(ds.get(0), gt1)
            assert np.array_equal(ds.get(1), gt2)
            fetch1, fetch2 = ds[0:2]
            assert np.array_equal(fetch1, gt1)
            assert np.array_equal(fetch2, gt2)
            chunk_id = ds.get_chunk_id(0, 64)
            assert chunk_id == 1
            assert ds.from_chunk_id_to_doc_id(0) == 0
            assert ds.from_chunk_id_to_doc_id(1) == 0
            with pytest.raises(ValueError):
                ds.get_chunk_id(0, 128)
            assert np.array_equal(ds.get_chunk(chunk_id), gt1[64 : 64 + 64])
            chunk_id = ds.get_chunk_id(1, 0)
            assert chunk_id == 2
            assert ds.from_chunk_id_to_doc_id(2) == 1
            assert ds.from_chunk_id_to_doc_id(3) == 1
            assert ds.from_chunk_id_to_doc_id(4) == 1
            assert ds.from_chunk_id_to_doc_id(5) == 1
            with pytest.raises(ValueError):
                ds.from_chunk_id_to_doc_id(6)
            assert np.array_equal(ds.get_chunk(chunk_id), gt2[0:64])
            assert np.array_equal(ds.get_chunk(chunk_id + 1), gt2[64:128])
            assert np.array_equal(ds.get_chunk(chunk_id + 2), gt2[128:192])
            assert np.array_equal(ds.get_chunk(chunk_id + 3), gt2[192:256])
            assert ds.get_chunk_id(1, 64) == 3
            assert ds.get_chunk_id(1, 128) == 4
            assert ds.get_chunk_id(1, 192) == 5
            with pytest.raises(ValueError):
                ds.get_chunk_id(0, 256)
        finally:
            os.remove(index_file)
            os.remove(bin_file)
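
    # test_create_retrieval_data_index_stride32: with retrieval_db=True every stored chunk is
    # followed by its continuation, so get_chunk returns 2 * chunk_size tokens; this variant
    # uses an overlapping stride of 32.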
    @pytest.mark.unit
    def test_create_retrieval_data_index_stride32(self):
        stride = 32
        chunk_size = 64
        pad_id = 0
        sentence1 = torch.arange(0, 200, 2, dtype=torch.int64)
        padded_size = chunk_size - (len(sentence1) % chunk_size)
        gt1 = np.pad(sentence1, (0, padded_size), 'constant', constant_values=pad_id)
        padded_gt1 = np.pad(sentence1, (0, padded_size + chunk_size), 'constant', constant_values=pad_id)

        sentence2 = torch.arange(1, 500, 2, dtype=torch.int64)
        padded_size = chunk_size - (len(sentence2) % chunk_size)
        gt2 = np.pad(sentence2, (0, padded_size), 'constant', constant_values=pad_id)
        padded_gt2 = np.pad(sentence2, (0, padded_size + chunk_size), 'constant', constant_values=pad_id)

        data_file = '/tmp/test'
        index_file = data_file + '.idx'
        bin_file = data_file + '.bin'
        try:
            builder = MMapRetrievalIndexedDatasetBuilder(bin_file, chunk_size, pad_id, True, stride=stride)
            builder.add_item(sentence1)
            builder.add_item(sentence2)
            builder.finalize(index_file)

            ds = MMapRetrievalIndexedDataset(data_file)
            assert np.array_equal(ds.get(0), gt1)
            assert np.array_equal(ds.get(1), gt2)
            fetch1, fetch2 = ds[0:2]
            assert np.array_equal(fetch1, gt1)
            assert np.array_equal(fetch2, gt2)
            chunk_id = ds.get_chunk_id(0, 64)
            assert chunk_id == 2
            assert ds.from_chunk_id_to_doc_id(0) == 0
            assert ds.from_chunk_id_to_doc_id(1) == 0
            assert ds.from_chunk_id_to_doc_id(2) == 0
            with pytest.raises(ValueError):
                ds.get_chunk_id(0, 128)
            assert np.array_equal(ds.get_chunk(chunk_id), padded_gt1[64 : 64 + 64 * 2])
            chunk_id = ds.get_chunk_id(1, 0)
            assert chunk_id == 3
            assert ds.from_chunk_id_to_doc_id(3) == 1
            assert ds.from_chunk_id_to_doc_id(4) == 1
            assert ds.from_chunk_id_to_doc_id(5) == 1
            assert ds.from_chunk_id_to_doc_id(6) == 1
            assert ds.from_chunk_id_to_doc_id(7) == 1
            assert ds.from_chunk_id_to_doc_id(8) == 1
            assert ds.from_chunk_id_to_doc_id(9) == 1
            with pytest.raises(ValueError):
                ds.from_chunk_id_to_doc_id(10)
            assert np.array_equal(ds.get_chunk(chunk_id), padded_gt2[0 : chunk_size * 2])
            assert np.array_equal(ds.get_chunk(chunk_id + 1), gt2[stride : stride + chunk_size * 2])
            assert np.array_equal(ds.get_chunk(chunk_id + 2), gt2[stride * 2 : stride * 2 + chunk_size * 2])
            assert np.array_equal(ds.get_chunk(chunk_id + 3), gt2[stride * 3 : stride * 3 + chunk_size * 2])
            assert ds.get_chunk_id(1, 64) == 5
            assert ds.get_chunk_id(1, 128) == 7
            assert ds.get_chunk_id(1, 192) == 9
            with pytest.raises(ValueError):
                ds.get_chunk_id(0, 256)
            chunk_id = ds.get_chunk_id(1, 64)
            assert np.array_equal(ds.get_chunk(chunk_id), padded_gt2[64:192])
            multi_chunks = ds.get_chunk(slice(0, ds.chunks))
            assert np.array_equal(multi_chunks[0], padded_gt1[0 : chunk_size * 2])
            assert np.array_equal(multi_chunks[1], padded_gt1[stride : stride + chunk_size * 2])
            assert np.array_equal(multi_chunks[2], padded_gt1[stride * 2 : stride * 2 + chunk_size * 2])
            assert np.array_equal(multi_chunks[3], padded_gt2[0 : chunk_size * 2])
            assert np.array_equal(multi_chunks[4], padded_gt2[stride : stride + chunk_size * 2])
            assert np.array_equal(multi_chunks[5], padded_gt2[stride * 2 : stride * 2 + chunk_size * 2])
            assert np.array_equal(multi_chunks[6], padded_gt2[stride * 3 : stride * 3 + chunk_size * 2])
            assert np.array_equal(multi_chunks[7], padded_gt2[stride * 4 : stride * 4 + chunk_size * 2])
            assert np.array_equal(multi_chunks[8], padded_gt2[stride * 5 : stride * 5 + chunk_size * 2])
            assert np.array_equal(multi_chunks[9], padded_gt2[stride * 6 : stride * 6 + chunk_size * 2])
        finally:
            os.remove(index_file)
            os.remove(bin_file)
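
    # test_create_retrieval_data_index: the retrieval-DB layout with the default chunking;
    # each returned chunk spans 128 tokens (its own 64 tokens plus the 64 that follow,
    # padded at the document end).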
    @pytest.mark.unit
    def test_create_retrieval_data_index(self):
        chunk_size = 64
        pad_id = 0
        sentence1 = torch.arange(0, 200, 2, dtype=torch.int64)
        padded_size = chunk_size - (len(sentence1) % chunk_size)
        gt1 = np.pad(sentence1, (0, padded_size), 'constant', constant_values=pad_id)
        padded_gt1 = np.pad(sentence1, (0, padded_size + chunk_size), 'constant', constant_values=pad_id)

        sentence2 = torch.arange(1, 500, 2, dtype=torch.int64)
        padded_size = chunk_size - (len(sentence2) % chunk_size)
        gt2 = np.pad(sentence2, (0, padded_size), 'constant', constant_values=pad_id)
        padded_gt2 = np.pad(sentence2, (0, padded_size + chunk_size), 'constant', constant_values=pad_id)

        data_file = '/tmp/test'
        index_file = data_file + '.idx'
        bin_file = data_file + '.bin'
        try:
            builder = MMapRetrievalIndexedDatasetBuilder(bin_file, chunk_size, pad_id, True)
            builder.add_item(sentence1)
            builder.add_item(sentence2)
            builder.finalize(index_file)

            ds = MMapRetrievalIndexedDataset(data_file)
            assert np.array_equal(ds.get(0), gt1)
            assert np.array_equal(ds.get(1), gt2)
            fetch1, fetch2 = ds[0:2]
            assert np.array_equal(fetch1, gt1)
            assert np.array_equal(fetch2, gt2)
            chunk_id = ds.get_chunk_id(0, 64)
            assert chunk_id == 1
            assert ds.from_chunk_id_to_doc_id(0) == 0
            assert ds.from_chunk_id_to_doc_id(1) == 0
            with pytest.raises(ValueError):
                ds.get_chunk_id(0, 128)
            assert np.array_equal(ds.get_chunk(chunk_id), padded_gt1[64 : 64 + 64 * 2])
            chunk_id = ds.get_chunk_id(1, 0)
            assert chunk_id == 2
            assert ds.from_chunk_id_to_doc_id(2) == 1
            assert ds.from_chunk_id_to_doc_id(3) == 1
            assert ds.from_chunk_id_to_doc_id(4) == 1
            assert ds.from_chunk_id_to_doc_id(5) == 1
            with pytest.raises(ValueError):
                ds.from_chunk_id_to_doc_id(6)
            assert np.array_equal(ds.get_chunk(chunk_id), padded_gt2[0:128])
            assert np.array_equal(ds.get_chunk(chunk_id + 1), padded_gt2[64:192])
            assert np.array_equal(ds.get_chunk(chunk_id + 2), padded_gt2[128:256])
            assert np.array_equal(ds.get_chunk(chunk_id + 3), padded_gt2[192:320])
            assert ds.get_chunk_id(1, 64) == 3
            assert ds.get_chunk_id(1, 128) == 4
            assert ds.get_chunk_id(1, 192) == 5
            with pytest.raises(ValueError):
                ds.get_chunk_id(0, 256)
            chunk_id = ds.get_chunk_id(1, 64)
            assert np.array_equal(ds.get_chunk(chunk_id), padded_gt2[64:192])
            multi_chunks = ds.get_chunk(slice(0, ds.chunks))
            assert np.array_equal(multi_chunks[0], padded_gt1[0:128])
            assert np.array_equal(multi_chunks[1], padded_gt1[64 : 64 + 128])
            assert np.array_equal(multi_chunks[2], padded_gt2[0:128])
            assert np.array_equal(multi_chunks[3], padded_gt2[64 : 64 + 128])
            assert np.array_equal(multi_chunks[4], padded_gt2[128 : 128 + 128])
            assert np.array_equal(multi_chunks[5], padded_gt2[192 : 192 + 128])
        finally:
            os.remove(index_file)
            os.remove(bin_file)
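
    # test_knn_index: round-trip a KNN neighbor map through KNNIndex.writer / KNNIndex, with
    # and without a chunk-id offset, then merge several shard files with merge_knn_files and
    # verify the merged map preserves every row.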
    @pytest.mark.unit
    def test_knn_index(self):
        data_file = '/tmp/test'
        index_file = data_file + '.idx'
        K = 8
        index_files = [f'{data_file}_{i}.idx' for i in range(3)]
        merged_file = '/tmp/merged.idx'
        try:
            with KNNIndex.writer(index_file, K) as w:
                map_np0 = np.random.randint(0, 100, (50, K))
                w.write(map_np0)
                map_np1 = np.random.randint(0, 100, (50, K))
                w.write(map_np1)
                map_np2 = np.random.randint(0, 100, (50, K))
                w.write(map_np2)
            f = KNNIndex(index_file)
            assert f.K == K
            assert f.len == map_np0.shape[0] + map_np1.shape[0] + map_np2.shape[0]
            assert np.array_equal(map_np0, f.knn_map[:50])
            assert np.array_equal(map_np1, f.knn_map[50:100])
            assert np.array_equal(map_np2, f.knn_map[100:])
            assert np.array_equal(f.get_KNN_chunk_ids(5), map_np0[5])
            assert f.chunk_start_id == 0
            assert f.chunk_end_id == f.len

            with KNNIndex.writer(index_file, K, 100) as w:
                map_np0 = np.random.randint(0, 100, (50, K))
                w.write(map_np0)
                map_np1 = np.random.randint(0, 100, (50, K))
                w.write(map_np1)
                map_np2 = np.random.randint(0, 100, (50, K))
                w.write(map_np2)
            f = KNNIndex(index_file)
            assert f.K == K
            assert f.len == map_np0.shape[0] + map_np1.shape[0] + map_np2.shape[0]
            assert np.array_equal(map_np0, f.knn_map[:50])
            assert np.array_equal(map_np1, f.knn_map[50:100])
            assert np.array_equal(map_np2, f.knn_map[100:])
            assert np.array_equal(f.get_KNN_chunk_ids(5 + 100), map_np0[5])
            assert f.chunk_start_id == 100
            assert f.chunk_end_id == f.len + 100

            inputs = []
            start = 0
            for i in range(3):
                with KNNIndex.writer(index_files[i], K, offset=start) as w:
                    map_np0 = np.random.randint(0, 100, (50, K))
                    inputs.append(map_np0)
                    w.write(map_np0)
                    map_np1 = np.random.randint(0, 100, (50, K))
                    inputs.append(map_np1)
                    w.write(map_np1)
                f = KNNIndex(index_files[i])
                start += f.len
            merge_knn_files(index_files, merged_file)
            f = KNNIndex(merged_file)
            input_array = np.vstack(inputs)
            assert f.len == 100 * 3
            for i in range(300):
                assert np.array_equal(f.get_KNN_chunk_ids(i), input_array[i])
            assert f.chunk_start_id == 0
            assert f.chunk_end_id == f.len
            assert f.K == K
        finally:
            os.remove(index_file)
            for i in range(3):
                os.remove(index_files[i])
            os.remove(merged_file)
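
    # test_retro_dataset: end-to-end check of RETRODataset. It builds a training-data index,
    # a retrieval DB, and a random KNN map (np.random.randint(-3, ...) lets some neighbor ids
    # be negative), then verifies every record's shapes: tokens, labels, and tokens_mask of
    # length seq_len and retrieved_ids of shape [seq_len // chunk_size, K, 2 * chunk_size].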
    @pytest.mark.unit
    @pytest.mark.skipif(not HAVE_APEX, reason="apex is not installed")
    def test_retro_dataset(self):
        chunk_size = 64
        pad_id = 0
        sentence1 = torch.arange(0, 200, 2, dtype=torch.int64)
        sentence2 = torch.arange(1, 500, 2, dtype=torch.int64)
        sentence3 = torch.arange(0, 300, 2, dtype=torch.int64)
        sentence4 = torch.arange(1, 400, 2, dtype=torch.int64)

        data_file = '/tmp/test_data'
        data_index_file = data_file + '.idx'
        data_bin_file = data_file + '.bin'
        db_file = '/tmp/test_db_data'
        db_index_file = db_file + '.idx'
        db_bin_file = db_file + '.bin'
        K = 8
        map_index_file = '/tmp/test_map.idx'
        index_path = '/tmp'

        cfg = OmegaConf.create({'data': {"index_mapping_dir": index_path}})

        class Tokenizer:
            eos_id = 1
            pad_id = 0

        tokenizer = Tokenizer()

        num_samples = 100
        seq_len = 192
        name = 'test'
        data_prefix = 'pref'
        seed = 1
        _filename = index_path + '/' + data_prefix
        _filename += '_{}_indexmap'.format(name)
        _filename += '_{}ns'.format(num_samples)
        _filename += '_{}sl'.format(seq_len)
        _filename += '_{}s'.format(seed)
        doc_idx_filename = _filename + '_doc_idx.npy'
        sample_idx_filename = _filename + '_sample_idx.npy'
        shuffle_idx_filename = _filename + '_shuffle_idx.npy'

        try:
            builder = MMapRetrievalIndexedDatasetBuilder(data_bin_file, chunk_size, pad_id, False)
            builder.add_item(sentence1)
            builder.add_item(sentence2)
            builder.finalize(data_index_file)

            builder = MMapRetrievalIndexedDatasetBuilder(db_bin_file, chunk_size, pad_id, True)
            builder.add_item(sentence3)
            builder.add_item(sentence4)
            builder.finalize(db_index_file)

            data_index = MMapRetrievalIndexedDataset(data_file)
            db_index = MMapRetrievalIndexedDataset(db_file)

            with KNNIndex.writer(map_index_file, K) as w:
                map_np = np.random.randint(-3, db_index.chunks, (data_index.chunks, K))
                w.write(map_np)
            map_index = KNNIndex(map_index_file)

            documents = np.arange(0, data_index.sizes.shape[0])
            d = RETRODataset(
                cfg,
                None,
                tokenizer,
                name,
                data_prefix,
                documents,
                data_index,
                num_samples,
                seq_len,
                seed,
                map_index,
                db_index,
            )
            for i in range(len(d)):
                record = d[i]
                assert record['tokens'].shape[0] == seq_len
                assert record['labels'].shape[0] == seq_len
                assert record['retrieved_ids'].shape[0] == seq_len // chunk_size
                assert record['retrieved_ids'].shape[1] == K
                assert record['retrieved_ids'].shape[2] == chunk_size * 2
                assert record['tokens_mask'].shape[0] == seq_len
        finally:
            os.remove(data_bin_file)
            os.remove(data_index_file)
            os.remove(db_bin_file)
            os.remove(db_index_file)
            os.remove(map_index_file)
            os.remove(doc_idx_filename)
            os.remove(sample_idx_filename)
            os.remove(shuffle_idx_filename)
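
        # Second pass: reuse the retrieval DB itself as the training data, so the same
        # .bin/.idx pair backs both data_index and db_index.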
        try:
            builder = MMapRetrievalIndexedDatasetBuilder(db_bin_file, chunk_size, pad_id, True)
            builder.add_item(sentence1)
            builder.add_item(sentence2)
            builder.add_item(sentence3)
            builder.add_item(sentence4)
            builder.finalize(db_index_file)

            data_index = MMapRetrievalIndexedDataset(db_file)
            db_index = MMapRetrievalIndexedDataset(db_file)

            with KNNIndex.writer(map_index_file, K) as w:
                map_np = np.random.randint(-3, db_index.chunks, (data_index.chunks, K))
                w.write(map_np)
            map_index = KNNIndex(map_index_file)

            documents = np.arange(0, data_index.sizes.shape[0])
            d = RETRODataset(
                cfg,
                None,
                tokenizer,
                name,
                data_prefix,
                documents,
                data_index,
                num_samples,
                seq_len,
                seed,
                map_index,
                db_index,
            )
            for i in range(len(d)):
                record = d[i]
                assert record['tokens'].shape[0] == seq_len
                assert record['labels'].shape[0] == seq_len
                assert record['retrieved_ids'].shape[0] == seq_len // chunk_size
                assert record['retrieved_ids'].shape[1] == K
                assert record['retrieved_ids'].shape[2] == chunk_size * 2
                assert record['tokens_mask'].shape[0] == seq_len
        finally:
            os.remove(db_bin_file)
            os.remove(db_index_file)
            os.remove(map_index_file)
            os.remove(doc_idx_filename)
            os.remove(sample_idx_filename)
            os.remove(shuffle_idx_filename)
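
    # test_retro_dataset_stride32: the same end-to-end RETRODataset checks, but both the
    # training-data index and the retrieval DB are built with an overlapping stride of 32.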
    @pytest.mark.unit
    @pytest.mark.skipif(not HAVE_APEX, reason="apex is not installed")
    def test_retro_dataset_stride32(self):
        chunk_size = 64
        pad_id = 0
        sentence1 = torch.arange(0, 200, 2, dtype=torch.int64)
        sentence2 = torch.arange(1, 500, 2, dtype=torch.int64)
        sentence3 = torch.arange(0, 300, 2, dtype=torch.int64)
        sentence4 = torch.arange(1, 400, 2, dtype=torch.int64)

        data_file = '/tmp/test_data'
        data_index_file = data_file + '.idx'
        data_bin_file = data_file + '.bin'
        db_file = '/tmp/test_db_data'
        db_index_file = db_file + '.idx'
        db_bin_file = db_file + '.bin'
        K = 8
        map_index_file = '/tmp/test_map.idx'
        index_path = '/tmp'

        cfg = OmegaConf.create({'data': {"index_mapping_dir": index_path}})

        class Tokenizer:
            eos_id = 1
            pad_id = 0

        tokenizer = Tokenizer()

        num_samples = 100
        stride = 32
        seq_len = 192
        name = 'test'
        data_prefix = 'pref'
        seed = 1
        _filename = index_path + '/' + data_prefix
        _filename += '_{}_indexmap'.format(name)
        _filename += '_{}ns'.format(num_samples)
        _filename += '_{}sl'.format(seq_len)
        _filename += '_{}s'.format(seed)
        doc_idx_filename = _filename + '_doc_idx.npy'
        sample_idx_filename = _filename + '_sample_idx.npy'
        shuffle_idx_filename = _filename + '_shuffle_idx.npy'

        try:
            builder = MMapRetrievalIndexedDatasetBuilder(data_bin_file, chunk_size, pad_id, False, stride=32)
            builder.add_item(sentence1)
            builder.add_item(sentence2)
            builder.finalize(data_index_file)

            builder = MMapRetrievalIndexedDatasetBuilder(db_bin_file, chunk_size, pad_id, True, stride=32)
            builder.add_item(sentence3)
            builder.add_item(sentence4)
            builder.finalize(db_index_file)

            data_index = MMapRetrievalIndexedDataset(data_file)
            db_index = MMapRetrievalIndexedDataset(db_file)

            with KNNIndex.writer(map_index_file, K) as w:
                map_np = np.random.randint(-3, db_index.chunks, (data_index.chunks, K))
                w.write(map_np)
            map_index = KNNIndex(map_index_file)

            documents = np.arange(0, data_index.sizes.shape[0])
            d = RETRODataset(
                cfg,
                None,
                tokenizer,
                name,
                data_prefix,
                documents,
                data_index,
                num_samples,
                seq_len,
                seed,
                map_index,
                db_index,
            )
            for i in range(len(d)):
                record = d[i]
                assert record['tokens'].shape[0] == seq_len
                assert record['labels'].shape[0] == seq_len
                assert record['retrieved_ids'].shape[0] == seq_len // chunk_size
                assert record['retrieved_ids'].shape[1] == K
                assert record['retrieved_ids'].shape[2] == chunk_size * 2
                assert record['tokens_mask'].shape[0] == seq_len
        finally:
            os.remove(data_bin_file)
            os.remove(data_index_file)
            os.remove(db_bin_file)
            os.remove(db_index_file)
            os.remove(map_index_file)
            os.remove(doc_idx_filename)
            os.remove(sample_idx_filename)
            os.remove(shuffle_idx_filename)
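
        # Second pass: as above, reuse the stride-32 retrieval DB as both data_index and db_index.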
        try:
            builder = MMapRetrievalIndexedDatasetBuilder(db_bin_file, chunk_size, pad_id, True, stride=32)
            builder.add_item(sentence1)
            builder.add_item(sentence2)
            builder.add_item(sentence3)
            builder.add_item(sentence4)
            builder.finalize(db_index_file)

            data_index = MMapRetrievalIndexedDataset(db_file)
            db_index = MMapRetrievalIndexedDataset(db_file)

            with KNNIndex.writer(map_index_file, K) as w:
                map_np = np.random.randint(-3, db_index.chunks, (data_index.chunks, K))
                w.write(map_np)
            map_index = KNNIndex(map_index_file)

            documents = np.arange(0, data_index.sizes.shape[0])
            d = RETRODataset(
                cfg,
                None,
                tokenizer,
                name,
                data_prefix,
                documents,
                data_index,
                num_samples,
                seq_len,
                seed,
                map_index,
                db_index,
            )
            for i in range(len(d)):
                record = d[i]
                assert record['tokens'].shape[0] == seq_len
                assert record['labels'].shape[0] == seq_len
                assert record['retrieved_ids'].shape[0] == seq_len // chunk_size
                assert record['retrieved_ids'].shape[1] == K
                assert record['retrieved_ids'].shape[2] == chunk_size * 2
                assert record['tokens_mask'].shape[0] == seq_len
        finally:
            os.remove(db_bin_file)
            os.remove(db_index_file)
            os.remove(map_index_file)
            os.remove(doc_idx_filename)
            os.remove(sample_idx_filename)
            os.remove(shuffle_idx_filename)
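
    # test_dedup: build_map fills a chunk-id -> [doc_start, doc_end) lookup for a slice of the
    # chunk range, and dedup masks out retrieved neighbor chunks that fall inside the same
    # source document as the query chunk, compacting the rest and padding with -1.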
    @pytest.mark.unit
    @pytest.mark.skipif(not HAVE_APEX, reason="apex is not installed")
    def test_dedup(self):
        total = 1000
        id_start = np.array([0, 100, 200, 300, 500, 900])
        beg = 30
        end = 210
        chunk_id_to_doc_id_map = np.zeros((end - beg, 2), dtype=np.int64)
        build_map(id_start, chunk_id_to_doc_id_map, total, beg, end)
        for i in range(30, 100):
            assert_array_equal(chunk_id_to_doc_id_map[i - beg], id_start[0:2])
        for i in range(100, 200):
            assert_array_equal(chunk_id_to_doc_id_map[i - beg], id_start[1:3])
        for i in range(200, 210):
            assert_array_equal(chunk_id_to_doc_id_map[i - beg], id_start[2:4])
        beg = 5
        end = 100
        chunk_id_to_doc_id_map = np.zeros((end - beg, 2), dtype=np.int64)
        build_map(id_start, chunk_id_to_doc_id_map, total, beg, end)
        for i in range(beg, end):
            assert_array_equal(chunk_id_to_doc_id_map[i - beg], id_start[0:2])
        beg = 100
        end = 200
        chunk_id_to_doc_id_map = np.zeros((end - beg, 2), dtype=np.int64)
        build_map(id_start, chunk_id_to_doc_id_map, total, beg, end)
        for i in range(beg, end):
            assert_array_equal(chunk_id_to_doc_id_map[i - beg], id_start[1:3])
        beg = 900
        end = 1000
        chunk_id_to_doc_id_map = np.zeros((end - beg, 2), dtype=np.int64)
        build_map(id_start, chunk_id_to_doc_id_map, total, beg, end)
        for i in range(beg, end):
            assert_array_equal(chunk_id_to_doc_id_map[i - beg], np.array([900, 1000]))
        beg = 150
        end = 250
        chunk_id_to_doc_id_map = np.zeros((end - beg, 2), dtype=np.int64)
        build_map(id_start, chunk_id_to_doc_id_map, total, beg, end)
        for i in range(beg, 200):
            assert_array_equal(chunk_id_to_doc_id_map[i - beg], id_start[1:3])
        for i in range(200, end):
            assert_array_equal(chunk_id_to_doc_id_map[i - beg], id_start[2:4])

        I = np.arange(1000)[None, :]
        tmp_neighbors = np.ones_like(I) * -1
        with pytest.raises(ValueError):
            dedup(chunk_id_to_doc_id_map, I, tmp_neighbors, 0, beg)

        I = np.arange(1000)[None, :]
        tmp_neighbors = np.ones_like(I) * -1
        with pytest.raises(ValueError):
            dedup(chunk_id_to_doc_id_map, I, tmp_neighbors, 250, beg)

        for i in range(beg, 200):
            I = np.arange(1000)[None, :]
            tmp_neighbors = np.ones_like(I) * -1
            dedup(chunk_id_to_doc_id_map, I, tmp_neighbors, i, beg)
            gt = np.array(list(range(100)) + list(range(200, 1000)) + ([-1] * 100))
            assert_array_equal(tmp_neighbors[0], gt)

        for i in range(200, 250):
            I = np.arange(1000)[None, :]
            tmp_neighbors = np.ones_like(I) * -1
            dedup(chunk_id_to_doc_id_map, I, tmp_neighbors, i, beg)
            gt = np.array(list(range(200)) + list(range(300, 1000)) + ([-1] * 100))
            assert_array_equal(tmp_neighbors[0], gt)

        I = np.arange(1000)[None, :]
        I = np.repeat(I, 70, axis=0)
        tmp_neighbors = np.ones_like(I) * -1
        dedup(chunk_id_to_doc_id_map, I, tmp_neighbors, 180, beg)
        gt0 = np.array(list(range(100)) + list(range(200, 1000)) + ([-1] * 100))
        gt1 = np.array(list(range(200)) + list(range(300, 1000)) + ([-1] * 100))
        for i in range(20):
            assert_array_equal(tmp_neighbors[i], gt0)
        for i in range(20, 70):
            assert_array_equal(tmp_neighbors[i], gt1)