BackEnd / core /embedding_model.py
HaRin2806
fix bug
76a8f20
import logging
from sentence_transformers import SentenceTransformer
import chromadb
from chromadb.config import Settings
import uuid
import os
from config import EMBEDDING_MODEL, CHROMA_PERSIST_DIRECTORY, COLLECTION_NAME
logger = logging.getLogger(__name__)
_embedding_model_instance = None
def get_embedding_model():
"""Kiểm tra và khởi tạo embedding đảm bảo chỉ khởi tạo một lần"""
global _embedding_model_instance
if _embedding_model_instance is None:
logger.info("Khởi tạo EmbeddingModel instance lần đầu")
_embedding_model_instance = EmbeddingModel()
else:
logger.debug("Sử dụng EmbeddingModel instance đã có")
return _embedding_model_instance
class EmbeddingModel:
def __init__(self):
"""Khởi tạo embedding model và ChromaDB client"""
logger.info(f"Đang khởi tạo embedding model: {EMBEDDING_MODEL}")
try:
# Khởi tạo sentence transformer với trust_remote_code=True
self.model = SentenceTransformer(EMBEDDING_MODEL, trust_remote_code=True)
logger.info("Đã tải sentence transformer model")
except Exception as e:
logger.error(f"Lỗi khởi tạo model: {e}")
# Thử với cache_folder explicit
cache_dir = os.getenv('SENTENCE_TRANSFORMERS_HOME', '/app/.cache/sentence-transformers')
self.model = SentenceTransformer(EMBEDDING_MODEL, cache_folder=cache_dir, trust_remote_code=True)
logger.info("Đã tải sentence transformer model với cache folder explicit")
# SỬA: Khai báo biến persist_directory local để tránh lỗi scope
persist_directory = CHROMA_PERSIST_DIRECTORY
# Đảm bảo thư mục ChromaDB tồn tại và có quyền ghi
try:
os.makedirs(persist_directory, exist_ok=True)
# Test ghi file để kiểm tra permission
test_file = os.path.join(persist_directory, 'test_permission.tmp')
with open(test_file, 'w') as f:
f.write('test')
os.remove(test_file)
logger.info(f"Thư mục ChromaDB đã sẵn sàng: {persist_directory}")
except Exception as e:
logger.error(f"Lỗi tạo/kiểm tra thư mục ChromaDB: {e}")
# Fallback to /tmp directory
import tempfile
persist_directory = os.path.join(tempfile.gettempdir(), 'chroma_db')
os.makedirs(persist_directory, exist_ok=True)
logger.warning(f"Sử dụng thư mục tạm thời: {persist_directory}")
# Khởi tạo ChromaDB client với persistent storage
try:
self.chroma_client = chromadb.PersistentClient(
path=persist_directory,
settings=Settings(
anonymized_telemetry=False,
allow_reset=True
)
)
logger.info(f"Đã kết nối ChromaDB tại: {persist_directory}")
except Exception as e:
logger.error(f"Lỗi kết nối ChromaDB: {e}")
# Fallback to in-memory client
logger.warning("Fallback to in-memory ChromaDB client")
self.chroma_client = chromadb.Client()
# Lấy hoặc tạo collection với cosine similarity
try:
self.collection = self.chroma_client.get_collection(name=COLLECTION_NAME)
logger.info(f"Đã kết nối collection '{COLLECTION_NAME}' với {self.collection.count()} items")
except Exception:
logger.info(f"Collection '{COLLECTION_NAME}' không tồn tại, tạo mới với cosine similarity...")
self.collection = self.chroma_client.create_collection(
name=COLLECTION_NAME,
metadata={
"hnsw:space": "cosine", # Cosine distance
"hnsw:M": 16, # Optimize for accuracy
"hnsw:construction_ef": 100
}
)
logger.info(f"Đã tạo collection mới với cosine similarity: {COLLECTION_NAME}")
def _initialize_collection(self):
"""Khởi tạo collection với cosine similarity"""
try:
# Kiểm tra xem collection đã tồn tại chưa
existing_collections = [col.name for col in self.chroma_client.list_collections()]
if COLLECTION_NAME in existing_collections:
self.collection = self.chroma_client.get_collection(name=COLLECTION_NAME)
# Kiểm tra distance function hiện tại
current_metadata = self.collection.metadata or {}
current_space = current_metadata.get("hnsw:space", "l2")
if current_space != "cosine":
logger.warning(f"Collection hiện tại đang dùng {current_space}, cần migration sang cosine")
if self.collection.count() > 0:
self._migrate_to_cosine()
else:
# Collection trống, xóa và tạo lại
self.chroma_client.delete_collection(name=COLLECTION_NAME)
self._create_cosine_collection()
else:
logger.info(f"Đã kết nối collection '{COLLECTION_NAME}' với cosine similarity, {self.collection.count()} items")
else:
# Collection chưa tồn tại, tạo mới với cosine
self._create_cosine_collection()
except Exception as e:
logger.error(f"Lỗi khởi tạo collection: {e}")
# Fallback: tạo collection mới
self._create_cosine_collection()
def _create_cosine_collection(self):
"""Tạo collection mới với cosine similarity"""
try:
self.collection = self.chroma_client.create_collection(
name=COLLECTION_NAME,
metadata={"hnsw:space": "cosine"}
)
logger.info(f"Đã tạo collection mới với cosine similarity: {COLLECTION_NAME}")
except Exception as e:
logger.error(f"Lỗi tạo collection với cosine: {e}")
# Fallback về collection mặc định
self.collection = self.chroma_client.get_or_create_collection(name=COLLECTION_NAME)
logger.warning("Đã fallback về collection mặc định (có thể dùng L2)")
def _migrate_to_cosine(self):
"""Migration collection từ L2 sang cosine"""
try:
logger.info("Bắt đầu migration collection sang cosine similarity...")
# Backup toàn bộ data
all_data = self.collection.get(
include=['documents', 'metadatas', 'embeddings'],
limit=self.collection.count()
)
if not all_data['documents']:
logger.info("Collection trống, chỉ cần tạo lại")
self.chroma_client.delete_collection(name=COLLECTION_NAME)
self._create_cosine_collection()
return
# Xóa collection cũ và tạo mới với cosine
self.chroma_client.delete_collection(name=COLLECTION_NAME)
self._create_cosine_collection()
# Restore data theo batch
documents = all_data['documents']
metadatas = all_data['metadatas']
embeddings = all_data['embeddings']
ids = all_data['ids']
batch_size = 100
total_items = len(documents)
for i in range(0, total_items, batch_size):
batch_docs = documents[i:i + batch_size]
batch_metas = metadatas[i:i + batch_size] if metadatas else None
batch_embeds = embeddings[i:i + batch_size] if embeddings else None
batch_ids = ids[i:i + batch_size]
if batch_embeds:
# Có embeddings sẵn, dùng luôn
self.collection.add(
documents=batch_docs,
metadatas=batch_metas,
embeddings=batch_embeds,
ids=batch_ids
)
else:
# Tính lại embeddings
new_embeddings = self.encode(batch_docs, is_query=False)
self.collection.add(
documents=batch_docs,
metadatas=batch_metas,
embeddings=new_embeddings,
ids=batch_ids
)
logger.info(f"Migration progress: {min(i + batch_size, total_items)}/{total_items}")
logger.info(f"Migration hoàn thành! Đã chuyển {total_items} items sang cosine similarity")
except Exception as e:
logger.error(f"Lỗi migration: {e}")
# Tạo collection mới nếu migration thất bại
self._create_cosine_collection()
def test_embedding_quality(self):
try:
# Test cases
test_cases = [
("query: Tháp dinh dưỡng cho trẻ", "passage: Tháp dinh dưỡng cho trẻ từ 6-11 tuổi"),
("query: dinh dưỡng", "passage: dinh dưỡng cho học sinh"),
("query: xin chào", "passage: Tháp dinh dưỡng cho trẻ")
]
for query_text, doc_text in test_cases:
# Encode
query_emb = self.model.encode([query_text], normalize_embeddings=True)[0]
doc_emb = self.model.encode([doc_text], normalize_embeddings=True)[0]
# Calculate cosine similarity manually
import numpy as np
similarity = np.dot(query_emb, doc_emb)
logger.info(f"Query: {query_text}")
logger.info(f"Doc: {doc_text}")
logger.info(f"Similarity: {similarity:.3f}")
logger.info(f"Query norm: {np.linalg.norm(query_emb):.3f}")
logger.info(f"Doc norm: {np.linalg.norm(doc_emb):.3f}")
logger.info("-" * 50)
except Exception as e:
logger.error(f"Test embedding error: {e}")
def _add_prefix_to_text(self, text, is_query=True):
# Clean text trước
text = text.strip()
# Kiểm tra xem text đã có prefix chưa
if text.startswith(('query:', 'passage:')):
return text
# Thêm prefix phù hợp
if is_query:
return f"query: {text}"
else:
return f"passage: {text}"
def encode(self, texts, is_query=True):
"""
Encode văn bản thành embeddings với proper normalization
"""
try:
if isinstance(texts, str):
texts = [texts]
# Thêm prefix cho texts (QUAN TRỌNG cho multilingual-e5-base)
processed_texts = [self._add_prefix_to_text(text, is_query) for text in texts]
logger.debug(f"Đang encode {len(processed_texts)} văn bản với prefix")
logger.debug(f"Sample processed text: {processed_texts[0][:100]}...")
# Encode với normalize_embeddings=True (QUAN TRỌNG!)
embeddings = self.model.encode(
processed_texts,
show_progress_bar=False,
normalize_embeddings=True # ✅ THÊM DÒNG NÀY
)
# Double-check normalization
import numpy as np
for i, emb in enumerate(embeddings[:2]): # Check first 2 embeddings
norm = np.linalg.norm(emb)
logger.debug(f"Embedding {i} norm: {norm}")
if abs(norm - 1.0) > 0.01:
logger.warning(f"Embedding {i} not properly normalized: norm = {norm}")
return embeddings.tolist()
except Exception as e:
logger.error(f"Lỗi encode văn bản: {e}")
raise
def search(self, query, top_k=5, age_filter=None):
"""Tìm kiếm văn bản tương tự trong ChromaDB"""
try:
query_embedding = self.encode(query, is_query=True)[0]
where_clause = None
if age_filter:
where_clause = {
"$and": [
{"age_min": {"$lte": age_filter}},
{"age_max": {"$gte": age_filter}}
]
}
print(f"🔍 AGE FILTER: Tìm kiếm cho tuổi {age_filter}")
print(f"🔍 WHERE CLAUSE: {where_clause}")
else:
print(f"⚠️ KHÔNG CÓ AGE FILTER - Tìm tất cả chunks")
search_results = self.collection.query(
query_embeddings=[query_embedding],
n_results=top_k,
where=where_clause,
include=['documents', 'metadatas', 'distances']
)
print(f"\n{'='*60}")
print(f"📊 CHROMADB SEARCH RESULTS")
print(f"{'='*60}")
print(f"Query: {query}")
print(f"Age filter: {age_filter}")
print(f"Found {len(search_results['documents'][0]) if search_results['documents'] else 0} chunks")
print(f"{'='*60}")
if not search_results or not search_results['documents']:
logger.warning("Không tìm thấy kết quả nào")
return []
results = []
documents = search_results['documents'][0]
metadatas = search_results['metadatas'][0]
distances = search_results['distances'][0]
for i, (doc, metadata, distance) in enumerate(zip(documents, metadatas, distances)):
chunk_id = metadata.get('chunk_id', f'chunk_{i}')
title = metadata.get('title', 'No title')
age_range = metadata.get('age_range', 'Unknown')
age_min = metadata.get('age_min', 'N/A')
age_max = metadata.get('age_max', 'N/A')
content_type = metadata.get('content_type', 'text')
chapter = metadata.get('chapter', 'Unknown')
similarity = round(1 - distance, 3)
results.append({
'document': doc,
'metadata': metadata or {},
'distance': distance,
'similarity': similarity,
'rank': i + 1
})
print(f"\n{'='*60}")
logger.info(f"Tim thay {len(results)} ket qua cho query")
return results
except Exception as e:
logger.error(f"Loi tim kiem: {e}")
return []
def add_documents(self, documents, metadatas=None, ids=None):
"""Thêm documents vào ChromaDB"""
try:
if not documents:
logger.warning("Không có documents để thêm")
return False
if not ids:
ids = [str(uuid.uuid4()) for _ in documents]
if not metadatas:
metadatas = [{} for _ in documents]
logger.info(f"Đang thêm {len(documents)} documents vào ChromaDB")
embeddings = self.encode(documents, is_query=False)
self.collection.add(
embeddings=embeddings,
documents=documents,
metadatas=metadatas,
ids=ids
)
logger.info(f"Đã thêm thành công {len(documents)} documents")
return True
except Exception as e:
logger.error(f"Lỗi thêm documents: {e}")
return False
def index_chunks(self, chunks):
"""Index các chunks dữ liệu vào ChromaDB"""
try:
if not chunks:
logger.warning("Không có chunks để index")
return False
documents = []
metadatas = []
ids = []
for chunk in chunks:
if not chunk.get('content'):
logger.warning(f"Chunk thiếu content: {chunk}")
continue
documents.append(chunk['content'])
metadata = chunk.get('metadata', {})
metadatas.append(metadata)
chunk_id = chunk.get('id') or str(uuid.uuid4())
ids.append(chunk_id)
if not documents:
logger.warning("Không có documents hợp lệ để index")
return False
batch_size = 100
total_batches = (len(documents) + batch_size - 1) // batch_size
for i in range(0, len(documents), batch_size):
batch_docs = documents[i:i + batch_size]
batch_metas = metadatas[i:i + batch_size]
batch_ids = ids[i:i + batch_size]
batch_num = (i // batch_size) + 1
logger.info(f"Đang xử lý batch {batch_num}/{total_batches} ({len(batch_docs)} items)")
success = self.add_documents(batch_docs, batch_metas, batch_ids)
if not success:
logger.error(f"Lỗi xử lý batch {batch_num}")
return False
logger.info(f"Đã index thành công {len(documents)} chunks")
return True
except Exception as e:
logger.error(f"Lỗi index chunks: {e}")
return False
def count(self):
"""Đếm số lượng documents trong collection"""
try:
return self.collection.count()
except Exception as e:
logger.error(f"Lỗi đếm documents: {e}")
return 0
def delete_collection(self):
"""Xóa collection hiện tại"""
try:
logger.warning(f"Đang xóa collection: {COLLECTION_NAME}")
self.chroma_client.delete_collection(name=COLLECTION_NAME)
# Tạo lại collection với cosine similarity
self._create_cosine_collection()
logger.info("Đã tạo lại collection mới với cosine similarity")
return True
except Exception as e:
logger.error(f"Lỗi xóa collection: {e}")
return False
def get_collection_info(self):
"""Lấy thông tin về collection và distance function"""
try:
metadata = self.collection.metadata or {}
distance_func = metadata.get("hnsw:space", "l2")
return {
'collection_name': COLLECTION_NAME,
'distance_function': distance_func,
'total_documents': self.count(),
'metadata': metadata
}
except Exception as e:
logger.error(f"Lỗi lấy collection info: {e}")
return {'error': str(e)}
def verify_cosine_similarity(self):
"""Kiểm tra và xác nhận đang sử dụng cosine similarity"""
try:
info = self.get_collection_info()
distance_func = info.get('distance_function', 'unknown')
logger.info(f"Collection đang sử dụng distance function: {distance_func}")
if distance_func == "cosine":
logger.info("Xác nhận: Đang sử dụng cosine similarity")
return True
else:
logger.warning(f"Cảnh báo: Đang sử dụng {distance_func}, không phải cosine")
return False
except Exception as e:
logger.error(f"Lỗi verify cosine: {e}")
return False
def get_stats(self):
"""Lấy thống kê về collection"""
try:
total_count = self.count()
collection_info = self.get_collection_info()
sample_results = self.collection.get(limit=min(100, total_count))
content_types = {}
chapters = {}
age_groups = {}
if sample_results and sample_results.get('metadatas'):
for metadata in sample_results['metadatas']:
if not metadata:
continue
content_type = metadata.get('content_type', 'unknown')
content_types[content_type] = content_types.get(content_type, 0) + 1
chapter = metadata.get('chapter', 'unknown')
chapters[chapter] = chapters.get(chapter, 0) + 1
age_group = metadata.get('age_group', 'unknown')
age_groups[age_group] = age_groups.get(age_group, 0) + 1
return {
'total_documents': total_count,
'content_types': content_types,
'chapters': chapters,
'age_groups': age_groups,
'collection_name': COLLECTION_NAME,
'embedding_model': EMBEDDING_MODEL,
'distance_function': collection_info.get('distance_function', 'unknown'),
'using_cosine_similarity': collection_info.get('distance_function') == 'cosine'
}
except Exception as e:
logger.error(f"Lỗi lấy stats: {e}")
return {
'total_documents': 0,
'error': str(e)
}