File size: 15,630 Bytes
c69c2f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
# llm_handler.py

import torch
import re
import json
from unsloth import FastLanguageModel
# from transformers import TextStreamer # Bỏ comment nếu bạn muốn dùng TextStreamer để stream token

# Giả định rằng hàm search_relevant_laws được import từ file retrieval.py
# Nếu file retrieval.py nằm cùng cấp, bạn có thể import như sau:
# from retrieval import search_relevant_laws
# Hoặc nếu bạn muốn hàm này độc lập hơn, bạn có thể truyền retrieved_results vào generate_response
# thay vì truyền tất cả các thành phần của RAG.
# Tuy nhiên, dựa theo code gốc, generate_response gọi search_relevant_laws.

# --- HÀM TẢI MÔ HÌNH LLM VÀ TOKENIZER ---
def load_llm_model_and_tokenizer(
    model_name_or_path: str,
    max_seq_length: int = 2048,
    load_in_4bit: bool = True,
    device_map: str = "auto" # Cho phép Unsloth tự quyết định device map
):
    """
    Tải mô hình ngôn ngữ lớn (LLM) đã được fine-tune bằng Unsloth và tokenizer tương ứng.

    Args:
        model_name_or_path (str): Tên hoặc đường dẫn đến mô hình đã fine-tune.
        max_seq_length (int): Độ dài chuỗi tối đa mà mô hình hỗ trợ.
        load_in_4bit (bool): Có tải mô hình ở dạng 4-bit quantization hay không.
        device_map (str): Cách map model lên các device (ví dụ "auto", "cuda:0").

    Returns:
        tuple: (model, tokenizer) nếu thành công, (None, None) nếu có lỗi.
    """
    print(f"Đang tải LLM model: {model_name_or_path}...")
    try:
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_name_or_path,
            max_seq_length=max_seq_length,
            dtype=None,  # Unsloth sẽ tự động chọn dtype tối ưu
            load_in_4bit=load_in_4bit,
            device_map=device_map, # Thêm device_map
            # token = "hf_YOUR_TOKEN_HERE" # Nếu model là private trên Hugging Face Hub
        )
        FastLanguageModel.for_inference(model)  # Tối ưu hóa mô hình cho inference
        print("Tải LLM model và tokenizer thành công.")
        return model, tokenizer
    except Exception as e:
        print(f"Lỗi khi tải LLM model và tokenizer: {e}")
        return None, None

