import lancedb
import os
import gradio as gr
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# Cross-encoder used to rerank retrieved passages; configurable via CROSS_ENC_MODEL.
CROSS_ENC_MODEL = os.getenv("CROSS_ENC_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2")

tokenizer = AutoTokenizer.from_pretrained(CROSS_ENC_MODEL)
cross_encoder = AutoModelForSequenceClassification.from_pretrained(CROSS_ENC_MODEL)
cross_encoder.eval()

# LanceDB table holding the pre-computed document embeddings.
# TABLE_NAME and EMB_MODEL have no defaults, so they must be set in the environment.
db = lancedb.connect(".lancedb")
TABLE = db.open_table(os.getenv("TABLE_NAME"))
VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))

# Bi-encoder used to embed incoming queries.
retriever = SentenceTransformer(os.getenv("EMB_MODEL"))


def rerank(query, documents):
    """Score each (query, document) pair with the cross-encoder and return the
    documents sorted by descending relevance."""
    pairs = [[query, doc] for doc in documents]
    inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        # squeeze(-1) keeps the batch dimension so a single retrieved document still works,
        # and .tolist() turns the scores into plain floats for sorting.
        scores = cross_encoder(**inputs).logits.squeeze(-1).tolist()
    sorted_docs = [doc for _, doc in sorted(zip(scores, documents), key=lambda x: x[0], reverse=True)]
    return sorted_docs


def retrieve(query, k, rr=True):
    """Embed the query, fetch the top-k nearest passages from LanceDB, and
    optionally rerank them with the cross-encoder before returning their text."""
    query_vec = retriever.encode(query)
    try:
        documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k).to_list()
        documents = [doc[TEXT_COLUMN] for doc in documents]

        if rr:
            documents = rerank(query, documents)

        return documents

    except Exception as e:
        # Surface retrieval errors in the Gradio UI instead of crashing the app.
        raise gr.Error(str(e))
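

# Minimal smoke-test sketch (not part of the original module): it assumes EMB_MODEL and
# TABLE_NAME are set in the environment and that the .lancedb table already contains
# embedded documents. The query string and k value below are illustrative only.
if __name__ == "__main__":
    sample_query = "What is LanceDB?"  # hypothetical example query
    top_passages = retrieve(sample_query, k=5, rr=True)
    for rank, passage in enumerate(top_passages, start=1):
        # Print a short preview of each reranked passage.
        print(f"{rank}. {passage[:100]}")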