import spaces
import gradio as gr
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
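
# NOTE (assumption): this app runs as a Hugging Face Space on ZeroGPU, which
# is why `import spaces` appears before any CUDA-touching import. The
# GPU-bound chat function below is therefore wrapped with @spaces.GPU;
# outside a ZeroGPU Space the decorator is a no-op.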

# Use a chat-tuned LLaMA 3.1 model (the Instruct variant; the base
# Llama-3.1-8B checkpoint is not chat-tuned)
model_id = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
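
# torch.bfloat16 halves memory use relative to float32 with little quality
# loss on supported hardware; device_map="auto" (handled by accelerate)
# places the weights across whatever devices are available.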

# Generation pipeline
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=False,
)
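
# With return_full_text=False the pipeline returns only the newly generated
# tokens, so the reply is not prefixed with an echo of the prompt.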

# Core chat function with custom system message
@spaces.GPU
def chat_fn(message, history, system_prompt):
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    else:
        messages.append({"role": "system", "content": "You are a helpful assistant."})
    for user_msg, bot_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})
    # Build the prompt with the tokenizer's built-in LLaMA 3.1 chat template
    # (<|start_header_id|>/<|eot_id|> headers). The previous hand-rolled
    # [INST]/<<SYS>> format is LLaMA-2-era and this model was not trained on it.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    output = pipe(
        prompt,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
    )
    return output[0]["generated_text"]

# Gradio UI: includes system prompt
with gr.Blocks() as demo:
    gr.Markdown("## LLaMA 3.1 Chat (with Custom System Prompt)")
    system_input = gr.Textbox(
        label="System Prompt (Optional)",
        value="You are a helpful AI assistant.",
        lines=2,
    )
    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Enter your message...")
    clear = gr.Button("Clear Chat")
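
    # The Chatbot component itself carries the [user, assistant] pair history
    # between events, so no separate gr.State is needed (the original
    # history_state was never written back, leaving it permanently empty).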

    def user(message, history):
        # Append the new user turn with a placeholder for the reply
        return "", history + [[message, None]]

    def bot(history, system_prompt):
        # Generate a reply for the latest user turn and fill the placeholder
        message, _ = history[-1]
        response = chat_fn(message, history[:-1], system_prompt)
        history[-1][1] = response
        return history

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot, system_input], chatbot, queue=True
    )
    clear.click(lambda: [], outputs=chatbot)

if __name__ == "__main__":
    demo.launch()
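
# Assumed requirements.txt for this Space (not part of the original file;
# accelerate is needed for device_map="auto"):
#   gradio
#   transformers
#   accelerate
#   torch
#   spaces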