Spaces:
Sleeping
Sleeping
# llm_handler.py | |
import torch | |
import re | |
import json | |
from unsloth import FastLanguageModel | |
# from transformers import TextStreamer # Bỏ comment nếu bạn muốn dùng TextStreamer để stream token | |
# Giả định rằng hàm search_relevant_laws được import từ file retrieval.py | |
# Nếu file retrieval.py nằm cùng cấp, bạn có thể import như sau: | |
# from retrieval import search_relevant_laws | |
# Hoặc nếu bạn muốn hàm này độc lập hơn, bạn có thể truyền retrieved_results vào generate_response | |
# thay vì truyền tất cả các thành phần của RAG. | |
# Tuy nhiên, dựa theo code gốc, generate_response gọi search_relevant_laws. | |
# --- HÀM TẢI MÔ HÌNH LLM VÀ TOKENIZER --- | |
def load_llm_model_and_tokenizer( | |
model_name_or_path: str, | |
max_seq_length: int = 2048, | |
load_in_4bit: bool = True, | |
device_map: str = "auto" # Cho phép Unsloth tự quyết định device map | |
): | |
""" | |
Tải mô hình ngôn ngữ lớn (LLM) đã được fine-tune bằng Unsloth và tokenizer tương ứng. | |
Args: | |
model_name_or_path (str): Tên hoặc đường dẫn đến mô hình đã fine-tune. | |
max_seq_length (int): Độ dài chuỗi tối đa mà mô hình hỗ trợ. | |
load_in_4bit (bool): Có tải mô hình ở dạng 4-bit quantization hay không. | |
device_map (str): Cách map model lên các device (ví dụ "auto", "cuda:0"). | |
Returns: | |
tuple: (model, tokenizer) nếu thành công, (None, None) nếu có lỗi. | |
""" | |
print(f"Đang tải LLM model: {model_name_or_path}...") | |
try: | |
model, tokenizer = FastLanguageModel.from_pretrained( | |
model_name=model_name_or_path, | |
max_seq_length=max_seq_length, | |
dtype=None, # Unsloth sẽ tự động chọn dtype tối ưu | |
load_in_4bit=load_in_4bit, | |
device_map=device_map, # Thêm device_map | |
# token = "hf_YOUR_TOKEN_HERE" # Nếu model là private trên Hugging Face Hub | |
) | |
FastLanguageModel.for_inference(model) # Tối ưu hóa mô hình cho inference | |
print("Tải LLM model và tokenizer thành công.") | |
return model, tokenizer | |
except Exception as e: | |
print(f"Lỗi khi tải LLM model và tokenizer: {e}") | |
return None, None | |
# --- HÀM TẠO CÂU TRẢ LỜI TỪ LLM --- | |
def generate_response( | |
query: str, | |
llama_model, # Mô hình LLM đã tải | |
tokenizer, # Tokenizer tương ứng | |
# Các thành phần RAG được truyền từ app.py (đã được tải trước đó) | |
faiss_index, | |
embed_model, | |
chunks_data_list: list, | |
bm25_model, | |
# Các tham số cho search_relevant_laws | |
search_k: int = 5, | |
search_multiplier: int = 10, # initial_k_multiplier | |
rrf_k_constant: int = 60, | |
# Các tham số cho generation của LLM | |
max_new_tokens: int = 768, # Tăng lên một chút cho câu trả lời đầy đủ hơn | |
temperature: float = 0.4, # Giảm một chút để câu trả lời bớt ngẫu nhiên, tập trung hơn | |
top_p: float = 0.9, # Giữ nguyên hoặc giảm nhẹ | |
top_k: int = 40, | |
repetition_penalty: float = 1.15, # Tăng nhẹ để tránh lặp từ | |
# Tham số để import hàm search_relevant_laws | |
search_function # Đây là hàm search_relevant_laws được truyền vào | |
): | |
""" | |
Truy xuất ngữ cảnh bằng hàm search_relevant_laws (được truyền vào) | |
và tạo câu trả lời từ LLM dựa trên ngữ cảnh đó. | |
Args: | |
query (str): Câu truy vấn của người dùng. | |
llama_model: Mô hình LLM đã tải. | |
tokenizer: Tokenizer tương ứng. | |
faiss_index: FAISS index đã tải. | |
embed_model: Mô hình embedding đã tải. | |
chunks_data_list (list): Danh sách các chunk dữ liệu luật. | |
bm25_model: Mô hình BM25 đã tạo. | |
search_k (int): Số lượng kết quả cuối cùng muốn lấy từ hàm search. | |
search_multiplier (int): Hệ số initial_k_multiplier cho hàm search. | |
rrf_k_constant (int): Hằng số k cho RRF trong hàm search. | |
max_new_tokens (int): Số token tối đa được tạo mới bởi LLM. | |
temperature (float): Nhiệt độ cho việc sinh văn bản. | |
top_p (float): Tham số top-p cho nucleus sampling. | |
top_k (int): Tham số top-k. | |
repetition_penalty (float): Phạt cho việc lặp từ. | |
search_function: Hàm thực hiện tìm kiếm (ví dụ: retrieval.search_relevant_laws). | |
Returns: | |
str: Câu trả lời được tạo ra bởi LLM. | |
""" | |
print(f"\n--- [LLM Handler] Bắt đầu xử lý query: '{query}' ---") | |
# === 1. Truy xuất ngữ cảnh (Sử dụng hàm search_function được truyền vào) === | |
print("--- [LLM Handler] Bước 1: Truy xuất ngữ cảnh (Hybrid Search)... ---") | |
try: | |
retrieved_results = search_function( | |
query_text=query, | |
embedding_model=embed_model, | |
faiss_index=faiss_index, | |
chunks_data=chunks_data_list, | |
bm25_model=bm25_model, | |
k=search_k, | |
initial_k_multiplier=search_multiplier, | |
rrf_k_constant=rrf_k_constant | |
# Các tham số boost có thể lấy giá trị mặc định trong search_function | |
# hoặc truyền vào đây nếu muốn tùy chỉnh sâu hơn từ app.py | |
) | |
print(f"--- [LLM Handler] Truy xuất xong, số kết quả: {len(retrieved_results)} ---") | |
if not retrieved_results: | |
print("--- [LLM Handler] Không tìm thấy ngữ cảnh nào. ---") | |
except Exception as e: | |
print(f"Lỗi trong quá trình truy xuất ngữ cảnh: {e}") | |
retrieved_results = [] # Xử lý lỗi bằng cách trả về danh sách rỗng | |
# === 2. Định dạng Context từ retrieved_results === | |
print("--- [LLM Handler] Bước 2: Định dạng context cho LLM... ---") | |
context_parts = [] | |
if not retrieved_results: | |
context = "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu để trả lời câu hỏi này." | |
else: | |
for i, res in enumerate(retrieved_results): | |
metadata = res.get('metadata', {}) | |
article_title = metadata.get('article_title', 'N/A') # Lấy tiêu đề Điều | |
article = metadata.get('article', 'N/A') | |
clause = metadata.get('clause_number', 'N/A') # Sửa key cho khớp với process_law_data_to_chunks | |
point = metadata.get('point_id', '') | |
source = metadata.get('source_document', 'N/A') | |
text_content = res.get('text', '*Nội dung không có*') | |
# Header rõ ràng cho mỗi nguồn | |
header_parts = [f"Trích dẫn {i+1}:"] | |
if source != 'N/A': | |
header_parts.append(f"(Nguồn: {source})") | |
if article != 'N/A': | |
header_parts.append(f"Điều {article}") | |
if article_title != 'N/A' and article_title != article: # Chỉ thêm tiêu đề nếu khác số điều | |
header_parts.append(f"({article_title})") | |
if clause != 'N/A': | |
header_parts.append(f", Khoản {clause}") | |
if point: # point_id có thể là None hoặc rỗng | |
header_parts.append(f", Điểm {point}") | |
header = " ".join(header_parts) | |
# Bổ sung thông tin phạt/điểm vào header nếu query có đề cập | |
# và metadata của chunk có thông tin đó (lấy từ logic boosting của search_relevant_laws) | |
query_analysis = metadata.get("query_analysis_for_boost", {}) # Giả sử search_relevant_laws có thể trả về | |
# Hoặc phân tích lại query ở đây (ít hiệu quả hơn) | |
mentions_fine_in_query = bool(re.search(r'tiền|phạt|bao nhiêu đồng|mức phạt', query.lower())) | |
mentions_points_in_query = bool(re.search(r'điểm|trừ điểm|bằng lái|gplx', query.lower())) | |
fine_info_text = [] | |
if metadata.get("has_fine") and mentions_fine_in_query: | |
if metadata.get("individual_fine_min") is not None and metadata.get("individual_fine_max") is not None: | |
fine_info_text.append(f"Phạt tiền: {metadata.get('individual_fine_min'):,} - {metadata.get('individual_fine_max'):,} VND.") | |
elif metadata.get("overall_fine_note_for_clause_text"): | |
fine_info_text.append(f"Ghi chú phạt tiền: {metadata.get('overall_fine_note_for_clause_text')}") | |
points_info_text = [] | |
if metadata.get("has_points_deduction") and mentions_points_in_query: | |
if metadata.get("points_deducted_values_str"): | |
points_info_text.append(f"Trừ điểm: {metadata.get('points_deducted_values_str')} điểm.") | |
elif metadata.get("overall_points_deduction_note_for_clause_text"): | |
points_info_text.append(f"Ghi chú trừ điểm: {metadata.get('overall_points_deduction_note_for_clause_text')}") | |
penalty_summary = "" | |
if fine_info_text or points_info_text: | |
penalty_summary = " (Liên quan: " + " ".join(fine_info_text + points_info_text) + ")" | |
context_parts.append(f"{header}{penalty_summary}\nNội dung: {text_content}") | |
context = "\n\n---\n\n".join(context_parts) | |
# print("\n--- [LLM Handler] Context đã định dạng ---\n", context[:1000] + "...") # Xem trước context | |
# === 3. Xây dựng Prompt === | |
# Sử dụng định dạng prompt mà mô hình của bạn được fine-tune (ví dụ: Alpaca) | |
# Dưới đây là một ví dụ, bạn có thể cần điều chỉnh cho phù hợp với `lora_model_base` | |
prompt = f"""Bạn là một trợ lý AI chuyên tư vấn về luật giao thông đường bộ Việt Nam. | |
Nhiệm vụ của bạn là dựa vào các thông tin luật được cung cấp dưới đây để trả lời câu hỏi của người dùng một cách chính xác, chi tiết và dễ hiểu. | |
Nếu thông tin không đủ hoặc không có trong các trích dẫn được cung cấp, hãy trả lời rằng bạn không tìm thấy thông tin cụ thể trong tài liệu được cung cấp. | |
Tránh đưa ra ý kiến cá nhân hoặc thông tin không có trong ngữ cảnh. Hãy trích dẫn điều, khoản, điểm nếu có thể. | |
### Thông tin luật được trích dẫn: | |
{context} | |
### Câu hỏi của người dùng: | |
{query} | |
### Trả lời của bạn:""" | |
# print("\n--- [LLM Handler] Prompt hoàn chỉnh (một phần) ---\n", prompt[:1000] + "...") | |
# === 4. Tạo câu trả lời từ LLM === | |
print("--- [LLM Handler] Bước 3: Tạo câu trả lời từ LLM... ---") | |
device = llama_model.device # Lấy device từ model đã tải | |
inputs = tokenizer(prompt, return_tensors="pt").to(device) # Chuyển inputs lên cùng device với model | |
generation_config = dict( | |
max_new_tokens=max_new_tokens, | |
temperature=temperature, | |
top_p=top_p, | |
top_k=top_k, | |
repetition_penalty=repetition_penalty, | |
do_sample=True if temperature > 0 else False, # Chỉ sample khi temperature > 0 | |
pad_token_id=tokenizer.eos_token_id, # Quan trọng cho batch generation và padding | |
eos_token_id=tokenizer.eos_token_id, | |
) | |
try: | |
# # Tùy chọn: Sử dụng TextStreamer nếu bạn muốn stream từng token (cần sửa đổi hàm này để yield) | |
# text_streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) | |
# output_ids = llama_model.generate(**inputs, streamer=text_streamer, **generation_config) | |
# response_text = "" # TextStreamer sẽ in ra, không cần decode lại nếu chỉ để hiển thị | |
# Generate bình thường để trả về chuỗi hoàn chỉnh | |
output_ids = llama_model.generate(**inputs, **generation_config) | |
# Lấy phần token được tạo mới (sau prompt) | |
input_length = inputs.input_ids.shape[1] | |
generated_ids = output_ids[0][input_length:] | |
response_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip() | |
print("--- [LLM Handler] Tạo câu trả lời hoàn tất. ---") | |
# print(f"--- [LLM Handler] Response text: {response_text[:300]}...") | |
return response_text | |
except Exception as e: | |
print(f"Lỗi trong quá trình LLM generating: {e}") | |
return "Xin lỗi, đã có lỗi xảy ra trong quá trình tạo câu trả lời từ mô hình ngôn ngữ." | |
# --- (Tùy chọn) Hàm main để test nhanh file này --- | |
if __name__ == '__main__': | |
# Phần này chỉ để test, bạn cần mock các đối tượng hoặc tải thật | |
print("Chạy test cho llm_handler.py (chưa có mock dữ liệu)...") | |
# # Ví dụ cách mock (bạn cần dữ liệu thật hoặc mock phức tạp hơn để chạy) | |
# mock_llm_model, mock_tokenizer = load_llm_model_and_tokenizer( | |
# "unsloth/Phi-3-mini-4k-instruct-bnb-4bit", # Thay bằng model bạn dùng hoặc một model nhỏ để test | |
# # model_name_or_path="path/to/your/lora_model_base" # Nếu test với model đã tải | |
# ) | |
# | |
# if mock_llm_model and mock_tokenizer: | |
# # Mock các thành phần RAG | |
# class MockFAISSIndex: | |
# def __init__(self): self.ntotal = 0 | |
# def search(self, query, k): return ([], []) # Trả về không có gì | |
# | |
# class MockEmbeddingModel: | |
# def encode(self, text, convert_to_tensor, device): return torch.randn(1, 10) # Vector dummy | |
# | |
# class MockBM25Model: | |
# def get_scores(self, query_tokens): return [] | |
# | |
# def mock_search_relevant_laws(**kwargs): | |
# print(f"Mock search_relevant_laws called with query: {kwargs.get('query_text')}") | |
# # Trả về một vài kết quả giả để test formatting | |
# return [ | |
# { | |
# "text": "Người điều khiển xe máy không đội mũ bảo hiểm sẽ bị phạt tiền.", | |
# "metadata": { | |
# "source_document": "Nghị định 100/2019/NĐ-CP", | |
# "article": "6", "clause_number": "2", "point_id": "i", | |
# "article_title": "Xử phạt người điều khiển xe mô tô, xe gắn máy", | |
# "has_fine": True, "individual_fine_min": 200000, "individual_fine_max": 300000, | |
# } | |
# } | |
# ] | |
# | |
# test_query = "Không đội mũ bảo hiểm xe máy phạt bao nhiêu?" | |
# response = generate_response( | |
# query=test_query, | |
# llama_model=mock_llm_model, | |
# tokenizer=mock_tokenizer, | |
# faiss_index=MockFAISSIndex(), | |
# embed_model=MockEmbeddingModel(), | |
# chunks_data_list=[{"text": "dummy chunk", "metadata": {}}], | |
# bm25_model=MockBM25Model(), | |
# search_function=mock_search_relevant_laws, # Truyền hàm mock | |
# search_k=1 | |
# ) | |
# print("\n--- Câu trả lời Test ---") | |
# print(response) | |
# else: | |
# print("Không thể tải mock LLM model để test.") | |
pass # Bỏ qua phần test nếu không có mock |