Spaces:
Running
on
T4
Running
on
T4
update
Browse files- rag_pipeline.py +77 -47
- retriever.py +89 -76
rag_pipeline.py
CHANGED
@@ -80,74 +80,104 @@ def initialize_components(data_path):
|
|
80 |
"bm25_model": bm25_model
|
81 |
}
|
82 |
|
83 |
-
def
|
84 |
"""
|
85 |
-
|
|
|
86 |
"""
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
|
93 |
-
# 1. Truy xuất ngữ cảnh
|
|
|
94 |
retrieved_results = search_relevant_laws(
|
95 |
query_text=query,
|
96 |
embedding_model=components["embedding_model"],
|
97 |
faiss_index=components["faiss_index"],
|
98 |
chunks_data=components["chunks_data"],
|
99 |
bm25_model=components["bm25_model"],
|
100 |
-
k=5,
|
101 |
-
initial_k_multiplier=
|
102 |
)
|
103 |
|
104 |
-
# 2. Định dạng Context
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
header = f"Trích dẫn {i+1}: Điều {metadata.get('article', 'N/A')}, Khoản {metadata.get('clause_number', 'N/A')} (Nguồn: {metadata.get('source_document', 'N/A')})"
|
112 |
-
text = res.get('text', '*Nội dung không có*')
|
113 |
-
context_parts.append(f"{header}\n{text}")
|
114 |
-
context = "\n\n---\n\n".join(context_parts)
|
115 |
-
|
116 |
-
# 3. Xây dựng Prompt đơn giản (không có lịch sử trò chuyện)
|
117 |
-
prompt = f"""Bạn là trợ lý pháp luật chuyên trả lời các câu hỏi liên quan đến luật giao thông đường bộ Việt Nam.
|
118 |
-
|
119 |
-
Dựa trên các đoạn luật dưới đây:
|
120 |
{context}
|
121 |
|
122 |
-
|
|
|
123 |
|
124 |
-
|
125 |
-
"""
|
126 |
|
127 |
-
|
|
|
|
|
128 |
|
129 |
-
#
|
130 |
-
inputs = tokenizer(
|
131 |
-
text=prompt,
|
132 |
-
images=None,
|
133 |
-
return_tensors="pt"
|
134 |
-
).to("cuda" if torch.cuda.is_available() else "cpu")
|
135 |
|
|
|
136 |
generation_config = dict(
|
137 |
max_new_tokens=256,
|
138 |
-
temperature=0.
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
do_sample=True,
|
143 |
-
pad_token_id=tokenizer.eos_token_id,
|
144 |
-
eos_token_id=tokenizer.eos_token_id
|
145 |
)
|
146 |
|
147 |
output_ids = llm_model.generate(**inputs, **generation_config)
|
148 |
-
|
149 |
-
|
150 |
-
response_text = tokenizer.decode(
|
151 |
|
152 |
print("--- Tạo câu trả lời hoàn tất ---")
|
153 |
-
return response_text
|
|
|
80 |
"bm25_model": bm25_model
|
81 |
}
|
82 |
|
83 |
+
def _format_context_with_summary(retrieved_results: list[dict]) -> str:
|
84 |
"""
|
85 |
+
Hàm phụ trợ: Định dạng ngữ cảnh từ kết quả truy xuất, bổ sung tóm tắt từ metadata.
|
86 |
+
Hàm này được tách ra để làm cho code sạch sẽ và dễ quản lý hơn.
|
87 |
"""
|
88 |
+
if not retrieved_results:
|
89 |
+
return "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu."
|
90 |
+
|
91 |
+
context_parts = []
|
92 |
+
for i, res in enumerate(retrieved_results):
|
93 |
+
metadata = res.get('metadata', {})
|
94 |
+
text = res.get('text', '*Nội dung không có*')
|
95 |
+
|
96 |
+
# Tạo header rõ ràng
|
97 |
+
header = f"Trích dẫn {i+1}: Điều {metadata.get('article', 'N/A')}, Điểm {metadata.get('point_id', '')} Khoản {metadata.get('clause_number', 'N/A')} (Nguồn: {metadata.get('source_document', 'N/A')})"
|
98 |
+
|
99 |
+
# --- LOGIC TÓM TẮT THÔNG MINH TỪ METADATA ---
|
100 |
+
metadata_summary = ""
|
101 |
+
penalty_details_list = metadata.get("penalties_detail", [])
|
102 |
+
|
103 |
+
if penalty_details_list:
|
104 |
+
summary_parts = []
|
105 |
+
# Chỉ lấy thông tin từ mục hình phạt đầu tiên trong danh sách
|
106 |
+
details = penalty_details_list[0].get('details', {})
|
107 |
+
|
108 |
+
# Tóm tắt mức phạt tiền cho cá nhân (phổ biến nhất)
|
109 |
+
i_min = details.get("individual_fine_min")
|
110 |
+
i_max = details.get("individual_fine_max")
|
111 |
+
if i_min is not None and i_max is not None:
|
112 |
+
summary_parts.append(f"Phạt tiền cá nhân từ {i_min:,} - {i_max:,} đồng.")
|
113 |
+
|
114 |
+
# Tóm tắt mức trừ điểm
|
115 |
+
points = details.get("points_deducted")
|
116 |
+
if points is not None:
|
117 |
+
summary_parts.append(f"Trừ {points} điểm GPLX.")
|
118 |
+
|
119 |
+
if summary_parts:
|
120 |
+
# Chèn dòng tóm tắt vào giữa header và text
|
121 |
+
metadata_summary = f"\n[Tóm tắt từ metadata: {' '.join(summary_parts)}]"
|
122 |
+
|
123 |
+
context_parts.append(f"{header}{metadata_summary}\n{text}")
|
124 |
+
|
125 |
+
return "\n\n---\n\n".join(context_parts)
|
126 |
+
|
127 |
+
|
128 |
+
def generate_response(query: str, components: dict) -> str:
|
129 |
+
"""
|
130 |
+
Tạo câu trả lời (single-turn) bằng cách sử dụng các thành phần đã được khởi tạo.
|
131 |
+
Phiên bản đã được tối ưu và tái cấu trúc.
|
132 |
+
"""
|
133 |
+
print("--- Bắt đầu quy trình RAG cho query mới ---")
|
134 |
|
135 |
+
# 1. Truy xuất ngữ cảnh bằng retriever đã được nâng cấp
|
136 |
+
# (Giả định search_relevant_laws đã được sửa để ưu tiên loại xe)
|
137 |
retrieved_results = search_relevant_laws(
|
138 |
query_text=query,
|
139 |
embedding_model=components["embedding_model"],
|
140 |
faiss_index=components["faiss_index"],
|
141 |
chunks_data=components["chunks_data"],
|
142 |
bm25_model=components["bm25_model"],
|
143 |
+
k=5,
|
144 |
+
initial_k_multiplier=15
|
145 |
)
|
146 |
|
147 |
+
# 2. Định dạng Context một cách thông minh bằng hàm phụ trợ
|
148 |
+
context = _format_context_with_summary(retrieved_results)
|
149 |
+
|
150 |
+
# 3. Xây dựng Prompt
|
151 |
+
prompt = f"""Bạn là một trợ lý pháp luật chuyên trả lời các câu hỏi về luật giao thông Việt Nam. Dựa vào các trích dẫn luật dưới đây để trả lời câu hỏi của người dùng một cách chính xác.
|
152 |
+
|
153 |
+
### Thông tin luật:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
{context}
|
155 |
|
156 |
+
### Câu hỏi:
|
157 |
+
{query}
|
158 |
|
159 |
+
### Trả lời:"""
|
|
|
160 |
|
161 |
+
# 4. Tạo câu trả lời từ LLM
|
162 |
+
llm_model = components["llm_model"]
|
163 |
+
tokenizer = components["tokenizer"]
|
164 |
|
165 |
+
# Chuyển input lên cùng device với model
|
166 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(llm_model.device)
|
|
|
|
|
|
|
|
|
167 |
|
168 |
+
# Cấu hình generation tối ưu cho việc trả lời câu hỏi dựa trên facts
|
169 |
generation_config = dict(
|
170 |
max_new_tokens=256,
|
171 |
+
temperature=0.1, # Rất thấp để câu trả lời bám sát ngữ cảnh
|
172 |
+
repetition_penalty=1.1, # Phạt nhẹ việc lặp từ
|
173 |
+
do_sample=True, # Vẫn cần bật để temperature và các tham số khác có hiệu lực
|
174 |
+
pad_token_id=tokenizer.eos_token_id
|
|
|
|
|
|
|
175 |
)
|
176 |
|
177 |
output_ids = llm_model.generate(**inputs, **generation_config)
|
178 |
+
|
179 |
+
# Chỉ decode phần văn bản được sinh ra mới, bỏ qua phần prompt
|
180 |
+
response_text = tokenizer.decode(output_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
|
181 |
|
182 |
print("--- Tạo câu trả lời hoàn tất ---")
|
183 |
+
return response_text
|
retriever.py
CHANGED
@@ -12,106 +12,119 @@ def tokenize_vi_for_bm25_setup(text):
|
|
12 |
text = re.sub(r'[^\w\s]', '', text)
|
13 |
return text.split()
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
def search_relevant_laws(
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
"""
|
26 |
-
Thực hiện
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
28 |
"""
|
29 |
if k <= 0:
|
30 |
-
print("Lỗi: k (số lượng kết quả) phải là số dương.")
|
31 |
return []
|
32 |
|
33 |
-
print(f"\n🔎 Đang tìm kiếm (Hybrid) cho truy vấn: '{query_text}'")
|
34 |
-
query_lower = query_text.lower()
|
35 |
-
|
36 |
-
# Phân tích query
|
37 |
-
fine_keywords = r'tiền|phạt|bao nhiêu đồng|bao nhiêu tiền|mức phạt|xử phạt hành chính'
|
38 |
-
points_keywords = r'điểm|trừ điểm|mấy điểm|trừ bao nhiêu điểm|bằng lái|gplx'
|
39 |
-
query_mentions_fine = bool(re.search(fine_keywords, query_lower))
|
40 |
-
query_mentions_points = bool(re.search(points_keywords, query_lower))
|
41 |
-
needs_specific_metadata_filter = query_mentions_fine or query_mentions_points
|
42 |
-
print(f" Phân tích query: Đề cập tiền phạt? {query_mentions_fine}, Đề cập điểm trừ? {query_mentions_points}")
|
43 |
-
|
44 |
num_vectors_in_index = faiss_index.ntotal
|
45 |
if num_vectors_in_index == 0:
|
46 |
-
print("Lỗi: FAISS index rỗng.")
|
47 |
return []
|
48 |
|
49 |
-
|
50 |
|
51 |
-
#
|
52 |
try:
|
53 |
-
|
54 |
-
query_embedding_np =
|
55 |
faiss.normalize_L2(query_embedding_np)
|
56 |
-
|
|
|
57 |
except Exception as e:
|
58 |
-
print(f"Lỗi
|
59 |
-
|
60 |
|
61 |
-
#
|
62 |
try:
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
top_bm25_results = bm25_results_with_indices[:num_candidates_each_retriever]
|
68 |
except Exception as e:
|
69 |
-
print(f"Lỗi
|
70 |
-
|
71 |
|
72 |
-
#
|
73 |
rrf_scores = defaultdict(float)
|
74 |
-
|
75 |
|
76 |
-
|
77 |
-
for rank, doc_idx in enumerate(semantic_indices_raw[0]):
|
78 |
-
if 0 <= doc_idx < num_vectors_in_index:
|
79 |
-
rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
|
80 |
-
all_retrieved_indices_set.add(doc_idx)
|
81 |
-
|
82 |
-
for rank, item in enumerate(top_bm25_results):
|
83 |
-
doc_idx = item['index']
|
84 |
rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
|
85 |
-
all_retrieved_indices_set.add(doc_idx)
|
86 |
|
87 |
-
|
88 |
-
|
89 |
|
90 |
-
#
|
91 |
-
|
92 |
-
|
|
|
93 |
|
94 |
-
for
|
95 |
try:
|
96 |
-
|
97 |
-
|
98 |
-
original_chunk = chunks_data[result_index]
|
99 |
-
original_metadata = original_chunk.get('metadata', {})
|
100 |
-
# Thêm logic xử lý metadata boosting ở đây nếu cần...
|
101 |
-
# Hiện tại, chỉ trả về kết quả đã fusion.
|
102 |
-
# Bạn có thể copy lại toàn bộ logic boosting từ script gốc vào đây.
|
103 |
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
})
|
113 |
-
except
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
text = re.sub(r'[^\w\s]', '', text)
|
13 |
return text.split()
|
14 |
|
15 |
+
def _get_vehicle_type(query_lower: str) -> str | None:
|
16 |
+
"""Xác định loại xe được đề cập trong câu truy vấn."""
|
17 |
+
# Từ điển định nghĩa các từ khóa cho từng loại xe
|
18 |
+
vehicle_keywords = {
|
19 |
+
"ô tô": ["ô tô", "xe con", "xe chở người", "xe chở hàng"],
|
20 |
+
"xe máy": ["xe máy", "xe mô tô", "xe gắn máy"],
|
21 |
+
"xe đạp": ["xe đạp", "xe thô sơ"],
|
22 |
+
"máy kéo": ["máy kéo", "xe chuyên dùng"]
|
23 |
+
}
|
24 |
+
for vehicle_type, keywords in vehicle_keywords.items():
|
25 |
+
if any(keyword in query_lower for keyword in keywords):
|
26 |
+
return vehicle_type
|
27 |
+
return None
|
28 |
+
|
29 |
def search_relevant_laws(
|
30 |
+
query_text: str,
|
31 |
+
embedding_model,
|
32 |
+
faiss_index,
|
33 |
+
chunks_data: list[dict],
|
34 |
+
bm25_model,
|
35 |
+
k: int = 5,
|
36 |
+
initial_k_multiplier: int = 15,
|
37 |
+
rrf_k_constant: int = 60
|
38 |
+
) -> list[dict]:
|
39 |
"""
|
40 |
+
Thực hiện Tìm kiếm Lai (Hybrid Search) với logic tăng điểm (boosting) cho loại xe.
|
41 |
+
|
42 |
+
Quy trình:
|
43 |
+
1. Tìm kiếm song song bằng FAISS (ngữ nghĩa) và BM25 (từ khóa).
|
44 |
+
2. Kết hợp kết quả bằng Reciprocal Rank Fusion (RRF).
|
45 |
+
3. Tăng điểm (boost) cho các kết quả khớp với metadata quan trọng (loại xe).
|
46 |
+
4. Sắp xếp lại và trả về top-k kết quả cuối cùng.
|
47 |
"""
|
48 |
if k <= 0:
|
|
|
49 |
return []
|
50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
num_vectors_in_index = faiss_index.ntotal
|
52 |
if num_vectors_in_index == 0:
|
|
|
53 |
return []
|
54 |
|
55 |
+
num_candidates = min(k * initial_k_multiplier, num_vectors_in_index)
|
56 |
|
57 |
+
# --- 1. Semantic Search (FAISS) ---
|
58 |
try:
|
59 |
+
query_embedding = embedding_model.encode([query_text], convert_to_tensor=True)
|
60 |
+
query_embedding_np = query_embedding.cpu().numpy().astype('float32')
|
61 |
faiss.normalize_L2(query_embedding_np)
|
62 |
+
_, semantic_indices = faiss_index.search(query_embedding_np, num_candidates)
|
63 |
+
semantic_indices = semantic_indices[0]
|
64 |
except Exception as e:
|
65 |
+
print(f"Lỗi FAISS search: {e}")
|
66 |
+
semantic_indices = []
|
67 |
|
68 |
+
# --- 2. Keyword Search (BM25) ---
|
69 |
try:
|
70 |
+
tokenized_query = tokenize_vi_for_bm25_setup(query_text)
|
71 |
+
bm25_scores = bm25_model.get_scores(tokenized_query)
|
72 |
+
# Lấy top N chỉ mục từ BM25
|
73 |
+
top_bm25_indices = np.argsort(bm25_scores)[::-1][:num_candidates]
|
|
|
74 |
except Exception as e:
|
75 |
+
print(f"Lỗi BM25 search: {e}")
|
76 |
+
top_bm25_indices = []
|
77 |
|
78 |
+
# --- 3. Result Fusion (RRF) ---
|
79 |
rrf_scores = defaultdict(float)
|
80 |
+
all_indices = set(semantic_indices) | set(top_bm25_indices)
|
81 |
|
82 |
+
for rank, doc_idx in enumerate(semantic_indices):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
|
|
|
84 |
|
85 |
+
for rank, doc_idx in enumerate(top_bm25_indices):
|
86 |
+
rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
|
87 |
|
88 |
+
# --- 4. Metadata Boosting & Final Ranking ---
|
89 |
+
query_lower = query_text.lower()
|
90 |
+
matched_vehicle = _get_vehicle_type(query_lower)
|
91 |
+
final_results = []
|
92 |
|
93 |
+
for doc_idx in all_indices:
|
94 |
try:
|
95 |
+
metadata = chunks_data[doc_idx].get('metadata', {})
|
96 |
+
final_score = rrf_scores[doc_idx]
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
+
# **LOGIC BOOSTING QUAN TRỌNG NHẤT**
|
99 |
+
if matched_vehicle:
|
100 |
+
article_title_lower = metadata.get("article_title", "").lower()
|
101 |
+
# Định nghĩa lại từ khóa bên trong để tránh phụ thuộc bên ngoài
|
102 |
+
vehicle_keywords = {
|
103 |
+
"ô tô": ["ô tô", "xe con"], "xe máy": ["xe máy", "xe mô tô"],
|
104 |
+
"xe đạp": ["xe đạp", "xe thô sơ"], "máy kéo": ["máy kéo", "xe chuyên dùng"]
|
105 |
+
}
|
106 |
+
if any(keyword in article_title_lower for keyword in vehicle_keywords.get(matched_vehicle, [])):
|
107 |
+
# Cộng một điểm thưởng rất lớn để đảm bảo nó được ưu tiên
|
108 |
+
final_score += 0.5
|
109 |
+
|
110 |
+
final_results.append({
|
111 |
+
'index': doc_idx,
|
112 |
+
'final_score': final_score
|
113 |
})
|
114 |
+
except IndexError:
|
115 |
+
continue
|
116 |
+
|
117 |
+
final_results.sort(key=lambda x: x['final_score'], reverse=True)
|
118 |
+
|
119 |
+
# Lấy đầy đủ thông tin cho top-k kết quả cuối cùng
|
120 |
+
top_k_results = []
|
121 |
+
for res in final_results[:k]:
|
122 |
+
doc_idx = res['index']
|
123 |
+
top_k_results.append({
|
124 |
+
'index': doc_idx,
|
125 |
+
'final_score': res['final_score'],
|
126 |
+
'text': chunks_data[doc_idx].get('text', ''),
|
127 |
+
'metadata': chunks_data[doc_idx].get('metadata', {})
|
128 |
+
})
|
129 |
+
|
130 |
+
return top_k_results
|