import asyncio
import os
import sqlite3
from concurrent.futures import ThreadPoolExecutor

import faiss
from pyvi.ViTokenizer import tokenize
from sentence_transformers import SentenceTransformer

from app.core.type import Node

FAISS_INDEX_PATH = 'app/data/faiss_index.index'
VECTOR_EMBEDDINGS_DB_PATH = 'app/data/vector_embeddings.db'
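
# Assumed layout of the `embeddings` table in vector_embeddings.db, inferred from the
# queries below rather than stated anywhere in this file: each row is keyed by `e_index`
# (the position of the corresponding vector in the FAISS index), carries a `label`
# column, and stores the node type in the fourth column (row[3]), which is what
# get_top_result_by_text / get_top_result_by_text_async filter on.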
class SingletonModel:
    """Process-wide singleton so the SentenceTransformer model is loaded only once."""
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(SingletonModel, cls).__new__(cls)
            cls._instance.model = SentenceTransformer('dangvantuan/vietnamese-embedding')
        return cls._instance
class DataMapping:
    """Maps free-text queries to stored embeddings via a FAISS index backed by SQLite."""

    def __init__(self):
        try:
            self.model: SentenceTransformer = SingletonModel().model
            self.index: faiss.IndexFlatL2 = self.__load_faiss_index()
            self.conn = sqlite3.connect(VECTOR_EMBEDDINGS_DB_PATH, check_same_thread=False)
            self.cursor = self.conn.cursor()
            self.executor = ThreadPoolExecutor(max_workers=4)
        except Exception as e:
            print(f"Error while initializing DataMapping: {e}")
            raise

    def __del__(self):
        if hasattr(self, 'cursor'):
            self.cursor.close()
        if hasattr(self, 'conn'):
            self.conn.close()
    def __load_faiss_index(self, index_file=FAISS_INDEX_PATH):
        if os.path.exists(index_file):
            index = faiss.read_index(index_file)
            print(f"Loaded FAISS index from {index_file}")
            return index
        print(f"FAISS index file not found at {index_file}")
        return None
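
    # Sketch (not from the source) of how a compatible index file could be produced.
    # The search below L2-normalises the query and keeps hits scoring >= the threshold,
    # which reads like a cosine-similarity cut-off, so the assumption here is an
    # inner-product index built from the same model over pyvi-tokenised text. If the
    # stored index really is IndexFlatL2, smaller distances mean closer matches and the
    # comparison in get_top_index_by_text would need to be flipped.
    #
    #     vecs = SingletonModel().model.encode([tokenize(t) for t in corpus_texts])
    #     faiss.normalize_L2(vecs)
    #     index = faiss.IndexFlatIP(vecs.shape[1])
    #     index.add(vecs)
    #     faiss.write_index(index, FAISS_INDEX_PATH)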
    def get_top_index_by_text(self, text, top_k=1, distance_threshold=0.6):
        if not text or top_k < 1:
            raise ValueError("Invalid input: text cannot be empty and top_k must be positive")
        if self.index is None:
            raise RuntimeError("FAISS index is not loaded")
        q_token = tokenize(text)
        q_vec = self.model.encode([q_token])
        faiss.normalize_L2(q_vec)
        D, I = self.index.search(q_vec, top_k)
        # Keep only hits whose score clears the threshold (assumes higher means closer).
        mask = D[0] >= distance_threshold
        filtered_indices = I[0][mask].tolist()
        distances = D[0][mask].tolist()
        return filtered_indices, distances
    def get_embedding_by_id(self, id):
        self.cursor.execute("SELECT * FROM embeddings WHERE e_index = ?", (id,))
        return self.cursor.fetchone()

    def get_embedding_by_label(self, label):
        self.cursor.execute("SELECT * FROM embeddings WHERE label = ?", (label,))
        result = self.cursor.fetchall()
        return [Node.data_row_to_node(row) for row in result]
    def get_embedding_by_id_threadsafe(self, id):
        """Thread-safe version - opens a separate connection for each call."""
        conn = sqlite3.connect(VECTOR_EMBEDDINGS_DB_PATH, check_same_thread=False)
        cursor = conn.cursor()
        try:
            cursor.execute("SELECT * FROM embeddings WHERE e_index = ?", (id,))
            return cursor.fetchone()
        finally:
            cursor.close()
            conn.close()
    def get_top_result_by_text(self, text, top_k=1, type=None) -> list[Node]:
        """Sync version of get_top_result_by_text_async."""
        top_index, distances = self.get_top_index_by_text(text, top_k)
        results = [self.get_embedding_by_id(int(index)) for index in top_index]
        # Pair each row with its distance before filtering so the two stay aligned.
        pairs = [(row, dist) for row, dist in zip(results, distances) if row is not None]
        if type:
            pairs = [(row, dist) for row, dist in pairs if row[3] == type]
        return [Node.data_row_to_node(row, dist) for row, dist in pairs]
    async def get_top_result_by_text_async(self, text, top_k=1, type=None) -> list[Node]:
        """Async version - thread-safe, each lookup uses its own connection."""
        def _get_top_result():
            try:
                top_index, distances = self.get_top_index_by_text(text, top_k)
                results = [self.get_embedding_by_id_threadsafe(int(index)) for index in top_index]
                # Pair rows with distances before filtering so they cannot fall out of step.
                pairs = [(row, dist) for row, dist in zip(results, distances) if row is not None]
                if type:
                    pairs = [(row, dist) for row, dist in pairs if row[3] == type]
                return [Node.data_row_to_node(row, dist) for row, dist in pairs]
            except Exception as e:
                print(f"Error in _get_top_result for text '{text}': {e}")
                return []

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, _get_top_result)
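

# Minimal usage sketch, not part of the original module: it assumes the FAISS index and
# the SQLite database referenced by the paths above already exist and match the assumed
# `embeddings` schema noted near the top of this file.
if __name__ == "__main__":
    async def _demo():
        mapper = DataMapping()
        # Example Vietnamese query against the Vietnamese embedding model.
        nodes = await mapper.get_top_result_by_text_async("xin chào", top_k=3)
        for node in nodes:
            print(node)

    asyncio.run(_demo())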