Spaces:
Sleeping
Sleeping
| import h5py | |
| import logging | |
| import pickle | |
| import sqlite3 | |
| import struct | |
| import numpy as np | |
| from models.vector_index import VectorIndex | |
| import random | |
class DenseRetriever:
    """Retrieve items by dense-vector similarity.

    Ties together three collaborators: an encoder ``model`` (must expose
    ``encode(queries, batch_size=...)``), a ``VectorIndex`` holding the
    vectors, and a SQLite database whose rows carry pre-encoded vectors.
    """

    def __init__(self, model, db_path, batch_size=64, use_gpu=False,
                 debug=False, embedding_dim=768):
        """Create the retriever and open the SQLite database.

        Args:
            model: Encoder object; ``search`` calls
                ``model.encode(queries, batch_size=...)`` on it.
            db_path: Path handed to ``sqlite3.connect``.
            batch_size: Batch size forwarded to ``model.encode``.
            use_gpu: Stored on the instance; not read by any method of
                this class.
            debug: When true, ``populate_index`` loads at most 1000 rows.
            embedding_dim: Value passed to the ``VectorIndex`` constructor
                (defaults to 768, the previously hard-coded value).
        """
        self.model = model
        self.vector_index = VectorIndex(embedding_dim)
        self.batch_size = batch_size
        self.use_gpu = use_gpu
        self.db = sqlite3.connect(db_path)
        # Row objects allow column access by name (r['encoded']).
        self.db.row_factory = sqlite3.Row
        self.debug = debug

    def load_pretrained_index(self, path):
        """Delegate to ``self.vector_index.load(path)``."""
        self.vector_index.load(path)

    @staticmethod
    def _decode_vector(blob):
        """Decode a blob of packed native-order float32 values.

        Returns a writable, C-contiguous ``float32`` array of shape
        ``(1, len(blob) // 4)``. Trailing bytes that do not form a whole
        4-byte float are ignored.
        """
        count = len(blob) // 4
        # frombuffer decodes the whole blob in one native call; the result
        # is a read-only view over the bytes, so copy it to get an array
        # the index is free to use however it likes.
        flat = np.frombuffer(blob, dtype=np.float32, count=count).copy()
        return flat.reshape(1, count)

    def populate_index(self, table_name):
        """Add every stored vector from ``table_name`` to the index.

        Rows are read in ``idx`` order; each row's ``encoded`` column is
        decoded and passed to ``self.vector_index.index.add`` one vector at
        a time. In debug mode only the first 1000 rows are read. Progress
        is printed to stdout on a single, continuously rewritten line.

        Args:
            table_name: Table to read. NOTE: the name is interpolated
                directly into the SQL text (identifiers cannot be bound as
                parameters), so it must come from a trusted source.
        """
        query = f'SELECT * FROM {table_name} ORDER BY idx'
        if self.debug:
            query += ' LIMIT 1000'

        index = self.vector_index.index
        cur = self.db.cursor()
        try:
            for row in cur.execute(query):
                index.add(self._decode_vector(row['encoded']))
                print(f"\rAdded {index.ntotal} vectors", end='')
        finally:
            cur.close()
        # Terminate the "\r"-rewritten progress line.
        print()
        logging.info("Finished adding vectors.")

    def search(self, queries, limit=1000, probes=512, min_similarity=0):
        """Encode ``queries`` and look them up in the vector index.

        Args:
            queries: Passed unchanged to ``self.model.encode``.
            limit: Forwarded to the index search as ``k``.
            probes: Forwarded to the index search as ``probes``.
            min_similarity: Hits whose similarity is not strictly greater
                than this value are dropped.

        Returns:
            One list per entry returned by the index search, each holding
            ``(id, similarity)`` tuples in the order the index returned
            them.
        """
        query_vectors = self.model.encode(queries, batch_size=self.batch_size)
        ids, similarities = self.vector_index.search(
            query_vectors, k=limit, probes=probes
        )
        return [
            [
                (hit_id, score)
                for hit_id, score in zip(row_ids, row_scores)
                if score > min_similarity
            ]
            for row_ids, row_scores in zip(ids, similarities)
        ]