deddoggo commited on
Commit
3264b15
·
1 Parent(s): 0002c1d
Files changed (2) hide show
  1. rag_pipeline.py +77 -47
  2. retriever.py +89 -76
rag_pipeline.py CHANGED
@@ -80,74 +80,104 @@ def initialize_components(data_path):
80
  "bm25_model": bm25_model
81
  }
82
 
83
- def generate_response(query, components):
84
  """
85
- Tạo câu trả lời cho một query (single-turn) bằng cách sử dụng các thành phần đã được khởi tạo.
 
86
  """
87
- print("--- Bắt đầu quy trình RAG (Single-turn) cho query mới ---")
88
-
89
- # Unpack các thành phần
90
- llm_model = components["llm_model"]
91
- tokenizer = components["tokenizer"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
- # 1. Truy xuất ngữ cảnh trực tiếp từ câu hỏi của người dùng
 
94
  retrieved_results = search_relevant_laws(
95
  query_text=query,
96
  embedding_model=components["embedding_model"],
97
  faiss_index=components["faiss_index"],
98
  chunks_data=components["chunks_data"],
99
  bm25_model=components["bm25_model"],
100
- k=5,
101
- initial_k_multiplier=18
102
  )
103
 
104
- # 2. Định dạng Context
105
- if not retrieved_results:
106
- context = "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu."
107
- else:
108
- context_parts = []
109
- for i, res in enumerate(retrieved_results):
110
- metadata = res.get('metadata', {})
111
- header = f"Trích dẫn {i+1}: Điều {metadata.get('article', 'N/A')}, Khoản {metadata.get('clause_number', 'N/A')} (Nguồn: {metadata.get('source_document', 'N/A')})"
112
- text = res.get('text', '*Nội dung không có*')
113
- context_parts.append(f"{header}\n{text}")
114
- context = "\n\n---\n\n".join(context_parts)
115
-
116
- # 3. Xây dựng Prompt đơn giản (không có lịch sử trò chuyện)
117
- prompt = f"""Bạn là trợ lý pháp luật chuyên trả lời các câu hỏi liên quan đến luật giao thông đường bộ Việt Nam.
118
-
119
- Dựa trên các đoạn luật dưới đây:
120
  {context}
121
 
122
- Hãy trả lời câu hỏi của người dùng bằng tiếng Việt, chính xác và dễ hiểu. Nếu cần, hãy trích dẫn điều, khoản hoặc điểm tương ứng trong văn bản luật. Nếu không đủ thông tin trong các đoạn trên, hãy trả lời “Tôi không chắc, cần kiểm tra thêm văn bản luật liên quan.”
 
123
 
124
- Câu hỏi: {query}
125
- """
126
 
127
- print("--- Bắt đầu tạo câu trả lời từ LLM ---")
 
 
128
 
129
- # SỬA LỖI CHO VISION MODEL: Sử dụng API tường minh
130
- inputs = tokenizer(
131
- text=prompt,
132
- images=None,
133
- return_tensors="pt"
134
- ).to("cuda" if torch.cuda.is_available() else "cpu")
135
 
 
136
  generation_config = dict(
137
  max_new_tokens=256,
138
- temperature=0.5,
139
- top_p=0.7,
140
- top_k=50,
141
- repetition_penalty=1.1,
142
- do_sample=True,
143
- pad_token_id=tokenizer.eos_token_id,
144
- eos_token_id=tokenizer.eos_token_id
145
  )
146
 
147
  output_ids = llm_model.generate(**inputs, **generation_config)
148
- input_length = inputs.input_ids.shape[1]
149
- generated_ids = output_ids[0][input_length:]
150
- response_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
151
 
152
  print("--- Tạo câu trả lời hoàn tất ---")
153
- return response_text
 
80
  "bm25_model": bm25_model
81
  }
82
 
83
+ def _format_context_with_summary(retrieved_results: list[dict]) -> str:
84
  """
85
+ Hàm phụ trợ: Định dạng ngữ cảnh từ kết quả truy xuất, bổ sung tóm tắt từ metadata.
86
+ Hàm này được tách ra để làm cho code sạch sẽ và dễ quản lý hơn.
87
  """
