# SmolLM2-135m / app.py
import gradio as gr
import torch
from model import SmolLM2
from transformers import AutoTokenizer
from config import Config
from utils import get_device
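
# Note: SmolLM2, Config, and get_device come from the local model.py, config.py,
# and utils.py files uploaded alongside this app, not from pip-installed packages.
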
# Initialize model and tokenizer
config = Config()
device = get_device(config.seed)
print("device: ", device)

def load_model():
    model = SmolLM2(config)
    # Load model weights to CPU first
    model.load_state_dict(
        torch.load(config.checkpoints_path + "/model_final.pt", map_location=torch.device("cpu"))
    )
    model.to(device)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path)
    return model, tokenizer

model, tokenizer = load_model()
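
# SmolLM2.generate is defined in model.py (not shown here); the call below assumes
# an HF-style signature, generate(input_ids, max_new_tokens, temperature, top_k),
# that samples autoregressively with temperature scaling and top-k filtering.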
def generate_text(input_text, max_new_tokens=100, temperature=0.8, top_k=50):
    """
    Generate text based on the input prompt.
    """
    # Tokenize input
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
    # Generate
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
        )
    # Move output back to CPU before decoding
    output_ids = output_ids.cpu()
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return generated_text
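
# Example (sampled output varies from run to run):
#   generate_text("Once upon a time", max_new_tokens=30)
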
# Create Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Input Text", placeholder="Enter your prompt here..."),
        gr.Slider(minimum=1, maximum=150, value=30, step=1, label="Max New Tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-K"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="SmolLM2 Text Generation",
    description="Enter a prompt and the model will generate text based on it.",
)
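
# On Hugging Face Spaces the app is started automatically; the guard below also
# supports local runs (demo.launch(share=True) creates a temporary public link).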
if __name__ == "__main__":
    demo.launch()