import logging
import math
import pickle

import faiss
import gpustat
import h5py
import numpy as np
# See http://ulrichpaquet.com/Papers/SpeedUp.pdf, Theorem 5: appending one
# extra dimension reduces maximum-inner-product search to L2 search.
def get_phi(xb):
    # Maximum squared norm over the database vectors.
    return (xb ** 2).sum(1).max()

def augment_xb(xb, phi=None):
    # Append sqrt(phi - ||x||^2) as an extra column to each database vector.
    norms = (xb ** 2).sum(1)
    if phi is None:
        phi = norms.max()
    extracol = np.sqrt(phi - norms)
    return np.hstack((xb, extracol.reshape(-1, 1)))

def augment_xq(xq):
    # Append a zero column to each query vector.
    extracol = np.zeros(len(xq), dtype='float32')
    return np.hstack((xq, extracol.reshape(-1, 1)))
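
# A minimal sketch (illustrative, not part of the original module) of why the
# augmentation works: ||aug(q) - aug(x)||^2 = ||q||^2 + phi - 2<q, x>, so the
# nearest L2 neighbour of an augmented query is the inner-product maximizer.
def _augmentation_demo():
    rng = np.random.RandomState(0)
    xb = rng.rand(1000, 64).astype('float32')
    xq = rng.rand(10, 64).astype('float32')
    index = faiss.IndexFlatL2(65)  # d + 1 after augmentation
    index.add(augment_xb(xb))
    _, ids = index.search(augment_xq(xq), 1)
    # Should agree with brute-force MIPS, barring float32 near-ties.
    assert (ids[:, 0] == (xq @ xb.T).argmax(axis=1)).all()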
class VectorIndex:
    def __init__(self, d):
        self.d = d
        self.vectors = []
        self.index = None

    def add(self, v):
        self.vectors.append(v)

    def add_vectors(self, vs):
        logging.info('Adding vectors to index...')
        self.index.add(vs)
    def build(self, use_gpu=False):
        # Materializing the list may OOM if too many vectors were added;
        # faiss requires float32 input.
        self.vectors = np.array(self.vectors, dtype='float32')
        faiss.normalize_L2(self.vectors)
        #self.vectors = augment_xq(self.vectors)
        logging.info('Indexing {} vectors'.format(self.vectors.shape[0]))
        if self.vectors.shape[0] > 50000:
            # Roughly 8 * sqrt(n) centroids, with n rounded down to a power of two.
            num_centroids = 8 * int(math.sqrt(math.pow(2, int(math.log(self.vectors.shape[0], 2)))))
            logging.info('Using {} centroids'.format(num_centroids))
            self.index = faiss.index_factory(self.d, "IVF{}_HNSW32,Flat".format(num_centroids))
            ngpu = faiss.get_num_gpus()
            if ngpu > 0 and use_gpu:
                logging.info('Using {} GPUs'.format(ngpu))
                # Run the k-means clustering on GPU; the index itself stays on CPU.
                index_ivf = faiss.extract_index_ivf(self.index)
                gpustat.print_gpustat()
                clustering_index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(self.d))
                index_ivf.clustering_index = clustering_index
            logging.info('Training index...')
            self.index.train(self.vectors)
        else:
            # Small corpus: exact brute-force search, no training needed.
            self.index = faiss.IndexFlatL2(self.d)
            if faiss.get_num_gpus() > 0 and use_gpu:
                gpustat.print_gpustat()
                self.index = faiss.index_cpu_to_all_gpus(self.index)
    def load(self, path):
        self.index = faiss.read_index(path)

    def save(self, path):
        gpustat.print_gpustat()
        # Bring the index back to the CPU before serializing it.
        faiss.write_index(faiss.index_gpu_to_cpu(self.index), path)
    def save_vectors(self, path):
        #pickle.dump(self.vectors, open(path, 'wb'), protocol=4)
        with h5py.File(path, 'w') as f:
            f.create_dataset('data', data=self.vectors)
    def search(self, vectors, k=1, probes=1):
        if not isinstance(vectors, np.ndarray):
            vectors = np.array(vectors, dtype='float32')
        #faiss.normalize_L2(vectors)
        try:
            self.index.nprobe = probes
        except AttributeError:
            pass  # flat indexes have no nprobe parameter
        distances, ids = self.index.search(vectors, k)
        # Vectors are L2-normalized, so squared L2 distance d = 2 - 2*cos;
        # recover cosine similarity as (2 - d) / 2.
        similarities = [(2 - d) / 2 for d in distances]
        return ids, similarities
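
# A minimal usage sketch (illustrative, not part of the original module);
# the dimensionality and corpus size here are arbitrary.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    vi = VectorIndex(d=128)
    for v in np.random.rand(1000, 128).astype('float32'):
        vi.add(v)
    vi.build(use_gpu=False)     # small corpus -> exact IndexFlatL2
    vi.add_vectors(vi.vectors)  # vectors were normalized in build()
    ids, similarities = vi.search(vi.vectors[:5], k=3)
    print(ids)
    print(similarities)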