# Hugging Face Spaces demo app (the "Runtime error" banner was scraped from the
# Spaces status page, not part of the program).
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import gradio as gr

# Base model and LoRA adapter locations.
base_model_name = "sail/Sailor-1.8B-Chat"
adapter_path = "."  # directory containing the uploaded adapter files

# Load the tokenizer from the adapter directory (assumes tokenizer files were
# uploaded alongside the adapter — TODO confirm), then the base model, and
# finally attach the adapter weights on top.
tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    # fp16 on GPU for memory/speed; fp32 on CPU where fp16 ops are poorly supported
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
    trust_remote_code=True,
)
model = PeftModel.from_pretrained(model, adapter_path)
model.eval()  # inference mode: disables dropout etc.
# Hàm xử lý hội thoại | |
def chat_fn(message, history):
    """Generate one assistant reply for a Gradio ChatInterface turn.

    Args:
        message: The latest user message (str).
        history: Prior turns as (user_message, assistant_message) string
            pairs (Gradio tuple-style history).

    Returns:
        The model's reply as a string, truncated to its first line.
    """
    # Rebuild the full conversation in the chat-template message format.
    messages = []
    for user_msg, bot_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True,  # append the assistant prompt header
        truncation=True,
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            max_new_tokens=32,
            do_sample=True,
            top_k=50,
            top_p=0.85,
            temperature=0.9,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens (everything after the prompt).
    generated_text = tokenizer.decode(
        outputs[0][input_ids.shape[-1]:], skip_special_tokens=True
    ).strip()
    # Keep only the first line so the reply stops at a single answer.
    generated_text = generated_text.split("\n")[0].strip()
    return generated_text
# Build the Gradio chat UI: ChatInterface wires chat_fn to a chatbot widget.
chatbot = gr.ChatInterface(
    fn=chat_fn,
    title="Sailor-1.8B-Chat-SFT",
    description="Demo chatbot chạy trên adapter của Sailor-1.8B-SFT.",
    theme="soft",
)

# Launch only when run as a script (Spaces also executes the file directly).
if __name__ == "__main__":
    chatbot.launch()