import logging
import math

import faiss
import gpustat
import h5py
import numpy as np

# MIPS-to-L2 trick: appending one extra coordinate turns maximum inner-product
# search into L2 nearest-neighbor search. See theorem 5 of
# http://ulrichpaquet.com/Papers/SpeedUp.pdf. A usage sketch follows the
# helpers below.

def get_phi(xb):
    """Return the maximum squared L2 norm over the database vectors xb."""
    return (xb ** 2).sum(1).max()


def augment_xb(xb, phi=None):
    """Append sqrt(phi - ||x||^2) as an extra column to each database vector."""
    norms = (xb ** 2).sum(1)
    if phi is None:
        phi = norms.max()
    # clip at zero so floating-point error cannot produce sqrt of a negative
    extracol = np.sqrt(np.maximum(phi - norms, 0))
    return np.hstack((xb, extracol.reshape(-1, 1)))


def augment_xq(xq):
    """Append a zero column to each query vector."""
    extracol = np.zeros(len(xq), dtype='float32')
    return np.hstack((xq, extracol.reshape(-1, 1)))
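
# A minimal sketch of how the two helpers combine (xb, xq, and k are assumed
# here for illustration and are not defined in this module): inner-product
# ranking over xb reduces to L2 ranking over the augmented vectors.
#
#   index = faiss.IndexFlatL2(xb.shape[1] + 1)
#   index.add(augment_xb(xb))
#   distances, ids = index.search(augment_xq(xq), k)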


class VectorIndex:
    def __init__(self, d):
        self.d = d            # vector dimensionality
        self.vectors = []     # raw vectors accumulated before indexing
        self.index = None     # faiss index, created by build() or load()

    def add(self, v):
        self.vectors.append(v)

    def add_vectors(self, vs):
        logging.info('Adding vectors to index...')
        self.index.add(vs)

    def build(self, use_gpu=False):
        # faiss requires float32; may OOM here when materializing a very
        # large number of vectors
        self.vectors = np.array(self.vectors, dtype='float32')

        faiss.normalize_L2(self.vectors)

        #self.vectors = augment_xq(self.vectors)

        logging.info('Indexing {} vectors'.format(self.vectors.shape[0]))

        if self.vectors.shape[0] > 50000:
            # roughly 8 * sqrt(n), with n rounded down to a power of two
            num_centroids = 8 * int(math.sqrt(2 ** int(math.log2(self.vectors.shape[0]))))

            logging.info('Using {} centroids'.format(num_centroids))

            self.index = faiss.index_factory(self.d, "IVF{}_HNSW32,Flat".format(num_centroids))

            ngpu = faiss.get_num_gpus()
            if ngpu > 0 and use_gpu:
                logging.info('Using {} GPUs'.format(ngpu))

                # run the k-means clustering on all available GPUs; the IVF
                # index itself stays on the CPU
                index_ivf = faiss.extract_index_ivf(self.index)
                gpustat.print_gpustat()
                clustering_index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(self.d))
                index_ivf.clustering_index = clustering_index

            logging.info('Training index...')

            self.index.train(self.vectors)
        else:
            # small collections: exact (brute-force) L2 search
            self.index = faiss.IndexFlatL2(self.d)
            if faiss.get_num_gpus() > 0 and use_gpu:
                gpustat.print_gpustat()
                self.index = faiss.index_cpu_to_all_gpus(self.index)

    def load(self, path):
        self.index = faiss.read_index(path)

    def save(self, path):
        gpustat.print_gpustat()
        # clone the index back to the CPU before serializing; GPU indexes
        # cannot be written to disk directly
        faiss.write_index(faiss.index_gpu_to_cpu(self.index), path)

    def save_vectors(self, path):
        with h5py.File(path, 'w') as f:
            f.create_dataset('data', data=self.vectors)

    def search(self, vectors, k=1, probes=1):
        if not isinstance(vectors, np.ndarray):
            vectors = np.array(vectors, dtype='float32')
        #faiss.normalize_L2(vectors)
        try:
            self.index.nprobe = probes  # only meaningful for IVF indexes
        except AttributeError:
            pass
        distances, ids = self.index.search(vectors, k)
        # for unit-norm vectors, squared L2 distance d relates to cosine
        # similarity s by d = 2 - 2s, hence s = (2 - d) / 2
        similarities = [(2 - d) / 2 for d in distances]
        return ids, similarities
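

if __name__ == '__main__':
    # Minimal smoke-test sketch: build a small exact index over random
    # vectors and query it. The dimension and counts are arbitrary choices
    # for illustration only.
    logging.basicConfig(level=logging.INFO)

    d = 64
    rng = np.random.RandomState(0)

    vi = VectorIndex(d)
    for v in rng.rand(1000, d).astype('float32'):
        vi.add(v)

    vi.build(use_gpu=False)     # 1000 <= 50000, so this builds an exact IndexFlatL2
    vi.add_vectors(vi.vectors)  # vectors must be added after build()

    queries = rng.rand(5, d).astype('float32')
    faiss.normalize_L2(queries)  # search() assumes unit-norm queries
    ids, sims = vi.search(queries, k=3)
    print(ids)
    print(sims)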