# gemi / app.py
# Author: edwardthefma
# Last commit: 1af1cd3 "Update app.py" (verified)
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Model setup: pick the target device, then fetch the tokenizer and the
# causal-LM weights for the checkpoint named below.
model_name = "TheDrummer/Gemmasutra-Mini-2B-v1"
# Pinned to CPU; the app targets the Hugging Face Spaces free tier, whose
# default hardware is CPU-only.
device = torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# nn.Module.to() returns the module itself, so the move can be chained.
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
# Chatbot function
def chat_with_model(user_input, history, *, max_new_tokens=150,
                    temperature=0.7, top_p=0.9):
    """Generate one bot reply to *user_input*, given the prior conversation.

    The conversation is rendered as a plain-text transcript of
    ``User: ...`` / ``Bot: ...`` lines. The model is asked to continue it
    after a trailing ``Bot: `` marker.

    Args:
        user_input: The new user message.
        history: Prior turns as ``(user, bot)`` pairs, or ``None`` for a
            fresh conversation. This object is not modified.
        max_new_tokens: Upper bound on generated tokens (limits reply length).
        temperature: Sampling temperature; higher values give more varied text.
        top_p: Probability mass kept by nucleus sampling.

    Returns:
        A tuple ``(bot_response, new_history)``. ``new_history`` is a new
        list holding the turns of *history* plus ``(user_input, bot_response)``.
    """
    # Copy rather than append to the caller's list, to avoid side effects.
    new_history = list(history) if history is not None else []

    # Build the transcript; join avoids repeated string concatenation.
    turns = [f"User: {h[0]}\nBot: {h[1]}\n" for h in new_history]
    turns.append(f"User: {user_input}\nBot: ")
    prompt = "".join(turns)

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,  # limit response length
        do_sample=True,                 # sampling for varied responses
        temperature=temperature,
        top_p=top_p,                    # nucleus sampling
        pad_token_id=tokenizer.eos_token_id,  # reuse EOS as the padding id
    )

    # generate() returns prompt + continuation. Decode only the new tokens
    # so that prompt text can never be mistaken for the reply.
    prompt_len = inputs["input_ids"].shape[-1]
    reply = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    # The continuation may go on to invent the next "User:" turn of the
    # transcript; keep only the bot's own first turn.
    bot_response = reply.split("\nUser:", 1)[0].strip()

    new_history.append((user_input, bot_response))
    return bot_response, new_history
# Gradio Interface
with gr.Blocks(title="Grok-like Chatbot") as iface:
    gr.Markdown("## Chat with Gemmasutra-Mini-2B-v1")
    chatbot = gr.Chatbot(label="Conversation")
    msg = gr.Textbox(label="Your Message", placeholder="Type here...")
    submit_btn = gr.Button("Send")
    # State to maintain conversation history (list of (user, bot) pairs)
    state = gr.State(value=[])

    def submit_message(user_input, history):
        """Handle one submission from the textbox or the Send button.

        Returns values for the outputs ``[state, chatbot, msg]``: the updated
        history (stored in state and shown in the chatbot) and an empty
        string that clears the input box.
        """
        history = history or []
        # Skip generation for blank input; just clear the box.
        if not user_input or not user_input.strip():
            return history, history, ""
        _response, updated_history = chat_with_model(user_input, history)
        return updated_history, updated_history, ""

    # Each output component appears exactly once. The previous wiring listed
    # `msg` twice and wrote the bot response into the input box.
    event_outputs = [state, chatbot, msg]
    # Connect both the button and the Enter key to the same handler.
    submit_btn.click(
        fn=submit_message,
        inputs=[msg, state],
        outputs=event_outputs,
    )
    msg.submit(
        fn=submit_message,
        inputs=[msg, state],
        outputs=event_outputs,
    )

if __name__ == "__main__":
    iface.launch()