import gradio as gr
import torch
from transformers import AutoTokenizer

from config import Config
from model import SmolLM2
from utils import get_device
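
# Note: config, model and utils are local modules in this Space. get_device is
# assumed to seed the RNGs (it receives config.seed) and return the best
# available torch device.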

# Initialize model and tokenizer
config = Config()
device = get_device(config.seed)
print("device: ", device)


def load_model():
    model = SmolLM2(config)
    # Load the weights onto the CPU first, then move the model to the target device
    model.load_state_dict(
        torch.load(
            config.checkpoints_path + "/model_final.pt",
            map_location=torch.device("cpu"),
        )
    )
    model.to(device)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path)
    return model, tokenizer


# Load the model and tokenizer once at startup; they are shared by all requests
model, tokenizer = load_model()


def generate_text(input_text, max_new_tokens=100, temperature=0.8, top_k=50):
    """
    Generate text from the input prompt.

    temperature scales the logits before sampling (lower = more deterministic),
    and top_k restricts sampling to the k most likely tokens.
    """
    # Tokenize the prompt and move it to the model's device
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

    # Generate without tracking gradients
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
        )

    # Move the output back to the CPU before decoding
    output_ids = output_ids.cpu()
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return generated_text
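
# Example call (hypothetical prompt and output, assuming the checkpoint loads):
#   generate_text("Once upon a time", max_new_tokens=30, temperature=0.8, top_k=50)
#   -> "Once upon a time, there was a ..."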


# Create Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Input Text", placeholder="Enter your prompt here..."),
        gr.Slider(minimum=1, maximum=150, value=30, step=1, label="Max New Tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-K"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="SmolLM2 Text Generation",
    description="Enter a prompt and the model will generate text based on it.",
)
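

# On Hugging Face Spaces this script is assumed to be the app entry point; when
# run locally, demo.launch() serves the UI on http://127.0.0.1:7860 by default.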
if __name__ == "__main__":
    demo.launch()