88
+ if not retrieved_results:
89
+ return "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu."
90
+
91
+ context_parts = []
92
+ for i, res in enumerate(retrieved_results):
93
+ metadata = res.get('metadata', {})
94
+ text = res.get('text', '*Nội dung không có*')
95
+
96
+ # Tạo header rõ ràng
97
+ header = f"Trích dẫn {i+1}: Điều {metadata.get('article', 'N/A')}, Điểm {metadata.get('point_id', '')} Khoản {metadata.get('clause_number', 'N/A')} (Nguồn: {metadata.get('source_document', 'N/A')})"
98
+
99
+ # --- LOGIC TÓM TẮT THÔNG MINH TỪ METADATA ---
100
+ metadata_summary = ""
101
+ penalty_details_list = metadata.get("penalties_detail", [])
102
+
103
+ if penalty_details_list:
104
+ summary_parts = []
105
+ # Chỉ lấy thông tin từ mục hình phạt đầu tiên trong danh sách
106
+ details = penalty_details_list[0].get('details', {})
107
+
108
+ # Tóm tắt mức phạt tiền cho cá nhân (phổ biến nhất)
109
+ i_min = details.get("individual_fine_min")
110
+ i_max = details.get("individual_fine_max")
111
+ if i_min is not None and i_max is not None:
112
+ summary_parts.append(f"Phạt tiền cá nhân từ {i_min:,} - {i_max:,} đồng.")
113
+
114
+ # Tóm tắt mức trừ điểm
115
+ points = details.get("points_deducted")
116
+ if points is not None:
117
+ summary_parts.append(f"Trừ {points} điểm GPLX.")
118
+
119
+ if summary_parts:
120
+ # Chèn dòng tóm tắt vào giữa header và text
121
+ metadata_summary = f"\n[Tóm tắt từ metadata: {' '.join(summary_parts)}]"
122
+
123
+ context_parts.append(f"{header}{metadata_summary}\n{text}")
124
+
125
+ return "\n\n---\n\n".join(context_parts)
126
+
127
+
128
+ def generate_response(query: str, components: dict) -> str:
129
+ """
130
+ Tạo câu trả lời (single-turn) bằng cách sử dụng các thành phần đã được khởi tạo.
131
+ Phiên bản đã được tối ưu và tái cấu trúc.
132
+ """
133
+ print("--- Bắt đầu quy trình RAG cho query mới ---")
134
 
135
+ # 1. Truy xuất ngữ cảnh bằng retriever đã được nâng cấp
136
+ # (Giả định search_relevant_laws đã được sửa để ưu tiên loại xe)
137
  retrieved_results = search_relevant_laws(
138
  query_text=query,
139
  embedding_model=components["embedding_model"],
140
  faiss_index=components["faiss_index"],
141
  chunks_data=components["chunks_data"],
142
  bm25_model=components["bm25_model"],
143
+ k=5,
144
+ initial_k_multiplier=15
145
  )
146
 
147
+ # 2. Định dạng Context một cách thông minh bằng hàm phụ trợ
148
+ context = _format_context_with_summary(retrieved_results)
149
+
150
+ # 3. Xây dựng Prompt
151
+ prompt = f"""Bạn là một trợ lý pháp luật chuyên trả lời các câu hỏi về luật giao thông Việt Nam. Dựa vào các trích dẫn luật dưới đây để trả lời câu hỏi của người dùng một cách chính xác.
152
+
153
+ ### Thông tin luật:
 
 
 
 
 
 
 
 
 
154
  {context}
155
 
156
+ ### Câu hỏi:
157
+ {query}
158
 
159
+ ### Trả lời:"""
 
160
 
161
+ # 4. Tạo câu trả lời từ LLM
162
+ llm_model = components["llm_model"]
163
+ tokenizer = components["tokenizer"]
164
 
165
+ # Chuyển input lên cùng device với model
166
+ inputs = tokenizer(prompt, return_tensors="pt").to(llm_model.device)
 
 
 
 
167
 
