import gradio as gr import torch # Cần cho việc kiểm tra CUDA và device # Import các hàm và lớp từ các file của bạn from retrieval import ( process_law_data_to_chunks, # VEHICLE_TYPE_MAP, # Có thể không cần import trực tiếp nếu chỉ dùng trong retrieval.py # get_standardized_vehicle_type, # Tương tự # analyze_query, # Được gọi bởi search_relevant_laws tokenize_vi_for_bm25_setup, # Cần cho BM25 search_relevant_laws ) from llm_handler import generate_response # Giả sử hàm này đã được điều chỉnh để nhận model, tokenizer, etc. from sentence_transformers import SentenceTransformer import faiss from rank_bm25 import BM25Okapi import json from unsloth import FastLanguageModel # Từ llm_handler.py hoặc import trực tiếp nếu logic tải model ở đây # --- KHỞI TẠO MỘT LẦN KHI APP KHỞI ĐỘNG --- # Đường dẫn (điều chỉnh nếu cần, có thể dùng os.path.join) JSON_FILE_PATH = "data/luat_chi_tiet_output_openai_sdk_final_cleaned.json" FAISS_INDEX_PATH = "data/my_law_faiss_flatip_normalized.index" LLM_MODEL_PATH = "models/lora_model_base" # Hoặc đường dẫn cục bộ EMBEDDING_MODEL_PATH = "models/embedding_model" # 1. Tải và xử lý dữ liệu luật print("Loading and processing law data...") try: with open(JSON_FILE_PATH, 'r', encoding='utf-8') as f: raw_data_from_file = json.load(f) chunks_data = process_law_data_to_chunks(raw_data_from_file) print(f"Loaded {len(chunks_data)} chunks.") if not chunks_data: raise ValueError("Chunks data is empty after processing.") except Exception as e: print(f"Error loading/processing law data: {e}") chunks_data = [] # Hoặc xử lý lỗi khác # 2. Tải mô hình embedding print(f"Loading embedding model: {EMBEDDING_MODEL_PATH}...") device = "cuda" if torch.cuda.is_available() else "cpu" try: embedding_model = SentenceTransformer(EMBEDDING_MODEL_PATH, device=device) print("Embedding model loaded successfully.") except Exception as e: print(f"Error loading embedding model: {e}") embedding_model = None # Xử lý lỗi # 3. Tải FAISS index print(f"Loading FAISS index from: {FAISS_INDEX_PATH}...") try: faiss_index = faiss.read_index(FAISS_INDEX_PATH) print(f"FAISS index loaded. Total vectors: {faiss_index.ntotal}") except Exception as e: print(f"Error loading FAISS index: {e}") faiss_index = None # Xử lý lỗi # 4. Tạo BM25 model print("Creating BM25 model...") bm25_model = None if chunks_data: try: corpus_texts_for_bm25 = [chunk.get('text', '') for chunk in chunks_data] tokenized_corpus_bm25 = [tokenize_vi_for_bm25_setup(text) for text in corpus_texts_for_bm25] bm25_model = BM25Okapi(tokenized_corpus_bm25) print("BM25 model created successfully.") except Exception as e: print(f"Error creating BM25 model: {e}") else: print("Skipping BM25 model creation as chunks_data is empty.") # 5. Tải mô hình LLM và tokenizer (sử dụng Unsloth) print(f"Loading LLM model: {LLM_MODEL_PATH}...") try: # Nên đặt logic tải model LLM vào llm_handler.py và gọi hàm đó ở đây # Hoặc trực tiếp: llm_model, llm_tokenizer = FastLanguageModel.from_pretrained( model_name=LLM_MODEL_PATH, # Đường dẫn tới model đã fine-tune max_seq_length=2048, dtype=None, # Unsloth sẽ tự động chọn load_in_4bit=True, # Sử dụng 4-bit quantization ) FastLanguageModel.for_inference(llm_model) # Tối ưu cho inference print("LLM model and tokenizer loaded successfully.") except Exception as e: print(f"Error loading LLM model: {e}") llm_model = None llm_tokenizer = None # --- KẾT THÚC KHỞI TẠO MỘT LẦN --- # Hàm respond mới sẽ sử dụng các model và data đã tải ở trên def respond(message, history: list[tuple[str, str]]): if not all([chunks_data, embedding_model, faiss_index, bm25_model, llm_model, llm_tokenizer]): # Ghi log chi tiết hơn ở đây nếu cần để biết thành phần nào bị thiếu missing_components = [] if not chunks_data: missing_components.append("chunks_data") if not embedding_model: missing_components.append("embedding_model") if not faiss_index: missing_components.append("faiss_index") if not bm25_model: missing_components.append("bm25_model") if not llm_model: missing_components.append("llm_model") if not llm_tokenizer: missing_components.append("llm_tokenizer") error_msg = f"Lỗi: Một hoặc nhiều thành phần của hệ thống chưa được khởi tạo thành công. Thành phần thiếu: {', '.join(missing_components)}. Vui lòng kiểm tra logs của Space." print(error_msg) # In ra console log của Space return error_msg # Trả về cho người dùng try: response_text = generate_response( query=message, llama_model=llm_model, tokenizer=llm_tokenizer, faiss_index=faiss_index, embed_model=embedding_model, chunks_data_list=chunks_data, bm25_model=bm25_model, search_function=search_relevant_laws # << RẤT QUAN TRỌNG: Đã thêm tham số này # Bạn có thể truyền thêm các tham số search_k, search_multiplier, # rrf_k_constant, max_new_tokens, temperature, etc. vào đây # nếu bạn muốn ghi đè giá trị mặc định trong llm_handler.generate_response # Ví dụ: # search_k=5, # max_new_tokens=768 ) yield response_text except Exception as e: # Ghi log lỗi chi tiết hơn import traceback print(f"Error during response generation for query '{message}': {e}") print(traceback.format_exc()) # In stack trace để debug yield f"Đã xảy ra lỗi nghiêm trọng khi xử lý yêu cầu của bạn. Vui lòng thử lại sau hoặc liên hệ quản trị viên." # Giao diện Gradio # Bỏ các additional_inputs không cần thiết nếu chúng được xử lý bên trong generate_response # hoặc nếu bạn không muốn người dùng cuối thay đổi chúng. demo = gr.ChatInterface( respond, # additional_inputs=[ # Bạn có thể thêm lại nếu muốn người dùng tùy chỉnh # gr.Textbox(value="You are a helpful Law Chatbot.", label="System message"), # Ví dụ # ] ) if __name__ == "__main__": demo.launch()