Spaces:
Runtime error
Runtime error
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from peft import PeftModel | |
import gradio as gr | |
# 1. Cấu hình tên mô hình gốc (base model) | |
base_model_name = "sail/Sailor-1.8B-Chat" | |
# 2. Load tokenizer từ thư mục adapter | |
adapter_path = "./Sailor-1.8B-Chat-SFT" | |
tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True) | |
# 3. Load base model và adapter | |
model = AutoModelForCausalLM.from_pretrained( | |
base_model_name, | |
torch_dtype=torch.float16, | |
device_map="auto", | |
trust_remote_code=True | |
) | |
model = PeftModel.from_pretrained(model, adapter_path, torch_dtype=torch.float16) | |
model.eval() | |
# 4. Hàm trò chuyện | |
def chat_fn(message, history): | |
# Biên dịch lịch sử hội thoại sang định dạng messages | |
messages = [] | |
for user_msg, bot_msg in history: | |
messages.append({"role": "user", "content": user_msg}) | |
messages.append({"role": "assistant", "content": bot_msg}) | |
messages.append({"role": "user", "content": message}) | |
# Áp dụng chat template chuẩn | |
input_ids = tokenizer.apply_chat_template( | |
messages, | |
return_tensors="pt", | |
add_generation_prompt=True, | |
truncation=True | |
).to(model.device) | |
# Sinh phản hồi | |
with torch.no_grad(): | |
outputs = model.generate( | |
input_ids=input_ids, | |
max_new_tokens=512, | |
do_sample=True, | |
top_k=50, | |
top_p=0.85, | |
temperature=0.9, | |
repetition_penalty=1.2, | |
pad_token_id=tokenizer.eos_token_id, | |
eos_token_id=tokenizer.eos_token_id | |
) | |
# Tách phần phản hồi | |
generated_text = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True).strip() | |
return generated_text | |
# 5. Giao diện Gradio | |
chatbot = gr.ChatInterface( | |
fn=chat_fn, | |
title="🧭 Sailor-1.8B-Chat-SFT", | |
description="Demo chatbot sử dụng mô hình fine-tune từ Sailor-1.8B với PEFT LoRA.", | |
theme="soft", | |
examples=[ | |
"Xin chào!", | |
"Bạn có thể giải thích học máy là gì không?", | |
"Kể cho tôi một sự thật thú vị về khoa học.", | |
], | |
) | |
if __name__ == "__main__": | |
chatbot.launch() | |