File size: 2,303 Bytes
d9efefa
25a42cf
d007845
d9efefa
 
1664d4f
d007845
 
d9efefa
1664d4f
d007845
d9efefa
d007845
76007e9
08edbce
 
 
 
6eb8c58
d9efefa
1664d4f
d007845
 
 
d9efefa
78ee627
 
969ff86
78ee627
 
19c1f0e
78ee627
 
 
 
d9efefa
 
78ee627
19c1f0e
 
78ee627
969ff86
fea1dfe
19c1f0e
08edbce
d007845
 
d9efefa
d007845
 
08edbce
d9efefa
d007845
78ee627
d9efefa
d007845
 
 
 
d9efefa
d007845
 
d9efefa
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import gradio as gr
import os
import json
from huggingface_hub import InferenceClient

# Load danh sách nhân vật từ file JSON
with open("characters.json", "r", encoding="utf-8") as f:
    characters = json.load(f)["characters"]

# Tra cứu theo ID
character_dict = {c["id"]: c for c in characters}

# Hugging Face Inference API
token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if token is None:
    raise ValueError("Bạn cần đặt biến môi trường HUGGINGFACEHUB_API_TOKEN để gọi API.")

# Dùng model Zephyr đã được deploy sẵn
client = InferenceClient("Rookie/Llama-3-8B-Instruct-Chinese", token=token)

# Hàm phản hồi
def respond(message, history, character_id, max_tokens, temperature, top_p):
    char = character_dict[character_id]
    system_message = char["persona_prompt"]

    # Format prompt theo style Zephyr (Instruct-tuned)
    prompt = f"<|system|>\n{system_message}</s>\n"
    for user_msg, bot_msg in history:
        prompt += f"<|user|>\n{user_msg}</s>\n<|assistant|>\n{bot_msg}</s>\n"
    prompt += f"<|user|>\n{message}</s>\n<|assistant|>\n"

    # Gọi text_generation API
    response = client.text_generation(
        prompt=prompt,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )

    reply = response.strip()
    history.append((message, reply))
    return history, history

# Hiển thị tên nhân vật
def format_label(c):
    return f"{c['name']} ({c['personality']}, {c['appearance']}, {c['voice']})"

char_choices = [(format_label(c), c["id"]) for c in characters]

# Giao diện Gradio
demo = gr.ChatInterface(
    fn=respond,
    chatbot=gr.Chatbot(label="Trò chuyện", type="tuples"),  # Cảnh báo deprecated nhưng vẫn hoạt động
    additional_inputs=[
        gr.Dropdown(choices=char_choices, value=characters[0]["id"], label="Chọn nhân vật"),
        gr.Slider(1, 2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
    ],
    title="🧠 Trợ lý ảo hoạt hình",
    description="Chọn nhân vật hoạt hình lý tưởng để trò chuyện!",
)

if __name__ == "__main__":
    demo.launch()