# NOTE(review): this chunk arrived whitespace-mangled (the whole file collapsed
# onto one physical line) and is truncated mid-statement at the trailing
# `elif "`. Formatting has been restored below; every code token is unchanged
# from the original, and the dangling `elif "` is kept as-is so the remainder
# of the file (outside this chunk) still attaches where it did.

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Initialize model and tokenizer
model_id = "Tesslate/WEBGEN-4B-Preview"

# Load model and tokenizer once during app initialization
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)


def generate_code(prompt):
    """Generate code from *prompt* with the loaded model, streaming tokens.

    Runs ``model.generate`` on a worker thread, accumulates the streamed
    text from a ``TextIteratorStreamer``, then extracts the fenced
    ```` ```html ```` block from the generated text.

    NOTE(review): the tail of this function (return value, any further
    branches, and the Gradio UI wiring) lies outside the visible chunk —
    confirm against the full file.
    """
    # Tokenize and move the input tensors to wherever device_map="auto"
    # placed the model.
    inputs = tok(prompt, return_tensors="pt").to(model.device)

    # Generate with streaming
    from transformers import TextIteratorStreamer
    from threading import Thread

    streamer = TextIteratorStreamer(tok, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=10000,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        # Explicit pad token: avoids the "pad_token_id not set" warning for
        # models whose tokenizer defines only an EOS token.
        pad_token_id=tok.eos_token_id,
        streamer=streamer
    )
    # model.generate blocks until generation finishes, so it runs on a worker
    # thread while this thread drains the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text

    # Extract only the code portion (remove prompt and any non-code text)
    if "```html" in generated_text:
        code_start = generated_text.find("```html") + 7  # 7 == len("```html")
        code_end = generated_text.find("```", code_start)
        if code_end != -1:
            # Closing fence found — take the text strictly between the fences.
            clean_code = generated_text[code_start:code_end].strip()
        else:
            # No closing fence (generation hit max_new_tokens?) — take the rest.
            clean_code = generated_text[code_start:].strip()
    # NOTE(review): chunk truncated here — the condition below is cut mid-token
    # in the visible SOURCE; the string literal continues past this chunk.
    elif "