Spaces:
Sleeping
Sleeping
import logging | |
from sentence_transformers import SentenceTransformer | |
import chromadb | |
from chromadb.config import Settings | |
import uuid | |
import os | |
from config import EMBEDDING_MODEL, CHROMA_PERSIST_DIRECTORY, COLLECTION_NAME | |
logger = logging.getLogger(__name__) | |
_embedding_model_instance = None | |
def get_embedding_model(): | |
"""Kiểm tra và khởi tạo embedding đảm bảo chỉ khởi tạo một lần""" | |
global _embedding_model_instance | |
if _embedding_model_instance is None: | |
logger.info("Khởi tạo EmbeddingModel instance lần đầu") | |
_embedding_model_instance = EmbeddingModel() | |
else: | |
logger.debug("Sử dụng EmbeddingModel instance đã có") | |
return _embedding_model_instance | |
class EmbeddingModel: | |
def __init__(self): | |
"""Khởi tạo embedding model và ChromaDB client""" | |
logger.info(f"Đang khởi tạo embedding model: {EMBEDDING_MODEL}") | |
try: | |
# Khởi tạo sentence transformer với trust_remote_code=True | |
self.model = SentenceTransformer(EMBEDDING_MODEL, trust_remote_code=True) | |
logger.info("Đã tải sentence transformer model") | |
except Exception as e: | |
logger.error(f"Lỗi khởi tạo model: {e}") | |
# Thử với cache_folder explicit | |
cache_dir = os.getenv('SENTENCE_TRANSFORMERS_HOME', '/app/.cache/sentence-transformers') | |
self.model = SentenceTransformer(EMBEDDING_MODEL, cache_folder=cache_dir, trust_remote_code=True) | |
logger.info("Đã tải sentence transformer model với cache folder explicit") | |
# SỬA: Khai báo biến persist_directory local để tránh lỗi scope | |
persist_directory = CHROMA_PERSIST_DIRECTORY | |
# Đảm bảo thư mục ChromaDB tồn tại và có quyền ghi | |
try: | |
os.makedirs(persist_directory, exist_ok=True) | |
# Test ghi file để kiểm tra permission | |
test_file = os.path.join(persist_directory, 'test_permission.tmp') | |
with open(test_file, 'w') as f: | |
f.write('test') | |
os.remove(test_file) | |
logger.info(f"Thư mục ChromaDB đã sẵn sàng: {persist_directory}") | |
except Exception as e: | |
logger.error(f"Lỗi tạo/kiểm tra thư mục ChromaDB: {e}") | |
# Fallback to /tmp directory | |
import tempfile | |
persist_directory = os.path.join(tempfile.gettempdir(), 'chroma_db') | |
os.makedirs(persist_directory, exist_ok=True) | |
logger.warning(f"Sử dụng thư mục tạm thời: {persist_directory}") | |
# Khởi tạo ChromaDB client với persistent storage | |
try: | |
self.chroma_client = chromadb.PersistentClient( | |
path=persist_directory, | |
settings=Settings( | |
anonymized_telemetry=False, | |
allow_reset=True | |
) | |
) | |
logger.info(f"Đã kết nối ChromaDB tại: {persist_directory}") | |
except Exception as e: | |
logger.error(f"Lỗi kết nối ChromaDB: {e}") | |
# Fallback to in-memory client | |
logger.warning("Fallback to in-memory ChromaDB client") | |
self.chroma_client = chromadb.Client() | |
# Lấy hoặc tạo collection với cosine similarity | |
try: | |
self.collection = self.chroma_client.get_collection(name=COLLECTION_NAME) | |
logger.info(f"Đã kết nối collection '{COLLECTION_NAME}' với {self.collection.count()} items") | |
except Exception: | |
logger.info(f"Collection '{COLLECTION_NAME}' không tồn tại, tạo mới với cosine similarity...") | |
self.collection = self.chroma_client.create_collection( | |
name=COLLECTION_NAME, | |
metadata={ | |
"hnsw:space": "cosine", # Cosine distance | |
"hnsw:M": 16, # Optimize for accuracy | |
"hnsw:construction_ef": 100 | |
} | |
) | |
logger.info(f"Đã tạo collection mới với cosine similarity: {COLLECTION_NAME}") | |
def _initialize_collection(self): | |
"""Khởi tạo collection với cosine similarity""" | |
try: | |
# Kiểm tra xem collection đã tồn tại chưa | |
existing_collections = [col.name for col in self.chroma_client.list_collections()] | |
if COLLECTION_NAME in existing_collections: | |
self.collection = self.chroma_client.get_collection(name=COLLECTION_NAME) | |
# Kiểm tra distance function hiện tại | |
current_metadata = self.collection.metadata or {} | |
current_space = current_metadata.get("hnsw:space", "l2") | |
if current_space != "cosine": | |
logger.warning(f"Collection hiện tại đang dùng {current_space}, cần migration sang cosine") | |
if self.collection.count() > 0: | |
self._migrate_to_cosine() | |
else: | |
# Collection trống, xóa và tạo lại | |
self.chroma_client.delete_collection(name=COLLECTION_NAME) | |
self._create_cosine_collection() | |
else: | |
logger.info(f"Đã kết nối collection '{COLLECTION_NAME}' với cosine similarity, {self.collection.count()} items") | |
else: | |
# Collection chưa tồn tại, tạo mới với cosine | |
self._create_cosine_collection() | |
except Exception as e: | |
logger.error(f"Lỗi khởi tạo collection: {e}") | |
# Fallback: tạo collection mới | |
self._create_cosine_collection() | |
def _create_cosine_collection(self): | |
"""Tạo collection mới với cosine similarity""" | |
try: | |
self.collection = self.chroma_client.create_collection( | |
name=COLLECTION_NAME, | |
metadata={"hnsw:space": "cosine"} | |
) | |
logger.info(f"Đã tạo collection mới với cosine similarity: {COLLECTION_NAME}") | |
except Exception as e: | |
logger.error(f"Lỗi tạo collection với cosine: {e}") | |
# Fallback về collection mặc định | |
self.collection = self.chroma_client.get_or_create_collection(name=COLLECTION_NAME) | |
logger.warning("Đã fallback về collection mặc định (có thể dùng L2)") | |
def _migrate_to_cosine(self): | |
"""Migration collection từ L2 sang cosine""" | |
try: | |
logger.info("Bắt đầu migration collection sang cosine similarity...") | |
# Backup toàn bộ data | |
all_data = self.collection.get( | |
include=['documents', 'metadatas', 'embeddings'], | |
limit=self.collection.count() | |
) | |
if not all_data['documents']: | |
logger.info("Collection trống, chỉ cần tạo lại") | |
self.chroma_client.delete_collection(name=COLLECTION_NAME) | |
self._create_cosine_collection() | |
return | |
# Xóa collection cũ và tạo mới với cosine | |
self.chroma_client.delete_collection(name=COLLECTION_NAME) | |
self._create_cosine_collection() | |
# Restore data theo batch | |
documents = all_data['documents'] | |
metadatas = all_data['metadatas'] | |
embeddings = all_data['embeddings'] | |
ids = all_data['ids'] | |
batch_size = 100 | |
total_items = len(documents) | |
for i in range(0, total_items, batch_size): | |
batch_docs = documents[i:i + batch_size] | |
batch_metas = metadatas[i:i + batch_size] if metadatas else None | |
batch_embeds = embeddings[i:i + batch_size] if embeddings else None | |
batch_ids = ids[i:i + batch_size] | |
if batch_embeds: | |
# Có embeddings sẵn, dùng luôn | |
self.collection.add( | |
documents=batch_docs, | |
metadatas=batch_metas, | |
embeddings=batch_embeds, | |
ids=batch_ids | |
) | |
else: | |
# Tính lại embeddings | |
new_embeddings = self.encode(batch_docs, is_query=False) | |
self.collection.add( | |
documents=batch_docs, | |
metadatas=batch_metas, | |
embeddings=new_embeddings, | |
ids=batch_ids | |
) | |
logger.info(f"Migration progress: {min(i + batch_size, total_items)}/{total_items}") | |
logger.info(f"Migration hoàn thành! Đã chuyển {total_items} items sang cosine similarity") | |
except Exception as e: | |
logger.error(f"Lỗi migration: {e}") | |
# Tạo collection mới nếu migration thất bại | |
self._create_cosine_collection() | |
def test_embedding_quality(self): | |
try: | |
# Test cases | |
test_cases = [ | |
("query: Tháp dinh dưỡng cho trẻ", "passage: Tháp dinh dưỡng cho trẻ từ 6-11 tuổi"), | |
("query: dinh dưỡng", "passage: dinh dưỡng cho học sinh"), | |
("query: xin chào", "passage: Tháp dinh dưỡng cho trẻ") | |
] | |
for query_text, doc_text in test_cases: | |
# Encode | |
query_emb = self.model.encode([query_text], normalize_embeddings=True)[0] | |
doc_emb = self.model.encode([doc_text], normalize_embeddings=True)[0] | |
# Calculate cosine similarity manually | |
import numpy as np | |
similarity = np.dot(query_emb, doc_emb) | |
logger.info(f"Query: {query_text}") | |
logger.info(f"Doc: {doc_text}") | |
logger.info(f"Similarity: {similarity:.3f}") | |
logger.info(f"Query norm: {np.linalg.norm(query_emb):.3f}") | |
logger.info(f"Doc norm: {np.linalg.norm(doc_emb):.3f}") | |
logger.info("-" * 50) | |
except Exception as e: | |
logger.error(f"Test embedding error: {e}") | |
def _add_prefix_to_text(self, text, is_query=True): | |
# Clean text trước | |
text = text.strip() | |
# Kiểm tra xem text đã có prefix chưa | |
if text.startswith(('query:', 'passage:')): | |
return text | |
# Thêm prefix phù hợp | |
if is_query: | |
return f"query: {text}" | |
else: | |
return f"passage: {text}" | |
def encode(self, texts, is_query=True): | |
""" | |
Encode văn bản thành embeddings với proper normalization | |
""" | |
try: | |
if isinstance(texts, str): | |
texts = [texts] | |
# Thêm prefix cho texts (QUAN TRỌNG cho multilingual-e5-base) | |
processed_texts = [self._add_prefix_to_text(text, is_query) for text in texts] | |
logger.debug(f"Đang encode {len(processed_texts)} văn bản với prefix") | |
logger.debug(f"Sample processed text: {processed_texts[0][:100]}...") | |
# Encode với normalize_embeddings=True (QUAN TRỌNG!) | |
embeddings = self.model.encode( | |
processed_texts, | |
show_progress_bar=False, | |
normalize_embeddings=True # ✅ THÊM DÒNG NÀY | |
) | |
# Double-check normalization | |
import numpy as np | |
for i, emb in enumerate(embeddings[:2]): # Check first 2 embeddings | |
norm = np.linalg.norm(emb) | |
logger.debug(f"Embedding {i} norm: {norm}") | |
if abs(norm - 1.0) > 0.01: | |
logger.warning(f"Embedding {i} not properly normalized: norm = {norm}") | |
return embeddings.tolist() | |
except Exception as e: | |
logger.error(f"Lỗi encode văn bản: {e}") | |
raise | |
def search(self, query, top_k=5, age_filter=None): | |
"""Tìm kiếm văn bản tương tự trong ChromaDB""" | |
try: | |
query_embedding = self.encode(query, is_query=True)[0] | |
where_clause = None | |
if age_filter: | |
where_clause = { | |
"$and": [ | |
{"age_min": {"$lte": age_filter}}, | |
{"age_max": {"$gte": age_filter}} | |
] | |
} | |
print(f"🔍 AGE FILTER: Tìm kiếm cho tuổi {age_filter}") | |
print(f"🔍 WHERE CLAUSE: {where_clause}") | |
else: | |
print(f"⚠️ KHÔNG CÓ AGE FILTER - Tìm tất cả chunks") | |
search_results = self.collection.query( | |
query_embeddings=[query_embedding], | |
n_results=top_k, | |
where=where_clause, | |
include=['documents', 'metadatas', 'distances'] | |
) | |
print(f"\n{'='*60}") | |
print(f"📊 CHROMADB SEARCH RESULTS") | |
print(f"{'='*60}") | |
print(f"Query: {query}") | |
print(f"Age filter: {age_filter}") | |
print(f"Found {len(search_results['documents'][0]) if search_results['documents'] else 0} chunks") | |
print(f"{'='*60}") | |
if not search_results or not search_results['documents']: | |
logger.warning("Không tìm thấy kết quả nào") | |
return [] | |
results = [] | |
documents = search_results['documents'][0] | |
metadatas = search_results['metadatas'][0] | |
distances = search_results['distances'][0] | |
for i, (doc, metadata, distance) in enumerate(zip(documents, metadatas, distances)): | |
chunk_id = metadata.get('chunk_id', f'chunk_{i}') | |
title = metadata.get('title', 'No title') | |
age_range = metadata.get('age_range', 'Unknown') | |
age_min = metadata.get('age_min', 'N/A') | |
age_max = metadata.get('age_max', 'N/A') | |
content_type = metadata.get('content_type', 'text') | |
chapter = metadata.get('chapter', 'Unknown') | |
similarity = round(1 - distance, 3) | |
results.append({ | |
'document': doc, | |
'metadata': metadata or {}, | |
'distance': distance, | |
'similarity': similarity, | |
'rank': i + 1 | |
}) | |
print(f"\n{'='*60}") | |
logger.info(f"Tim thay {len(results)} ket qua cho query") | |
return results | |
except Exception as e: | |
logger.error(f"Loi tim kiem: {e}") | |
return [] | |
def add_documents(self, documents, metadatas=None, ids=None): | |
"""Thêm documents vào ChromaDB""" | |
try: | |
if not documents: | |
logger.warning("Không có documents để thêm") | |
return False | |
if not ids: | |
ids = [str(uuid.uuid4()) for _ in documents] | |
if not metadatas: | |
metadatas = [{} for _ in documents] | |
logger.info(f"Đang thêm {len(documents)} documents vào ChromaDB") | |
embeddings = self.encode(documents, is_query=False) | |
self.collection.add( | |
embeddings=embeddings, | |
documents=documents, | |
metadatas=metadatas, | |
ids=ids | |
) | |
logger.info(f"Đã thêm thành công {len(documents)} documents") | |
return True | |
except Exception as e: | |
logger.error(f"Lỗi thêm documents: {e}") | |
return False | |
def index_chunks(self, chunks): | |
"""Index các chunks dữ liệu vào ChromaDB""" | |
try: | |
if not chunks: | |
logger.warning("Không có chunks để index") | |
return False | |
documents = [] | |
metadatas = [] | |
ids = [] | |
for chunk in chunks: | |
if not chunk.get('content'): | |
logger.warning(f"Chunk thiếu content: {chunk}") | |
continue | |
documents.append(chunk['content']) | |
metadata = chunk.get('metadata', {}) | |
metadatas.append(metadata) | |
chunk_id = chunk.get('id') or str(uuid.uuid4()) | |
ids.append(chunk_id) | |
if not documents: | |
logger.warning("Không có documents hợp lệ để index") | |
return False | |
batch_size = 100 | |
total_batches = (len(documents) + batch_size - 1) // batch_size | |
for i in range(0, len(documents), batch_size): | |
batch_docs = documents[i:i + batch_size] | |
batch_metas = metadatas[i:i + batch_size] | |
batch_ids = ids[i:i + batch_size] | |
batch_num = (i // batch_size) + 1 | |
logger.info(f"Đang xử lý batch {batch_num}/{total_batches} ({len(batch_docs)} items)") | |
success = self.add_documents(batch_docs, batch_metas, batch_ids) | |
if not success: | |
logger.error(f"Lỗi xử lý batch {batch_num}") | |
return False | |
logger.info(f"Đã index thành công {len(documents)} chunks") | |
return True | |
except Exception as e: | |
logger.error(f"Lỗi index chunks: {e}") | |
return False | |
def count(self): | |
"""Đếm số lượng documents trong collection""" | |
try: | |
return self.collection.count() | |
except Exception as e: | |
logger.error(f"Lỗi đếm documents: {e}") | |
return 0 | |
def delete_collection(self): | |
"""Xóa collection hiện tại""" | |
try: | |
logger.warning(f"Đang xóa collection: {COLLECTION_NAME}") | |
self.chroma_client.delete_collection(name=COLLECTION_NAME) | |
# Tạo lại collection với cosine similarity | |
self._create_cosine_collection() | |
logger.info("Đã tạo lại collection mới với cosine similarity") | |
return True | |
except Exception as e: | |
logger.error(f"Lỗi xóa collection: {e}") | |
return False | |
def get_collection_info(self): | |
"""Lấy thông tin về collection và distance function""" | |
try: | |
metadata = self.collection.metadata or {} | |
distance_func = metadata.get("hnsw:space", "l2") | |
return { | |
'collection_name': COLLECTION_NAME, | |
'distance_function': distance_func, | |
'total_documents': self.count(), | |
'metadata': metadata | |
} | |
except Exception as e: | |
logger.error(f"Lỗi lấy collection info: {e}") | |
return {'error': str(e)} | |
def verify_cosine_similarity(self): | |
"""Kiểm tra và xác nhận đang sử dụng cosine similarity""" | |
try: | |
info = self.get_collection_info() | |
distance_func = info.get('distance_function', 'unknown') | |
logger.info(f"Collection đang sử dụng distance function: {distance_func}") | |
if distance_func == "cosine": | |
logger.info("Xác nhận: Đang sử dụng cosine similarity") | |
return True | |
else: | |
logger.warning(f"Cảnh báo: Đang sử dụng {distance_func}, không phải cosine") | |
return False | |
except Exception as e: | |
logger.error(f"Lỗi verify cosine: {e}") | |
return False | |
def get_stats(self): | |
"""Lấy thống kê về collection""" | |
try: | |
total_count = self.count() | |
collection_info = self.get_collection_info() | |
sample_results = self.collection.get(limit=min(100, total_count)) | |
content_types = {} | |
chapters = {} | |
age_groups = {} | |
if sample_results and sample_results.get('metadatas'): | |
for metadata in sample_results['metadatas']: | |
if not metadata: | |
continue | |
content_type = metadata.get('content_type', 'unknown') | |
content_types[content_type] = content_types.get(content_type, 0) + 1 | |
chapter = metadata.get('chapter', 'unknown') | |
chapters[chapter] = chapters.get(chapter, 0) + 1 | |
age_group = metadata.get('age_group', 'unknown') | |
age_groups[age_group] = age_groups.get(age_group, 0) + 1 | |
return { | |
'total_documents': total_count, | |
'content_types': content_types, | |
'chapters': chapters, | |
'age_groups': age_groups, | |
'collection_name': COLLECTION_NAME, | |
'embedding_model': EMBEDDING_MODEL, | |
'distance_function': collection_info.get('distance_function', 'unknown'), | |
'using_cosine_similarity': collection_info.get('distance_function') == 'cosine' | |
} | |
except Exception as e: | |
logger.error(f"Lỗi lấy stats: {e}") | |
return { | |
'total_documents': 0, | |
'error': str(e) | |
} |