import logging
import sqlite3

import numpy as np

from models.vector_index import VectorIndex
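
# `VectorIndex` lives in models/vector_index.py. From its use below it is
# assumed to expose a FAISS-style interface: load(path) to read a serialized
# index from disk, search(vectors, k=..., probes=...) returning a pair of
# (ids, similarities) arrays, and an .index attribute with add() and .ntotal.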

class DenseRetriever:
    """Dense retriever: encodes queries with `model` and searches an index of
    precomputed 768-dimensional passage embeddings stored in SQLite."""

    def __init__(self, model, db_path, batch_size=64, use_gpu=False, debug=False):
        self.model = model
        self.vector_index = VectorIndex(768)  # must match the stored embedding dimension
        self.batch_size = batch_size
        self.use_gpu = use_gpu  # kept for callers; not used directly in this class
        self.db = sqlite3.connect(db_path)
        self.db.row_factory = sqlite3.Row  # allows access by column name, e.g. row['encoded']
        self.debug = debug
    def load_pretrained_index(self, path):
        """Load a prebuilt index from disk instead of repopulating it row by row."""
        self.vector_index.load(path)
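
    # The table read by populate_index is assumed (it is not defined in this
    # file) to hold one row per passage: an integer 'idx' plus an 'encoded'
    # BLOB containing the raw bytes of a float32 embedding, e.g. written with
    # embedding.astype(np.float32).tobytes().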
    def populate_index(self, table_name):
        """Stream embeddings out of SQLite and add them to the vector index.

        `table_name` is interpolated into the query (SQLite cannot bind
        identifiers as parameters), so it must come from a trusted source.
        """
        cur = self.db.cursor()
        query = f'SELECT * FROM {table_name} ORDER BY idx'
        if self.debug:
            query += ' LIMIT 1000'  # index only a small sample while debugging
        for row in cur.execute(query):
            # Decode the float32 BLOB into a writable, C-contiguous (1, dim) array.
            vector = np.frombuffer(row['encoded'], dtype=np.float32).copy()
            self.vector_index.index.add(vector.reshape(1, -1))
            # '\r' overwrites the progress line in place; flush so it shows immediately.
            print(f"\rAdded {self.vector_index.index.ntotal} vectors", end='', flush=True)
        print()
        logging.info("Finished adding vectors.")
    def search(self, queries, limit=1000, probes=512, min_similarity=0):
        """Encode `queries` and return, for each query, a list of
        (id, similarity) pairs scoring above `min_similarity`."""
        query_vectors = self.model.encode(queries, batch_size=self.batch_size)
        ids, similarities = self.vector_index.search(query_vectors, k=limit, probes=probes)
        return [
            [(i, s) for i, s in zip(row_ids, row_sims) if s > min_similarity]
            for row_ids, row_sims in zip(ids, similarities)
        ]
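

if __name__ == '__main__':
    # Minimal usage sketch. The encoder name, database path, index path, and
    # query are hypothetical placeholders; any 768-dimensional encoder with a
    # sentence-transformers style .encode(texts, batch_size=...) would fit
    # VectorIndex(768).
    from sentence_transformers import SentenceTransformer

    encoder = SentenceTransformer('multi-qa-mpnet-base-dot-v1')  # 768-dim model
    retriever = DenseRetriever(encoder, 'passages.db')
    retriever.load_pretrained_index('index.faiss')  # or retriever.populate_index('passages')

    for doc_id, score in retriever.search(['who wrote hamlet?'], limit=10)[0]:
        print(doc_id, score)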