import os

import torch
import gradio as gr

from config import SmolLM2Config
from model import SmolLM2Lightning
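# Note: `config` and `model` are assumed to be local modules in this Space
# (config.py defining the SmolLM2Config wrapper, model.py defining the
# SmolLM2Lightning LightningModule loaded below).
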
def load_model(checkpoint_path):
    """Load the trained model from a checkpoint."""
    try:
        config = SmolLM2Config("config.yaml")
        model = SmolLM2Lightning.load_from_checkpoint(checkpoint_path, config=config)
        model.eval()
        if torch.cuda.is_available():
            model = model.cuda()
            print(f"Model loaded on GPU: {torch.cuda.get_device_name()}")
        else:
            print("Model loaded on CPU")
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        return None

def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9, top_k=50):
    """Generate text from a prompt with the globally loaded model."""
    try:
        if model is None:
            return "Model not loaded. Please check if the checkpoint exists."
        inputs = model.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=model.config.model.max_position_embeddings,
            padding=True,
        )
        # Move input tensors to the GPU if one is available. The dict
        # comprehension turns the BatchEncoding into a plain dict, so the
        # tensors are accessed by key below rather than by attribute.
        if torch.cuda.is_available():
            inputs = {k: v.cuda() for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                do_sample=True,
                pad_token_id=model.tokenizer.pad_token_id,
                bos_token_id=model.tokenizer.bos_token_id,
                eos_token_id=model.tokenizer.eos_token_id,
            )
        return model.tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error generating text: {e}"

# Load the model once at startup
print("Loading model...")
checkpoint_path = "smol-lm2-final.ckpt"
if not os.path.exists(checkpoint_path):
    print(f"Warning: Checkpoint not found at {checkpoint_path}")
    print("Please train the model first or specify the correct checkpoint path")
    model = None
else:
    model = load_model(checkpoint_path)

# Create the Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="Enter your prompt here..."),
        gr.Slider(minimum=10, maximum=200, value=100, step=1, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p"),
        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="SmolLM2 Text Generation",
    description="Enter a prompt and adjust generation parameters to create text with SmolLM2",
    examples=[
        ["Explain what machine learning is:", 100, 0.7, 0.9, 50],
        ["Once upon a time", 150, 0.8, 0.9, 40],
        ["The best way to learn programming is", 120, 0.7, 0.9, 50],
    ],
)

if __name__ == "__main__":
    print("Starting Gradio interface...")
    # Simple launch configuration
    demo.launch(
        server_port=7860,
        share=True,
    )
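
# Notes (assumptions, not verified against this repo): on Hugging Face Spaces
# the app is served on port 7860 automatically and `share=True` is ignored;
# when running locally with `python app.py`, `share=True` requests a temporary
# public gradio.live link, and the UI is also available at
# http://localhost:7860. The files config.yaml and smol-lm2-final.ckpt must
# sit alongside this script.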