import torch from transformers import AutoTokenizer, AutoModelForCausalLM from peft import PeftModel import gradio as gr # 1. Cấu hình tên mô hình gốc (base model) base_model_name = "sail/Sailor-1.8B-Chat" # 2. Load tokenizer từ thư mục adapter adapter_path = "./Sailor-1.8B-Chat-SFT" tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True) # 3. Load base model và adapter model = AutoModelForCausalLM.from_pretrained( base_model_name, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True ) model = PeftModel.from_pretrained(model, adapter_path, torch_dtype=torch.float16) model.eval() # 4. Hàm trò chuyện def chat_fn(message, history): # Biên dịch lịch sử hội thoại sang định dạng messages messages = [] for user_msg, bot_msg in history: messages.append({"role": "user", "content": user_msg}) messages.append({"role": "assistant", "content": bot_msg}) messages.append({"role": "user", "content": message}) # Áp dụng chat template chuẩn input_ids = tokenizer.apply_chat_template( messages, return_tensors="pt", add_generation_prompt=True, truncation=True ).to(model.device) # Sinh phản hồi with torch.no_grad(): outputs = model.generate( input_ids=input_ids, max_new_tokens=512, do_sample=True, top_k=50, top_p=0.85, temperature=0.9, repetition_penalty=1.2, pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id ) # Tách phần phản hồi generated_text = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True).strip() return generated_text # 5. Giao diện Gradio chatbot = gr.ChatInterface( fn=chat_fn, title="🧭 Sailor-1.8B-Chat-SFT", description="Demo chatbot sử dụng mô hình fine-tune từ Sailor-1.8B với PEFT LoRA.", theme="soft", examples=[ "Xin chào!", "Bạn có thể giải thích học máy là gì không?", "Kể cho tôi một sự thật thú vị về khoa học.", ], ) if __name__ == "__main__": chatbot.launch()