# quin-general/models/dense_retriever.py
import logging
import sqlite3

import numpy as np

from models.vector_index import VectorIndex

class DenseRetriever:
    """Dense retriever: encodes queries with the given model and searches a
    vector index of pre-encoded passage embeddings stored in a SQLite database."""

    def __init__(self, model, db_path, batch_size=64, use_gpu=False, debug=False):
        self.model = model
        self.vector_index = VectorIndex(768)  # 768-dimensional embeddings
        self.batch_size = batch_size
        self.use_gpu = use_gpu
        self.db = sqlite3.connect(db_path)
        self.db.row_factory = sqlite3.Row  # allow column access by name
        self.debug = debug

    def load_pretrained_index(self, path):
        """Load a prebuilt vector index from disk instead of repopulating it."""
        self.vector_index.load(path)

    def populate_index(self, table_name):
        """Read pre-encoded passage vectors from the given SQLite table and add
        them to the vector index. In debug mode, only the first 1000 rows are indexed."""
        cur = self.db.cursor()
        query = f'SELECT * FROM {table_name} ORDER BY idx' if not self.debug \
            else f'SELECT * FROM {table_name} ORDER BY idx LIMIT 1000'
        for r in cur.execute(query):
            # The 'encoded' column stores the embedding as a blob of packed
            # 32-bit floats; decode it into a float32 vector.
            v = np.frombuffer(r['encoded'], dtype=np.float32)
            self.vector_index.index.add(np.ascontiguousarray(v.reshape(1, -1)))
            print(f"\rAdded {self.vector_index.index.ntotal} vectors", end='')
        print()
        logging.info("Finished adding vectors.")

    def search(self, queries, limit=1000, probes=512, min_similarity=0):
        """Encode the queries and return, for each query, a list of
        (vector_id, similarity) pairs whose similarity exceeds min_similarity."""
        query_vectors = self.model.encode(queries, batch_size=self.batch_size)
        ids, similarities = self.vector_index.search(query_vectors, k=limit, probes=probes)
        results = []
        for j in range(len(ids)):
            results.append([
                (ids[j][i], similarities[j][i])
                for i in range(len(ids[j]))
                if similarities[j][i] > min_similarity
            ])
        return results
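

# Example usage (a minimal sketch; `my_encoder`, 'passages.db', and the table
# name 'passages' are hypothetical placeholders, not part of this repository):
#
#   retriever = DenseRetriever(model=my_encoder, db_path='passages.db')
#   retriever.populate_index('passages')
#   results = retriever.search(['what is dense retrieval?'], limit=10)
#   # results[0] is a list of (vector_id, similarity) pairs for the first query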