Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch # Cần cho việc kiểm tra CUDA và device | |
# Import các hàm và lớp từ các file của bạn | |
from retrieval import ( | |
process_law_data_to_chunks, | |
# VEHICLE_TYPE_MAP, # Có thể không cần import trực tiếp nếu chỉ dùng trong retrieval.py | |
# get_standardized_vehicle_type, # Tương tự | |
# analyze_query, # Được gọi bởi search_relevant_laws | |
tokenize_vi_for_bm25_setup, # Cần cho BM25 | |
search_relevant_laws | |
) | |
from llm_handler import generate_response # Giả sử hàm này đã được điều chỉnh để nhận model, tokenizer, etc. | |
from sentence_transformers import SentenceTransformer | |
import faiss | |
from rank_bm25 import BM25Okapi | |
import json | |
from unsloth import FastLanguageModel # Từ llm_handler.py hoặc import trực tiếp nếu logic tải model ở đây | |
# --- KHỞI TẠO MỘT LẦN KHI APP KHỞI ĐỘNG --- | |
# Đường dẫn (điều chỉnh nếu cần, có thể dùng os.path.join) | |
JSON_FILE_PATH = "data/luat_chi_tiet_output_openai_sdk_final_cleaned.json" | |
FAISS_INDEX_PATH = "data/my_law_faiss_flatip_normalized.index" | |
LLM_MODEL_PATH = "models/lora_model_base" # Hoặc đường dẫn cục bộ | |
EMBEDDING_MODEL_PATH = "models/embedding_model" | |
# 1. Tải và xử lý dữ liệu luật | |
print("Loading and processing law data...") | |
try: | |
with open(JSON_FILE_PATH, 'r', encoding='utf-8') as f: | |
raw_data_from_file = json.load(f) | |
chunks_data = process_law_data_to_chunks(raw_data_from_file) | |
print(f"Loaded {len(chunks_data)} chunks.") | |
if not chunks_data: | |
raise ValueError("Chunks data is empty after processing.") | |
except Exception as e: | |
print(f"Error loading/processing law data: {e}") | |
chunks_data = [] # Hoặc xử lý lỗi khác | |
# 2. Tải mô hình embedding | |
print(f"Loading embedding model: {EMBEDDING_MODEL_PATH}...") | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
try: | |
embedding_model = SentenceTransformer(EMBEDDING_MODEL_PATH, device=device) | |
print("Embedding model loaded successfully.") | |
except Exception as e: | |
print(f"Error loading embedding model: {e}") | |
embedding_model = None # Xử lý lỗi | |
# 3. Tải FAISS index | |
print(f"Loading FAISS index from: {FAISS_INDEX_PATH}...") | |
try: | |
faiss_index = faiss.read_index(FAISS_INDEX_PATH) | |
print(f"FAISS index loaded. Total vectors: {faiss_index.ntotal}") | |
except Exception as e: | |
print(f"Error loading FAISS index: {e}") | |
faiss_index = None # Xử lý lỗi | |
# 4. Tạo BM25 model | |
print("Creating BM25 model...") | |
bm25_model = None | |
if chunks_data: | |
try: | |
corpus_texts_for_bm25 = [chunk.get('text', '') for chunk in chunks_data] | |
tokenized_corpus_bm25 = [tokenize_vi_for_bm25_setup(text) for text in corpus_texts_for_bm25] | |
bm25_model = BM25Okapi(tokenized_corpus_bm25) | |
print("BM25 model created successfully.") | |
except Exception as e: | |
print(f"Error creating BM25 model: {e}") | |
else: | |
print("Skipping BM25 model creation as chunks_data is empty.") | |
# 5. Tải mô hình LLM và tokenizer (sử dụng Unsloth) | |
print(f"Loading LLM model: {LLM_MODEL_PATH}...") | |
try: | |
# Nên đặt logic tải model LLM vào llm_handler.py và gọi hàm đó ở đây | |
# Hoặc trực tiếp: | |
llm_model, llm_tokenizer = FastLanguageModel.from_pretrained( | |
model_name=LLM_MODEL_PATH, # Đường dẫn tới model đã fine-tune | |
max_seq_length=2048, | |
dtype=None, # Unsloth sẽ tự động chọn | |
load_in_4bit=True, # Sử dụng 4-bit quantization | |
) | |
FastLanguageModel.for_inference(llm_model) # Tối ưu cho inference | |
print("LLM model and tokenizer loaded successfully.") | |
except Exception as e: | |
print(f"Error loading LLM model: {e}") | |
llm_model = None | |
llm_tokenizer = None | |
# --- KẾT THÚC KHỞI TẠO MỘT LẦN --- | |
# Hàm respond mới sẽ sử dụng các model và data đã tải ở trên | |
def respond(message, history: list[tuple[str, str]]): | |
if not all([chunks_data, embedding_model, faiss_index, bm25_model, llm_model, llm_tokenizer]): | |
# Ghi log chi tiết hơn ở đây nếu cần để biết thành phần nào bị thiếu | |
missing_components = [] | |
if not chunks_data: missing_components.append("chunks_data") | |
if not embedding_model: missing_components.append("embedding_model") | |
if not faiss_index: missing_components.append("faiss_index") | |
if not bm25_model: missing_components.append("bm25_model") | |
if not llm_model: missing_components.append("llm_model") | |
if not llm_tokenizer: missing_components.append("llm_tokenizer") | |
error_msg = f"Lỗi: Một hoặc nhiều thành phần của hệ thống chưa được khởi tạo thành công. Thành phần thiếu: {', '.join(missing_components)}. Vui lòng kiểm tra logs của Space." | |
print(error_msg) # In ra console log của Space | |
return error_msg # Trả về cho người dùng | |
try: | |
response_text = generate_response( | |
query=message, | |
llama_model=llm_model, | |
tokenizer=llm_tokenizer, | |
faiss_index=faiss_index, | |
embed_model=embedding_model, | |
chunks_data_list=chunks_data, | |
bm25_model=bm25_model, | |
search_function=search_relevant_laws # << RẤT QUAN TRỌNG: Đã thêm tham số này | |
# Bạn có thể truyền thêm các tham số search_k, search_multiplier, | |
# rrf_k_constant, max_new_tokens, temperature, etc. vào đây | |
# nếu bạn muốn ghi đè giá trị mặc định trong llm_handler.generate_response | |
# Ví dụ: | |
# search_k=5, | |
# max_new_tokens=768 | |
) | |
yield response_text | |
except Exception as e: | |
# Ghi log lỗi chi tiết hơn | |
import traceback | |
print(f"Error during response generation for query '{message}': {e}") | |
print(traceback.format_exc()) # In stack trace để debug | |
yield f"Đã xảy ra lỗi nghiêm trọng khi xử lý yêu cầu của bạn. Vui lòng thử lại sau hoặc liên hệ quản trị viên." | |
# Giao diện Gradio | |
# Bỏ các additional_inputs không cần thiết nếu chúng được xử lý bên trong generate_response | |
# hoặc nếu bạn không muốn người dùng cuối thay đổi chúng. | |
demo = gr.ChatInterface( | |
respond, | |
# additional_inputs=[ # Bạn có thể thêm lại nếu muốn người dùng tùy chỉnh | |
# gr.Textbox(value="You are a helpful Law Chatbot.", label="System message"), # Ví dụ | |
# ] | |
) | |
if __name__ == "__main__": | |
demo.launch() |