# --- HÀM TẠO CÂU TRẢ LỜI TỪ LLM ---
def generate_response(
    query: str,
    llama_model, # Mô hình LLM đã tải
    tokenizer,   # Tokenizer tương ứng
    # Các thành phần RAG được truyền từ app.py (đã được tải trước đó)
    faiss_index,
    embed_model,
    chunks_data_list: list,
    bm25_model,
    # Các tham số cho search_relevant_laws
    search_k: int = 5,
    search_multiplier: int = 10, # initial_k_multiplier
    rrf_k_constant: int = 60,
    # Các tham số cho generation của LLM
    max_new_tokens: int = 768, # Tăng lên một chút cho câu trả lời đầy đủ hơn
    temperature: float = 0.4, # Giảm một chút để câu trả lời bớt ngẫu nhiên, tập trung hơn
    top_p: float = 0.9,       # Giữ nguyên hoặc giảm nhẹ
    top_k: int = 40,
    repetition_penalty: float = 1.15, # Tăng nhẹ để tránh lặp từ
    # Tham số để import hàm search_relevant_laws
    search_function # Đây là hàm search_relevant_laws được truyền vào
):
    """
    Truy xuất ngữ cảnh bằng hàm search_relevant_laws (được truyền vào)
    và tạo câu trả lời từ LLM dựa trên ngữ cảnh đó.

    Args:
        query (str): Câu truy vấn của người dùng.
        llama_model: Mô hình LLM đã tải.
        tokenizer: Tokenizer tương ứng.
        faiss_index: FAISS index đã tải.
        embed_model: Mô hình embedding đã tải.
        chunks_data_list (list): Danh sách các chunk dữ liệu luật.
        bm25_model: Mô hình BM25 đã tạo.
        search_k (int): Số lượng kết quả cuối cùng muốn lấy từ hàm search.
        search_multiplier (int): Hệ số initial_k_multiplier cho hàm search.
        rrf_k_constant (int): Hằng số k cho RRF trong hàm search.
        max_new_tokens (int): Số token tối đa được tạo mới bởi LLM.
        temperature (float): Nhiệt độ cho việc sinh văn bản.
        top_p (float): Tham số top-p cho nucleus sampling.
        top_k (int): Tham số top-k.
        repetition_penalty (float): Phạt cho việc lặp từ.
        search_function: Hàm thực hiện tìm kiếm (ví dụ: retrieval.search_relevant_laws).

    Returns:
        str: Câu trả lời được tạo ra bởi LLM.
    """
    print(f"\n--- [LLM Handler] Bắt đầu xử lý query: '{query}' ---")

    # === 1. Truy xuất ngữ cảnh (Sử dụng hàm search_function được truyền vào) ===
    print("--- [LLM Handler] Bước 1: Truy xuất ngữ cảnh (Hybrid Search)... ---")
    try:
        retrieved_results = search_function(
            query_text=query,
            embedding_model=embed_model,
            faiss_index=faiss_index,
            chunks_data=chunks_data_list,
            bm25_model=bm25_model,
            k=search_k,
            initial_k_multiplier=search_multiplier,
            rrf_k_constant=rrf_k_constant
            # Các tham số boost có thể lấy giá trị mặc định trong search_function
            # hoặc truyền vào đây nếu muốn tùy chỉnh sâu hơn từ app.py
        )
        print(f"--- [LLM Handler] Truy xuất xong, số kết quả: {len(retrieved_results)} ---")
        if not retrieved_results:
            print("--- [LLM Handler] Không tìm thấy ngữ cảnh nào. ---")
    except Exception as e:
        print(f"Lỗi trong quá trình truy xuất ngữ cảnh: {e}")
        retrieved_results = [] # Xử lý lỗi bằng cách trả về danh sách rỗng

    # === 2. Định dạng Context từ retrieved_results ===
    print("--- [LLM Handler] Bước 2: Định dạng context cho LLM... ---")
    context_parts = []
    if not retrieved_results:
        context = "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu để trả lời câu hỏi này."
    else:
        for i, res in enumerate(retrieved_results):
            metadata = res.get('metadata', {})
            article_title = metadata.get('article_title', 'N/A') # Lấy tiêu đề Điều
            article = metadata.get('article', 'N/A')
            clause = metadata.get('clause_number', 'N/A') # Sửa key cho khớp với process_law_data_to_chunks
            point = metadata.get('point_id', '')
            source = metadata.get('source_document', 'N/A')
            text_content = res.get('text', '*Nội dung không có*')

            # Header rõ ràng cho mỗi nguồn
            header_parts = [f"Trích dẫn {i+1}:"]
            if source != 'N/A':
                header_parts.append(f"(Nguồn: {source})")
            if article != 'N/A':
                header_parts.append(f"Điều {article}")
                if article_title != 'N/A' and article_title != article: # Chỉ thêm tiêu đề nếu khác số điều
                    header_parts.append(f"({article_title})")
            if clause != 'N/A':
                header_parts.append(f", Khoản {clause}")
            if point: # point_id có thể là None hoặc rỗng
                header_parts.append(f", Điểm {point}")

            header = " ".join(header_parts)

            # Bổ sung thông tin phạt/điểm vào header nếu query có đề cập
            # và metadata của chunk có thông tin đó (lấy từ logic boosting của search_relevant_laws)
            query_analysis = metadata.get("query_analysis_for_boost", {}) # Giả sử search_relevant_laws có thể trả về
            # Hoặc phân tích lại query ở đây (ít hiệu quả hơn)
            mentions_fine_in_query = bool(re.search(r'tiền|phạt|bao nhiêu đồng|mức phạt', query.lower()))
            mentions_points_in_query = bool(re.search(r'điểm|trừ điểm|bằng lái|gplx', query.lower()))

            fine_info_text = []
            if metadata.get("has_fine") and mentions_fine_in_query:
                if metadata.get("individual_fine_min") is not None and metadata.get("individual_fine_max") is not None:
                    fine_info_text.append(f"Phạt tiền: {metadata.get('individual_fine_min'):,} - {metadata.get('individual_fine_max'):,} VND.")
                elif metadata.get("overall_fine_note_for_clause_text"):
                     fine_info_text.append(f"Ghi chú phạt tiền: {metadata.get('overall_fine_note_for_clause_text')}")

            points_info_text = []
            if metadata.get("has_points_deduction") and mentions_points_in_query:
                if metadata.get("points_deducted_values_str"):
                    points_info_text.append(f"Trừ điểm: {metadata.get('points_deducted_values_str')} điểm.")
                elif metadata.get("overall_points_deduction_note_for_clause_text"):
                    points_info_text.append(f"Ghi chú trừ điểm: {metadata.get('overall_points_deduction_note_for_clause_text')}")


            penalty_summary = ""
            if fine_info_text or points_info_text:
                penalty_summary = " (Liên quan: " + " ".join(fine_info_text + points_info_text) + ")"

            context_parts.append(f"{header}{penalty_summary}\nNội dung: {text_content}")

        context = "\n\n---\n\n".join(context_parts)
        # print("\n--- [LLM Handler] Context đã định dạng ---\n", context[:1000] + "...") # Xem trước context

    # === 3. Xây dựng Prompt ===
    # Sử dụng định dạng prompt mà mô hình của bạn được fine-tune (ví dụ: Alpaca)
    # Dưới đây là một ví dụ, bạn có thể cần điều chỉnh cho phù hợp với `lora_model_base`
    prompt = f"""Bạn là một trợ lý AI chuyên tư vấn về luật giao thông đường bộ Việt Nam.
Nhiệm vụ của bạn là dựa vào các thông tin luật được cung cấp dưới đây để trả lời câu hỏi của người dùng một cách chính xác, chi tiết và dễ hiểu.
Nếu thông tin không đủ hoặc không có trong các trích dẫn được cung cấp, hãy trả lời rằng bạn không tìm thấy thông tin cụ thể trong tài liệu được cung cấp.
Tránh đưa ra ý kiến cá nhân hoặc thông tin không có trong ngữ cảnh. Hãy trích dẫn điều, khoản, điểm nếu có thể.

### Thông tin luật được trích dẫn:
{context}

### Câu hỏi của người dùng:
{query}

### Trả lời của bạn:"""

    # print("\n--- [LLM Handler] Prompt hoàn chỉnh (một phần) ---\n", prompt[:1000] + "...")

    # === 4. Tạo câu trả lời từ LLM ===
    print("--- [LLM Handler] Bước 3: Tạo câu trả lời từ LLM... ---")
    device = llama_model.device # Lấy device từ model đã tải
    inputs = tokenizer(prompt, return_tensors="pt").to(device) # Chuyển inputs lên cùng device với model

    generation_config = dict(
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        do_sample=True if temperature > 0 else False, # Chỉ sample khi temperature > 0
        pad_token_id=tokenizer.eos_token_id, # Quan trọng cho batch generation và padding
        eos_token_id=tokenizer.eos_token_id,
    )

    try:
        # # Tùy chọn: Sử dụng TextStreamer nếu bạn muốn stream từng token (cần sửa đổi hàm này để yield)
        # text_streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        # output_ids = llama_model.generate(**inputs, streamer=text_streamer, **generation_config)
        # response_text = "" # TextStreamer sẽ in ra, không cần decode lại nếu chỉ để hiển thị

        # Generate bình thường để trả về chuỗi hoàn chỉnh
        output_ids = llama_model.generate(**inputs, **generation_config)

        # Lấy phần token được tạo mới (sau prompt)
        input_length = inputs.input_ids.shape[1]
        generated_ids = output_ids[0][input_length:]
        response_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        print("--- [LLM Handler] Tạo câu trả lời hoàn tất. ---")
        # print(f"--- [LLM Handler] Response text: {response_text[:300]}...")
        return response_text

    except Exception as e:
        print(f"Lỗi trong quá trình LLM generating: {e}")
        return "Xin lỗi, đã có lỗi xảy ra trong quá trình tạo câu trả lời từ mô hình ngôn ngữ."


