chatbot_demo / app.py
deddoggo's picture
update full version
c69c2f9
raw
history blame
6.66 kB
import gradio as gr
import torch # Cần cho việc kiểm tra CUDA và device
# Import các hàm và lớp từ các file của bạn
from retrieval import (
process_law_data_to_chunks,
# VEHICLE_TYPE_MAP, # Có thể không cần import trực tiếp nếu chỉ dùng trong retrieval.py
# get_standardized_vehicle_type, # Tương tự
# analyze_query, # Được gọi bởi search_relevant_laws
tokenize_vi_for_bm25_setup, # Cần cho BM25
search_relevant_laws
)
from llm_handler import generate_response # Giả sử hàm này đã được điều chỉnh để nhận model, tokenizer, etc.
from sentence_transformers import SentenceTransformer
import faiss
from rank_bm25 import BM25Okapi
import json
from unsloth import FastLanguageModel # Từ llm_handler.py hoặc import trực tiếp nếu logic tải model ở đây
# --- KHỞI TẠO MỘT LẦN KHI APP KHỞI ĐỘNG ---
# Đường dẫn (điều chỉnh nếu cần, có thể dùng os.path.join)
JSON_FILE_PATH = "data/luat_chi_tiet_output_openai_sdk_final_cleaned.json"
FAISS_INDEX_PATH = "data/my_law_faiss_flatip_normalized.index"
LLM_MODEL_PATH = "models/lora_model_base" # Hoặc đường dẫn cục bộ
EMBEDDING_MODEL_PATH = "models/embedding_model"
# 1. Tải và xử lý dữ liệu luật
print("Loading and processing law data...")
try:
with open(JSON_FILE_PATH, 'r', encoding='utf-8') as f:
raw_data_from_file = json.load(f)
chunks_data = process_law_data_to_chunks(raw_data_from_file)
print(f"Loaded {len(chunks_data)} chunks.")
if not chunks_data:
raise ValueError("Chunks data is empty after processing.")
except Exception as e:
print(f"Error loading/processing law data: {e}")
chunks_data = [] # Hoặc xử lý lỗi khác
# 2. Tải mô hình embedding
print(f"Loading embedding model: {EMBEDDING_MODEL_PATH}...")
device = "cuda" if torch.cuda.is_available() else "cpu"
try:
embedding_model = SentenceTransformer(EMBEDDING_MODEL_PATH, device=device)
print("Embedding model loaded successfully.")
except Exception as e:
print(f"Error loading embedding model: {e}")
embedding_model = None # Xử lý lỗi
# 3. Tải FAISS index
print(f"Loading FAISS index from: {FAISS_INDEX_PATH}...")
try:
faiss_index = faiss.read_index(FAISS_INDEX_PATH)
print(f"FAISS index loaded. Total vectors: {faiss_index.ntotal}")
except Exception as e:
print(f"Error loading FAISS index: {e}")
faiss_index = None # Xử lý lỗi
# 4. Tạo BM25 model
print("Creating BM25 model...")
bm25_model = None
if chunks_data:
try:
corpus_texts_for_bm25 = [chunk.get('text', '') for chunk in chunks_data]
tokenized_corpus_bm25 = [tokenize_vi_for_bm25_setup(text) for text in corpus_texts_for_bm25]
bm25_model = BM25Okapi(tokenized_corpus_bm25)
print("BM25 model created successfully.")
except Exception as e:
print(f"Error creating BM25 model: {e}")
else:
print("Skipping BM25 model creation as chunks_data is empty.")
# 5. Tải mô hình LLM và tokenizer (sử dụng Unsloth)
print(f"Loading LLM model: {LLM_MODEL_PATH}...")
try:
# Nên đặt logic tải model LLM vào llm_handler.py và gọi hàm đó ở đây
# Hoặc trực tiếp:
llm_model, llm_tokenizer = FastLanguageModel.from_pretrained(
model_name=LLM_MODEL_PATH, # Đường dẫn tới model đã fine-tune
max_seq_length=2048,
dtype=None, # Unsloth sẽ tự động chọn
load_in_4bit=True, # Sử dụng 4-bit quantization
)
FastLanguageModel.for_inference(llm_model) # Tối ưu cho inference
print("LLM model and tokenizer loaded successfully.")
except Exception as e:
print(f"Error loading LLM model: {e}")
llm_model = None
llm_tokenizer = None
# --- KẾT THÚC KHỞI TẠO MỘT LẦN ---
# Hàm respond mới sẽ sử dụng các model và data đã tải ở trên
def respond(message, history: list[tuple[str, str]]):
if not all([chunks_data, embedding_model, faiss_index, bm25_model, llm_model, llm_tokenizer]):
# Ghi log chi tiết hơn ở đây nếu cần để biết thành phần nào bị thiếu
missing_components = []
if not chunks_data: missing_components.append("chunks_data")
if not embedding_model: missing_components.append("embedding_model")
if not faiss_index: missing_components.append("faiss_index")
if not bm25_model: missing_components.append("bm25_model")
if not llm_model: missing_components.append("llm_model")
if not llm_tokenizer: missing_components.append("llm_tokenizer")
error_msg = f"Lỗi: Một hoặc nhiều thành phần của hệ thống chưa được khởi tạo thành công. Thành phần thiếu: {', '.join(missing_components)}. Vui lòng kiểm tra logs của Space."
print(error_msg) # In ra console log của Space
return error_msg # Trả về cho người dùng
try:
response_text = generate_response(
query=message,
llama_model=llm_model,
tokenizer=llm_tokenizer,
faiss_index=faiss_index,
embed_model=embedding_model,
chunks_data_list=chunks_data,
bm25_model=bm25_model,
search_function=search_relevant_laws # << RẤT QUAN TRỌNG: Đã thêm tham số này
# Bạn có thể truyền thêm các tham số search_k, search_multiplier,
# rrf_k_constant, max_new_tokens, temperature, etc. vào đây
# nếu bạn muốn ghi đè giá trị mặc định trong llm_handler.generate_response
# Ví dụ:
# search_k=5,
# max_new_tokens=768
)
yield response_text
except Exception as e:
# Ghi log lỗi chi tiết hơn
import traceback
print(f"Error during response generation for query '{message}': {e}")
print(traceback.format_exc()) # In stack trace để debug
yield f"Đã xảy ra lỗi nghiêm trọng khi xử lý yêu cầu của bạn. Vui lòng thử lại sau hoặc liên hệ quản trị viên."
# Giao diện Gradio
# Bỏ các additional_inputs không cần thiết nếu chúng được xử lý bên trong generate_response
# hoặc nếu bạn không muốn người dùng cuối thay đổi chúng.
demo = gr.ChatInterface(
respond,
# additional_inputs=[ # Bạn có thể thêm lại nếu muốn người dùng tùy chỉnh
# gr.Textbox(value="You are a helpful Law Chatbot.", label="System message"), # Ví dụ
# ]
)
if __name__ == "__main__":
demo.launch()