|
import argparse |
|
import lancedb |
|
import torch |
|
import pyarrow as pa |
|
import pandas as pd |
|
from pathlib import Path |
|
import tqdm |
|
import numpy as np |
|
import logging |
|
|
|
from transformers import AutoConfig |
|
from sentence_transformers import SentenceTransformer |
|
|
|
logging.basicConfig(level=logging.INFO) |
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def main(): |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("--emb-model", help="embedding model name on HF hub", type=str) |
|
parser.add_argument("--table", help="table name in DB", type=str) |
|
parser.add_argument("--input-dir", help="input directory with documents to ingest", type=str) |
|
parser.add_argument("--vec-column", help="vector column name in the table", type=str, default="vector") |
|
parser.add_argument("--text-column", help="text column name in the table", type=str, default="text") |
|
parser.add_argument("--db-loc", help="database location", type=str, |
|
default=str(Path().resolve() / ".lancedb")) |
|
parser.add_argument("--batch-size", help="batch size for embedding model", type=int, default=32) |
|
parser.add_argument("--num-partitions", help="number of partitions for index", type=int, default=256) |
|
parser.add_argument("--num-sub-vectors", help="number of sub-vectors for index", type=int, default=96) |
|
|
|
args = parser.parse_args() |
|
|
|
emb_config = AutoConfig.from_pretrained(args.emb_model) |
|
emb_dimension = emb_config.hidden_size |
|
|
|
assert emb_dimension % args.num_sub_vectors == 0, \ |
|
"Embedding size must be divisible by the num of sub vectors" |
|
|
|
model = SentenceTransformer(args.emb_model) |
|
model.eval() |
|
|
|
if torch.backends.mps.is_available(): |
|
device = "mps" |
|
elif torch.cuda.is_available(): |
|
device = "cuda" |
|
else: |
|
device = "cpu" |
|
logger.info(f"using {str(device)} device") |
|
|
|
db = lancedb.connect(args.db_loc) |
|
|
|
schema = pa.schema( |
|
[ |
|
pa.field(args.vec_column, pa.list_(pa.float32(), emb_dimension)), |
|
pa.field(args.text_column, pa.string()) |
|
] |
|
) |
|
tbl = db.create_table(args.table, schema=schema, mode="overwrite") |
|
|
|
input_dir = Path(args.input_dir) |
|
|
|
files = list(input_dir.rglob("*")) |
|
sentences = [] |
|
for file in files: |
|
if file.is_file(): |
|
with open(file, encoding='utf-8') as f: |
|
sentences.append(f.read()) |
|
|
|
for i in tqdm.tqdm(range(0, int(np.ceil(len(sentences) / args.batch_size)))): |
|
try: |
|
batch = [sent for sent in sentences[i * args.batch_size:(i + 1) * args.batch_size] if len(sent) > 0] |
|
encoded = model.encode(batch, normalize_embeddings=True, device=device) |
|
encoded = [list(vec) for vec in encoded] |
|
|
|
df = pd.DataFrame({ |
|
args.vec_column: encoded, |
|
args.text_column: batch |
|
}) |
|
|
|
tbl.add(df) |
|
except: |
|
logger.info(f"batch {i} was skipped") |
|
|
|
''' |
|
create ivf-pd index https://lancedb.github.io/lancedb/ann_indexes/ |
|
with the size of the transformer docs, index is not really needed |
|
but we'll do it for demonstrational purposes |
|
''' |
|
tbl.create_index( |
|
num_partitions=args.num_partitions, |
|
num_sub_vectors=args.num_sub_vectors, |
|
vector_column_name=args.vec_column |
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |