# llama3.1_8b / app.py — author: Himanshu806
# Last commit 60d2619 (verified): "Update app.py"
import spaces
import gradio as gr
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
# Use the chat-tuned (Instruct) LLaMA 3.1 model. The plain "Llama-3.1-8B"
# checkpoint is the base model, which is not fine-tuned to follow a chat format.
model_id = "meta-llama/Llama-3.1-8B-Instruct"
# Llama is implemented natively in transformers, so trust_remote_code is not
# needed; leaving it off avoids executing code shipped inside the model repo.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# Text-generation pipeline wrapping the already-loaded model and tokenizer.
# return_full_text=False: each result holds only the newly generated
# continuation, without the prompt echoed back.
pipe = pipeline(
    task="text-generation",
    tokenizer=tokenizer,
    model=model,
    return_full_text=False,
)
# Core chat function with custom system message
@spaces.GPU
def chat_fn(message, history, system_prompt):
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
else:
messages.append({"role": "system", "content": "You are a helpful assistant."})
for user_msg, bot_msg in history:
messages.append({"role": "user", "content": user_msg})
messages.append({"role": "assistant", "content": bot_msg})
messages.append({"role": "user", "content": message})
# Manual chat formatting (LLaMA 3.1 style)
prompt = ""
for msg in messages:
if msg["role"] == "system":
prompt += f"<<SYS>>\n{msg['content']}\n<</SYS>>\n\n"
elif msg["role"] == "user":
prompt += f"[INST] {msg['content']} [/INST]\n"
elif msg["role"] == "assistant":
prompt += f"{msg['content']}\n"
output = pipe(
prompt,
max_new_tokens=512,
do_sample=True,
temperature=0.7,
top_p=0.9,
repetition_penalty=1.1,
)
return output[0]["generated_text"]
# Gradio UI: includes system prompt
with gr.Blocks() as demo:
gr.Markdown("## LLaMA 3.1 Chat (with Custom System Prompt)")
system_input = gr.Textbox(
label="System Prompt (Optional)",
value="You are a helpful AI assistant.",
lines=2
)
chatbot = gr.Chatbot()
msg = gr.Textbox(placeholder="Enter your message...")
clear = gr.Button("Clear Chat")
history_state = gr.State([])
def user(message, history):
return "", history + [[message, None]]
def bot(history, system_prompt):
message, _ = history[-1]
response = chat_fn(message, history[:-1], system_prompt)
history[-1][1] = response
return history
msg.submit(user, [msg, history_state], [msg, chatbot], queue=False).then(
bot, [chatbot, system_input], chatbot, queue=True
)
clear.click(lambda: ([], []), outputs=[chatbot, history_state])
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed directly,
    # not when it is imported as a module.
    demo.launch()