import logging
import pickle
import random
import sqlite3
import struct

import h5py
import numpy as np

from models.vector_index import VectorIndex


class DenseRetriever:
    """Dense retrieval over a ``VectorIndex`` fed from a SQLite table.

    The retriever encodes text queries with ``model`` and looks them up in
    ``self.vector_index``. The index can either be loaded from disk
    (``load_pretrained_index``) or filled from rows of a SQLite table
    (``populate_index``).
    """

    def __init__(self, model, db_path, batch_size=64, use_gpu=False, debug=False):
        """Store configuration and open the SQLite database.

        Args:
            model: Object exposing ``encode(queries, batch_size=...)``; used
                by ``search`` to embed queries.
            db_path: Path handed to ``sqlite3.connect``.
            batch_size: Batch size forwarded to ``model.encode``.
            use_gpu: Stored on the instance; not read by any method in this
                class.
            debug: When true, ``populate_index`` only reads the first 1000
                rows.
        """
        self.model = model
        # NOTE(review): 768 is presumably the embedding dimension expected by
        # VectorIndex -- confirm against models.vector_index.
        self.vector_index = VectorIndex(768)
        self.batch_size = batch_size
        self.use_gpu = use_gpu
        self.db = sqlite3.connect(db_path)
        # Row factory lets populate_index access columns by name.
        self.db.row_factory = sqlite3.Row
        self.debug = debug

    def load_pretrained_index(self, path):
        """Load a previously saved index from ``path`` via ``VectorIndex.load``."""
        self.vector_index.load(path)

    @staticmethod
    def _decode_vector(blob):
        """Decode a blob of packed native-endian float32 values.

        Returns a writable ``(1, n)`` float32 array where ``n`` is the number
        of complete 4-byte values in ``blob``; trailing bytes that do not
        form a full value are ignored.
        """
        item_size = np.dtype(np.float32).itemsize
        values = np.frombuffer(blob, dtype=np.float32, count=len(blob) // item_size)
        # np.frombuffer returns a read-only view over the blob; copy so the
        # index receives an ordinary writable, C-contiguous array.
        return np.array(values, dtype=np.float32).reshape(1, -1)

    def populate_index(self, table_name):
        """Add every vector stored in ``table_name`` to the index.

        Rows are read ordered by their ``idx`` column and the ``encoded``
        column of each row is decoded as packed float32 values. In debug mode
        only the first 1000 rows are read. Progress is printed to stdout.

        Args:
            table_name: Name of the table to read.
        """
        # NOTE(review): table_name is interpolated directly into the SQL text
        # (identifiers cannot be bound as parameters). Only pass trusted,
        # application-defined table names here.
        query = f'SELECT * FROM {table_name} ORDER BY idx'
        if self.debug:
            query += ' LIMIT 1000'

        cursor = self.db.cursor()
        try:
            for row in cursor.execute(query):
                self.vector_index.index.add(self._decode_vector(row['encoded']))
                print(f"\rAdded {self.vector_index.index.ntotal} vectors", end='')
        finally:
            cursor.close()
        # Terminate the carriage-return progress line.
        print()
        logging.info("Finished adding vectors.")

    def search(self, queries, limit=1000, probes=512, min_similarity=0):
        """Encode ``queries`` and return their nearest indexed vectors.

        Args:
            queries: Inputs passed unchanged to ``self.model.encode``.
            limit: Passed to ``VectorIndex.search`` as ``k``.
            probes: Passed to ``VectorIndex.search`` as ``probes``.
            min_similarity: Hits whose similarity is not strictly greater
                than this value are dropped.

        Returns:
            One list per query, each containing ``(id, similarity)`` tuples
            in the order returned by ``VectorIndex.search``.
        """
        query_vectors = self.model.encode(queries, batch_size=self.batch_size)
        ids, similarities = self.vector_index.search(
            query_vectors, k=limit, probes=probes
        )
        return [
            [
                (hit_id, similarity)
                for hit_id, similarity in zip(row_ids, row_similarities)
                if similarity > min_similarity
            ]
            for row_ids, row_similarities in zip(ids, similarities)
        ]