import logging
import math
import pickle

import h5py
import faiss
import numpy as np
import gpustat


# MIPS-to-L2 augmentation helpers.
# See http://ulrichpaquet.com/Papers/SpeedUp.pdf, theorem 5: appending
# sqrt(phi - ||x||^2) to each database vector and 0 to each query turns
# maximum-inner-product search into exact L2 nearest-neighbor search.
def get_phi(xb):
    # phi is the maximum squared norm over the database vectors
    return (xb ** 2).sum(1).max()


def augment_xb(xb, phi=None):
    norms = (xb ** 2).sum(1)
    if phi is None:
        phi = norms.max()
    extracol = np.sqrt(phi - norms)
    return np.hstack((xb, extracol.reshape(-1, 1)))


def augment_xq(xq):
    extracol = np.zeros(len(xq), dtype='float32')
    return np.hstack((xq, extracol.reshape(-1, 1)))


class VectorIndex:
    def __init__(self, d):
        self.d = d
        self.vectors = []
        self.index = None

    def add(self, v):
        self.vectors.append(v)

    def add_vectors(self, vs):
        logging.info('Adding vectors to index...')
        # FAISS expects contiguous float32 input
        self.index.add(np.ascontiguousarray(vs, dtype='float32'))

    def build(self, use_gpu=False):
        # May OOM at this step if building from too many vectors
        self.vectors = np.ascontiguousarray(self.vectors, dtype='float32')
        # Unit-normalize so L2 distance is monotonic in cosine similarity
        faiss.normalize_L2(self.vectors)
        # self.vectors = augment_xq(self.vectors)
        logging.info('Indexing {} vectors'.format(self.vectors.shape[0]))
        if self.vectors.shape[0] > 50000:
            # Heuristic: ~8 * sqrt(n), with n rounded down to a power of two
            num_centroids = 8 * int(
                math.sqrt(math.pow(2, int(math.log(self.vectors.shape[0], 2)))))
            logging.info('Using {} centroids'.format(num_centroids))
            self.index = faiss.index_factory(
                self.d, "IVF{}_HNSW32,Flat".format(num_centroids))
            ngpu = faiss.get_num_gpus()
            if ngpu > 0 and use_gpu:
                logging.info('Using {} GPUs'.format(ngpu))
                # Run the k-means clustering on the GPUs; the IVF index
                # itself stays on the CPU
                index_ivf = faiss.extract_index_ivf(self.index)
                gpustat.print_gpustat()
                clustering_index = faiss.index_cpu_to_all_gpus(
                    faiss.IndexFlatL2(self.d))
                index_ivf.clustering_index = clustering_index
            logging.info('Training index...')
            self.index.train(self.vectors)
        else:
            # Small collections: exact brute-force search
            self.index = faiss.IndexFlatL2(self.d)
            if faiss.get_num_gpus() > 0 and use_gpu:
                gpustat.print_gpustat()
                self.index = faiss.index_cpu_to_all_gpus(self.index)

    def load(self, path):
        self.index = faiss.read_index(path)

    def save(self, path):
        gpustat.print_gpustat()
        index = self.index
        try:
            # Move a GPU-resident index back to the CPU before serializing;
            # this raises if the index already lives on the CPU
            index = faiss.index_gpu_to_cpu(index)
        except Exception:
            pass
        faiss.write_index(index, path)

    def save_vectors(self, path):
        # pickle.dump(self.vectors, open(path, 'wb'), protocol=4)
        with h5py.File(path, 'w') as f:
            f.create_dataset('data', data=self.vectors)

    def search(self, vectors, k=1, probes=1):
        vectors = np.ascontiguousarray(vectors, dtype='float32')
        # Queries are assumed to be unit-normalized already; otherwise the
        # cosine conversion below does not hold.
        # faiss.normalize_L2(vectors)
        if hasattr(self.index, 'nprobe'):
            self.index.nprobe = probes
        distances, ids = self.index.search(vectors, k)
        # For unit vectors, squared L2 distance d = 2 - 2*cos, so cos = (2-d)/2
        similarities = (2 - distances) / 2
        return ids, similarities
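

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original module): exercises the
    # VectorIndex API above end-to-end on random data. The dimensionality,
    # vector counts, and seed are illustrative assumptions, not values taken
    # from the original code.
    logging.basicConfig(level=logging.INFO)
    rng = np.random.default_rng(0)

    index = VectorIndex(d=128)
    for v in rng.random((1000, 128), dtype=np.float32):
        index.add(v)

    # 1,000 vectors stays under the 50,000 threshold, so build() takes the
    # exact IndexFlatL2 branch; larger collections would train IVF_HNSW.
    index.build(use_gpu=False)
    # build() only trains the index; vectors must still be added explicitly
    index.add_vectors(index.vectors)

    queries = rng.random((5, 128), dtype=np.float32)
    faiss.normalize_L2(queries)  # search() assumes unit-normalized queries
    ids, sims = index.search(queries, k=3)
    print(ids)
    print(sims)

    # Sketch of the MIPS-to-L2 augmentation helpers (Paquet, theorem 5):
    # after augmenting, exact L2 nearest neighbors coincide with
    # maximum-inner-product neighbors.
    xb = rng.random((100, 8), dtype=np.float32)
    xq = rng.random((3, 8), dtype=np.float32)
    flat = faiss.IndexFlatL2(8 + 1)  # one extra augmented dimension
    flat.add(augment_xb(xb))
    _, nn = flat.search(augment_xq(xq), 1)
    assert (nn[:, 0] == (xq @ xb.T).argmax(axis=1)).all()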