Spaces:
Sleeping
Sleeping
| from transformers import AutoConfig | |
| from sentence_transformers import SentenceTransformer | |
| import lancedb | |
| import torch | |
| import pyarrow as pa | |
| import pandas as pd | |
| import numpy as np | |
| import tqdm | |
| import logging | |
# Module-scoped logger; its own threshold is DEBUG, so records from this module
# are only filtered by whatever handlers the application attaches.
logger = logging.getLogger(name=__name__)
logger.setLevel("DEBUG")
class VectorDB:
    """Embed a CSV of actions into a LanceDB table and run nearest-neighbour lookups on it.

    The CSV is read positionally: its first column is stored as the action
    name and its second column is embedded and stored as the description.
    """

    vector_column = "vector"
    description_column = "description"
    name_column = "name"
    table_name = "pimcore_actions"
    # Class-level defaults kept for compatibility; __init__ sets the real values per instance.
    emb_model = ''
    db_location = ''

    def __init__(self, emb_model, db_location, actions_list_file_path, num_sub_vectors, batch_size):
        """Load the embedding model, (re)create the table and fill it from the CSV.

        Args:
            emb_model: SentenceTransformer model name or path.
            db_location: location passed to ``lancedb.connect``.
            actions_list_file_path: CSV whose first two columns are name and description.
            num_sub_vectors: only used to check that the embedding size divides
                evenly by it (presumably meant for a product-quantization index;
                no index is built here).
            batch_size: number of rows embedded and inserted per batch.

        Raises:
            ValueError: if ``batch_size`` < 1, if the embedding size is not
                divisible by ``num_sub_vectors``, or if the CSV has fewer than
                two columns.

        Note: the table ``table_name`` is created with ``mode="overwrite"``, so
        any existing table of that name is replaced.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        self.emb_model = emb_model
        self.db_location = db_location

        # Load the model once; the same instance is used for indexing and for queries.
        self.retriever = SentenceTransformer(emb_model)
        self.retriever.eval()

        # Use the sentence-embedding width (accounts for pooling/Dense layers);
        # fall back to the transformer hidden size if the model does not report it.
        emb_dimension = self.retriever.get_sentence_embedding_dimension()
        if emb_dimension is None:
            emb_dimension = AutoConfig.from_pretrained(emb_model).hidden_size
        if emb_dimension % num_sub_vectors != 0:
            raise ValueError("Embedding size must be divisible by the num of sub vectors")
        print('Model loaded...')
        print(emb_model)

        self.device = self._select_device()
        print(f"Device: {self.device}")

        db = lancedb.connect(db_location)
        schema = pa.schema(
            [
                pa.field(self.vector_column, pa.list_(pa.float32(), emb_dimension)),
                pa.field(self.description_column, pa.string()),
                pa.field(self.name_column, pa.string()),
            ]
        )
        tbl = db.create_table(self.table_name, schema=schema, mode="overwrite")

        names, descriptions = self._load_rows(pd.read_csv(actions_list_file_path))
        print("Starting vector generation")
        num_batches = (len(names) + batch_size - 1) // batch_size
        for i in tqdm.tqdm(range(num_batches)):
            lo, hi = i * batch_size, (i + 1) * batch_size
            batch_names = names[lo:hi]
            batch_descriptions = descriptions[lo:hi]
            try:
                encoded = self.retriever.encode(
                    batch_descriptions, normalize_embeddings=True, device=self.device
                )
                tbl.add(pd.DataFrame({
                    self.vector_column: [vec.tolist() for vec in encoded],
                    self.description_column: batch_descriptions,
                    self.name_column: batch_names,
                }))
            except Exception:
                # Best-effort ingestion: keep going, but record why the batch was lost.
                print(f"batch {i} was skipped")
                logger.exception("Embedding/insert failed for batch %d", i)
        self.db = db
        self.table = tbl
        print("Vector generation done.")

    @staticmethod
    def _select_device() -> str:
        """Return the preferred torch device name: "mps", then "cuda", then "cpu"."""
        if torch.backends.mps.is_available():
            return "mps"
        if torch.cuda.is_available():
            return "cuda"
        return "cpu"

    @staticmethod
    def _load_rows(frame: "pd.DataFrame") -> "tuple[list[str], list[str]]":
        """Split *frame* into parallel name and description lists.

        The first column is the name and the second the description (by
        position). Rows with a missing (NaN/None) or blank name or description
        are dropped, because one bad value would otherwise fail its whole batch.

        Raises:
            ValueError: if *frame* has fewer than two columns.
        """
        if frame.shape[1] < 2:
            raise ValueError("actions list CSV must have at least two columns (name, description)")
        names, descriptions = [], []
        for name, description in frame.iloc[:, :2].dropna().itertuples(index=False, name=None):
            name, description = str(name), str(description)
            if name.strip() and description.strip():
                names.append(name)
                descriptions.append(description)
        return names, descriptions

    def retrieve_prefiltered_hits(self, query, k):
        """Return ``(names, descriptions)`` of the *k* stored actions closest to *query*.

        The query is normalized the same way as the indexed vectors so both
        live in the same embedding space.
        """
        query_vec = self.retriever.encode(query, normalize_embeddings=True)
        logger.info('encoded')
        documents = (
            self.table.search(query_vec, vector_column_name=self.vector_column)
            .limit(k)
            .to_list()
        )
        names = [doc[self.name_column] for doc in documents]
        descriptions = [doc[self.description_column] for doc in documents]
        logger.info('done topK lookup')
        return names, descriptions