import logging
import sqlite3

import numpy as np

from models.vector_index import VectorIndex


class DenseRetriever:
    """Encodes queries with a dense model and searches a vector index of passages."""

    def __init__(self, model, db_path, batch_size=64, use_gpu=False, debug=False):
        self.model = model
        # The index dimensionality must match the encoder's output (768 here).
        self.vector_index = VectorIndex(768)
        self.batch_size = batch_size
        self.use_gpu = use_gpu
        self.db = sqlite3.connect(db_path)
        self.db.row_factory = sqlite3.Row  # allows column access by name
        self.debug = debug

    def load_pretrained_index(self, path):
        self.vector_index.load(path)

    def populate_index(self, table_name):
        """Decode the pre-computed vectors stored in the database and add them to the index."""
        cur = self.db.cursor()
        query = f'SELECT * FROM {table_name} ORDER BY idx'
        if self.debug:
            query += ' LIMIT 1000'  # cap the index size for quick debugging runs
        for r in cur.execute(query):
            # The 'encoded' column holds the raw float32 bytes of one vector.
            v = np.frombuffer(r['encoded'], dtype=np.float32).reshape(1, -1)
            self.vector_index.index.add(np.ascontiguousarray(v))
            print(f"\rAdded {self.vector_index.index.ntotal} vectors", end='')
        print()
        logging.info("Finished adding vectors.")

    def search(self, queries, limit=1000, probes=512, min_similarity=0):
        """Encode the queries and return, per query, (id, similarity) pairs above the threshold."""
        query_vectors = self.model.encode(queries, batch_size=self.batch_size)
        ids, similarities = self.vector_index.search(query_vectors, k=limit, probes=probes)
        results = []
        for j in range(len(ids)):
            # Keep only hits strictly above the similarity threshold.
            results.append([
                (ids[j][i], similarities[j][i])
                for i in range(len(ids[j]))
                if similarities[j][i] > min_similarity
            ])
        return results
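

# Minimal usage sketch. `DummyEncoder` is a hypothetical stand-in for the real
# encoder model (anything with an `encode(texts, batch_size=...)` method that
# returns 768-dim float32 vectors); `passages.db` and the `passages` table are
# assumed example names with an 'encoded' BLOB column of packed float32 bytes.
if __name__ == '__main__':
    class DummyEncoder:
        def encode(self, texts, batch_size=64):
            # Random unit vectors stand in for real sentence embeddings.
            v = np.random.rand(len(texts), 768).astype(np.float32)
            return v / np.linalg.norm(v, axis=1, keepdims=True)

    retriever = DenseRetriever(DummyEncoder(), 'passages.db')
    retriever.populate_index('passages')
    for hits in retriever.search(['example query'], limit=5):
        print(hits[:3])  # top (id, similarity) pairs for this query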