Spaces:
Sleeping
Sleeping
HaRin2806
commited on
Commit
·
76a8f20
1
Parent(s):
59a1c47
fix bug
Browse files- core/data_processor.py +10 -60
- core/embedding_model.py +251 -81
- core/rag_pipeline.py +52 -76
core/data_processor.py
CHANGED
@@ -5,7 +5,6 @@ import logging
|
|
5 |
import datetime
|
6 |
from typing import Dict, List, Any, Union, Tuple
|
7 |
|
8 |
-
# Cấu hình logging
|
9 |
logger = logging.getLogger(__name__)
|
10 |
|
11 |
class DataProcessor:
|
@@ -26,18 +25,14 @@ class DataProcessor:
|
|
26 |
"""Tải tất cả dữ liệu từ các thư mục con trong data"""
|
27 |
logger.info(f"Đang tải dữ liệu từ thư mục: {self.data_dir}")
|
28 |
|
29 |
-
# Quét qua tất cả thư mục trong data
|
30 |
for item in os.listdir(self.data_dir):
|
31 |
folder_path = os.path.join(self.data_dir, item)
|
32 |
|
33 |
-
# Kiểm tra xem đây có phải là thư mục không
|
34 |
if os.path.isdir(folder_path):
|
35 |
metadata_file = os.path.join(folder_path, "metadata.json")
|
36 |
|
37 |
-
# Nếu có file metadata.json
|
38 |
if os.path.exists(metadata_file):
|
39 |
try:
|
40 |
-
# Tải metadata
|
41 |
with open(metadata_file, 'r', encoding='utf-8') as f:
|
42 |
content = f.read()
|
43 |
if not content.strip():
|
@@ -45,7 +40,6 @@ class DataProcessor:
|
|
45 |
continue
|
46 |
folder_metadata = json.loads(content)
|
47 |
|
48 |
-
# Xác định ID của thư mục
|
49 |
folder_id = None
|
50 |
if "bai_info" in folder_metadata:
|
51 |
folder_id = folder_metadata["bai_info"].get("id", item)
|
@@ -54,10 +48,8 @@ class DataProcessor:
|
|
54 |
else:
|
55 |
folder_id = item
|
56 |
|
57 |
-
# Lưu metadata vào từ điển
|
58 |
self.metadata[folder_id] = folder_metadata
|
59 |
|
60 |
-
# Tải tất cả chunks, tables và figures
|
61 |
self._load_content_from_metadata(folder_path, folder_metadata)
|
62 |
|
63 |
logger.info(f"Đã tải xong thư mục: {item}")
|
@@ -68,33 +60,28 @@ class DataProcessor:
|
|
68 |
|
69 |
def _load_content_from_metadata(self, folder_path: str, folder_metadata: Dict[str, Any]):
|
70 |
"""Tải nội dung chunks, tables và figures từ metadata"""
|
71 |
-
# Tải chunks
|
72 |
for chunk_meta in folder_metadata.get("chunks", []):
|
73 |
chunk_id = chunk_meta.get("id")
|
74 |
chunk_path = os.path.join(folder_path, "chunks", f"{chunk_id}.md")
|
75 |
|
76 |
-
chunk_data = chunk_meta.copy()
|
77 |
|
78 |
-
# Thêm nội dung từ file markdown nếu tồn tại
|
79 |
if os.path.exists(chunk_path):
|
80 |
with open(chunk_path, 'r', encoding='utf-8') as f:
|
81 |
content = f.read()
|
82 |
chunk_data["content"] = self._extract_content_from_markdown(content)
|
83 |
else:
|
84 |
-
# Nếu không tìm thấy file, tạo nội dung mẫu và ghi log ở debug level
|
85 |
chunk_data["content"] = f"Nội dung cho {chunk_id} không tìm thấy."
|
86 |
logger.debug(f"Không tìm thấy file chunk: {chunk_path}")
|
87 |
|
88 |
self.chunks.append(chunk_data)
|
89 |
|
90 |
-
# Tải tables
|
91 |
for table_meta in folder_metadata.get("tables", []):
|
92 |
table_id = table_meta.get("id")
|
93 |
table_path = os.path.join(folder_path, "tables", f"{table_id}.md")
|
94 |
|
95 |
table_data = table_meta.copy()
|
96 |
|
97 |
-
# Thêm nội dung từ file markdown nếu tồn tại
|
98 |
if os.path.exists(table_path):
|
99 |
with open(table_path, 'r', encoding='utf-8') as f:
|
100 |
content = f.read()
|
@@ -105,13 +92,11 @@ class DataProcessor:
|
|
105 |
|
106 |
self.tables.append(table_data)
|
107 |
|
108 |
-
# Tải figures
|
109 |
for figure_meta in folder_metadata.get("figures", []):
|
110 |
figure_id = figure_meta.get("id")
|
111 |
figure_path = os.path.join(folder_path, "figures", f"{figure_id}.md")
|
112 |
figure_data = figure_meta.copy()
|
113 |
|
114 |
-
# Thêm nội dung từ file markdown nếu tồn tại
|
115 |
content_loaded = False
|
116 |
if os.path.exists(figure_path):
|
117 |
with open(figure_path, 'r', encoding='utf-8') as f:
|
@@ -119,7 +104,6 @@ class DataProcessor:
|
|
119 |
figure_data["content"] = self._extract_content_from_markdown(content)
|
120 |
content_loaded = True
|
121 |
|
122 |
-
# Thêm đường dẫn đến file hình ảnh nếu có
|
123 |
image_path = None
|
124 |
image_extensions = ['.png', '.jpg', '.jpeg', '.gif', '.svg']
|
125 |
for ext in image_extensions:
|
@@ -130,18 +114,15 @@ class DataProcessor:
|
|
130 |
|
131 |
if image_path:
|
132 |
figure_data["image_path"] = image_path
|
133 |
-
# Tạo nội dung mặc định nếu không có file markdown
|
134 |
if not content_loaded:
|
135 |
figure_caption = figure_meta.get("title", f"Hình {figure_id}")
|
136 |
figure_data["content"] = f""
|
137 |
elif not content_loaded:
|
138 |
-
# Nếu không có cả file markdown và file hình
|
139 |
figure_data["content"] = f"Hình {figure_id} không tìm thấy."
|
140 |
logger.debug(f"Không tìm thấy file hình cho {figure_id}")
|
141 |
|
142 |
self.figures.append(figure_data)
|
143 |
|
144 |
-
# Tải data_files (trường hợp phụ lục)
|
145 |
if "data_files" in folder_metadata:
|
146 |
for data_file_meta in folder_metadata.get("data_files", []):
|
147 |
data_id = data_file_meta.get("id")
|
@@ -149,16 +130,13 @@ class DataProcessor:
|
|
149 |
|
150 |
data_file = data_file_meta.copy()
|
151 |
|
152 |
-
# Thêm nội dung từ file markdown nếu tồn tại
|
153 |
if os.path.exists(data_path):
|
154 |
with open(data_path, 'r', encoding='utf-8') as f:
|
155 |
content = f.read()
|
156 |
data_file["content"] = self._extract_content_from_markdown(content)
|
157 |
|
158 |
-
# Xác định loại nội dung
|
159 |
content_type = data_file.get("content_type", "table")
|
160 |
|
161 |
-
# Thêm vào danh sách phù hợp dựa trên loại nội dung
|
162 |
if content_type == "table":
|
163 |
self.tables.append(data_file)
|
164 |
elif content_type == "text":
|
@@ -173,7 +151,6 @@ class DataProcessor:
|
|
173 |
|
174 |
def _extract_content_from_markdown(self, md_content: str) -> str:
|
175 |
"""Trích xuất nội dung từ markdown, bỏ qua phần frontmatter"""
|
176 |
-
# Tách frontmatter (nằm giữa "---")
|
177 |
if md_content.startswith("---"):
|
178 |
parts = md_content.split("---", 2)
|
179 |
if len(parts) >= 3:
|
@@ -214,26 +191,23 @@ class DataProcessor:
|
|
214 |
return None
|
215 |
|
216 |
def find_items_by_age(self, age: int) -> Dict[str, List[Dict[str, Any]]]:
|
217 |
-
"""Tìm các items
|
218 |
relevant_chunks = []
|
219 |
relevant_tables = []
|
220 |
relevant_figures = []
|
221 |
|
222 |
-
# Lọc chunks
|
223 |
for chunk in self.chunks:
|
224 |
-
age_range = chunk.get("age_range", [0,
|
225 |
if len(age_range) == 2 and age_range[0] <= age <= age_range[1]:
|
226 |
relevant_chunks.append(chunk)
|
227 |
|
228 |
-
# Lọc tables
|
229 |
for table in self.tables:
|
230 |
-
age_range = table.get("age_range", [0,
|
231 |
if len(age_range) == 2 and age_range[0] <= age <= age_range[1]:
|
232 |
relevant_tables.append(table)
|
233 |
|
234 |
-
# Lọc figures
|
235 |
for figure in self.figures:
|
236 |
-
age_range = figure.get("age_range", [0,
|
237 |
if len(age_range) == 2 and age_range[0] <= age <= age_range[1]:
|
238 |
relevant_figures.append(figure)
|
239 |
|
@@ -249,7 +223,6 @@ class DataProcessor:
|
|
249 |
related_tables = []
|
250 |
related_figures = []
|
251 |
|
252 |
-
# Tìm item gốc
|
253 |
source_item = None
|
254 |
for item in self.chunks + self.tables + self.figures:
|
255 |
if item.get("id") == item_id:
|
@@ -263,24 +236,19 @@ class DataProcessor:
|
|
263 |
"figures": []
|
264 |
}
|
265 |
|
266 |
-
# Lấy danh sách IDs của các items liên quan
|
267 |
related_ids = source_item.get("related_chunks", [])
|
268 |
|
269 |
-
# Tìm các items liên quan
|
270 |
for related_id in related_ids:
|
271 |
-
# Tìm trong chunks
|
272 |
for chunk in self.chunks:
|
273 |
if chunk.get("id") == related_id:
|
274 |
related_chunks.append(chunk)
|
275 |
break
|
276 |
|
277 |
-
# Tìm trong tables
|
278 |
for table in self.tables:
|
279 |
if table.get("id") == related_id:
|
280 |
related_tables.append(table)
|
281 |
break
|
282 |
|
283 |
-
# Tìm trong figures
|
284 |
for figure in self.figures:
|
285 |
if figure.get("id") == related_id:
|
286 |
related_figures.append(figure)
|
@@ -294,9 +262,7 @@ class DataProcessor:
|
|
294 |
|
295 |
def preprocess_query(self, query: str) -> str:
|
296 |
"""Tiền xử lý câu truy vấn"""
|
297 |
-
# Loại bỏ ký tự đặc biệt
|
298 |
query = re.sub(r'[^\w\s\d]', ' ', query)
|
299 |
-
# Loại bỏ khoảng trắng thừa
|
300 |
query = re.sub(r'\s+', ' ', query).strip()
|
301 |
return query
|
302 |
|
@@ -310,10 +276,8 @@ class DataProcessor:
|
|
310 |
content = item.get("content", "")
|
311 |
content_type = item.get("content_type", "text")
|
312 |
|
313 |
-
# Nếu là bảng, thêm tiêu đề "B��ng:"
|
314 |
if content_type == "table":
|
315 |
title = f"Bảng: {title}"
|
316 |
-
# Nếu là hình, thêm tiêu đề "Hình:"
|
317 |
elif content_type == "figure":
|
318 |
title = f"Hình: {title}"
|
319 |
|
@@ -326,9 +290,7 @@ class DataProcessor:
|
|
326 |
"""Chuẩn bị dữ liệu cho việc nhúng (embedding)"""
|
327 |
all_items = []
|
328 |
|
329 |
-
# Thêm chunks
|
330 |
for chunk in self.chunks:
|
331 |
-
# Tìm chapter từ chunk ID
|
332 |
chunk_id = chunk.get("id", "")
|
333 |
chapter = "unknown"
|
334 |
if chunk_id.startswith("bai1_"):
|
@@ -346,13 +308,11 @@ class DataProcessor:
|
|
346 |
if chunk.get("title"):
|
347 |
content = f"Tiêu đề: {chunk.get('title')}\n\nNội dung: {content}"
|
348 |
|
349 |
-
|
350 |
-
age_range = chunk.get("age_range", [0, 100])
|
351 |
age_min = age_range[0] if len(age_range) > 0 else 0
|
352 |
-
age_max = age_range[1] if len(age_range) > 1 else
|
353 |
age_range_str = f"{age_min}-{age_max}"
|
354 |
|
355 |
-
# Xử lý related_chunks - convert list thành string
|
356 |
related_chunks = chunk.get("related_chunks", [])
|
357 |
related_chunks_str = ",".join(related_chunks) if related_chunks else ""
|
358 |
|
@@ -379,9 +339,7 @@ class DataProcessor:
|
|
379 |
}
|
380 |
all_items.append(embedding_item)
|
381 |
|
382 |
-
# Thêm tables
|
383 |
for table in self.tables:
|
384 |
-
# Tìm chapter từ table ID
|
385 |
table_id = table.get("id", "")
|
386 |
chapter = "unknown"
|
387 |
if table_id.startswith("bai1_"):
|
@@ -399,13 +357,11 @@ class DataProcessor:
|
|
399 |
if table.get("title"):
|
400 |
content = f"Bảng: {table.get('title')}\n\nNội dung: {content}"
|
401 |
|
402 |
-
|
403 |
-
age_range = table.get("age_range", [0, 100])
|
404 |
age_min = age_range[0] if len(age_range) > 0 else 0
|
405 |
-
age_max = age_range[1] if len(age_range) > 1 else
|
406 |
age_range_str = f"{age_min}-{age_max}"
|
407 |
|
408 |
-
# Xử lý related_chunks và table_columns
|
409 |
related_chunks = table.get("related_chunks", [])
|
410 |
related_chunks_str = ",".join(related_chunks) if related_chunks else ""
|
411 |
table_columns = table.get("table_columns", [])
|
@@ -433,9 +389,7 @@ class DataProcessor:
|
|
433 |
}
|
434 |
all_items.append(embedding_item)
|
435 |
|
436 |
-
# Thêm figures
|
437 |
for figure in self.figures:
|
438 |
-
# Tìm chapter từ figure ID
|
439 |
figure_id = figure.get("id", "")
|
440 |
chapter = "unknown"
|
441 |
if figure_id.startswith("bai1_"):
|
@@ -453,13 +407,11 @@ class DataProcessor:
|
|
453 |
if figure.get("title"):
|
454 |
content = f"Hình: {figure.get('title')}\n\nMô tả: {content}"
|
455 |
|
456 |
-
# Xử lý age_range
|
457 |
age_range = figure.get("age_range", [0, 100])
|
458 |
age_min = age_range[0] if len(age_range) > 0 else 0
|
459 |
age_max = age_range[1] if len(age_range) > 1 else 100
|
460 |
age_range_str = f"{age_min}-{age_max}"
|
461 |
|
462 |
-
# Xử lý related_chunks
|
463 |
related_chunks = figure.get("related_chunks", [])
|
464 |
related_chunks_str = ",".join(related_chunks) if related_chunks else ""
|
465 |
|
@@ -509,16 +461,14 @@ class DataProcessor:
|
|
509 |
"by_age": {}
|
510 |
}
|
511 |
|
512 |
-
# Thống kê theo bài
|
513 |
for item in os.listdir(self.data_dir):
|
514 |
if os.path.isdir(os.path.join(self.data_dir, item)):
|
515 |
item_stats = self.count_items_by_prefix(f"{item}_")
|
516 |
stats["by_lesson"][item] = item_stats
|
517 |
|
518 |
-
# Thống kê theo độ tuổi
|
519 |
age_ranges = {}
|
520 |
for chunk in self.chunks + self.tables + self.figures:
|
521 |
-
age_range = chunk.get("age_range", [0,
|
522 |
if len(age_range) == 2:
|
523 |
range_key = f"{age_range[0]}-{age_range[1]}"
|
524 |
if range_key not in age_ranges:
|
|
|
5 |
import datetime
|
6 |
from typing import Dict, List, Any, Union, Tuple
|
7 |
|
|
|
8 |
logger = logging.getLogger(__name__)
|
9 |
|
10 |
class DataProcessor:
|
|
|
25 |
"""Tải tất cả dữ liệu từ các thư mục con trong data"""
|
26 |
logger.info(f"Đang tải dữ liệu từ thư mục: {self.data_dir}")
|
27 |
|
|
|
28 |
for item in os.listdir(self.data_dir):
|
29 |
folder_path = os.path.join(self.data_dir, item)
|
30 |
|
|
|
31 |
if os.path.isdir(folder_path):
|
32 |
metadata_file = os.path.join(folder_path, "metadata.json")
|
33 |
|
|
|
34 |
if os.path.exists(metadata_file):
|
35 |
try:
|
|
|
36 |
with open(metadata_file, 'r', encoding='utf-8') as f:
|
37 |
content = f.read()
|
38 |
if not content.strip():
|
|
|
40 |
continue
|
41 |
folder_metadata = json.loads(content)
|
42 |
|
|
|
43 |
folder_id = None
|
44 |
if "bai_info" in folder_metadata:
|
45 |
folder_id = folder_metadata["bai_info"].get("id", item)
|
|
|
48 |
else:
|
49 |
folder_id = item
|
50 |
|
|
|
51 |
self.metadata[folder_id] = folder_metadata
|
52 |
|
|
|
53 |
self._load_content_from_metadata(folder_path, folder_metadata)
|
54 |
|
55 |
logger.info(f"Đã tải xong thư mục: {item}")
|
|
|
60 |
|
61 |
def _load_content_from_metadata(self, folder_path: str, folder_metadata: Dict[str, Any]):
|
62 |
"""Tải nội dung chunks, tables và figures từ metadata"""
|
|
|
63 |
for chunk_meta in folder_metadata.get("chunks", []):
|
64 |
chunk_id = chunk_meta.get("id")
|
65 |
chunk_path = os.path.join(folder_path, "chunks", f"{chunk_id}.md")
|
66 |
|
67 |
+
chunk_data = chunk_meta.copy()
|
68 |
|
|
|
69 |
if os.path.exists(chunk_path):
|
70 |
with open(chunk_path, 'r', encoding='utf-8') as f:
|
71 |
content = f.read()
|
72 |
chunk_data["content"] = self._extract_content_from_markdown(content)
|
73 |
else:
|
|
|
74 |
chunk_data["content"] = f"Nội dung cho {chunk_id} không tìm thấy."
|
75 |
logger.debug(f"Không tìm thấy file chunk: {chunk_path}")
|
76 |
|
77 |
self.chunks.append(chunk_data)
|
78 |
|
|
|
79 |
for table_meta in folder_metadata.get("tables", []):
|
80 |
table_id = table_meta.get("id")
|
81 |
table_path = os.path.join(folder_path, "tables", f"{table_id}.md")
|
82 |
|
83 |
table_data = table_meta.copy()
|
84 |
|
|
|
85 |
if os.path.exists(table_path):
|
86 |
with open(table_path, 'r', encoding='utf-8') as f:
|
87 |
content = f.read()
|
|
|
92 |
|
93 |
self.tables.append(table_data)
|
94 |
|
|
|
95 |
for figure_meta in folder_metadata.get("figures", []):
|
96 |
figure_id = figure_meta.get("id")
|
97 |
figure_path = os.path.join(folder_path, "figures", f"{figure_id}.md")
|
98 |
figure_data = figure_meta.copy()
|
99 |
|
|
|
100 |
content_loaded = False
|
101 |
if os.path.exists(figure_path):
|
102 |
with open(figure_path, 'r', encoding='utf-8') as f:
|
|
|
104 |
figure_data["content"] = self._extract_content_from_markdown(content)
|
105 |
content_loaded = True
|
106 |
|
|
|
107 |
image_path = None
|
108 |
image_extensions = ['.png', '.jpg', '.jpeg', '.gif', '.svg']
|
109 |
for ext in image_extensions:
|
|
|
114 |
|
115 |
if image_path:
|
116 |
figure_data["image_path"] = image_path
|
|
|
117 |
if not content_loaded:
|
118 |
figure_caption = figure_meta.get("title", f"Hình {figure_id}")
|
119 |
figure_data["content"] = f""
|
120 |
elif not content_loaded:
|
|
|
121 |
figure_data["content"] = f"Hình {figure_id} không tìm thấy."
|
122 |
logger.debug(f"Không tìm thấy file hình cho {figure_id}")
|
123 |
|
124 |
self.figures.append(figure_data)
|
125 |
|
|
|
126 |
if "data_files" in folder_metadata:
|
127 |
for data_file_meta in folder_metadata.get("data_files", []):
|
128 |
data_id = data_file_meta.get("id")
|
|
|
130 |
|
131 |
data_file = data_file_meta.copy()
|
132 |
|
|
|
133 |
if os.path.exists(data_path):
|
134 |
with open(data_path, 'r', encoding='utf-8') as f:
|
135 |
content = f.read()
|
136 |
data_file["content"] = self._extract_content_from_markdown(content)
|
137 |
|
|
|
138 |
content_type = data_file.get("content_type", "table")
|
139 |
|
|
|
140 |
if content_type == "table":
|
141 |
self.tables.append(data_file)
|
142 |
elif content_type == "text":
|
|
|
151 |
|
152 |
def _extract_content_from_markdown(self, md_content: str) -> str:
|
153 |
"""Trích xuất nội dung từ markdown, bỏ qua phần frontmatter"""
|
|
|
154 |
if md_content.startswith("---"):
|
155 |
parts = md_content.split("---", 2)
|
156 |
if len(parts) >= 3:
|
|
|
191 |
return None
|
192 |
|
193 |
def find_items_by_age(self, age: int) -> Dict[str, List[Dict[str, Any]]]:
|
194 |
+
"""Tìm các items liên quan đến độ tuổi của người dùng"""
|
195 |
relevant_chunks = []
|
196 |
relevant_tables = []
|
197 |
relevant_figures = []
|
198 |
|
|
|
199 |
for chunk in self.chunks:
|
200 |
+
age_range = chunk.get("age_range", [0, 19])
|
201 |
if len(age_range) == 2 and age_range[0] <= age <= age_range[1]:
|
202 |
relevant_chunks.append(chunk)
|
203 |
|
|
|
204 |
for table in self.tables:
|
205 |
+
age_range = table.get("age_range", [0, 19])
|
206 |
if len(age_range) == 2 and age_range[0] <= age <= age_range[1]:
|
207 |
relevant_tables.append(table)
|
208 |
|
|
|
209 |
for figure in self.figures:
|
210 |
+
age_range = figure.get("age_range", [0, 19])
|
211 |
if len(age_range) == 2 and age_range[0] <= age <= age_range[1]:
|
212 |
relevant_figures.append(figure)
|
213 |
|
|
|
223 |
related_tables = []
|
224 |
related_figures = []
|
225 |
|
|
|
226 |
source_item = None
|
227 |
for item in self.chunks + self.tables + self.figures:
|
228 |
if item.get("id") == item_id:
|
|
|
236 |
"figures": []
|
237 |
}
|
238 |
|
|
|
239 |
related_ids = source_item.get("related_chunks", [])
|
240 |
|
|
|
241 |
for related_id in related_ids:
|
|
|
242 |
for chunk in self.chunks:
|
243 |
if chunk.get("id") == related_id:
|
244 |
related_chunks.append(chunk)
|
245 |
break
|
246 |
|
|
|
247 |
for table in self.tables:
|
248 |
if table.get("id") == related_id:
|
249 |
related_tables.append(table)
|
250 |
break
|
251 |
|
|
|
252 |
for figure in self.figures:
|
253 |
if figure.get("id") == related_id:
|
254 |
related_figures.append(figure)
|
|
|
262 |
|
263 |
def preprocess_query(self, query: str) -> str:
|
264 |
"""Tiền xử lý câu truy vấn"""
|
|
|
265 |
query = re.sub(r'[^\w\s\d]', ' ', query)
|
|
|
266 |
query = re.sub(r'\s+', ' ', query).strip()
|
267 |
return query
|
268 |
|
|
|
276 |
content = item.get("content", "")
|
277 |
content_type = item.get("content_type", "text")
|
278 |
|
|
|
279 |
if content_type == "table":
|
280 |
title = f"Bảng: {title}"
|
|
|
281 |
elif content_type == "figure":
|
282 |
title = f"Hình: {title}"
|
283 |
|
|
|
290 |
"""Chuẩn bị dữ liệu cho việc nhúng (embedding)"""
|
291 |
all_items = []
|
292 |
|
|
|
293 |
for chunk in self.chunks:
|
|
|
294 |
chunk_id = chunk.get("id", "")
|
295 |
chapter = "unknown"
|
296 |
if chunk_id.startswith("bai1_"):
|
|
|
308 |
if chunk.get("title"):
|
309 |
content = f"Tiêu đề: {chunk.get('title')}\n\nNội dung: {content}"
|
310 |
|
311 |
+
age_range = chunk.get("age_range", [0, 19])
|
|
|
312 |
age_min = age_range[0] if len(age_range) > 0 else 0
|
313 |
+
age_max = age_range[1] if len(age_range) > 1 else 19
|
314 |
age_range_str = f"{age_min}-{age_max}"
|
315 |
|
|
|
316 |
related_chunks = chunk.get("related_chunks", [])
|
317 |
related_chunks_str = ",".join(related_chunks) if related_chunks else ""
|
318 |
|
|
|
339 |
}
|
340 |
all_items.append(embedding_item)
|
341 |
|
|
|
342 |
for table in self.tables:
|
|
|
343 |
table_id = table.get("id", "")
|
344 |
chapter = "unknown"
|
345 |
if table_id.startswith("bai1_"):
|
|
|
357 |
if table.get("title"):
|
358 |
content = f"Bảng: {table.get('title')}\n\nNội dung: {content}"
|
359 |
|
360 |
+
age_range = table.get("age_range", [0, 19])
|
|
|
361 |
age_min = age_range[0] if len(age_range) > 0 else 0
|
362 |
+
age_max = age_range[1] if len(age_range) > 1 else 19
|
363 |
age_range_str = f"{age_min}-{age_max}"
|
364 |
|
|
|
365 |
related_chunks = table.get("related_chunks", [])
|
366 |
related_chunks_str = ",".join(related_chunks) if related_chunks else ""
|
367 |
table_columns = table.get("table_columns", [])
|
|
|
389 |
}
|
390 |
all_items.append(embedding_item)
|
391 |
|
|
|
392 |
for figure in self.figures:
|
|
|
393 |
figure_id = figure.get("id", "")
|
394 |
chapter = "unknown"
|
395 |
if figure_id.startswith("bai1_"):
|
|
|
407 |
if figure.get("title"):
|
408 |
content = f"Hình: {figure.get('title')}\n\nMô tả: {content}"
|
409 |
|
|
|
410 |
age_range = figure.get("age_range", [0, 100])
|
411 |
age_min = age_range[0] if len(age_range) > 0 else 0
|
412 |
age_max = age_range[1] if len(age_range) > 1 else 100
|
413 |
age_range_str = f"{age_min}-{age_max}"
|
414 |
|
|
|
415 |
related_chunks = figure.get("related_chunks", [])
|
416 |
related_chunks_str = ",".join(related_chunks) if related_chunks else ""
|
417 |
|
|
|
461 |
"by_age": {}
|
462 |
}
|
463 |
|
|
|
464 |
for item in os.listdir(self.data_dir):
|
465 |
if os.path.isdir(os.path.join(self.data_dir, item)):
|
466 |
item_stats = self.count_items_by_prefix(f"{item}_")
|
467 |
stats["by_lesson"][item] = item_stats
|
468 |
|
|
|
469 |
age_ranges = {}
|
470 |
for chunk in self.chunks + self.tables + self.figures:
|
471 |
+
age_range = chunk.get("age_range", [0, 19])
|
472 |
if len(age_range) == 2:
|
473 |
range_key = f"{age_range[0]}-{age_range[1]}"
|
474 |
if range_key not in age_ranges:
|
core/embedding_model.py
CHANGED
@@ -6,16 +6,12 @@ import uuid
|
|
6 |
import os
|
7 |
from config import EMBEDDING_MODEL, CHROMA_PERSIST_DIRECTORY, COLLECTION_NAME
|
8 |
|
9 |
-
# Cấu hình logging
|
10 |
logger = logging.getLogger(__name__)
|
11 |
|
12 |
-
# Global instance để implement singleton pattern
|
13 |
_embedding_model_instance = None
|
14 |
|
15 |
def get_embedding_model():
|
16 |
-
"""
|
17 |
-
Singleton pattern để đảm bảo chỉ có một instance của EmbeddingModel
|
18 |
-
"""
|
19 |
global _embedding_model_instance
|
20 |
if _embedding_model_instance is None:
|
21 |
logger.info("Khởi tạo EmbeddingModel instance lần đầu")
|
@@ -40,52 +36,199 @@ class EmbeddingModel:
|
|
40 |
self.model = SentenceTransformer(EMBEDDING_MODEL, cache_folder=cache_dir, trust_remote_code=True)
|
41 |
logger.info("Đã tải sentence transformer model với cache folder explicit")
|
42 |
|
|
|
|
|
|
|
43 |
# Đảm bảo thư mục ChromaDB tồn tại và có quyền ghi
|
44 |
try:
|
45 |
-
os.makedirs(
|
46 |
# Test ghi file để kiểm tra permission
|
47 |
-
test_file = os.path.join(
|
48 |
with open(test_file, 'w') as f:
|
49 |
f.write('test')
|
50 |
os.remove(test_file)
|
51 |
-
logger.info(f"Thư mục ChromaDB đã sẵn sàng: {
|
52 |
except Exception as e:
|
53 |
logger.error(f"Lỗi tạo/kiểm tra thư mục ChromaDB: {e}")
|
54 |
# Fallback to /tmp directory
|
55 |
import tempfile
|
56 |
-
|
57 |
-
os.makedirs(
|
58 |
-
logger.warning(f"Sử dụng thư mục tạm thời: {
|
59 |
|
60 |
# Khởi tạo ChromaDB client với persistent storage
|
61 |
try:
|
62 |
self.chroma_client = chromadb.PersistentClient(
|
63 |
-
path=
|
64 |
settings=Settings(
|
65 |
anonymized_telemetry=False,
|
66 |
allow_reset=True
|
67 |
)
|
68 |
)
|
69 |
-
logger.info(f"Đã kết nối ChromaDB tại: {
|
70 |
except Exception as e:
|
71 |
logger.error(f"Lỗi kết nối ChromaDB: {e}")
|
72 |
# Fallback to in-memory client
|
73 |
logger.warning("Fallback to in-memory ChromaDB client")
|
74 |
self.chroma_client = chromadb.Client()
|
75 |
|
76 |
-
# Lấy hoặc tạo collection
|
77 |
try:
|
78 |
self.collection = self.chroma_client.get_collection(name=COLLECTION_NAME)
|
79 |
logger.info(f"Đã kết nối collection '{COLLECTION_NAME}' với {self.collection.count()} items")
|
80 |
except Exception:
|
81 |
-
logger.
|
82 |
-
self.collection = self.chroma_client.create_collection(
|
83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
def _add_prefix_to_text(self, text, is_query=True):
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
# Kiểm tra xem text đã có prefix chưa
|
90 |
if text.startswith(('query:', 'passage:')):
|
91 |
return text
|
@@ -98,24 +241,32 @@ class EmbeddingModel:
|
|
98 |
|
99 |
def encode(self, texts, is_query=True):
|
100 |
"""
|
101 |
-
Encode văn bản thành embeddings
|
102 |
-
|
103 |
-
Args:
|
104 |
-
texts (str or list): Văn bản hoặc danh sách văn bản cần encode
|
105 |
-
is_query (bool): True nếu là query, False nếu là passage
|
106 |
-
|
107 |
-
Returns:
|
108 |
-
list: Embeddings vector
|
109 |
"""
|
110 |
try:
|
111 |
if isinstance(texts, str):
|
112 |
texts = [texts]
|
113 |
|
114 |
-
# Thêm prefix cho texts
|
115 |
processed_texts = [self._add_prefix_to_text(text, is_query) for text in texts]
|
116 |
|
117 |
-
logger.debug(f"Đang encode {len(processed_texts)} văn bản")
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
|
120 |
return embeddings.tolist()
|
121 |
|
@@ -124,24 +275,10 @@ class EmbeddingModel:
|
|
124 |
raise
|
125 |
|
126 |
def search(self, query, top_k=5, age_filter=None):
|
127 |
-
"""
|
128 |
-
Tìm kiếm văn bản tương tự trong ChromaDB
|
129 |
-
|
130 |
-
Args:
|
131 |
-
query (str): Câu hỏi cần tìm kiếm
|
132 |
-
top_k (int): Số lượng kết quả trả về
|
133 |
-
age_filter (int): Lọc theo độ tuổi (optional)
|
134 |
-
|
135 |
-
Returns:
|
136 |
-
list: Danh sách kết quả tìm kiếm
|
137 |
-
"""
|
138 |
try:
|
139 |
-
logger.debug(f"Dang tim kiem cho query: {query[:50]}...")
|
140 |
-
|
141 |
-
# Encode query thành embedding (với prefix query:)
|
142 |
query_embedding = self.encode(query, is_query=True)[0]
|
143 |
|
144 |
-
# Tạo where clause cho age filter
|
145 |
where_clause = None
|
146 |
if age_filter:
|
147 |
where_clause = {
|
@@ -150,34 +287,53 @@ class EmbeddingModel:
|
|
150 |
{"age_max": {"$gte": age_filter}}
|
151 |
]
|
152 |
}
|
153 |
-
|
154 |
-
|
|
|
|
|
155 |
search_results = self.collection.query(
|
156 |
query_embeddings=[query_embedding],
|
157 |
n_results=top_k,
|
158 |
where=where_clause,
|
159 |
include=['documents', 'metadatas', 'distances']
|
160 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
|
162 |
if not search_results or not search_results['documents']:
|
163 |
-
logger.warning("
|
164 |
return []
|
165 |
|
166 |
-
# Format kết quả
|
167 |
results = []
|
168 |
documents = search_results['documents'][0]
|
169 |
metadatas = search_results['metadatas'][0]
|
170 |
distances = search_results['distances'][0]
|
171 |
|
172 |
for i, (doc, metadata, distance) in enumerate(zip(documents, metadatas, distances)):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
results.append({
|
174 |
'document': doc,
|
175 |
'metadata': metadata or {},
|
176 |
'distance': distance,
|
177 |
-
'similarity':
|
178 |
'rank': i + 1
|
179 |
})
|
180 |
|
|
|
181 |
logger.info(f"Tim thay {len(results)} ket qua cho query")
|
182 |
return results
|
183 |
|
@@ -186,36 +342,22 @@ class EmbeddingModel:
|
|
186 |
return []
|
187 |
|
188 |
def add_documents(self, documents, metadatas=None, ids=None):
|
189 |
-
"""
|
190 |
-
Thêm documents vào ChromaDB
|
191 |
-
|
192 |
-
Args:
|
193 |
-
documents (list): Danh sách văn bản
|
194 |
-
metadatas (list): Danh sách metadata tương ứng
|
195 |
-
ids (list): Danh sách ID tương ứng (optional)
|
196 |
-
|
197 |
-
Returns:
|
198 |
-
bool: True nếu thành công
|
199 |
-
"""
|
200 |
try:
|
201 |
if not documents:
|
202 |
logger.warning("Không có documents để thêm")
|
203 |
return False
|
204 |
|
205 |
-
# Tạo IDs nếu không được cung cấp
|
206 |
if not ids:
|
207 |
ids = [str(uuid.uuid4()) for _ in documents]
|
208 |
|
209 |
-
# Tạo metadatas rỗng nếu không được cung cấp
|
210 |
if not metadatas:
|
211 |
metadatas = [{} for _ in documents]
|
212 |
|
213 |
logger.info(f"Đang thêm {len(documents)} documents vào ChromaDB")
|
214 |
|
215 |
-
# Encode documents thành embeddings (với prefix passage:)
|
216 |
embeddings = self.encode(documents, is_query=False)
|
217 |
|
218 |
-
# Thêm vào collection
|
219 |
self.collection.add(
|
220 |
embeddings=embeddings,
|
221 |
documents=documents,
|
@@ -231,9 +373,7 @@ class EmbeddingModel:
|
|
231 |
return False
|
232 |
|
233 |
def index_chunks(self, chunks):
|
234 |
-
"""
|
235 |
-
Index các chunks dữ liệu vào ChromaDB
|
236 |
-
"""
|
237 |
try:
|
238 |
if not chunks:
|
239 |
logger.warning("Không có chunks để index")
|
@@ -250,11 +390,9 @@ class EmbeddingModel:
|
|
250 |
|
251 |
documents.append(chunk['content'])
|
252 |
|
253 |
-
# Lấy metadata đã được chuẩn bị sẵn
|
254 |
metadata = chunk.get('metadata', {})
|
255 |
metadatas.append(metadata)
|
256 |
|
257 |
-
# Sử dụng ID có sẵn hoặc tạo mới
|
258 |
chunk_id = chunk.get('id') or str(uuid.uuid4())
|
259 |
ids.append(chunk_id)
|
260 |
|
@@ -262,7 +400,6 @@ class EmbeddingModel:
|
|
262 |
logger.warning("Không có documents hợp lệ để index")
|
263 |
return False
|
264 |
|
265 |
-
# Batch processing để tránh overload
|
266 |
batch_size = 100
|
267 |
total_batches = (len(documents) + batch_size - 1) // batch_size
|
268 |
|
@@ -300,9 +437,9 @@ class EmbeddingModel:
|
|
300 |
logger.warning(f"Đang xóa collection: {COLLECTION_NAME}")
|
301 |
self.chroma_client.delete_collection(name=COLLECTION_NAME)
|
302 |
|
303 |
-
# Tạo lại collection
|
304 |
-
self.
|
305 |
-
logger.info("Đã tạo lại collection mới")
|
306 |
|
307 |
return True
|
308 |
|
@@ -310,15 +447,49 @@ class EmbeddingModel:
|
|
310 |
logger.error(f"Lỗi xóa collection: {e}")
|
311 |
return False
|
312 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
313 |
def get_stats(self):
|
314 |
"""Lấy thống kê về collection"""
|
315 |
try:
|
316 |
total_count = self.count()
|
|
|
317 |
|
318 |
-
# Lấy sample để phân tích metadata
|
319 |
sample_results = self.collection.get(limit=min(100, total_count))
|
320 |
|
321 |
-
# Thống kê content types
|
322 |
content_types = {}
|
323 |
chapters = {}
|
324 |
age_groups = {}
|
@@ -328,15 +499,12 @@ class EmbeddingModel:
|
|
328 |
if not metadata:
|
329 |
continue
|
330 |
|
331 |
-
# Content type stats
|
332 |
content_type = metadata.get('content_type', 'unknown')
|
333 |
content_types[content_type] = content_types.get(content_type, 0) + 1
|
334 |
|
335 |
-
# Chapter stats
|
336 |
chapter = metadata.get('chapter', 'unknown')
|
337 |
chapters[chapter] = chapters.get(chapter, 0) + 1
|
338 |
|
339 |
-
# Age group stats
|
340 |
age_group = metadata.get('age_group', 'unknown')
|
341 |
age_groups[age_group] = age_groups.get(age_group, 0) + 1
|
342 |
|
@@ -346,7 +514,9 @@ class EmbeddingModel:
|
|
346 |
'chapters': chapters,
|
347 |
'age_groups': age_groups,
|
348 |
'collection_name': COLLECTION_NAME,
|
349 |
-
'embedding_model': EMBEDDING_MODEL
|
|
|
|
|
350 |
}
|
351 |
|
352 |
except Exception as e:
|
|
|
6 |
import os
|
7 |
from config import EMBEDDING_MODEL, CHROMA_PERSIST_DIRECTORY, COLLECTION_NAME
|
8 |
|
|
|
9 |
logger = logging.getLogger(__name__)
|
10 |
|
|
|
11 |
_embedding_model_instance = None
|
12 |
|
13 |
def get_embedding_model():
|
14 |
+
"""Kiểm tra và khởi tạo embedding đảm bảo chỉ khởi tạo một lần"""
|
|
|
|
|
15 |
global _embedding_model_instance
|
16 |
if _embedding_model_instance is None:
|
17 |
logger.info("Khởi tạo EmbeddingModel instance lần đầu")
|
|
|
36 |
self.model = SentenceTransformer(EMBEDDING_MODEL, cache_folder=cache_dir, trust_remote_code=True)
|
37 |
logger.info("Đã tải sentence transformer model với cache folder explicit")
|
38 |
|
39 |
+
# SỬA: Khai báo biến persist_directory local để tránh lỗi scope
|
40 |
+
persist_directory = CHROMA_PERSIST_DIRECTORY
|
41 |
+
|
42 |
# Đảm bảo thư mục ChromaDB tồn tại và có quyền ghi
|
43 |
try:
|
44 |
+
os.makedirs(persist_directory, exist_ok=True)
|
45 |
# Test ghi file để kiểm tra permission
|
46 |
+
test_file = os.path.join(persist_directory, 'test_permission.tmp')
|
47 |
with open(test_file, 'w') as f:
|
48 |
f.write('test')
|
49 |
os.remove(test_file)
|
50 |
+
logger.info(f"Thư mục ChromaDB đã sẵn sàng: {persist_directory}")
|
51 |
except Exception as e:
|
52 |
logger.error(f"Lỗi tạo/kiểm tra thư mục ChromaDB: {e}")
|
53 |
# Fallback to /tmp directory
|
54 |
import tempfile
|
55 |
+
persist_directory = os.path.join(tempfile.gettempdir(), 'chroma_db')
|
56 |
+
os.makedirs(persist_directory, exist_ok=True)
|
57 |
+
logger.warning(f"Sử dụng thư mục tạm thời: {persist_directory}")
|
58 |
|
59 |
# Khởi tạo ChromaDB client với persistent storage
|
60 |
try:
|
61 |
self.chroma_client = chromadb.PersistentClient(
|
62 |
+
path=persist_directory,
|
63 |
settings=Settings(
|
64 |
anonymized_telemetry=False,
|
65 |
allow_reset=True
|
66 |
)
|
67 |
)
|
68 |
+
logger.info(f"Đã kết nối ChromaDB tại: {persist_directory}")
|
69 |
except Exception as e:
|
70 |
logger.error(f"Lỗi kết nối ChromaDB: {e}")
|
71 |
# Fallback to in-memory client
|
72 |
logger.warning("Fallback to in-memory ChromaDB client")
|
73 |
self.chroma_client = chromadb.Client()
|
74 |
|
75 |
+
# Lấy hoặc tạo collection với cosine similarity
|
76 |
try:
|
77 |
self.collection = self.chroma_client.get_collection(name=COLLECTION_NAME)
|
78 |
logger.info(f"Đã kết nối collection '{COLLECTION_NAME}' với {self.collection.count()} items")
|
79 |
except Exception:
|
80 |
+
logger.info(f"Collection '{COLLECTION_NAME}' không tồn tại, tạo mới với cosine similarity...")
|
81 |
+
self.collection = self.chroma_client.create_collection(
|
82 |
+
name=COLLECTION_NAME,
|
83 |
+
metadata={
|
84 |
+
"hnsw:space": "cosine", # Cosine distance
|
85 |
+
"hnsw:M": 16, # Optimize for accuracy
|
86 |
+
"hnsw:construction_ef": 100
|
87 |
+
}
|
88 |
+
)
|
89 |
+
logger.info(f"Đã tạo collection mới với cosine similarity: {COLLECTION_NAME}")
|
90 |
+
|
91 |
+
def _initialize_collection(self):
|
92 |
+
"""Khởi tạo collection với cosine similarity"""
|
93 |
+
try:
|
94 |
+
# Kiểm tra xem collection đã tồn tại chưa
|
95 |
+
existing_collections = [col.name for col in self.chroma_client.list_collections()]
|
96 |
+
|
97 |
+
if COLLECTION_NAME in existing_collections:
|
98 |
+
self.collection = self.chroma_client.get_collection(name=COLLECTION_NAME)
|
99 |
+
|
100 |
+
# Kiểm tra distance function hiện tại
|
101 |
+
current_metadata = self.collection.metadata or {}
|
102 |
+
current_space = current_metadata.get("hnsw:space", "l2")
|
103 |
+
|
104 |
+
if current_space != "cosine":
|
105 |
+
logger.warning(f"Collection hiện tại đang dùng {current_space}, cần migration sang cosine")
|
106 |
+
if self.collection.count() > 0:
|
107 |
+
self._migrate_to_cosine()
|
108 |
+
else:
|
109 |
+
# Collection trống, xóa và tạo lại
|
110 |
+
self.chroma_client.delete_collection(name=COLLECTION_NAME)
|
111 |
+
self._create_cosine_collection()
|
112 |
+
else:
|
113 |
+
logger.info(f"Đã kết nối collection '{COLLECTION_NAME}' với cosine similarity, {self.collection.count()} items")
|
114 |
+
else:
|
115 |
+
# Collection chưa tồn tại, tạo mới với cosine
|
116 |
+
self._create_cosine_collection()
|
117 |
+
|
118 |
+
except Exception as e:
|
119 |
+
logger.error(f"Lỗi khởi tạo collection: {e}")
|
120 |
+
# Fallback: tạo collection mới
|
121 |
+
self._create_cosine_collection()
|
122 |
+
|
123 |
+
def _create_cosine_collection(self):
|
124 |
+
"""Tạo collection mới với cosine similarity"""
|
125 |
+
try:
|
126 |
+
self.collection = self.chroma_client.create_collection(
|
127 |
+
name=COLLECTION_NAME,
|
128 |
+
metadata={"hnsw:space": "cosine"}
|
129 |
+
)
|
130 |
+
logger.info(f"Đã tạo collection mới với cosine similarity: {COLLECTION_NAME}")
|
131 |
+
except Exception as e:
|
132 |
+
logger.error(f"Lỗi tạo collection với cosine: {e}")
|
133 |
+
# Fallback về collection mặc định
|
134 |
+
self.collection = self.chroma_client.get_or_create_collection(name=COLLECTION_NAME)
|
135 |
+
logger.warning("Đã fallback về collection mặc định (có thể dùng L2)")
|
136 |
+
|
137 |
+
def _migrate_to_cosine(self):
|
138 |
+
"""Migration collection từ L2 sang cosine"""
|
139 |
+
try:
|
140 |
+
logger.info("Bắt đầu migration collection sang cosine similarity...")
|
141 |
+
|
142 |
+
# Backup toàn bộ data
|
143 |
+
all_data = self.collection.get(
|
144 |
+
include=['documents', 'metadatas', 'embeddings'],
|
145 |
+
limit=self.collection.count()
|
146 |
+
)
|
147 |
+
|
148 |
+
if not all_data['documents']:
|
149 |
+
logger.info("Collection trống, chỉ cần tạo lại")
|
150 |
+
self.chroma_client.delete_collection(name=COLLECTION_NAME)
|
151 |
+
self._create_cosine_collection()
|
152 |
+
return
|
153 |
+
|
154 |
+
# Xóa collection cũ và tạo mới với cosine
|
155 |
+
self.chroma_client.delete_collection(name=COLLECTION_NAME)
|
156 |
+
self._create_cosine_collection()
|
157 |
+
|
158 |
+
# Restore data theo batch
|
159 |
+
documents = all_data['documents']
|
160 |
+
metadatas = all_data['metadatas']
|
161 |
+
embeddings = all_data['embeddings']
|
162 |
+
ids = all_data['ids']
|
163 |
+
|
164 |
+
batch_size = 100
|
165 |
+
total_items = len(documents)
|
166 |
+
|
167 |
+
for i in range(0, total_items, batch_size):
|
168 |
+
batch_docs = documents[i:i + batch_size]
|
169 |
+
batch_metas = metadatas[i:i + batch_size] if metadatas else None
|
170 |
+
batch_embeds = embeddings[i:i + batch_size] if embeddings else None
|
171 |
+
batch_ids = ids[i:i + batch_size]
|
172 |
+
|
173 |
+
if batch_embeds:
|
174 |
+
# Có embeddings sẵn, dùng luôn
|
175 |
+
self.collection.add(
|
176 |
+
documents=batch_docs,
|
177 |
+
metadatas=batch_metas,
|
178 |
+
embeddings=batch_embeds,
|
179 |
+
ids=batch_ids
|
180 |
+
)
|
181 |
+
else:
|
182 |
+
# Tính lại embeddings
|
183 |
+
new_embeddings = self.encode(batch_docs, is_query=False)
|
184 |
+
self.collection.add(
|
185 |
+
documents=batch_docs,
|
186 |
+
metadatas=batch_metas,
|
187 |
+
embeddings=new_embeddings,
|
188 |
+
ids=batch_ids
|
189 |
+
)
|
190 |
+
|
191 |
+
logger.info(f"Migration progress: {min(i + batch_size, total_items)}/{total_items}")
|
192 |
+
|
193 |
+
logger.info(f"Migration hoàn thành! Đã chuyển {total_items} items sang cosine similarity")
|
194 |
+
|
195 |
+
except Exception as e:
|
196 |
+
logger.error(f"Lỗi migration: {e}")
|
197 |
+
# Tạo collection mới nếu migration thất bại
|
198 |
+
self._create_cosine_collection()
|
199 |
+
|
200 |
+
def test_embedding_quality(self):
|
201 |
+
try:
|
202 |
+
# Test cases
|
203 |
+
test_cases = [
|
204 |
+
("query: Tháp dinh dưỡng cho trẻ", "passage: Tháp dinh dưỡng cho trẻ từ 6-11 tuổi"),
|
205 |
+
("query: dinh dưỡng", "passage: dinh dưỡng cho học sinh"),
|
206 |
+
("query: xin chào", "passage: Tháp dinh dưỡng cho trẻ")
|
207 |
+
]
|
208 |
+
|
209 |
+
for query_text, doc_text in test_cases:
|
210 |
+
# Encode
|
211 |
+
query_emb = self.model.encode([query_text], normalize_embeddings=True)[0]
|
212 |
+
doc_emb = self.model.encode([doc_text], normalize_embeddings=True)[0]
|
213 |
+
|
214 |
+
# Calculate cosine similarity manually
|
215 |
+
import numpy as np
|
216 |
+
similarity = np.dot(query_emb, doc_emb)
|
217 |
+
|
218 |
+
logger.info(f"Query: {query_text}")
|
219 |
+
logger.info(f"Doc: {doc_text}")
|
220 |
+
logger.info(f"Similarity: {similarity:.3f}")
|
221 |
+
logger.info(f"Query norm: {np.linalg.norm(query_emb):.3f}")
|
222 |
+
logger.info(f"Doc norm: {np.linalg.norm(doc_emb):.3f}")
|
223 |
+
logger.info("-" * 50)
|
224 |
+
|
225 |
+
except Exception as e:
|
226 |
+
logger.error(f"Test embedding error: {e}")
|
227 |
|
228 |
def _add_prefix_to_text(self, text, is_query=True):
|
229 |
+
# Clean text trước
|
230 |
+
text = text.strip()
|
231 |
+
|
232 |
# Kiểm tra xem text đã có prefix chưa
|
233 |
if text.startswith(('query:', 'passage:')):
|
234 |
return text
|
|
|
241 |
|
242 |
def encode(self, texts, is_query=True):
|
243 |
"""
|
244 |
+
Encode văn bản thành embeddings với proper normalization
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
245 |
"""
|
246 |
try:
|
247 |
if isinstance(texts, str):
|
248 |
texts = [texts]
|
249 |
|
250 |
+
# Thêm prefix cho texts (QUAN TRỌNG cho multilingual-e5-base)
|
251 |
processed_texts = [self._add_prefix_to_text(text, is_query) for text in texts]
|
252 |
|
253 |
+
logger.debug(f"Đang encode {len(processed_texts)} văn bản với prefix")
|
254 |
+
logger.debug(f"Sample processed text: {processed_texts[0][:100]}...")
|
255 |
+
|
256 |
+
# Encode với normalize_embeddings=True (QUAN TRỌNG!)
|
257 |
+
embeddings = self.model.encode(
|
258 |
+
processed_texts,
|
259 |
+
show_progress_bar=False,
|
260 |
+
normalize_embeddings=True # ✅ THÊM DÒNG NÀY
|
261 |
+
)
|
262 |
+
|
263 |
+
# Double-check normalization
|
264 |
+
import numpy as np
|
265 |
+
for i, emb in enumerate(embeddings[:2]): # Check first 2 embeddings
|
266 |
+
norm = np.linalg.norm(emb)
|
267 |
+
logger.debug(f"Embedding {i} norm: {norm}")
|
268 |
+
if abs(norm - 1.0) > 0.01:
|
269 |
+
logger.warning(f"Embedding {i} not properly normalized: norm = {norm}")
|
270 |
|
271 |
return embeddings.tolist()
|
272 |
|
|
|
275 |
raise
|
276 |
|
277 |
def search(self, query, top_k=5, age_filter=None):
|
278 |
+
"""Tìm kiếm văn bản tương tự trong ChromaDB"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
279 |
try:
|
|
|
|
|
|
|
280 |
query_embedding = self.encode(query, is_query=True)[0]
|
281 |
|
|
|
282 |
where_clause = None
|
283 |
if age_filter:
|
284 |
where_clause = {
|
|
|
287 |
{"age_max": {"$gte": age_filter}}
|
288 |
]
|
289 |
}
|
290 |
+
print(f"🔍 AGE FILTER: Tìm kiếm cho tuổi {age_filter}")
|
291 |
+
print(f"🔍 WHERE CLAUSE: {where_clause}")
|
292 |
+
else:
|
293 |
+
print(f"⚠️ KHÔNG CÓ AGE FILTER - Tìm tất cả chunks")
|
294 |
search_results = self.collection.query(
|
295 |
query_embeddings=[query_embedding],
|
296 |
n_results=top_k,
|
297 |
where=where_clause,
|
298 |
include=['documents', 'metadatas', 'distances']
|
299 |
)
|
300 |
+
|
301 |
+
print(f"\n{'='*60}")
|
302 |
+
print(f"📊 CHROMADB SEARCH RESULTS")
|
303 |
+
print(f"{'='*60}")
|
304 |
+
print(f"Query: {query}")
|
305 |
+
print(f"Age filter: {age_filter}")
|
306 |
+
print(f"Found {len(search_results['documents'][0]) if search_results['documents'] else 0} chunks")
|
307 |
+
print(f"{'='*60}")
|
308 |
|
309 |
if not search_results or not search_results['documents']:
|
310 |
+
logger.warning("Không tìm thấy kết quả nào")
|
311 |
return []
|
312 |
|
|
|
313 |
results = []
|
314 |
documents = search_results['documents'][0]
|
315 |
metadatas = search_results['metadatas'][0]
|
316 |
distances = search_results['distances'][0]
|
317 |
|
318 |
for i, (doc, metadata, distance) in enumerate(zip(documents, metadatas, distances)):
|
319 |
+
chunk_id = metadata.get('chunk_id', f'chunk_{i}')
|
320 |
+
title = metadata.get('title', 'No title')
|
321 |
+
age_range = metadata.get('age_range', 'Unknown')
|
322 |
+
age_min = metadata.get('age_min', 'N/A')
|
323 |
+
age_max = metadata.get('age_max', 'N/A')
|
324 |
+
content_type = metadata.get('content_type', 'text')
|
325 |
+
chapter = metadata.get('chapter', 'Unknown')
|
326 |
+
similarity = round(1 - distance, 3)
|
327 |
+
|
328 |
results.append({
|
329 |
'document': doc,
|
330 |
'metadata': metadata or {},
|
331 |
'distance': distance,
|
332 |
+
'similarity': similarity,
|
333 |
'rank': i + 1
|
334 |
})
|
335 |
|
336 |
+
print(f"\n{'='*60}")
|
337 |
logger.info(f"Tim thay {len(results)} ket qua cho query")
|
338 |
return results
|
339 |
|
|
|
342 |
return []
|
343 |
|
344 |
def add_documents(self, documents, metadatas=None, ids=None):
|
345 |
+
"""Thêm documents vào ChromaDB"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
346 |
try:
|
347 |
if not documents:
|
348 |
logger.warning("Không có documents để thêm")
|
349 |
return False
|
350 |
|
|
|
351 |
if not ids:
|
352 |
ids = [str(uuid.uuid4()) for _ in documents]
|
353 |
|
|
|
354 |
if not metadatas:
|
355 |
metadatas = [{} for _ in documents]
|
356 |
|
357 |
logger.info(f"Đang thêm {len(documents)} documents vào ChromaDB")
|
358 |
|
|
|
359 |
embeddings = self.encode(documents, is_query=False)
|
360 |
|
|
|
361 |
self.collection.add(
|
362 |
embeddings=embeddings,
|
363 |
documents=documents,
|
|
|
373 |
return False
|
374 |
|
375 |
def index_chunks(self, chunks):
|
376 |
+
"""Index các chunks dữ liệu vào ChromaDB"""
|
|
|
|
|
377 |
try:
|
378 |
if not chunks:
|
379 |
logger.warning("Không có chunks để index")
|
|
|
390 |
|
391 |
documents.append(chunk['content'])
|
392 |
|
|
|
393 |
metadata = chunk.get('metadata', {})
|
394 |
metadatas.append(metadata)
|
395 |
|
|
|
396 |
chunk_id = chunk.get('id') or str(uuid.uuid4())
|
397 |
ids.append(chunk_id)
|
398 |
|
|
|
400 |
logger.warning("Không có documents hợp lệ để index")
|
401 |
return False
|
402 |
|
|
|
403 |
batch_size = 100
|
404 |
total_batches = (len(documents) + batch_size - 1) // batch_size
|
405 |
|
|
|
437 |
logger.warning(f"Đang xóa collection: {COLLECTION_NAME}")
|
438 |
self.chroma_client.delete_collection(name=COLLECTION_NAME)
|
439 |
|
440 |
+
# Tạo lại collection với cosine similarity
|
441 |
+
self._create_cosine_collection()
|
442 |
+
logger.info("Đã tạo lại collection mới với cosine similarity")
|
443 |
|
444 |
return True
|
445 |
|
|
|
447 |
logger.error(f"Lỗi xóa collection: {e}")
|
448 |
return False
|
449 |
|
450 |
+
def get_collection_info(self):
|
451 |
+
"""Lấy thông tin về collection và distance function"""
|
452 |
+
try:
|
453 |
+
metadata = self.collection.metadata or {}
|
454 |
+
distance_func = metadata.get("hnsw:space", "l2")
|
455 |
+
|
456 |
+
return {
|
457 |
+
'collection_name': COLLECTION_NAME,
|
458 |
+
'distance_function': distance_func,
|
459 |
+
'total_documents': self.count(),
|
460 |
+
'metadata': metadata
|
461 |
+
}
|
462 |
+
except Exception as e:
|
463 |
+
logger.error(f"Lỗi lấy collection info: {e}")
|
464 |
+
return {'error': str(e)}
|
465 |
+
|
466 |
+
def verify_cosine_similarity(self):
|
467 |
+
"""Kiểm tra và xác nhận đang sử dụng cosine similarity"""
|
468 |
+
try:
|
469 |
+
info = self.get_collection_info()
|
470 |
+
distance_func = info.get('distance_function', 'unknown')
|
471 |
+
|
472 |
+
logger.info(f"Collection đang sử dụng distance function: {distance_func}")
|
473 |
+
|
474 |
+
if distance_func == "cosine":
|
475 |
+
logger.info("Xác nhận: Đang sử dụng cosine similarity")
|
476 |
+
return True
|
477 |
+
else:
|
478 |
+
logger.warning(f"Cảnh báo: Đang sử dụng {distance_func}, không phải cosine")
|
479 |
+
return False
|
480 |
+
|
481 |
+
except Exception as e:
|
482 |
+
logger.error(f"Lỗi verify cosine: {e}")
|
483 |
+
return False
|
484 |
+
|
485 |
def get_stats(self):
|
486 |
"""Lấy thống kê về collection"""
|
487 |
try:
|
488 |
total_count = self.count()
|
489 |
+
collection_info = self.get_collection_info()
|
490 |
|
|
|
491 |
sample_results = self.collection.get(limit=min(100, total_count))
|
492 |
|
|
|
493 |
content_types = {}
|
494 |
chapters = {}
|
495 |
age_groups = {}
|
|
|
499 |
if not metadata:
|
500 |
continue
|
501 |
|
|
|
502 |
content_type = metadata.get('content_type', 'unknown')
|
503 |
content_types[content_type] = content_types.get(content_type, 0) + 1
|
504 |
|
|
|
505 |
chapter = metadata.get('chapter', 'unknown')
|
506 |
chapters[chapter] = chapters.get(chapter, 0) + 1
|
507 |
|
|
|
508 |
age_group = metadata.get('age_group', 'unknown')
|
509 |
age_groups[age_group] = age_groups.get(age_group, 0) + 1
|
510 |
|
|
|
514 |
'chapters': chapters,
|
515 |
'age_groups': age_groups,
|
516 |
'collection_name': COLLECTION_NAME,
|
517 |
+
'embedding_model': EMBEDDING_MODEL,
|
518 |
+
'distance_function': collection_info.get('distance_function', 'unknown'),
|
519 |
+
'using_cosine_similarity': collection_info.get('distance_function') == 'cosine'
|
520 |
}
|
521 |
|
522 |
except Exception as e:
|
core/rag_pipeline.py
CHANGED
@@ -5,68 +5,56 @@ from config import GEMINI_API_KEY, HUMAN_PROMPT_TEMPLATE, SYSTEM_PROMPT, TOP_K_R
|
|
5 |
import os
|
6 |
import re
|
7 |
|
8 |
-
# Cấu hình logging
|
9 |
logger = logging.getLogger(__name__)
|
10 |
|
11 |
-
# Cấu hình Gemini
|
12 |
genai.configure(api_key=GEMINI_API_KEY)
|
13 |
|
14 |
class RAGPipeline:
|
15 |
def __init__(self):
|
16 |
-
|
17 |
-
logger.info("
|
18 |
|
19 |
self.embedding_model = get_embedding_model()
|
20 |
-
|
21 |
-
# Khởi tạo Gemini model
|
22 |
self.gemini_model = genai.GenerativeModel('gemini-2.0-flash')
|
23 |
|
24 |
-
logger.info("RAG Pipeline đã sẵn sàng")
|
25 |
|
26 |
def generate_response(self, query, age=1):
|
27 |
-
|
28 |
-
Generate response cho user query sử dụng RAG
|
29 |
-
|
30 |
-
Args:
|
31 |
-
query (str): Câu hỏi của người dùng
|
32 |
-
age (int): Tuổi của người dùng (1-19)
|
33 |
-
|
34 |
-
Returns:
|
35 |
-
dict: Response data with success status
|
36 |
-
"""
|
37 |
try:
|
38 |
-
logger.info(f"Bắt đầu
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
-
# SỬA: Chỉ search trong ChromaDB, không load lại dữ liệu
|
41 |
-
logger.info("Đang tìm kiếm thông tin liên quan...")
|
42 |
-
search_results = self.embedding_model.search(query, top_k=TOP_K_RESULTS)
|
43 |
if not search_results or len(search_results) == 0:
|
44 |
-
logger.warning("Không tìm thấy thông tin liên quan")
|
45 |
return {
|
46 |
"success": True,
|
47 |
"response": "Xin lỗi, tôi không tìm thấy thông tin liên quan đến câu hỏi của bạn trong tài liệu.",
|
48 |
"sources": []
|
49 |
}
|
50 |
|
51 |
-
# Chuẩn bị
|
52 |
contexts = []
|
53 |
sources = []
|
54 |
|
55 |
for result in search_results:
|
56 |
-
# Lấy thông tin từ metadata
|
57 |
metadata = result.get('metadata', {})
|
58 |
content = result.get('document', '')
|
59 |
|
60 |
-
# Thêm
|
61 |
contexts.append({
|
62 |
"content": content,
|
63 |
"metadata": metadata
|
64 |
})
|
65 |
|
66 |
-
#
|
67 |
source_info = {
|
68 |
-
"
|
69 |
-
"title": metadata.get('title', metadata.get('chapter', 'Tài liệu dinh dưỡng')), # Giữ title nếu cần
|
70 |
"pages": metadata.get('pages'),
|
71 |
"content_type": metadata.get('content_type', 'text')
|
72 |
}
|
@@ -74,14 +62,14 @@ class RAGPipeline:
|
|
74 |
if source_info not in sources:
|
75 |
sources.append(source_info)
|
76 |
|
77 |
-
#
|
78 |
formatted_contexts = self._format_contexts(contexts)
|
79 |
|
80 |
-
# Tạo prompt với
|
81 |
full_prompt = self._create_prompt_with_age_context(query, age, formatted_contexts)
|
82 |
|
83 |
-
#
|
84 |
-
logger.info("Đang tạo phản hồi với Gemini
|
85 |
response = self.gemini_model.generate_content(
|
86 |
full_prompt,
|
87 |
generation_config=genai.types.GenerationConfig(
|
@@ -91,7 +79,7 @@ class RAGPipeline:
|
|
91 |
)
|
92 |
|
93 |
if not response or not response.text:
|
94 |
-
logger.error("Gemini không trả về
|
95 |
return {
|
96 |
"success": False,
|
97 |
"error": "Không thể tạo phản hồi"
|
@@ -99,7 +87,7 @@ class RAGPipeline:
|
|
99 |
|
100 |
response_text = response.text.strip()
|
101 |
|
102 |
-
#
|
103 |
response_text = self._process_image_links(response_text)
|
104 |
|
105 |
logger.info("Đã tạo phản hồi thành công")
|
@@ -111,25 +99,23 @@ class RAGPipeline:
|
|
111 |
}
|
112 |
|
113 |
except Exception as e:
|
114 |
-
logger.error(f"Lỗi
|
115 |
return {
|
116 |
"success": False,
|
117 |
"error": f"Lỗi tạo phản hồi: {str(e)}"
|
118 |
}
|
119 |
|
120 |
def _format_contexts(self, contexts):
|
121 |
-
|
122 |
formatted = []
|
123 |
|
124 |
for i, context in enumerate(contexts, 1):
|
125 |
content = context['content']
|
126 |
metadata = context['metadata']
|
127 |
|
128 |
-
# Thêm thông tin metadata
|
129 |
context_str = f"[Tài liệu {i}]"
|
130 |
-
if metadata.get('
|
131 |
-
context_str += f" - ID: {metadata['chunk_id']}"
|
132 |
-
elif metadata.get('title'):
|
133 |
context_str += f" - {metadata['title']}"
|
134 |
if metadata.get('pages'):
|
135 |
context_str += f" (Trang {metadata['pages']})"
|
@@ -139,9 +125,8 @@ class RAGPipeline:
|
|
139 |
|
140 |
return "\n".join(formatted)
|
141 |
|
142 |
-
def _create_prompt_with_age_context(self, query, age, contexts):
|
143 |
-
|
144 |
-
# Xác định age group
|
145 |
if age <= 3:
|
146 |
age_guidance = "Sử dụng ngôn ngữ đơn giản, dễ hiểu cho phụ huynh có con nhỏ."
|
147 |
elif age <= 6:
|
@@ -153,7 +138,7 @@ class RAGPipeline:
|
|
153 |
else:
|
154 |
age_guidance = "Thông tin đầy đủ, chi tiết cho học sinh trung học phổ thông."
|
155 |
|
156 |
-
# Tạo system prompt
|
157 |
age_aware_system_prompt = f"""{SYSTEM_PROMPT}
|
158 |
|
159 |
QUAN TRỌNG - Hướng dẫn theo độ tuổi:
|
@@ -163,7 +148,7 @@ Người dùng hiện tại {age} tuổi. {age_guidance}
|
|
163 |
- Tránh thông tin quá phức tạp hoặc không phù hợp
|
164 |
"""
|
165 |
|
166 |
-
# Tạo human prompt
|
167 |
human_prompt = HUMAN_PROMPT_TEMPLATE.format(
|
168 |
query=query,
|
169 |
age=age,
|
@@ -173,30 +158,28 @@ Người dùng hiện tại {age} tuổi. {age_guidance}
|
|
173 |
return f"{age_aware_system_prompt}\n\n{human_prompt}"
|
174 |
|
175 |
def _process_image_links(self, response_text):
|
176 |
-
|
177 |
try:
|
178 |
import re
|
179 |
|
180 |
-
# Tìm các pattern markdown
|
181 |
image_pattern = r'!\[([^\]]*)\]\(([^)]+)\)'
|
182 |
|
183 |
def replace_image_path(match):
|
184 |
alt_text = match.group(1)
|
185 |
image_path = match.group(2)
|
186 |
|
187 |
-
# Xử lý đường dẫn local Windows/Linux
|
188 |
if '\\' in image_path or image_path.startswith('/') or ':' in image_path:
|
189 |
-
#
|
190 |
filename = image_path.split('\\')[-1].split('/')[-1]
|
191 |
|
192 |
-
# Tìm bai_id từ
|
193 |
bai_match = re.match(r'^(bai\d+)_', filename)
|
194 |
if bai_match:
|
195 |
bai_id = bai_match.group(1)
|
196 |
-
else:
|
197 |
-
bai_id = 'bai1' # default
|
198 |
|
199 |
-
# Tạo API
|
200 |
api_url = f"/api/figures/{bai_id}/{filename}"
|
201 |
return f""
|
202 |
|
@@ -210,39 +193,29 @@ Người dùng hiện tại {age} tuổi. {age_guidance}
|
|
210 |
bai_match = re.match(r'^(bai\d+)_', filename)
|
211 |
if bai_match:
|
212 |
bai_id = bai_match.group(1)
|
213 |
-
else:
|
214 |
-
bai_id = 'bai1'
|
215 |
|
216 |
api_url = f"/api/figures/{bai_id}/{filename}"
|
217 |
return f""
|
218 |
-
|
219 |
-
# Các trường hợp khác, giữ nguyên
|
220 |
return match.group(0)
|
221 |
|
222 |
-
# Thay thế tất cả
|
223 |
processed_text = re.sub(image_pattern, replace_image_path, response_text)
|
224 |
|
225 |
-
|
|
|
|
|
|
|
226 |
return processed_text
|
227 |
|
228 |
except Exception as e:
|
229 |
-
logger.error(f"Lỗi xử lý
|
230 |
return response_text
|
231 |
|
232 |
def generate_follow_up_questions(self, query, answer, age=1):
|
233 |
-
|
234 |
-
Tạo câu hỏi gợi ý dựa trên query và answer
|
235 |
-
|
236 |
-
Args:
|
237 |
-
query (str): Câu hỏi gốc
|
238 |
-
answer (str): Câu trả lời đã được tạo
|
239 |
-
age (int): Tuổi người dùng
|
240 |
-
|
241 |
-
Returns:
|
242 |
-
dict: Response data với danh sách câu hỏi gợi ý
|
243 |
-
"""
|
244 |
try:
|
245 |
-
logger.info("Đang tạo câu hỏi
|
246 |
|
247 |
follow_up_prompt = f"""
|
248 |
Dựa trên cuộc hội thoại sau, hãy tạo 3-5 câu hỏi gợi ý phù hợp cho người dùng {age} tuổi về chủ đề dinh dưỡng:
|
@@ -273,27 +246,30 @@ Trả về danh sách câu hỏi, mỗi câu một dòng, không đánh số.
|
|
273 |
"error": "Không thể tạo câu hỏi gợi ý"
|
274 |
}
|
275 |
|
276 |
-
#
|
277 |
questions = []
|
278 |
lines = response.text.strip().split('\n')
|
279 |
|
280 |
for line in lines:
|
281 |
line = line.strip()
|
|
|
282 |
if line and not line.startswith('#') and len(line) > 10:
|
283 |
-
# Loại bỏ số thứ tự nếu có
|
284 |
line = re.sub(r'^\d+[\.\)]\s*', '', line)
|
285 |
questions.append(line)
|
286 |
|
287 |
-
# Giới hạn 5 câu hỏi
|
288 |
questions = questions[:5]
|
289 |
|
|
|
|
|
290 |
return {
|
291 |
"success": True,
|
292 |
"questions": questions
|
293 |
}
|
294 |
|
295 |
except Exception as e:
|
296 |
-
logger.error(f"Lỗi tạo
|
297 |
return {
|
298 |
"success": False,
|
299 |
"error": f"Lỗi tạo câu hỏi gợi ý: {str(e)}"
|
|
|
5 |
import os
|
6 |
import re
|
7 |
|
|
|
8 |
logger = logging.getLogger(__name__)
|
9 |
|
10 |
+
# Cấu hình Gemini API
|
11 |
genai.configure(api_key=GEMINI_API_KEY)
|
12 |
|
13 |
class RAGPipeline:
|
14 |
def __init__(self):
|
15 |
+
# Khởi tạo RAG Pipeline với embedding model
|
16 |
+
logger.info("Đang khởi tạo RAG Pipeline")
|
17 |
|
18 |
self.embedding_model = get_embedding_model()
|
|
|
|
|
19 |
self.gemini_model = genai.GenerativeModel('gemini-2.0-flash')
|
20 |
|
21 |
+
logger.info("RAG Pipeline đã sẵn sàng hoạt động")
|
22 |
|
23 |
def generate_response(self, query, age=1):
|
24 |
+
# Tạo phản hồi cho câu hỏi của người dùng sử dụng RAG
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
try:
|
26 |
+
logger.info(f"Bắt đầu tạo phản hồi cho câu hỏi: {query[:50]}... (tuổi: {age})")
|
27 |
+
|
28 |
+
# Tìm kiếm thông tin liên quan trong ChromaDB
|
29 |
+
logger.info("Đang tìm kiếm thông tin liên quan trong cơ sở dữ liệu")
|
30 |
+
search_results = self.embedding_model.search(query, top_k=TOP_K_RESULTS, age_filter=age)
|
31 |
+
# search_results = self.embedding_model.search(query, top_k=TOP_K_RESULTS)
|
32 |
|
|
|
|
|
|
|
33 |
if not search_results or len(search_results) == 0:
|
34 |
+
logger.warning("Không tìm thấy thông tin liên quan trong cơ sở dữ liệu")
|
35 |
return {
|
36 |
"success": True,
|
37 |
"response": "Xin lỗi, tôi không tìm thấy thông tin liên quan đến câu hỏi của bạn trong tài liệu.",
|
38 |
"sources": []
|
39 |
}
|
40 |
|
41 |
+
# Chuẩn bị ngữ cảnh từ kết quả tìm kiếm
|
42 |
contexts = []
|
43 |
sources = []
|
44 |
|
45 |
for result in search_results:
|
|
|
46 |
metadata = result.get('metadata', {})
|
47 |
content = result.get('document', '')
|
48 |
|
49 |
+
# Thêm nội dung vào ngữ cảnh
|
50 |
contexts.append({
|
51 |
"content": content,
|
52 |
"metadata": metadata
|
53 |
})
|
54 |
|
55 |
+
# Tạo thông tin nguồn tài liệu
|
56 |
source_info = {
|
57 |
+
"title": metadata.get('title', metadata.get('chapter', 'Tài liệu dinh dưỡng')),
|
|
|
58 |
"pages": metadata.get('pages'),
|
59 |
"content_type": metadata.get('content_type', 'text')
|
60 |
}
|
|
|
62 |
if source_info not in sources:
|
63 |
sources.append(source_info)
|
64 |
|
65 |
+
# Định dạng ngữ cảnh cho prompt
|
66 |
formatted_contexts = self._format_contexts(contexts)
|
67 |
|
68 |
+
# Tạo prompt với ngữ cảnh độ tuổi
|
69 |
full_prompt = self._create_prompt_with_age_context(query, age, formatted_contexts)
|
70 |
|
71 |
+
# Tạo phản hồi với Gemini AI
|
72 |
+
logger.info("Đang tạo phản hồi với Gemini AI")
|
73 |
response = self.gemini_model.generate_content(
|
74 |
full_prompt,
|
75 |
generation_config=genai.types.GenerationConfig(
|
|
|
79 |
)
|
80 |
|
81 |
if not response or not response.text:
|
82 |
+
logger.error("Gemini AI không trả về phản hồi")
|
83 |
return {
|
84 |
"success": False,
|
85 |
"error": "Không thể tạo phản hồi"
|
|
|
87 |
|
88 |
response_text = response.text.strip()
|
89 |
|
90 |
+
# Xử lý các đường dẫn hình ảnh trong phản hồi
|
91 |
response_text = self._process_image_links(response_text)
|
92 |
|
93 |
logger.info("Đã tạo phản hồi thành công")
|
|
|
99 |
}
|
100 |
|
101 |
except Exception as e:
|
102 |
+
logger.error(f"Lỗi khi tạo phản hồi: {str(e)}")
|
103 |
return {
|
104 |
"success": False,
|
105 |
"error": f"Lỗi tạo phản hồi: {str(e)}"
|
106 |
}
|
107 |
|
108 |
def _format_contexts(self, contexts):
|
109 |
+
# Định dạng ngữ cảnh thành chuỗi cho prompt
|
110 |
formatted = []
|
111 |
|
112 |
for i, context in enumerate(contexts, 1):
|
113 |
content = context['content']
|
114 |
metadata = context['metadata']
|
115 |
|
116 |
+
# Thêm thông tin metadata vào ngữ cảnh
|
117 |
context_str = f"[Tài liệu {i}]"
|
118 |
+
if metadata.get('title'):
|
|
|
|
|
119 |
context_str += f" - {metadata['title']}"
|
120 |
if metadata.get('pages'):
|
121 |
context_str += f" (Trang {metadata['pages']})"
|
|
|
125 |
|
126 |
return "\n".join(formatted)
|
127 |
|
128 |
+
def _create_prompt_with_age_context(self, query, age, contexts):
|
129 |
+
# Xác định hướng dẫn theo nhóm tuổi
|
|
|
130 |
if age <= 3:
|
131 |
age_guidance = "Sử dụng ngôn ngữ đơn giản, dễ hiểu cho phụ huynh có con nhỏ."
|
132 |
elif age <= 6:
|
|
|
138 |
else:
|
139 |
age_guidance = "Thông tin đầy đủ, chi tiết cho học sinh trung học phổ thông."
|
140 |
|
141 |
+
# Tạo system prompt có tính đến độ tuổi
|
142 |
age_aware_system_prompt = f"""{SYSTEM_PROMPT}
|
143 |
|
144 |
QUAN TRỌNG - Hướng dẫn theo độ tuổi:
|
|
|
148 |
- Tránh thông tin quá phức tạp hoặc không phù hợp
|
149 |
"""
|
150 |
|
151 |
+
# Tạo human prompt từ template
|
152 |
human_prompt = HUMAN_PROMPT_TEMPLATE.format(
|
153 |
query=query,
|
154 |
age=age,
|
|
|
158 |
return f"{age_aware_system_prompt}\n\n{human_prompt}"
|
159 |
|
160 |
def _process_image_links(self, response_text):
|
161 |
+
# Xử lý và chuyển đổi các đường dẫn hình ảnh trong phản hồi
|
162 |
try:
|
163 |
import re
|
164 |
|
165 |
+
# Tìm các pattern markdown: 
|
166 |
image_pattern = r'!\[([^\]]*)\]\(([^)]+)\)'
|
167 |
|
168 |
def replace_image_path(match):
|
169 |
alt_text = match.group(1)
|
170 |
image_path = match.group(2)
|
171 |
|
172 |
+
# Xử lý đường dẫn local (Windows/Linux)
|
173 |
if '\\' in image_path or image_path.startswith('/') or ':' in image_path:
|
174 |
+
# Trích xuất tên file từ đường dẫn local
|
175 |
filename = image_path.split('\\')[-1].split('/')[-1]
|
176 |
|
177 |
+
# Tìm bai_id từ tên file (format: baiX_filename)
|
178 |
bai_match = re.match(r'^(bai\d+)_', filename)
|
179 |
if bai_match:
|
180 |
bai_id = bai_match.group(1)
|
|
|
|
|
181 |
|
182 |
+
# Tạo URL API
|
183 |
api_url = f"/api/figures/{bai_id}/{filename}"
|
184 |
return f""
|
185 |
|
|
|
193 |
bai_match = re.match(r'^(bai\d+)_', filename)
|
194 |
if bai_match:
|
195 |
bai_id = bai_match.group(1)
|
|
|
|
|
196 |
|
197 |
api_url = f"/api/figures/{bai_id}/{filename}"
|
198 |
return f""
|
199 |
+
|
|
|
200 |
return match.group(0)
|
201 |
|
202 |
+
# Thay thế tất cả các liên kết hình ảnh
|
203 |
processed_text = re.sub(image_pattern, replace_image_path, response_text)
|
204 |
|
205 |
+
image_count = len(re.findall(image_pattern, response_text))
|
206 |
+
if image_count > 0:
|
207 |
+
logger.info(f"Đã xử lý {image_count} liên kết hình ảnh")
|
208 |
+
|
209 |
return processed_text
|
210 |
|
211 |
except Exception as e:
|
212 |
+
logger.error(f"Lỗi khi xử lý liên kết hình ảnh: {e}")
|
213 |
return response_text
|
214 |
|
215 |
def generate_follow_up_questions(self, query, answer, age=1):
|
216 |
+
# Tạo câu hỏi gợi ý dựa trên cuộc hội thoại hiện tại
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
try:
|
218 |
+
logger.info("Đang tạo câu hỏi gợi ý")
|
219 |
|
220 |
follow_up_prompt = f"""
|
221 |
Dựa trên cuộc hội thoại sau, hãy tạo 3-5 câu hỏi gợi ý phù hợp cho người dùng {age} tuổi về chủ đề dinh dưỡng:
|
|
|
246 |
"error": "Không thể tạo câu hỏi gợi ý"
|
247 |
}
|
248 |
|
249 |
+
# Chuyển đổi phản hồi thành danh sách câu hỏi
|
250 |
questions = []
|
251 |
lines = response.text.strip().split('\n')
|
252 |
|
253 |
for line in lines:
|
254 |
line = line.strip()
|
255 |
+
# Lọc các dòng hợp lệ (không rỗng, không phải comment, đủ dài)
|
256 |
if line and not line.startswith('#') and len(line) > 10:
|
257 |
+
# Loại bỏ số thứ tự nếu có (1. 2. hoặc 1) 2))
|
258 |
line = re.sub(r'^\d+[\.\)]\s*', '', line)
|
259 |
questions.append(line)
|
260 |
|
261 |
+
# Giới hạn tối đa 5 câu hỏi
|
262 |
questions = questions[:5]
|
263 |
|
264 |
+
logger.info(f"Đã tạo {len(questions)} câu hỏi gợi ý")
|
265 |
+
|
266 |
return {
|
267 |
"success": True,
|
268 |
"questions": questions
|
269 |
}
|
270 |
|
271 |
except Exception as e:
|
272 |
+
logger.error(f"Lỗi khi tạo câu hỏi gợi ý: {str(e)}")
|
273 |
return {
|
274 |
"success": False,
|
275 |
"error": f"Lỗi tạo câu hỏi gợi ý: {str(e)}"
|