# crop-diag-module / app/utils/data_mapping.py
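"""Vector-search utilities for the crop diagnosis module.

Wraps a shared SentenceTransformer embedding model, a FAISS index, and a
SQLite store of embeddings, and maps free-text queries to Node records with
both synchronous and asynchronous lookup paths.
"""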
import os
import asyncio
from concurrent.futures import ThreadPoolExecutor
from sentence_transformers import SentenceTransformer
import faiss
from pyvi.ViTokenizer import tokenize
import sqlite3
from app.core.type import Node
FAISS_INDEX_PATH = 'app/data/faiss_index.index'
VECTOR_EMBEDDINGS_DB_PATH = 'app/data/vector_embeddings.db'


class SingletonModel:
    """Process-wide singleton that loads the embedding model exactly once."""
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(SingletonModel, cls).__new__(cls)
            cls._instance.model = SentenceTransformer('dangvantuan/vietnamese-embedding')
        return cls._instance


class DataMapping:
    """Maps free-text queries to embedding records via FAISS + SQLite."""

    def __init__(self):
        try:
            self.model: SentenceTransformer = SingletonModel().model
            self.index: faiss.IndexFlatL2 = self.__load_faiss_index()
            self.conn = sqlite3.connect(VECTOR_EMBEDDINGS_DB_PATH, check_same_thread=False)
            self.cursor = self.conn.cursor()
            self.executor = ThreadPoolExecutor(max_workers=4)
        except Exception as e:
            print(f"Error while initializing DataMapping: {e}")
            raise

    def __del__(self):
        if hasattr(self, 'cursor'):
            self.cursor.close()
        if hasattr(self, 'conn'):
            self.conn.close()

    def __load_faiss_index(self, index_file=FAISS_INDEX_PATH):
        if os.path.exists(index_file):
            index = faiss.read_index(index_file)
            print(f"Loaded FAISS index from {index_file}")
            return index
        return None

    def get_top_index_by_text(self, text, top_k=1, distance_threshold=0.6):
        if not text or top_k < 1:
            raise ValueError("Invalid input: text cannot be empty and top_k must be positive")
        q_token = tokenize(text)
        q_vec = self.model.encode([q_token])
        faiss.normalize_L2(q_vec)
        D, I = self.index.search(q_vec, top_k)
        # Keep only hits whose score meets the threshold.
        mask = D[0] >= distance_threshold
        filtered_indices = I[0][mask].tolist()
        distances = D[0][mask].tolist()
        return filtered_indices, distances

    def get_embedding_by_id(self, id):
        self.cursor.execute("SELECT * FROM embeddings WHERE e_index = ?", (id,))
        return self.cursor.fetchone()

    def get_embedding_by_label(self, label):
        self.cursor.execute("SELECT * FROM embeddings WHERE label = ?", (label,))
        result = self.cursor.fetchall()
        return [Node.data_row_to_node(row) for row in result]

    def get_embedding_by_id_threadsafe(self, id):
        """Thread-safe version - opens a separate connection per thread."""
        conn = sqlite3.connect(VECTOR_EMBEDDINGS_DB_PATH, check_same_thread=False)
        cursor = conn.cursor()
        try:
            cursor.execute("SELECT * FROM embeddings WHERE e_index = ?", (id,))
            result = cursor.fetchone()
            return result
        finally:
            cursor.close()
            conn.close()

    def get_top_result_by_text(self, text, top_k=1, type=None) -> list[Node]:
        """Sync counterpart of get_top_result_by_text_async."""
        top_index, distances = self.get_top_index_by_text(text, top_k)
        results = [self.get_embedding_by_id(int(index)) for index in top_index]
        # Pair each row with its distance before filtering so they stay aligned.
        pairs = [(row, dist) for row, dist in zip(results, distances) if row is not None]
        if type:
            pairs = [(row, dist) for row, dist in pairs if row[3] == type]
        return [Node.data_row_to_node(row, dist) for row, dist in pairs]

    async def get_top_result_by_text_async(self, text, top_k=1, type=None) -> list[Node]:
        """Async version - thread-safe, uses a per-thread connection."""
        def _get_top_result():
            try:
                top_index, distances = self.get_top_index_by_text(text, top_k)
                results = [self.get_embedding_by_id_threadsafe(int(index)) for index in top_index]
                # Pair rows with their distances before filtering so they stay aligned.
                pairs = [(row, dist) for row, dist in zip(results, distances) if row is not None]
                if type:
                    pairs = [(row, dist) for row, dist in pairs if row[3] == type]
                return [Node.data_row_to_node(row, dist) for row, dist in pairs]
            except Exception as e:
                print(f"Error in _get_top_result for text '{text}': {e}")
                return []
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(self.executor, _get_top_result)
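

# Usage sketch (illustrative addition, not part of the original module): shows
# how a caller might query DataMapping synchronously and asynchronously. It
# assumes the FAISS index and SQLite database exist at FAISS_INDEX_PATH and
# VECTOR_EMBEDDINGS_DB_PATH; the query string and top_k value are placeholders.
if __name__ == "__main__":
    mapper = DataMapping()

    # Synchronous lookup: returns Node objects for the closest matches.
    print(mapper.get_top_result_by_text("ví dụ truy vấn", top_k=3))

    # Asynchronous lookup: the FAISS/SQLite work runs in the ThreadPoolExecutor.
    print(asyncio.run(mapper.get_top_result_by_text_async("ví dụ truy vấn", top_k=3)))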