import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import gradio as gr

# Base model name
base_model_name = "sail/Sailor-1.8B-Chat"
adapter_path = "."  # Directory containing the uploaded adapter

# Load the tokenizer and the base model
tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
    trust_remote_code=True,
)
# Attach the fine-tuned LoRA adapter on top of the base model
model = PeftModel.from_pretrained(model, adapter_path)
model.eval()


# Conversation handler: rebuild the chat history, then generate a reply
def chat_fn(message, history):
    # ChatInterface passes history as (user, assistant) pairs (Gradio's
    # default "tuples" format); convert it to the chat-template schema
    messages = []
    for user_msg, bot_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True,
        truncation=True,
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            max_new_tokens=32,
            do_sample=True,
            top_k=50,
            top_p=0.85,
            temperature=0.9,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens and keep just the first line
    generated_text = tokenizer.decode(
        outputs[0][input_ids.shape[-1]:], skip_special_tokens=True
    ).strip()
    generated_text = generated_text.split("\n")[0].strip()
    return generated_text


# Build the Gradio chat interface
chatbot = gr.ChatInterface(
    fn=chat_fn,
    title="Sailor-1.8B-Chat-SFT",
    description="Chatbot demo running on the Sailor-1.8B-SFT adapter.",
    theme="soft",
)

if __name__ == "__main__":
    chatbot.launch()
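
# Usage sketch (an assumption, not part of the original script): run locally
# with the dependencies the imports above require, e.g.
#   pip install torch transformers peft gradio
#   python app.py
# When deployed as app.py in a Hugging Face Space, the Space runs it
# automatically and serves the Gradio UI.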