Spaces:
Sleeping
Sleeping
# retrieval.py | |
import json | |
import re | |
import numpy as np | |
import faiss # Thư viện này cần được cài đặt (faiss-cpu hoặc faiss-gpu) | |
from collections import defaultdict | |
from typing import List, Dict, Any, Tuple, Optional, Callable | |
# --- 1. CÁC HẰNG SỐ VÀ MAP CHO LOẠI PHƯƠNG TIỆN --- | |
VEHICLE_TYPE_MAP: Dict[str, List[str]] = { | |
"xe máy": ["xe máy", "xe mô tô", "xe gắn máy", "xe máy điện", "mô tô hai bánh", "mô tô ba bánh"], | |
"ô tô": ["xe ô tô", "ô tô con", "ô tô tải", "ô tô khách", "xe con", "xe tải", "xe khách", "ô tô điện"], | |
"xe cơ giới": ["xe cơ giới"], # Loại chung hơn | |
"xe thô sơ": ["xe thô sơ", "xe đạp", "xích lô", "xe đạp điện"], # Thêm xe đạp điện vào xe thô sơ | |
"người đi bộ": ["người đi bộ"], | |
# Thêm các loại khác nếu cần, ví dụ "xe chuyên dùng" nếu muốn tách riêng | |
} | |
# --- 2. HÀM TIỆN ÍCH --- | |
def get_standardized_vehicle_type(text_input: Optional[str]) -> Optional[str]: | |
""" | |
Suy luận và chuẩn hóa loại phương tiện từ một chuỗi text. | |
Ưu tiên các loại cụ thể trước, sau đó đến các loại chung hơn. | |
""" | |
if not text_input or not isinstance(text_input, str): | |
return None | |
text_lower = text_input.lower() | |
# Ưu tiên kiểm tra "xe máy" và các biến thể trước "xe cơ giới" | |
# Cẩn thận với "xe máy chuyên dùng" | |
is_moto = any(re.search(r'\b' + re.escape(kw) + r'\b', text_lower) for kw in VEHICLE_TYPE_MAP["xe máy"]) | |
if is_moto: | |
# Tránh trường hợp "xe máy chuyên dùng" bị tính là "xe máy" | |
# nếu "xe máy chuyên dùng" là một category riêng hoặc cần xử lý đặc biệt. | |
# Hiện tại, nếu có "chuyên dùng" và "xe máy", nó vẫn sẽ là "xe máy" nếu không có category "xe máy chuyên dùng". | |
if "chuyên dùng" in text_lower and "xe máy chuyên dùng" not in text_lower: # Logic này có thể cần review tùy theo định nghĩa | |
# Nếu bạn có category "xe máy chuyên dùng" trong VEHICLE_TYPE_MAP, nó sẽ được xử lý ở vòng lặp sau. | |
# Nếu không, "xe máy chuyên dùng" vẫn có thể bị coi là "xe máy". | |
pass # Để nó rơi xuống các kiểm tra khác nếu cần | |
else: | |
return "xe máy" | |
# Kiểm tra "ô tô" và các biến thể | |
is_car = any(re.search(r'\b' + re.escape(kw) + r'\b', text_lower) for kw in VEHICLE_TYPE_MAP["ô tô"]) | |
if is_car: | |
return "ô tô" | |
# Kiểm tra các loại chung hơn hoặc khác | |
# Thứ tự này quan trọng nếu có sự chồng chéo (ví dụ: "xe cơ giới" bao gồm "ô tô", "xe máy") | |
# Đã xử lý ô tô, xe máy ở trên nên thứ tự ở đây ít quan trọng hơn giữa các loại còn lại. | |
for standard_type, keywords in VEHICLE_TYPE_MAP.items(): | |
if standard_type in ["xe máy", "ô tô"]: # Đã xử lý ở trên | |
continue | |
if any(re.search(r'\b' + re.escape(kw) + r'\b', text_lower) for kw in keywords): | |
return standard_type | |
return None # Trả về None nếu không khớp rõ ràng | |
def tokenize_vi_for_bm25_setup(text: str) -> List[str]: | |
""" | |
Tokenize tiếng Việt đơn giản cho BM25: lowercase, loại bỏ dấu câu, split theo khoảng trắng. | |
""" | |
text = text.lower() | |
text = re.sub(r'[^\w\s]', '', text) # Loại bỏ các ký tự không phải chữ, số, hoặc khoảng trắng | |
return text.split() | |
# --- 3. HÀM XỬ LÝ DỮ LIỆU LUẬT TỪ JSON SANG CHUNKS --- | |
def process_law_data_to_chunks(structured_data_input: Any) -> List[Dict[str, Any]]: | |
""" | |
Xử lý dữ liệu luật có cấu trúc (từ JSON) thành một danh sách phẳng các "chunks". | |
Mỗi chunk bao gồm "text" và "metadata". | |
""" | |
flat_list: List[Dict[str, Any]] = [] | |
if isinstance(structured_data_input, dict) and "article" in structured_data_input: | |
articles_list: List[Dict[str, Any]] = [structured_data_input] | |
elif isinstance(structured_data_input, list): | |
articles_list = structured_data_input | |
else: | |
print("Lỗi: Dữ liệu đầu vào không phải là danh sách các Điều luật hoặc một đối tượng Điều luật.") | |
return flat_list | |
for article_data in articles_list: | |
if not isinstance(article_data, dict): | |
# print(f"Cảnh báo: Bỏ qua mục không phải dict trong articles_list: {article_data}") | |
continue | |
raw_article_title = article_data.get("article_title", "") | |
article_metadata_base = { | |
"source_document": article_data.get("source_document"), | |
"article": article_data.get("article"), # Số điều, ví dụ "Điều 5" | |
"article_title": raw_article_title # Tiêu đề của Điều, ví dụ "Xử phạt người điều khiển..." | |
} | |
article_level_vehicle_type = get_standardized_vehicle_type(raw_article_title) or "không xác định" | |
clauses = article_data.get("clauses", []) | |
if not isinstance(clauses, list): | |
# print(f"Cảnh báo: 'clauses' trong Điều {article_data.get('article')} không phải list.") | |
continue | |
for clause_data in clauses: | |
if not isinstance(clause_data, dict): | |
# print(f"Cảnh báo: Bỏ qua mục không phải dict trong 'clauses' của Điều {article_data.get('article')}") | |
continue | |
clause_metadata_base = article_metadata_base.copy() | |
clause_number = clause_data.get("clause_number") # Số khoản, ví dụ "1" hoặc "1a" | |
clause_metadata_base.update({"clause_number": clause_number}) | |
clause_summary_data = clause_data.get("clause_metadata_summary") | |
if isinstance(clause_summary_data, dict): | |
clause_metadata_base["overall_fine_note_for_clause_text"] = clause_summary_data.get("overall_fine_note_for_clause") | |
clause_metadata_base["overall_points_deduction_note_for_clause_text"] = clause_summary_data.get("overall_points_deduction_note_for_clause") | |
points_in_clause = clause_data.get("points_in_clause", []) | |
if not isinstance(points_in_clause, list): | |
# print(f"Cảnh báo: 'points_in_clause' trong Khoản {clause_number} của Điều {article_data.get('article')} không phải list.") | |
continue | |
if points_in_clause: # Nếu có các Điểm trong Khoản | |
for point_data in points_in_clause: | |
if not isinstance(point_data, dict): | |
# print(f"Cảnh báo: Bỏ qua mục không phải dict trong 'points_in_clause'.") | |
continue | |
chunk_text = point_data.get("point_text_original") | |
if not chunk_text: chunk_text = point_data.get("violation_description_summary") # Fallback | |
if not chunk_text: continue # Bỏ qua nếu không có text | |
current_chunk_metadata = clause_metadata_base.copy() | |
current_chunk_metadata["point_id"] = point_data.get("point_id") # ID của điểm, ví dụ "a" | |
current_chunk_metadata["violation_description_summary"] = point_data.get("violation_description_summary") | |
# 1. Làm phẳng 'penalties_detail' | |
current_chunk_metadata.update({ | |
'has_fine': False, 'has_points_deduction': False, | |
'has_license_suspension': False, 'has_confiscation': False | |
}) | |
penalty_types_for_this_point: List[str] = [] | |
points_values: List[Any] = [] | |
s_min_months: List[float] = [] # Sửa thành float để chứa giá trị tháng lẻ | |
s_max_months: List[float] = [] | |
confiscation_items_list: List[str] = [] | |
penalties = point_data.get("penalties_detail", []) | |
if isinstance(penalties, list): | |
for p_item in penalties: | |
if not isinstance(p_item, dict): continue | |
p_type_original, p_details = p_item.get("penalty_type"), p_item.get("details", {}) | |
if p_type_original: penalty_types_for_this_point.append(str(p_type_original)) | |
if not isinstance(p_details, dict): continue | |
p_type_lower = str(p_type_original).lower() | |
if "phạt tiền" in p_type_lower: | |
current_chunk_metadata['has_fine'] = True | |
if p_details.get("individual_fine_min") is not None: current_chunk_metadata['individual_fine_min'] = p_details.get("individual_fine_min") | |
if p_details.get("individual_fine_max") is not None: current_chunk_metadata['individual_fine_max'] = p_details.get("individual_fine_max") | |
if "trừ điểm" in p_type_lower or "điểm giấy phép lái xe" in p_type_lower : # Mở rộng kiểm tra | |
current_chunk_metadata['has_points_deduction'] = True | |
if p_details.get("points_deducted") is not None: points_values.append(p_details.get("points_deducted")) | |
if "tước quyền sử dụng giấy phép lái xe" in p_type_lower or "tước bằng lái" in p_type_lower: | |
current_chunk_metadata['has_license_suspension'] = True | |
if p_details.get("suspension_duration_min_months") is not None: s_min_months.append(float(p_details.get("suspension_duration_min_months"))) | |
if p_details.get("suspension_duration_max_months") is not None: s_max_months.append(float(p_details.get("suspension_duration_max_months"))) | |
if "tịch thu" in p_type_lower: | |
current_chunk_metadata['has_confiscation'] = True | |
if p_details.get("confiscation_item"): confiscation_items_list.append(str(p_details.get("confiscation_item"))) | |
if penalty_types_for_this_point: current_chunk_metadata['penalty_types_str'] = ", ".join(sorted(list(set(penalty_types_for_this_point)))) | |
if points_values: current_chunk_metadata['points_deducted_values_str'] = ", ".join(map(str, sorted(list(set(points_values))))) | |
if s_min_months: current_chunk_metadata['suspension_min_months'] = min(s_min_months) | |
if s_max_months: current_chunk_metadata['suspension_max_months'] = max(s_max_months) | |
if confiscation_items_list: current_chunk_metadata['confiscation_items_str'] = ", ".join(sorted(list(set(confiscation_items_list)))) | |
# 2. Thông tin tốc độ | |
if point_data.get("speed_limit") is not None: current_chunk_metadata['speed_limit_value'] = point_data.get("speed_limit") | |
if point_data.get("speed_limit_min") is not None: current_chunk_metadata['speed_limit_min_value'] = point_data.get("speed_limit_min") | |
if point_data.get("speed_type"): current_chunk_metadata['speed_category'] = point_data.get("speed_type") | |
# 3. Thông tin loại xe/đường cụ thể từ point_data | |
speed_limits_extra = point_data.get("speed_limits_by_vehicle_type_and_road_type", []) | |
point_specific_vehicle_types_raw: List[str] = [] | |
point_specific_road_types: List[str] = [] | |
if isinstance(speed_limits_extra, list): | |
for sl_item in speed_limits_extra: | |
if isinstance(sl_item, dict): | |
if sl_item.get("vehicle_type"): point_specific_vehicle_types_raw.append(str(sl_item.get("vehicle_type")).lower()) | |
if sl_item.get("road_type"): point_specific_road_types.append(str(sl_item.get("road_type")).lower()) | |
if point_specific_vehicle_types_raw: current_chunk_metadata['point_specific_vehicle_types_str'] = ", ".join(sorted(list(set(point_specific_vehicle_types_raw)))) | |
if point_specific_road_types: current_chunk_metadata['point_specific_road_types_str'] = ", ".join(sorted(list(set(point_specific_road_types)))) | |
# 4. Gán 'applicable_vehicle_type' chính | |
derived_vehicle_type_from_point = "không xác định" | |
if point_specific_vehicle_types_raw: | |
normalized_types_from_point_data = set() | |
for vt_raw in set(point_specific_vehicle_types_raw): | |
standard_type = get_standardized_vehicle_type(vt_raw) | |
if standard_type: normalized_types_from_point_data.add(standard_type) | |
if len(normalized_types_from_point_data) == 1: | |
derived_vehicle_type_from_point = list(normalized_types_from_point_data)[0] | |
elif len(normalized_types_from_point_data) > 1: | |
# Xử lý trường hợp có nhiều loại xe đã chuẩn hóa | |
if "ô tô" in normalized_types_from_point_data and "xe máy" in normalized_types_from_point_data: | |
derived_vehicle_type_from_point = "ô tô và xe máy" | |
# Có thể thêm các logic ưu tiên khác nếu cần (ví dụ: nếu có "xe cơ giới" và "ô tô" -> "ô tô") | |
elif "ô tô" in normalized_types_from_point_data: derived_vehicle_type_from_point = "ô tô" | |
elif "xe máy" in normalized_types_from_point_data: derived_vehicle_type_from_point = "xe máy" | |
else: # Nếu không có cặp ưu tiên rõ ràng | |
derived_vehicle_type_from_point = "nhiều loại cụ thể" # Hoặc join các loại: ", ".join(sorted(list(normalized_types_from_point_data))) | |
# Ưu tiên thông tin từ point, nếu không rõ ràng thì mới dùng từ article | |
# Các loại được coi là rõ ràng: "ô tô", "xe máy", "xe cơ giới", "xe thô sơ", "người đi bộ", "ô tô và xe máy" | |
clear_types = ["ô tô", "xe máy", "xe cơ giới", "xe thô sơ", "người đi bộ", "ô tô và xe máy"] | |
if derived_vehicle_type_from_point not in clear_types and derived_vehicle_type_from_point != "không xác định": | |
# Nếu là "nhiều loại cụ thể" mà không phải "ô tô và xe máy" hoặc các loại không rõ ràng khác | |
current_chunk_metadata['applicable_vehicle_type'] = article_level_vehicle_type | |
elif derived_vehicle_type_from_point == "không xác định": | |
current_chunk_metadata['applicable_vehicle_type'] = article_level_vehicle_type | |
else: # Các trường hợp còn lại (rõ ràng từ point) | |
current_chunk_metadata['applicable_vehicle_type'] = derived_vehicle_type_from_point | |
# 5. Các trường khác | |
if point_data.get("applies_to"): current_chunk_metadata['applies_to_context'] = point_data.get("applies_to") | |
if point_data.get("location"): current_chunk_metadata['specific_location_info'] = point_data.get("location") | |
final_metadata_cleaned = {k: v for k, v in current_chunk_metadata.items() if v is not None} | |
flat_list.append({ "text": chunk_text, "metadata": final_metadata_cleaned }) | |
else: # Nếu Khoản không có Điểm nào, thì text của Khoản là một chunk | |
chunk_text = clause_data.get("clause_text_original") | |
if chunk_text: # Chỉ thêm nếu có text | |
current_clause_level_metadata = clause_metadata_base.copy() | |
# Với chunk cấp độ Khoản, loại xe sẽ lấy từ article_level_vehicle_type | |
current_clause_level_metadata['applicable_vehicle_type'] = article_level_vehicle_type | |
# Kiểm tra xem có thông tin phạt tiền tổng thể ở Khoản không | |
if current_clause_level_metadata.get("overall_fine_note_for_clause_text"): | |
current_clause_level_metadata['has_fine_clause_level'] = True # Dùng để boost nếu cần | |
final_metadata_cleaned = {k:v for k,v in current_clause_level_metadata.items() if v is not None} | |
flat_list.append({ "text": chunk_text, "metadata": final_metadata_cleaned }) | |
return flat_list | |
# --- 4. HÀM PHÂN TÍCH QUERY --- | |
def analyze_query(query_text: str) -> Dict[str, Any]: | |
""" | |
Phân tích query để xác định ý định của người dùng (ví dụ: hỏi về phạt tiền, điểm, loại xe...). | |
""" | |
query_lower = query_text.lower() | |
analysis: Dict[str, Any] = { | |
"mentions_fine": bool(re.search(r'tiền|phạt|bao nhiêu đồng|bao nhiêu tiền|mức phạt|xử phạt hành chính|nộp phạt', query_lower)), | |
"mentions_points": bool(re.search(r'điểm|trừ điểm|mấy điểm|trừ bao nhiêu điểm|bằng lái|gplx|giấy phép lái xe', query_lower)), | |
"mentions_suspension": bool(re.search(r'tước bằng|tước giấy phép lái xe|giam bằng|treo bằng|thu bằng lái|tước quyền sử dụng', query_lower)), | |
"mentions_confiscation": bool(re.search(r'tịch thu|thu xe|thu phương tiện', query_lower)), | |
"mentions_max_speed": bool(re.search(r'tốc độ tối đa|giới hạn tốc độ|chạy quá tốc độ|vượt tốc độ', query_lower)), | |
"mentions_min_speed": bool(re.search(r'tốc độ tối thiểu|chạy chậm hơn', query_lower)), | |
"mentions_safe_distance": bool(re.search(r'khoảng cách an toàn|cự ly an toàn|cự ly tối thiểu|giữ khoảng cách', query_lower)), | |
"mentions_remedial_measures": bool(re.search(r'biện pháp khắc phục|khắc phục hậu quả', query_lower)), | |
"vehicle_type_query": None, # Sẽ được điền bằng loại xe chuẩn hóa nếu có | |
} | |
# Sử dụng lại VEHICLE_TYPE_MAP để chuẩn hóa loại xe trong query | |
# Ưu tiên loại cụ thể trước | |
queried_vehicle_standardized = get_standardized_vehicle_type(query_lower) | |
if queried_vehicle_standardized: | |
analysis["vehicle_type_query"] = queried_vehicle_standardized | |
return analysis | |
# --- 5. HÀM TÌM KIẾM KẾT HỢP (HYBRID SEARCH) --- | |
def search_relevant_laws( | |
query_text: str, | |
embedding_model, # Kiểu dữ liệu: SentenceTransformer model | |
faiss_index, # Kiểu dữ liệu: faiss.Index | |
chunks_data: List[Dict[str, Any]], | |
bm25_model, # Kiểu dữ liệu: BM25Okapi model | |
k: int = 5, | |
initial_k_multiplier: int = 10, | |
rrf_k_constant: int = 60, | |
# Trọng số cho các loại boost (có thể điều chỉnh) | |
boost_fine: float = 0.15, | |
boost_points: float = 0.15, | |
boost_both_fine_points: float = 0.10, # Boost thêm nếu khớp cả hai | |
boost_vehicle_type: float = 0.20, | |
boost_suspension: float = 0.18, | |
boost_confiscation: float = 0.18, | |
boost_max_speed: float = 0.15, | |
boost_min_speed: float = 0.15, | |
boost_safe_distance: float = 0.12, | |
boost_remedial_measures: float = 0.10 | |
) -> List[Dict[str, Any]]: | |
""" | |
Thực hiện tìm kiếm kết hợp (semantic + keyword) với RRF và metadata re-ranking. | |
""" | |
if k <= 0: | |
print("Lỗi: k (số lượng kết quả) phải là số dương.") | |
return [] | |
if not chunks_data: | |
print("Lỗi: chunks_data rỗng, không thể tìm kiếm.") | |
return [] | |
print(f"\n🔎 Đang tìm kiếm (Hybrid) cho truy vấn: '{query_text}'") | |
# === 1. Phân tích Query === | |
query_analysis = analyze_query(query_text) | |
# print(f" Phân tích query: {json.dumps(query_analysis, ensure_ascii=False, indent=2)}") | |
num_vectors_in_index = faiss_index.ntotal | |
if num_vectors_in_index == 0: | |
print("Lỗi: FAISS index rỗng.") | |
return [] | |
# Số lượng ứng viên ban đầu từ mỗi retriever | |
num_candidates_each_retriever = max(min(k * initial_k_multiplier, num_vectors_in_index), min(k, num_vectors_in_index)) | |
if num_candidates_each_retriever == 0: | |
print(f" Không thể lấy đủ số lượng ứng viên ban đầu (num_candidates = 0).") | |
return [] | |
# === 2. Semantic Search (FAISS) === | |
semantic_indices_raw: np.ndarray = np.array([[]], dtype=int) # Khởi tạo rỗng | |
try: | |
query_embedding_tensor = embedding_model.encode( | |
[query_text], convert_to_tensor=True, device=embedding_model.device | |
) | |
query_embedding_np = query_embedding_tensor.cpu().numpy().astype('float32') | |
faiss.normalize_L2(query_embedding_np) # Chuẩn hóa vector query | |
# print(f" Đã tạo và chuẩn hóa vector truy vấn shape: {query_embedding_np.shape}") | |
# print(f" Tìm kiếm {num_candidates_each_retriever} kết quả ngữ nghĩa (FAISS)...") | |
_, semantic_indices_raw = faiss_index.search(query_embedding_np, num_candidates_each_retriever) | |
# print(f" ✅ Tìm kiếm ngữ nghĩa (FAISS) hoàn tất.") | |
except Exception as e: | |
print(f"Lỗi khi tìm kiếm ngữ nghĩa (FAISS): {e}") | |
# semantic_indices_raw đã được khởi tạo là rỗng | |
# === 3. Keyword Search (BM25) === | |
# print(f" Tìm kiếm {num_candidates_each_retriever} kết quả từ khóa (BM25)...") | |
tokenized_query_bm25 = tokenize_vi_for_bm25_setup(query_text) | |
top_bm25_results: List[Dict[str, Any]] = [] | |
try: | |
if bm25_model and tokenized_query_bm25: | |
all_bm25_scores = bm25_model.get_scores(tokenized_query_bm25) | |
# Lấy chỉ số và score cho các document có score > 0 | |
bm25_results_with_indices = [ | |
{'index': i, 'score': score} for i, score in enumerate(all_bm25_scores) if score > 0 | |
] | |
# Sắp xếp theo score giảm dần | |
bm25_results_with_indices.sort(key=lambda x: x['score'], reverse=True) | |
top_bm25_results = bm25_results_with_indices[:num_candidates_each_retriever] | |
# print(f" ✅ Tìm kiếm từ khóa (BM25) hoàn tất, tìm thấy {len(top_bm25_results)} ứng viên.") | |
else: | |
# print(" Cảnh báo: BM25 model hoặc tokenized query không hợp lệ, bỏ qua BM25.") | |
pass | |
except Exception as e: | |
print(f"Lỗi khi tìm kiếm từ khóa (BM25): {e}") | |
# === 4. Result Fusion (Reciprocal Rank Fusion - RRF) === | |
# print(f" Kết hợp kết quả từ FAISS và BM25 bằng RRF (k_const={rrf_k_constant})...") | |
rrf_scores: Dict[int, float] = defaultdict(float) | |
all_retrieved_indices_set: set[int] = set() | |
if semantic_indices_raw.size > 0: | |
for rank, doc_idx_int in enumerate(semantic_indices_raw[0]): | |
doc_idx = int(doc_idx_int) # Đảm bảo là int | |
if 0 <= doc_idx < len(chunks_data): # Kiểm tra doc_idx hợp lệ với chunks_data | |
rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank) | |
all_retrieved_indices_set.add(doc_idx) | |
for rank, item in enumerate(top_bm25_results): | |
doc_idx = item['index'] | |
if 0 <= doc_idx < len(chunks_data): | |
rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank) | |
all_retrieved_indices_set.add(doc_idx) | |
fused_initial_results: List[Dict[str, Any]] = [] | |
for doc_idx in all_retrieved_indices_set: | |
fused_initial_results.append({ | |
'index': doc_idx, | |
'fused_score': rrf_scores[doc_idx] | |
}) | |
fused_initial_results.sort(key=lambda x: x['fused_score'], reverse=True) | |
# print(f" ✅ Kết hợp RRF hoàn tất, có {len(fused_initial_results)} ứng viên duy nhất.") | |
# === 5. Xử lý Metadata, Lọc và Tái xếp hạng cuối cùng === | |
final_processed_results: List[Dict[str, Any]] = [] | |
# Xử lý metadata cho số lượng ứng viên lớn hơn để có đủ lựa chọn sau khi lọc | |
num_to_process_metadata = min(len(fused_initial_results), num_candidates_each_retriever * 2 if num_candidates_each_retriever > 0 else k * 3) | |
# print(f" Xử lý metadata và tính điểm cuối cùng cho top {num_to_process_metadata} ứng viên lai ghép...") | |
for rank_idx, res_item in enumerate(fused_initial_results[:num_to_process_metadata]): | |
result_index = res_item['index'] | |
base_score_from_fusion = res_item['fused_score'] | |
metadata_boost_components: Dict[str, float] = defaultdict(float) | |
passes_all_strict_filters = True | |
try: | |
original_chunk = chunks_data[result_index] | |
chunk_metadata = original_chunk.get('metadata', {}) | |
chunk_text_lower = original_chunk.get('text', '').lower() | |
# 5.1 & 5.2: Tiền phạt và Điểm phạt | |
has_fine_info_in_chunk = chunk_metadata.get("has_fine", False) or chunk_metadata.get("has_fine_clause_level", False) | |
has_points_info_in_chunk = chunk_metadata.get("has_points_deduction", False) | |
if query_analysis["mentions_fine"]: | |
if has_fine_info_in_chunk: | |
metadata_boost_components["fine"] += boost_fine | |
elif not query_analysis["mentions_points"]: # Chỉ hỏi tiền mà chunk không có -> lọc | |
passes_all_strict_filters = False | |
if query_analysis["mentions_points"]: | |
if has_points_info_in_chunk: | |
metadata_boost_components["points"] += boost_points | |
elif not query_analysis["mentions_fine"]: # Chỉ hỏi điểm mà chunk không có -> lọc | |
passes_all_strict_filters = False | |
if query_analysis["mentions_fine"] and query_analysis["mentions_points"]: | |
if not has_fine_info_in_chunk and not has_points_info_in_chunk: # Query hỏi cả hai mà chunk ko có cả hai | |
passes_all_strict_filters = False | |
elif has_fine_info_in_chunk and has_points_info_in_chunk: | |
metadata_boost_components["both_fine_points"] += boost_both_fine_points | |
# 5.3. Loại xe | |
queried_vehicle = query_analysis["vehicle_type_query"] | |
if queried_vehicle: | |
applicable_vehicle_meta = chunk_metadata.get("applicable_vehicle_type", "").lower() | |
point_specific_vehicles_meta = chunk_metadata.get("point_specific_vehicle_types_str", "").lower() | |
article_title_lower = chunk_metadata.get("article_title", "").lower() | |
match_vehicle = False | |
if queried_vehicle in applicable_vehicle_meta: match_vehicle = True | |
elif queried_vehicle in point_specific_vehicles_meta: match_vehicle = True | |
# Kiểm tra xem queried_vehicle (đã chuẩn hóa) có trong article_title không | |
elif queried_vehicle in article_title_lower: match_vehicle = True # Đơn giản hóa, có thể dùng regex nếu cần chính xác hơn | |
# Kiểm tra xem queried_vehicle (đã chuẩn hóa) có trong chunk_text không | |
elif queried_vehicle in chunk_text_lower: match_vehicle = True | |
if match_vehicle: | |
metadata_boost_components["vehicle_type"] += boost_vehicle_type | |
else: | |
# Logic lọc: nếu query có loại xe cụ thể, mà applicable_vehicle_type của chunk là một loại khác | |
# và không phải là "không xác định", "nhiều loại cụ thể", hoặc loại cha chung ("xe cơ giới" cho "ô tô") | |
if applicable_vehicle_meta and \ | |
applicable_vehicle_meta not in ["không xác định", "nhiều loại cụ thể", "ô tô và xe máy"] and \ | |
not (applicable_vehicle_meta == "xe cơ giới" and queried_vehicle in ["ô tô", "xe máy"]): | |
passes_all_strict_filters = False | |
# 5.4. Tước quyền sử dụng GPLX | |
if query_analysis["mentions_suspension"] and chunk_metadata.get("has_license_suspension"): | |
metadata_boost_components["suspension"] += boost_suspension | |
# 5.5. Tịch thu | |
if query_analysis["mentions_confiscation"] and chunk_metadata.get("has_confiscation"): | |
metadata_boost_components["confiscation"] += boost_confiscation | |
# 5.6. Tốc độ tối đa | |
if query_analysis["mentions_max_speed"]: | |
if chunk_metadata.get("speed_limit_value") is not None or \ | |
"tốc độ tối đa" in chunk_metadata.get("speed_category","").lower() or \ | |
any(kw in chunk_text_lower for kw in ["quá tốc độ", "tốc độ tối đa", "vượt tốc độ quy định"]): | |
metadata_boost_components["max_speed"] += boost_max_speed | |
# 5.7. Tốc độ tối thiểu | |
if query_analysis["mentions_min_speed"]: | |
if chunk_metadata.get("speed_limit_min_value") is not None or \ | |
"tốc độ tối thiểu" in chunk_metadata.get("speed_category","").lower() or \ | |
"tốc độ tối thiểu" in chunk_text_lower: | |
metadata_boost_components["min_speed"] += boost_min_speed | |
# 5.8. Khoảng cách an toàn | |
if query_analysis["mentions_safe_distance"]: | |
if any(kw in chunk_text_lower for kw in ["khoảng cách an toàn", "cự ly an toàn", "cự ly tối thiểu", "giữ khoảng cách"]): | |
metadata_boost_components["safe_distance"] += boost_safe_distance | |
# 5.9. Biện pháp khắc phục | |
if query_analysis["mentions_remedial_measures"]: | |
if any(kw in chunk_text_lower for kw in ["biện pháp khắc phục", "khắc phục hậu quả"]): | |
metadata_boost_components["remedial_measures"] += boost_remedial_measures | |
if not passes_all_strict_filters: | |
continue | |
total_metadata_boost = sum(metadata_boost_components.values()) | |
final_score_calculated = base_score_from_fusion + total_metadata_boost | |
final_processed_results.append({ | |
"rank_after_fusion": rank_idx + 1, | |
"index": int(result_index), | |
"base_score_rrf": float(base_score_from_fusion), | |
"metadata_boost_components": dict(metadata_boost_components), # Lưu lại để debug | |
"metadata_boost_total": float(total_metadata_boost), | |
"final_score": final_score_calculated, | |
"text": original_chunk.get('text', '*Không có text*'), | |
"metadata": chunk_metadata, # Giữ nguyên metadata gốc | |
"query_analysis_for_boost": query_analysis # (Tùy chọn) Lưu lại query analysis dùng cho boosting | |
}) | |
except IndexError: | |
print(f"Lỗi Index: Chỉ số {result_index} nằm ngoài chunks_data (size: {len(chunks_data)}). Bỏ qua chunk này.") | |
except Exception as e: | |
print(f"Lỗi khi xử lý ứng viên lai ghép tại chỉ số {result_index}: {e}. Bỏ qua chunk này.") | |
final_processed_results.sort(key=lambda x: x["final_score"], reverse=True) | |
final_results_top_k = final_processed_results[:k] | |
print(f" ✅ Xử lý, kết hợp metadata boost và tái xếp hạng hoàn tất. Trả về {len(final_results_top_k)} kết quả.") | |
return final_results_top_k | |
# --- (Tùy chọn) Hàm main để test nhanh file này --- | |
if __name__ == '__main__': | |
print("Chạy test cho retrieval.py...") | |
# --- Test get_standardized_vehicle_type --- | |
print("\n--- Test get_standardized_vehicle_type ---") | |
test_vehicles = [ | |
"người điều khiển xe ô tô con", "xe gắn máy", "xe cơ giới", "xe máy chuyên dùng", | |
"xe đạp điện", "người đi bộ", "ô tô tải và xe mô tô", None, "" | |
] | |
for tv in test_vehicles: | |
print(f"'{tv}' -> '{get_standardized_vehicle_type(tv)}'") | |
# --- Test analyze_query --- | |
print("\n--- Test analyze_query ---") | |
test_queries = [ | |
"xe máy không gương phạt bao nhiêu tiền?", | |
"ô tô chạy quá tốc độ 20km bị trừ mấy điểm gplx", | |
"đi bộ ở đâu thì đúng luật", | |
"biện pháp khắc phục khi gây tai nạn là gì" | |
] | |
for tq in test_queries: | |
print(f"Query: '{tq}'\nAnalysis: {json.dumps(analyze_query(tq), indent=2, ensure_ascii=False)}") | |
# --- Test process_law_data_to_chunks (cần file JSON mẫu) --- | |
# Giả sử bạn có file JSON mẫu 'sample_law_data.json' cùng cấp | |
# sample_data = { | |
# "article": "Điều 5", | |
# "article_title": "Xử phạt người điều khiển xe ô tô và các loại xe tương tự xe ô tô vi phạm quy tắc giao thông đường bộ", | |
# "source_document": "Nghị định 100/2019/NĐ-CP", | |
# "clauses": [ | |
# { | |
# "clause_number": "1", | |
# "clause_text_original": "Phạt tiền từ 200.000 đồng đến 400.000 đồng đối với người điều khiển xe thực hiện một trong các hành vi vi phạm sau đây:", | |
# "points_in_clause": [ | |
# { | |
# "point_id": "a", | |
# "point_text_original": "Không chấp hành hiệu lệnh, chỉ dẫn của biển báo hiệu, vạch kẻ đường, trừ các hành vi vi phạm quy định tại điểm a khoản 2, điểm c khoản 3, điểm đ khoản 4, điểm g khoản 5, điểm b khoản 6, điểm b khoản 7, điểm d khoản 8 Điều này;", | |
# "penalties_detail": [ | |
# {"penalty_type": "Phạt tiền", "details": {"individual_fine_min": 200000, "individual_fine_max": 400000}} | |
# ], | |
# "speed_limits_by_vehicle_type_and_road_type": [{"vehicle_type": "xe ô tô con"}] | |
# } | |
# ] | |
# }, | |
# { | |
# "clause_number": "12", # Khoản không có điểm | |
# "clause_text_original": "Ngoài việc bị phạt tiền, người điều khiển xe thực hiện hành vi vi phạm còn bị áp dụng các hình thức xử phạt bổ sung sau đây: ...", | |
# "clause_metadata_summary": {"overall_fine_note_for_clause": "Áp dụng hình phạt bổ sung"} | |
# } | |
# ] | |
# } | |
# | |
# print("\n--- Test process_law_data_to_chunks ---") | |
# chunks = process_law_data_to_chunks(sample_data) | |
# print(f"Số chunks được tạo: {len(chunks)}") | |
# if chunks: | |
# print("Chunk đầu tiên:") | |
# print(json.dumps(chunks[0], indent=2, ensure_ascii=False)) | |
# print("Chunk cuối cùng (nếu có nhiều hơn 1):") | |
# if len(chunks) > 1: | |
# print(json.dumps(chunks[-1], indent=2, ensure_ascii=False)) | |
# Để test search_relevant_laws, bạn cần mock embedding_model, faiss_index, bm25_model | |
# và có chunks_data đã được xử lý. | |
print("\n--- Các test khác (ví dụ: search_relevant_laws) cần mock hoặc dữ liệu đầy đủ. ---") |