chatbot_demo / llm_handler.py
deddoggo's picture
update full version
c69c2f9
raw
history blame
15.6 kB
# llm_handler.py
import torch
import re
import json
from unsloth import FastLanguageModel
# from transformers import TextStreamer # Bỏ comment nếu bạn muốn dùng TextStreamer để stream token
# Giả định rằng hàm search_relevant_laws được import từ file retrieval.py
# Nếu file retrieval.py nằm cùng cấp, bạn có thể import như sau:
# from retrieval import search_relevant_laws
# Hoặc nếu bạn muốn hàm này độc lập hơn, bạn có thể truyền retrieved_results vào generate_response
# thay vì truyền tất cả các thành phần của RAG.
# Tuy nhiên, dựa theo code gốc, generate_response gọi search_relevant_laws.
# --- HÀM TẢI MÔ HÌNH LLM VÀ TOKENIZER ---
def load_llm_model_and_tokenizer(
model_name_or_path: str,
max_seq_length: int = 2048,
load_in_4bit: bool = True,
device_map: str = "auto" # Cho phép Unsloth tự quyết định device map
):
"""
Tải mô hình ngôn ngữ lớn (LLM) đã được fine-tune bằng Unsloth và tokenizer tương ứng.
Args:
model_name_or_path (str): Tên hoặc đường dẫn đến mô hình đã fine-tune.
max_seq_length (int): Độ dài chuỗi tối đa mà mô hình hỗ trợ.
load_in_4bit (bool): Có tải mô hình ở dạng 4-bit quantization hay không.
device_map (str): Cách map model lên các device (ví dụ "auto", "cuda:0").
Returns:
tuple: (model, tokenizer) nếu thành công, (None, None) nếu có lỗi.
"""
print(f"Đang tải LLM model: {model_name_or_path}...")
try:
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=model_name_or_path,
max_seq_length=max_seq_length,
dtype=None, # Unsloth sẽ tự động chọn dtype tối ưu
load_in_4bit=load_in_4bit,
device_map=device_map, # Thêm device_map
# token = "hf_YOUR_TOKEN_HERE" # Nếu model là private trên Hugging Face Hub
)
FastLanguageModel.for_inference(model) # Tối ưu hóa mô hình cho inference
print("Tải LLM model và tokenizer thành công.")
return model, tokenizer
except Exception as e:
print(f"Lỗi khi tải LLM model và tokenizer: {e}")
return None, None
# --- HÀM TẠO CÂU TRẢ LỜI TỪ LLM ---
def generate_response(
query: str,
llama_model, # Mô hình LLM đã tải
tokenizer, # Tokenizer tương ứng
# Các thành phần RAG được truyền từ app.py (đã được tải trước đó)
faiss_index,
embed_model,
chunks_data_list: list,
bm25_model,
# Các tham số cho search_relevant_laws
search_k: int = 5,
search_multiplier: int = 10, # initial_k_multiplier
rrf_k_constant: int = 60,
# Các tham số cho generation của LLM
max_new_tokens: int = 768, # Tăng lên một chút cho câu trả lời đầy đủ hơn
temperature: float = 0.4, # Giảm một chút để câu trả lời bớt ngẫu nhiên, tập trung hơn
top_p: float = 0.9, # Giữ nguyên hoặc giảm nhẹ
top_k: int = 40,
repetition_penalty: float = 1.15, # Tăng nhẹ để tránh lặp từ
# Tham số để import hàm search_relevant_laws
search_function # Đây là hàm search_relevant_laws được truyền vào
):
"""
Truy xuất ngữ cảnh bằng hàm search_relevant_laws (được truyền vào)
và tạo câu trả lời từ LLM dựa trên ngữ cảnh đó.
Args:
query (str): Câu truy vấn của người dùng.
llama_model: Mô hình LLM đã tải.
tokenizer: Tokenizer tương ứng.
faiss_index: FAISS index đã tải.
embed_model: Mô hình embedding đã tải.
chunks_data_list (list): Danh sách các chunk dữ liệu luật.
bm25_model: Mô hình BM25 đã tạo.
search_k (int): Số lượng kết quả cuối cùng muốn lấy từ hàm search.
search_multiplier (int): Hệ số initial_k_multiplier cho hàm search.
rrf_k_constant (int): Hằng số k cho RRF trong hàm search.
max_new_tokens (int): Số token tối đa được tạo mới bởi LLM.
temperature (float): Nhiệt độ cho việc sinh văn bản.
top_p (float): Tham số top-p cho nucleus sampling.
top_k (int): Tham số top-k.
repetition_penalty (float): Phạt cho việc lặp từ.
search_function: Hàm thực hiện tìm kiếm (ví dụ: retrieval.search_relevant_laws).
Returns:
str: Câu trả lời được tạo ra bởi LLM.
"""
print(f"\n--- [LLM Handler] Bắt đầu xử lý query: '{query}' ---")
# === 1. Truy xuất ngữ cảnh (Sử dụng hàm search_function được truyền vào) ===
print("--- [LLM Handler] Bước 1: Truy xuất ngữ cảnh (Hybrid Search)... ---")
try:
retrieved_results = search_function(
query_text=query,
embedding_model=embed_model,
faiss_index=faiss_index,
chunks_data=chunks_data_list,
bm25_model=bm25_model,
k=search_k,
initial_k_multiplier=search_multiplier,
rrf_k_constant=rrf_k_constant
# Các tham số boost có thể lấy giá trị mặc định trong search_function
# hoặc truyền vào đây nếu muốn tùy chỉnh sâu hơn từ app.py
)
print(f"--- [LLM Handler] Truy xuất xong, số kết quả: {len(retrieved_results)} ---")
if not retrieved_results:
print("--- [LLM Handler] Không tìm thấy ngữ cảnh nào. ---")
except Exception as e:
print(f"Lỗi trong quá trình truy xuất ngữ cảnh: {e}")
retrieved_results = [] # Xử lý lỗi bằng cách trả về danh sách rỗng
# === 2. Định dạng Context từ retrieved_results ===
print("--- [LLM Handler] Bước 2: Định dạng context cho LLM... ---")
context_parts = []
if not retrieved_results:
context = "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu để trả lời câu hỏi này."
else:
for i, res in enumerate(retrieved_results):
metadata = res.get('metadata', {})
article_title = metadata.get('article_title', 'N/A') # Lấy tiêu đề Điều
article = metadata.get('article', 'N/A')
clause = metadata.get('clause_number', 'N/A') # Sửa key cho khớp với process_law_data_to_chunks
point = metadata.get('point_id', '')
source = metadata.get('source_document', 'N/A')
text_content = res.get('text', '*Nội dung không có*')
# Header rõ ràng cho mỗi nguồn
header_parts = [f"Trích dẫn {i+1}:"]
if source != 'N/A':
header_parts.append(f"(Nguồn: {source})")
if article != 'N/A':
header_parts.append(f"Điều {article}")
if article_title != 'N/A' and article_title != article: # Chỉ thêm tiêu đề nếu khác số điều
header_parts.append(f"({article_title})")
if clause != 'N/A':
header_parts.append(f", Khoản {clause}")
if point: # point_id có thể là None hoặc rỗng
header_parts.append(f", Điểm {point}")
header = " ".join(header_parts)
# Bổ sung thông tin phạt/điểm vào header nếu query có đề cập
# và metadata của chunk có thông tin đó (lấy từ logic boosting của search_relevant_laws)
query_analysis = metadata.get("query_analysis_for_boost", {}) # Giả sử search_relevant_laws có thể trả về
# Hoặc phân tích lại query ở đây (ít hiệu quả hơn)
mentions_fine_in_query = bool(re.search(r'tiền|phạt|bao nhiêu đồng|mức phạt', query.lower()))
mentions_points_in_query = bool(re.search(r'điểm|trừ điểm|bằng lái|gplx', query.lower()))
fine_info_text = []
if metadata.get("has_fine") and mentions_fine_in_query:
if metadata.get("individual_fine_min") is not None and metadata.get("individual_fine_max") is not None:
fine_info_text.append(f"Phạt tiền: {metadata.get('individual_fine_min'):,} - {metadata.get('individual_fine_max'):,} VND.")
elif metadata.get("overall_fine_note_for_clause_text"):
fine_info_text.append(f"Ghi chú phạt tiền: {metadata.get('overall_fine_note_for_clause_text')}")
points_info_text = []
if metadata.get("has_points_deduction") and mentions_points_in_query:
if metadata.get("points_deducted_values_str"):
points_info_text.append(f"Trừ điểm: {metadata.get('points_deducted_values_str')} điểm.")
elif metadata.get("overall_points_deduction_note_for_clause_text"):
points_info_text.append(f"Ghi chú trừ điểm: {metadata.get('overall_points_deduction_note_for_clause_text')}")
penalty_summary = ""
if fine_info_text or points_info_text:
penalty_summary = " (Liên quan: " + " ".join(fine_info_text + points_info_text) + ")"
context_parts.append(f"{header}{penalty_summary}\nNội dung: {text_content}")
context = "\n\n---\n\n".join(context_parts)
# print("\n--- [LLM Handler] Context đã định dạng ---\n", context[:1000] + "...") # Xem trước context
# === 3. Xây dựng Prompt ===
# Sử dụng định dạng prompt mà mô hình của bạn được fine-tune (ví dụ: Alpaca)
# Dưới đây là một ví dụ, bạn có thể cần điều chỉnh cho phù hợp với `lora_model_base`
prompt = f"""Bạn là một trợ lý AI chuyên tư vấn về luật giao thông đường bộ Việt Nam.
Nhiệm vụ của bạn là dựa vào các thông tin luật được cung cấp dưới đây để trả lời câu hỏi của người dùng một cách chính xác, chi tiết và dễ hiểu.
Nếu thông tin không đủ hoặc không có trong các trích dẫn được cung cấp, hãy trả lời rằng bạn không tìm thấy thông tin cụ thể trong tài liệu được cung cấp.
Tránh đưa ra ý kiến cá nhân hoặc thông tin không có trong ngữ cảnh. Hãy trích dẫn điều, khoản, điểm nếu có thể.
### Thông tin luật được trích dẫn:
{context}
### Câu hỏi của người dùng:
{query}
### Trả lời của bạn:"""
# print("\n--- [LLM Handler] Prompt hoàn chỉnh (một phần) ---\n", prompt[:1000] + "...")
# === 4. Tạo câu trả lời từ LLM ===
print("--- [LLM Handler] Bước 3: Tạo câu trả lời từ LLM... ---")
device = llama_model.device # Lấy device từ model đã tải
inputs = tokenizer(prompt, return_tensors="pt").to(device) # Chuyển inputs lên cùng device với model
generation_config = dict(
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
do_sample=True if temperature > 0 else False, # Chỉ sample khi temperature > 0
pad_token_id=tokenizer.eos_token_id, # Quan trọng cho batch generation và padding
eos_token_id=tokenizer.eos_token_id,
)
try:
# # Tùy chọn: Sử dụng TextStreamer nếu bạn muốn stream từng token (cần sửa đổi hàm này để yield)
# text_streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# output_ids = llama_model.generate(**inputs, streamer=text_streamer, **generation_config)
# response_text = "" # TextStreamer sẽ in ra, không cần decode lại nếu chỉ để hiển thị
# Generate bình thường để trả về chuỗi hoàn chỉnh
output_ids = llama_model.generate(**inputs, **generation_config)
# Lấy phần token được tạo mới (sau prompt)
input_length = inputs.input_ids.shape[1]
generated_ids = output_ids[0][input_length:]
response_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
print("--- [LLM Handler] Tạo câu trả lời hoàn tất. ---")
# print(f"--- [LLM Handler] Response text: {response_text[:300]}...")
return response_text
except Exception as e:
print(f"Lỗi trong quá trình LLM generating: {e}")
return "Xin lỗi, đã có lỗi xảy ra trong quá trình tạo câu trả lời từ mô hình ngôn ngữ."
# --- (Tùy chọn) Hàm main để test nhanh file này ---
if __name__ == '__main__':
# Phần này chỉ để test, bạn cần mock các đối tượng hoặc tải thật
print("Chạy test cho llm_handler.py (chưa có mock dữ liệu)...")
# # Ví dụ cách mock (bạn cần dữ liệu thật hoặc mock phức tạp hơn để chạy)
# mock_llm_model, mock_tokenizer = load_llm_model_and_tokenizer(
# "unsloth/Phi-3-mini-4k-instruct-bnb-4bit", # Thay bằng model bạn dùng hoặc một model nhỏ để test
# # model_name_or_path="path/to/your/lora_model_base" # Nếu test với model đã tải
# )
#
# if mock_llm_model and mock_tokenizer:
# # Mock các thành phần RAG
# class MockFAISSIndex:
# def __init__(self): self.ntotal = 0
# def search(self, query, k): return ([], []) # Trả về không có gì
#
# class MockEmbeddingModel:
# def encode(self, text, convert_to_tensor, device): return torch.randn(1, 10) # Vector dummy
#
# class MockBM25Model:
# def get_scores(self, query_tokens): return []
#
# def mock_search_relevant_laws(**kwargs):
# print(f"Mock search_relevant_laws called with query: {kwargs.get('query_text')}")
# # Trả về một vài kết quả giả để test formatting
# return [
# {
# "text": "Người điều khiển xe máy không đội mũ bảo hiểm sẽ bị phạt tiền.",
# "metadata": {
# "source_document": "Nghị định 100/2019/NĐ-CP",
# "article": "6", "clause_number": "2", "point_id": "i",
# "article_title": "Xử phạt người điều khiển xe mô tô, xe gắn máy",
# "has_fine": True, "individual_fine_min": 200000, "individual_fine_max": 300000,
# }
# }
# ]
#
# test_query = "Không đội mũ bảo hiểm xe máy phạt bao nhiêu?"
# response = generate_response(
# query=test_query,
# llama_model=mock_llm_model,
# tokenizer=mock_tokenizer,
# faiss_index=MockFAISSIndex(),
# embed_model=MockEmbeddingModel(),
# chunks_data_list=[{"text": "dummy chunk", "metadata": {}}],
# bm25_model=MockBM25Model(),
# search_function=mock_search_relevant_laws, # Truyền hàm mock
# search_k=1
# )
# print("\n--- Câu trả lời Test ---")
# print(response)
# else:
# print("Không thể tải mock LLM model để test.")
pass # Bỏ qua phần test nếu không có mock