Spaces:
Running
on
T4
Running
on
T4
- rag_pipeline.py +21 -16
rag_pipeline.py
CHANGED
@@ -82,13 +82,13 @@ def initialize_components(data_path):
|
|
82 |
|
83 |
def generate_response(query: str, components: dict) -> str:
|
84 |
"""
|
85 |
-
Tạo câu trả lời (single-turn).
|
86 |
-
Phiên bản
|
|
|
87 |
"""
|
88 |
print("--- Bắt đầu quy trình RAG cho query mới ---")
|
89 |
|
90 |
-
#
|
91 |
-
# 1. Truy xuất ngữ cảnh
|
92 |
retrieved_results = search_relevant_laws(
|
93 |
query_text=query,
|
94 |
embedding_model=components["embedding_model"],
|
@@ -99,25 +99,24 @@ def generate_response(query: str, components: dict) -> str:
|
|
99 |
initial_k_multiplier=15
|
100 |
)
|
101 |
|
102 |
-
#
|
103 |
-
# 2. Định dạng Context
|
104 |
if not retrieved_results:
|
105 |
context = "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu."
|
106 |
else:
|
107 |
context_parts = []
|
108 |
for i, res in enumerate(retrieved_results):
|
109 |
metadata = res.get('metadata', {})
|
110 |
-
# Tạo header đơn giản, không có gợi ý
|
111 |
header = f"Trích dẫn {i+1}: Điều {metadata.get('article', 'N/A')}, Khoản {metadata.get('clause_number', 'N/A')} (Nguồn: {metadata.get('source_document', 'N/A')})"
|
112 |
text = res.get('text', '*Nội dung không có*')
|
113 |
context_parts.append(f"{header}\n{text}")
|
114 |
context = "\n\n---\n\n".join(context_parts)
|
115 |
|
116 |
-
# 3
|
117 |
-
print("---
|
118 |
llm_model = components["llm_model"]
|
119 |
tokenizer = components["tokenizer"]
|
120 |
|
|
|
121 |
messages = [
|
122 |
{
|
123 |
"role": "system",
|
@@ -136,13 +135,17 @@ def generate_response(query: str, components: dict) -> str:
|
|
136 |
}
|
137 |
]
|
138 |
|
139 |
-
|
140 |
-
|
141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
print("--- Bắt đầu tạo câu trả lời từ LLM ---")
|
143 |
|
144 |
-
inputs = tokenizer([prompt], return_tensors="pt").to(llm_model.device)
|
145 |
-
|
146 |
generation_config = dict(
|
147 |
max_new_tokens=256,
|
148 |
temperature=0.1,
|
@@ -151,8 +154,10 @@ def generate_response(query: str, components: dict) -> str:
|
|
151 |
pad_token_id=tokenizer.eos_token_id
|
152 |
)
|
153 |
|
154 |
-
output_ids = llm_model.generate(
|
155 |
-
|
|
|
|
|
156 |
|
157 |
print("--- Tạo câu trả lời hoàn tất ---")
|
158 |
return response_text
|
|
|
82 |
|
83 |
def generate_response(query: str, components: dict) -> str:
|
84 |
"""
|
85 |
+
Tạo câu trả lời (single-turn) bằng cách sử dụng các thành phần đã được khởi tạo.
|
86 |
+
Phiên bản cuối cùng, sửa lỗi ValueError cho mô hình Vision bằng cách
|
87 |
+
sử dụng apply_chat_template để tokenization trực tiếp.
|
88 |
"""
|
89 |
print("--- Bắt đầu quy trình RAG cho query mới ---")
|
90 |
|
91 |
+
# --- Bước 1: Truy xuất Ngữ cảnh ---
|
|
|
92 |
retrieved_results = search_relevant_laws(
|
93 |
query_text=query,
|
94 |
embedding_model=components["embedding_model"],
|
|
|
99 |
initial_k_multiplier=15
|
100 |
)
|
101 |
|
102 |
+
# --- Bước 2: Định dạng Ngữ cảnh ---
|
|
|
103 |
if not retrieved_results:
|
104 |
context = "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu."
|
105 |
else:
|
106 |
context_parts = []
|
107 |
for i, res in enumerate(retrieved_results):
|
108 |
metadata = res.get('metadata', {})
|
|
|
109 |
header = f"Trích dẫn {i+1}: Điều {metadata.get('article', 'N/A')}, Khoản {metadata.get('clause_number', 'N/A')} (Nguồn: {metadata.get('source_document', 'N/A')})"
|
110 |
text = res.get('text', '*Nội dung không có*')
|
111 |
context_parts.append(f"{header}\n{text}")
|
112 |
context = "\n\n---\n\n".join(context_parts)
|
113 |
|
114 |
+
# --- Bước 3: Chuẩn bị Dữ liệu và Tokenize bằng Chat Template (Phần sửa lỗi cốt lõi) ---
|
115 |
+
print("--- Chuẩn bị và tokenize prompt bằng chat template ---")
|
116 |
llm_model = components["llm_model"]
|
117 |
tokenizer = components["tokenizer"]
|
118 |
|
119 |
+
# Tạo cấu trúc tin nhắn theo chuẩn
|
120 |
messages = [
|
121 |
{
|
122 |
"role": "system",
|
|
|
135 |
}
|
136 |
]
|
137 |
|
138 |
+
# SỬA LỖI: Dùng apply_chat_template để tokenize trực tiếp
|
139 |
+
# Nó sẽ tự động định dạng và chuyển thành tensor, tương thích với mô hình Vision
|
140 |
+
inputs = tokenizer.apply_chat_template(
|
141 |
+
messages,
|
142 |
+
return_tensors="pt",
|
143 |
+
add_generation_prompt=True,
|
144 |
+
).to(llm_model.device)
|
145 |
+
|
146 |
+
# --- Bước 4: Tạo câu trả lời từ LLM ---
|
147 |
print("--- Bắt đầu tạo câu trả lời từ LLM ---")
|
148 |
|
|
|
|
|
149 |
generation_config = dict(
|
150 |
max_new_tokens=256,
|
151 |
temperature=0.1,
|
|
|
154 |
pad_token_id=tokenizer.eos_token_id
|
155 |
)
|
156 |
|
157 |
+
output_ids = llm_model.generate(inputs, **generation_config)
|
158 |
+
|
159 |
+
# Decode như cũ, nhưng đầu vào là `inputs` thay vì `inputs.input_ids`
|
160 |
+
response_text = tokenizer.decode(output_ids[0][inputs.shape[1]:], skip_special_tokens=True)
|
161 |
|
162 |
print("--- Tạo câu trả lời hoàn tất ---")
|
163 |
return response_text
|