# --- (Tùy chọn) Hàm main để test nhanh file này ---
if __name__ == '__main__':
    # Phần này chỉ để test, bạn cần mock các đối tượng hoặc tải thật
    print("Chạy test cho llm_handler.py (chưa có mock dữ liệu)...")

    # # Ví dụ cách mock (bạn cần dữ liệu thật hoặc mock phức tạp hơn để chạy)
    # mock_llm_model, mock_tokenizer = load_llm_model_and_tokenizer(
    #     "unsloth/Phi-3-mini-4k-instruct-bnb-4bit", # Thay bằng model bạn dùng hoặc một model nhỏ để test
    #     # model_name_or_path="path/to/your/lora_model_base" # Nếu test với model đã tải
    # )
    #
    # if mock_llm_model and mock_tokenizer:
    #     # Mock các thành phần RAG
    #     class MockFAISSIndex:
    #         def __init__(self): self.ntotal = 0
    #         def search(self, query, k): return ([], []) # Trả về không có gì
    #
    #     class MockEmbeddingModel:
    #         def encode(self, text, convert_to_tensor, device): return torch.randn(1, 10) # Vector dummy
    #
    #     class MockBM25Model:
    #         def get_scores(self, query_tokens): return []
    #
    #     def mock_search_relevant_laws(**kwargs):
    #         print(f"Mock search_relevant_laws called with query: {kwargs.get('query_text')}")
    #         # Trả về một vài kết quả giả để test formatting
    #         return [
    #             {
    #                 "text": "Người điều khiển xe máy không đội mũ bảo hiểm sẽ bị phạt tiền.",
    #                 "metadata": {
    #                     "source_document": "Nghị định 100/2019/NĐ-CP",
    #                     "article": "6", "clause_number": "2", "point_id": "i",
    #                     "article_title": "Xử phạt người điều khiển xe mô tô, xe gắn máy",
    #                     "has_fine": True, "individual_fine_min": 200000, "individual_fine_max": 300000,
    #                 }
    #             }
    #         ]
    #
    #     test_query = "Không đội mũ bảo hiểm xe máy phạt bao nhiêu?"
    #     response = generate_response(
    #         query=test_query,
    #         llama_model=mock_llm_model,
    #         tokenizer=mock_tokenizer,
    #         faiss_index=MockFAISSIndex(),
    #         embed_model=MockEmbeddingModel(),
    #         chunks_data_list=[{"text": "dummy chunk", "metadata": {}}],
    #         bm25_model=MockBM25Model(),
    #         search_function=mock_search_relevant_laws, # Truyền hàm mock
    #         search_k=1
    #     )
    #     print("\n--- Câu trả lời Test ---")
    #     print(response)
    # else:
    #     print("Không thể tải mock LLM model để test.")
    pass # Bỏ qua phần test nếu không có mock