168
+ # Cấu hình generation tối ưu cho việc trả lời câu hỏi dựa trên facts
169
  generation_config = dict(
170
  max_new_tokens=256,
171
+ temperature=0.1, # Rất thấp để câu trả lời bám sát ngữ cảnh
172
+ repetition_penalty=1.1, # Phạt nhẹ việc lặp từ
173
+ do_sample=True, # Vẫn cần bật để temperature và các tham số khác có hiệu lực
174
+ pad_token_id=tokenizer.eos_token_id
 
 
 
175
  )
176
 
177
  output_ids = llm_model.generate(**inputs, **generation_config)
178
+
179
+ # Chỉ decode phần văn bản được sinh ra mới, bỏ qua phần prompt
180
+ response_text = tokenizer.decode(output_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
181
 
182
  print("--- Tạo câu trả lời hoàn tất ---")
183
+ return response_text
retriever.py CHANGED
@@ -12,106 +12,119 @@ def tokenize_vi_for_bm25_setup(text):
12
  text = re.sub(r'[^\w\s]', '', text)
13
  return text.split()
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def search_relevant_laws(
16
- query_text,
17
- embedding_model,
18
- faiss_index,
19
- chunks_data,
20
- bm25_model,
21
- k=5,
22
- initial_k_multiplier=10,
23
- rrf_k_constant=60
24
- ):
25
  """
26
- Thực hiện tìm kiếm lai (Hybrid Search) kết hợp Semantic Search (FAISS) Keyword Search (BM25),
27
- sau đó kết hợp kết quả bằng Reciprocal Rank Fusion (RRF) và tăng cường bằng metadata.
 
 
 
 
 
28
  """
29
  if k <= 0:
30
- print("Lỗi: k (số lượng kết quả) phải là số dương.")
31
  return []
32
 
33
- print(f"\n🔎 Đang tìm kiếm (Hybrid) cho truy vấn: '{query_text}'")
34
- query_lower = query_text.lower()
35
-
36
- # Phân tích query
37
- fine_keywords = r'tiền|phạt|bao nhiêu đồng|bao nhiêu tiền|mức phạt|xử phạt hành chính'
38
- points_keywords = r'điểm|trừ điểm|mấy điểm|trừ bao nhiêu điểm|bằng lái|gplx'
39
- query_mentions_fine = bool(re.search(fine_keywords, query_lower))
40
- query_mentions_points = bool(re.search(points_keywords, query_lower))
41
- needs_specific_metadata_filter = query_mentions_fine or query_mentions_points
42
- print(f" Phân tích query: Đề cập tiền phạt? {query_mentions_fine}, Đề cập điểm trừ? {query_mentions_points}")
43
-
44
  num_vectors_in_index = faiss_index.ntotal
45
  if num_vectors_in_index == 0:
46
- print("Lỗi: FAISS index rỗng.")
47
  return []
48
 
49
- num_candidates_each_retriever = min(k * initial_k_multiplier, num_vectors_in_index)
50
 
51
- # === 1. Semantic Search (FAISS) ===
52
  try:
53
- query_embedding_tensor = embedding_model.encode([query_text], convert_to_tensor=True, device=embedding_model.device)
54
- query_embedding_np = query_embedding_tensor.cpu().numpy().astype('float32')
55
  faiss.normalize_L2(query_embedding_np)
56
- semantic_scores_raw, semantic_indices_raw = faiss_index.search(query_embedding_np, num_candidates_each_retriever)
 
57
  except Exception as e:
58
- print(f"Lỗi khi tìm kiếm ngữ nghĩa (FAISS): {e}")
59
- semantic_indices_raw = np.array([[]], dtype=int)
60
 
61
- # === 2. Keyword Search (BM25) ===
62
  try:
63
- tokenized_query_bm25 = tokenize_vi_for_bm25_setup(query_text)
64
- all_bm25_scores = bm25_model.get_scores(tokenized_query_bm25)
65
- bm25_results_with_indices = [{'index': i, 'score': score} for i, score in enumerate(all_bm25_scores) if score > 0]
66
- bm25_results_with_indices.sort(key=lambda x: x['score'], reverse=True)
67
- top_bm25_results = bm25_results_with_indices[:num_candidates_each_retriever]
68
  except Exception as e:
69
- print(f"Lỗi khi tìm kiếm từ khóa (BM25): {e}")
70
- top_bm25_results = []
71
 
72
- # === 3. Result Fusion (RRF) ===
73
  rrf_scores = defaultdict(float)
74
- all_retrieved_indices_set = set()
75
 
76
- if semantic_indices_raw.size > 0:
77
- for rank, doc_idx in enumerate(semantic_indices_raw[0]):
78
- if 0 <= doc_idx < num_vectors_in_index:
79
- rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
80
- all_retrieved_indices_set.add(doc_idx)
81
-
82
- for rank, item in enumerate(top_bm25_results):
83
- doc_idx = item['index']
84
  rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
85
- all_retrieved_indices_set.add(doc_idx)
86
 
87
- fused_initial_results = [{'index': doc_idx, 'fused_score': rrf_scores[doc_idx]} for doc_idx in all_retrieved_indices_set]
88
- fused_initial_results.sort(key=lambda x: x['fused_score'], reverse=True)
89
 
90
- # === 4. Lọc Tái xếp hạng cuối cùng ===
91
- final_processed_results = []
92
- num_to_process_metadata = min(len(fused_initial_results), num_candidates_each_retriever * 2)
 
93
 
94
- for rank_idx, res_item in enumerate(fused_initial_results[:num_to_process_metadata]):
95
  try:
96
- result_index = res_item['index']
97
- base_score_from_fusion = res_item['fused_score']
98
- original_chunk = chunks_data[result_index]
99
- original_metadata = original_chunk.get('metadata', {})
100
- # Thêm logic xử lý metadata boosting ở đây nếu cần...
101
- # Hiện tại, chỉ trả về kết quả đã fusion.
102
- # Bạn có thể copy lại toàn bộ logic boosting từ script gốc vào đây.
103
 
104
- final_score_calculated = base_score_from_fusion # (Thêm boosting vào đây)
105
-
106
- final_processed_results.append({
107
- "rank_after_fusion": rank_idx + 1,
108
- "index": int(result_index),
109
- "final_score": final_score_calculated,
110
- "text": original_chunk.get('text', '*Không text*'),
111
- "metadata": original_metadata
 
 
 
 
 
 
 
112
  })
113
- except Exception as e:
114
- print(f"Lỗi khi xử lý ứng viên tại chỉ số {res_item.get('index')}: {e}")
115
-
116
- final_processed_results.sort(key=lambda x: x["final_score"], reverse=True)
117
- return final_processed_results[:k]
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  text = re.sub(r'[^\w\s]', '', text)
13
  return text.split()
14
 
15
+ def _get_vehicle_type(query_lower: str) -> str | None:
16
+ """Xác định loại xe được đề cập trong câu truy vấn."""
17
+ # Từ điển định nghĩa các từ khóa cho từng loại xe
18
+ vehicle_keywords = {
19
+ "ô tô": ["ô tô", "xe con", "xe chở người", "xe chở hàng"],
20
+ "xe máy": ["xe máy", "xe mô tô", "xe gắn máy"],
21
+ "xe đạp": ["xe đạp", "xe thô sơ"],
22
+ "máy kéo": ["máy kéo", "xe chuyên dùng"]
23
+ }
24
+ for vehicle_type, keywords in vehicle_keywords.items():
25
+ if any(keyword in query_lower for keyword in keywords):
26
+ return vehicle_type
27
+ return None
28
+
29
  def search_relevant_laws(
30
+ query_text: str,
31
+ embedding_model,
32
+ faiss_index,
33
+ chunks_data: list[dict],
34
+ bm25_model,
35
+ k: int = 5,
36
+ initial_k_multiplier: int = 15,
37
+ rrf_k_constant: int = 60
38
+ ) -> list[dict]:
39
  """
40
+ Thực hiện Tìm kiếm Lai (Hybrid Search) với logic tăng điểm (boosting) cho loại xe.
41
+
42
+ Quy trình:
43
+ 1. Tìm kiếm song song bằng FAISS (ngữ nghĩa) và BM25 (từ khóa).
44
+ 2. Kết hợp kết quả bằng Reciprocal Rank Fusion (RRF).
45
+ 3. Tăng điểm (boost) cho các kết quả khớp với metadata quan trọng (loại xe).
46
+ 4. Sắp xếp lại và trả về top-k kết quả cuối cùng.
47
  """
48
  if k <= 0:
 
49
  return []
50
 
 
 
 
 
 
 
 
 
 
 
 
51
  num_vectors_in_index = faiss_index.ntotal
52
  if num_vectors_in_index == 0:
 
53
  return []
54
 
55
+ num_candidates = min(k * initial_k_multiplier, num_vectors_in_index)
56
 
57
+ # --- 1. Semantic Search (FAISS) ---
58
  try:
59
+ query_embedding = embedding_model.encode([query_text], convert_to_tensor=True)
60
+ query_embedding_np = query_embedding.cpu().numpy().astype('float32')
61
  faiss.normalize_L2(query_embedding_np)
62
+ _, semantic_indices = faiss_index.search(query_embedding_np, num_candidates)
63
+ semantic_indices = semantic_indices[0]
64
  except Exception as e:
65
+ print(f"Lỗi FAISS search: {e}")
66
+ semantic_indices = []
67
 
68
+ # --- 2. Keyword Search (BM25) ---
69
  try:
70
+ tokenized_query = tokenize_vi_for_bm25_setup(query_text)
71
+ bm25_scores = bm25_model.get_scores(tokenized_query)
72
+ # Lấy top N chỉ mục từ BM25
73
+ top_bm25_indices = np.argsort(bm25_scores)[::-1][:num_candidates]
 
74
  except Exception as e:
75
+ print(f"Lỗi BM25 search: {e}")
76
+ top_bm25_indices = []
77
 
78
+ # --- 3. Result Fusion (RRF) ---
79
  rrf_scores = defaultdict(float)
80
+ all_indices = set(semantic_indices) | set(top_bm25_indices)
81
 
82
+ for rank, doc_idx in enumerate(semantic_indices):
 
 
 
 
 
 
 
83
  rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
 
84
 
85
+ for rank, doc_idx in enumerate(top_bm25_indices):
86
+ rrf_scores[doc_idx] += 1.0 / (rrf_k_constant + rank)
87
 
88
+ # --- 4. Metadata Boosting & Final Ranking ---
89
+ query_lower = query_text.lower()
90
+ matched_vehicle = _get_vehicle_type(query_lower)
91
+ final_results = []
92
 
93
+ for doc_idx in all_indices:
94
  try:
95
+ metadata = chunks_data[doc_idx].get('metadata', {})
96
+ final_score = rrf_scores[doc_idx]
 
 
 
 
 
97
 
98
+ # **LOGIC BOOSTING QUAN TRỌNG NHẤT**
99
+ if matched_vehicle:
100
+ article_title_lower = metadata.get("article_title", "").lower()
101
+ # Định nghĩa lại từ khóa bên trong để tránh phụ thuộc bên ngoài
102
+ vehicle_keywords = {
103
+ "ô tô": ["ô tô", "xe con"], "xe máy": ["xe máy", "xe mô tô"],
104
+ "xe đạp": ["xe đạp", "xe thô sơ"], "máy kéo": ["máy kéo", "xe chuyên dùng"]
105
+ }
106
+ if any(keyword in article_title_lower for keyword in vehicle_keywords.get(matched_vehicle, [])):
107
+ # Cộng một điểm thưởng rất lớn để đảm bảo nó được ưu tiên
108
+ final_score += 0.5
109
+
110
+ final_results.append({
111
+ 'index': doc_idx,
112
+ 'final_score': final_score
113
  })
114
+ except IndexError:
115
+ continue
116
+
117
+ final_results.sort(key=lambda x: x['final_score'], reverse=True)
118
+
119
+ # Lấy đầy đủ thông tin cho top-k kết quả cuối cùng
120
+ top_k_results = []
121
+ for res in final_results[:k]:
122
+ doc_idx = res['index']
123
+ top_k_results.append({
124
+ 'index': doc_idx,
125
+ 'final_score': res['final_score'],
126
+ 'text': chunks_data[doc_idx].get('text', ''),
127
+ 'metadata': chunks_data[doc_idx].get('metadata', {})
128
+ })
129
+
130
+ return top_k_results