import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the model and tokenizer
model_name = "TheDrummer/Gemmasutra-Mini-2B-v1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Ensure the model runs on CPU (default for Hugging Face Spaces free tier)
device = torch.device("cpu")
model.to(device)

# Chatbot function
def chat_with_model(user_input, history):
    # Format history and input into a single prompt
    if history is None:
        history = []

    # Build conversation context
    prompt = ""
    for user_msg, bot_msg in history:
        prompt += f"User: {user_msg}\nBot: {bot_msg}\n"
    prompt += f"User: {user_input}\nBot: "

    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Generate response
    outputs = model.generate(
        **inputs,
        max_new_tokens=150,    # Limit response length
        do_sample=True,        # Enable sampling for varied responses
        temperature=0.7,       # Control creativity
        top_p=0.9,             # Nucleus sampling
        pad_token_id=tokenizer.eos_token_id,  # Handle padding
    )

    # Decode only the newly generated tokens (everything after the prompt)
    prompt_length = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    # Keep just the bot's reply; the model may keep going with a fake "User:" turn
    bot_response = response.split("User:")[0].strip()

    # Update history
    history.append((user_input, bot_response))
    return bot_response, history

# Gradio Interface
with gr.Blocks(title="Grok-like Chatbot") as iface:
    gr.Markdown("## Chat with Gemmasutra-Mini-2B-v1")
    chatbot = gr.Chatbot(label="Conversation")
    msg = gr.Textbox(label="Your Message", placeholder="Type here...")
    submit_btn = gr.Button("Send")

    # State to maintain conversation history
    state = gr.State(value=[])

    def submit_message(user_input, history):
        response, updated_history = chat_with_model(user_input, history)
        # Clear the textbox, store the new history, and refresh the chat display
        return "", updated_history, updated_history

    # Connect button and Enter key to submit
    submit_btn.click(
        fn=submit_message,
        inputs=[msg, state],
        outputs=[msg, state, chatbot],  # Clear input after submission
    )
    msg.submit(
        fn=submit_message,
        inputs=[msg, state],
        outputs=[msg, state, chatbot],
    )

if __name__ == "__main__":
    iface.launch()
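
# Quick sanity check (a sketch, not part of the Space itself): assuming the file
# is saved as app.py (the usual Spaces convention), chat_with_model can be
# exercised from a Python shell without launching the Gradio UI, e.g.:
#
#   >>> from app import chat_with_model
#   >>> reply, history = chat_with_model("Hello, who are you?", [])
#   >>> print(reply)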