import argparse
import logging
from pathlib import Path

import lancedb
import numpy as np
import pandas as pd
import pyarrow as pa
import torch
import tqdm
from sentence_transformers import SentenceTransformer
from transformers import AutoConfig

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--emb-model", help="embedding model name on HF hub", type=str, required=True)
    parser.add_argument("--table", help="table name in DB", type=str, required=True)
    parser.add_argument("--input-dir", help="input directory with documents to ingest", type=str, required=True)
    parser.add_argument("--vec-column", help="vector column name in the table", type=str, default="vector")
    parser.add_argument("--text-column", help="text column name in the table", type=str, default="text")
    parser.add_argument("--db-loc", help="database location", type=str,
                        default=str(Path().resolve() / ".lancedb"))
    parser.add_argument("--batch-size", help="batch size for embedding model", type=int, default=32)
    parser.add_argument("--num-partitions", help="number of partitions for index", type=int, default=256)
    parser.add_argument("--num-sub-vectors", help="number of sub-vectors for index", type=int, default=96)
    args = parser.parse_args()

    # Read the embedding dimension from the model config; product quantization
    # requires the dimension to be divisible by the number of sub-vectors.
    emb_config = AutoConfig.from_pretrained(args.emb_model)
    emb_dimension = emb_config.hidden_size
    assert emb_dimension % args.num_sub_vectors == 0, \
        "Embedding size must be divisible by the number of sub-vectors"

    model = SentenceTransformer(args.emb_model)
    model.eval()

    # Prefer Apple's MPS backend, then CUDA, then fall back to CPU.
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    logger.info(f"using {device} device")

    db = lancedb.connect(args.db_loc)
    schema = pa.schema(
        [
            pa.field(args.vec_column, pa.list_(pa.float32(), emb_dimension)),
            pa.field(args.text_column, pa.string()),
        ]
    )
    tbl = db.create_table(args.table, schema=schema, mode="overwrite")

    # Read every file under the input directory; drop empty documents up front
    # so each batch stays aligned with its embeddings.
    input_dir = Path(args.input_dir)
    sentences = []
    for file in input_dir.rglob("*"):
        if file.is_file():
            with open(file, encoding="utf-8") as f:
                text = f.read()
            if len(text) > 0:
                sentences.append(text)

    num_batches = int(np.ceil(len(sentences) / args.batch_size))
    for i in tqdm.tqdm(range(num_batches)):
        batch = sentences[i * args.batch_size:(i + 1) * args.batch_size]
        try:
            encoded = model.encode(batch, normalize_embeddings=True, device=device)
            df = pd.DataFrame({
                args.vec_column: [list(vec) for vec in encoded],
                args.text_column: batch,
            })
            tbl.add(df)
        except Exception:
            logger.warning(f"batch {i} was skipped", exc_info=True)

    # Create an IVF-PQ index: https://lancedb.github.io/lancedb/ann_indexes/
    # At the scale of the transformers docs an index is not really needed,
    # but we build one for demonstration purposes.
    tbl.create_index(
        num_partitions=args.num_partitions,
        num_sub_vectors=args.num_sub_vectors,
        vector_column_name=args.vec_column,
    )


if __name__ == "__main__":
    main()
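# ---------------------------------------------------------------------------
# Example usage (a sketch; the script filename, model, and table names below
# are hypothetical placeholders, not defined by this script). Note that
# all-MiniLM-L6-v2 has hidden_size 384, which is divisible by the default
# --num-sub-vectors of 96, so it satisfies the assertion above:
#
#   python ingest.py \
#       --emb-model sentence-transformers/all-MiniLM-L6-v2 \
#       --table docs \
#       --input-dir ./docs
#
# Once ingested, the table can be queried by embedding a query string with the
# same model and searching the vector column. A minimal sketch, assuming the
# same model and table name as the invocation above and the default --db-loc:
#
#   import lancedb
#   from sentence_transformers import SentenceTransformer
#
#   db = lancedb.connect(".lancedb")
#   tbl = db.open_table("docs")
#   model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
#   query_vec = model.encode("how do attention masks work?", normalize_embeddings=True)
#   results = tbl.search(query_vec).limit(5).to_pandas()
# ---------------------------------------------------------------------------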