deddoggo commited on
Commit
734d387
·
1 Parent(s): 5a1649d
Files changed (1) hide show
  1. rag_pipeline.py +21 -16
rag_pipeline.py CHANGED
@@ -82,13 +82,13 @@ def initialize_components(data_path):
82
 
83
  def generate_response(query: str, components: dict) -> str:
84
  """
85
- Tạo câu trả lời (single-turn).
86
- Phiên bản đơn giản hóa, không logic vehicle_type.
 
87
  """
88
  print("--- Bắt đầu quy trình RAG cho query mới ---")
89
 
90
- # === THAY ĐỔI 1: Chỉ nhận 1 giá trị trả về ===
91
- # 1. Truy xuất ngữ cảnh
92
  retrieved_results = search_relevant_laws(
93
  query_text=query,
94
  embedding_model=components["embedding_model"],
@@ -99,25 +99,24 @@ def generate_response(query: str, components: dict) -> str:
99
  initial_k_multiplier=15
100
  )
101
 
102
- # === THAY ĐỔI 2: Loại bỏ logic vehicle_type trong context ===
103
- # 2. Định dạng Context
104
  if not retrieved_results:
105
  context = "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu."
106
  else:
107
  context_parts = []
108
  for i, res in enumerate(retrieved_results):
109
  metadata = res.get('metadata', {})
110
- # Tạo header đơn giản, không có gợi ý
111
  header = f"Trích dẫn {i+1}: Điều {metadata.get('article', 'N/A')}, Khoản {metadata.get('clause_number', 'N/A')} (Nguồn: {metadata.get('source_document', 'N/A')})"
112
  text = res.get('text', '*Nội dung không có*')
113
  context_parts.append(f"{header}\n{text}")
114
  context = "\n\n---\n\n".join(context_parts)
115
 
116
- # 3. Xây dựng Prompt bằng Chat Template (giữ nguyên logic tương thích Vision)
117
- print("--- Xây dựng prompt bằng chat template ---")
118
  llm_model = components["llm_model"]
119
  tokenizer = components["tokenizer"]
120
 
 
121
  messages = [
122
  {
123
  "role": "system",
@@ -136,13 +135,17 @@ def generate_response(query: str, components: dict) -> str:
136
  }
137
  ]
138
 
139
- prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
140
-
141
- # 4. Tạo câu trả lời từ LLM
 
 
 
 
 
 
142
  print("--- Bắt đầu tạo câu trả lời từ LLM ---")
143
 
144
- inputs = tokenizer([prompt], return_tensors="pt").to(llm_model.device)
145
-
146
  generation_config = dict(
147
  max_new_tokens=256,
148
  temperature=0.1,
@@ -151,8 +154,10 @@ def generate_response(query: str, components: dict) -> str:
151
  pad_token_id=tokenizer.eos_token_id
152
  )
153
 
154
- output_ids = llm_model.generate(**inputs, **generation_config)
155
- response_text = tokenizer.decode(output_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
 
 
156
 
157
  print("--- Tạo câu trả lời hoàn tất ---")
158
  return response_text
 
82
 
83
  def generate_response(query: str, components: dict) -> str:
84
  """
85
+ Tạo câu trả lời (single-turn) bằng cách sử dụng các thành phần đã được khởi tạo.
86
+ Phiên bản cuối cùng, sửa lỗi ValueError cho mô hình Vision bằng cách
87
+ sử dụng apply_chat_template để tokenization trực tiếp.
88
  """
89
  print("--- Bắt đầu quy trình RAG cho query mới ---")
90
 
91
+ # --- Bước 1: Truy xuất Ngữ cảnh ---
 
92
  retrieved_results = search_relevant_laws(
93
  query_text=query,
94
  embedding_model=components["embedding_model"],
 
99
  initial_k_multiplier=15
100
  )
101
 
102
+ # --- Bước 2: Định dạng Ngữ cảnh ---
 
103
  if not retrieved_results:
104
  context = "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu."
105
  else:
106
  context_parts = []
107
  for i, res in enumerate(retrieved_results):
108
  metadata = res.get('metadata', {})
 
109
  header = f"Trích dẫn {i+1}: Điều {metadata.get('article', 'N/A')}, Khoản {metadata.get('clause_number', 'N/A')} (Nguồn: {metadata.get('source_document', 'N/A')})"
110
  text = res.get('text', '*Nội dung không có*')
111
  context_parts.append(f"{header}\n{text}")
112
  context = "\n\n---\n\n".join(context_parts)
113
 
114
+ # --- Bước 3: Chuẩn bị Dữ liệu và Tokenize bằng Chat Template (Phần sửa lỗi cốt lõi) ---
115
+ print("--- Chuẩn bị và tokenize prompt bằng chat template ---")
116
  llm_model = components["llm_model"]
117
  tokenizer = components["tokenizer"]
118
 
119
+ # Tạo cấu trúc tin nhắn theo chuẩn
120
  messages = [
121
  {
122
  "role": "system",
 
135
  }
136
  ]
137
 
138
+ # SỬA LỖI: Dùng apply_chat_template để tokenize trực tiếp
139
+ # Nó sẽ tự động định dạng và chuyển thành tensor, tương thích với mô hình Vision
140
+ inputs = tokenizer.apply_chat_template(
141
+ messages,
142
+ return_tensors="pt",
143
+ add_generation_prompt=True,
144
+ ).to(llm_model.device)
145
+
146
+ # --- Bước 4: Tạo câu trả lời từ LLM ---
147
  print("--- Bắt đầu tạo câu trả lời từ LLM ---")
148
 
 
 
149
  generation_config = dict(
150
  max_new_tokens=256,
151
  temperature=0.1,
 
154
  pad_token_id=tokenizer.eos_token_id
155
  )
156
 
157
+ output_ids = llm_model.generate(inputs, **generation_config)
158
+
159
+ # Decode như cũ, nhưng đầu vào là `inputs` thay vì `inputs.input_ids`
160
+ response_text = tokenizer.decode(output_ids[0][inputs.shape[1]:], skip_special_tokens=True)
161
 
162
  print("--- Tạo câu trả lời hoàn tất ---")
163
  return response_text