# quin-general/models/vector_index.py
import logging
import math
import pickle  # only used by the commented-out fallback in save_vectors

import faiss
import gpustat
import h5py
import numpy as np

# MIPS -> L2 trick: see http://ulrichpaquet.com/Papers/SpeedUp.pdf, Theorem 5.
# Appending one extra coordinate to each database vector turns maximum
# inner-product search into plain L2 nearest-neighbour search.

def get_phi(xb):
    """Return the largest squared norm among the database vectors."""
    return (xb ** 2).sum(1).max()


def augment_xb(xb, phi=None):
    """Append sqrt(phi - ||x||^2) as an extra column to each database vector."""
    norms = (xb ** 2).sum(1)
    if phi is None:
        phi = norms.max()
    extracol = np.sqrt(phi - norms)
    return np.hstack((xb, extracol.reshape(-1, 1)))


def augment_xq(xq):
    """Append a zero column to each query vector."""
    extracol = np.zeros(len(xq), dtype='float32')
    return np.hstack((xq, extracol.reshape(-1, 1)))
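
# A minimal sketch (assumed usage, not part of the original module) of how the
# helpers above combine: augment the database once, give queries a zero extra
# column, and a plain L2 index then ranks by inner product, because
# ||q' - x'||^2 = ||q||^2 + phi - 2 <q, x> is monotone in -<q, x> for fixed q.
#
#     xb_aug = augment_xb(xb)              # database, shape (n, d + 1)
#     index = faiss.IndexFlatL2(d + 1)
#     index.add(xb_aug)
#     distances, ids = index.search(augment_xq(xq), k)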

class VectorIndex:
    """FAISS index over L2-normalized vectors, optionally placed on GPUs."""

    def __init__(self, d):
        self.d = d
        self.vectors = []
        self.index = None

    def add(self, v):
        self.vectors.append(v)

    def add_vectors(self, vs):
        logging.info('Adding vectors to index...')
        self.index.add(vs)
    def build(self, use_gpu=False):
        # Stacking into one contiguous array can OOM if there are too many vectors.
        self.vectors = np.array(self.vectors)
        faiss.normalize_L2(self.vectors)
        #self.vectors = augment_xq(self.vectors)
        logging.info('Indexing {} vectors'.format(self.vectors.shape[0]))
        if self.vectors.shape[0] > 50000:
            # Roughly 8 * sqrt(n), with n first rounded down to a power of two.
            num_centroids = 8 * int(math.sqrt(math.pow(2, int(math.log(self.vectors.shape[0], 2)))))
            logging.info('Using {} centroids'.format(num_centroids))
            # IVF with an HNSW coarse quantizer and flat (uncompressed) storage.
            self.index = faiss.index_factory(self.d, "IVF{}_HNSW32,Flat".format(num_centroids))
            ngpu = faiss.get_num_gpus()
            if ngpu > 0 and use_gpu:
                logging.info('Using {} GPUs'.format(ngpu))
                index_ivf = faiss.extract_index_ivf(self.index)
                gpustat.print_gpustat()
                # Run the coarse-quantizer k-means on all available GPUs.
                clustering_index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(self.d))
                index_ivf.clustering_index = clustering_index
            logging.info('Training index...')
            self.index.train(self.vectors)
        else:
            self.index = faiss.IndexFlatL2(self.d)
        if faiss.get_num_gpus() > 0 and use_gpu:
            gpustat.print_gpustat()
            self.index = faiss.index_cpu_to_all_gpus(self.index)
    def load(self, path):
        self.index = faiss.read_index(path)

    def save(self, path):
        gpustat.print_gpustat()
        # The index lives on the GPU(s) when built with use_gpu=True.
        faiss.write_index(faiss.index_gpu_to_cpu(self.index), path)

    def save_vectors(self, path):
        #pickle.dump(self.vectors, open(path, 'wb'), protocol=4)
        with h5py.File(path, 'w') as f:
            f.create_dataset('data', data=self.vectors)
    def search(self, vectors, k=1, probes=1):
        if not isinstance(vectors, np.ndarray):
            vectors = np.array(vectors, dtype='float32')  # faiss expects float32
        #faiss.normalize_L2(vectors)
        try:
            self.index.nprobe = probes
        except Exception:
            # Not every index type (e.g. a flat index) exposes nprobe.
            pass
        distances, ids = self.index.search(vectors, k)
        # For unit vectors, squared L2 distance d = 2 - 2*cos, so cos = (2 - d) / 2.
        similarities = [(2 - d) / 2 for d in distances]
        return ids, similarities
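

# Minimal usage sketch, assuming faiss, gpustat, and h5py are installed; the
# random data and sizes below are illustrative and not part of the original module.
if __name__ == '__main__':
    rng = np.random.default_rng(0)
    d = 128
    index = VectorIndex(d)
    for v in rng.random((1000, d), dtype=np.float32):
        index.add(v)
    index.build(use_gpu=False)          # 1000 <= 50000, so this is IndexFlatL2
    index.add_vectors(index.vectors)    # build() trains but does not add vectors
    ids, sims = index.search(index.vectors[:5], k=3)
    print(ids)
    print(sims)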