#!/usr/bin/env python3
# svg_compare_gradio.py
# ------------------------------------------------------------
import re
from io import BytesIO

import cairosvg
import clip                      # OpenAI CLIP
import gradio as gr
import lpips
import spaces
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from transformers import BitsAndBytesConfig
# ---------- paths YOU may want to edit ----------------------
ADAPTER_DIR = "unsloth_trained_weights/checkpoint-1700"  # LoRA ckpt
BASE_MODEL  = "Qwen/Qwen2.5-Coder-7B-Instruct"
MAX_NEW     = 512
DEVICE      = "cuda"  # ZeroGPU attaches a CUDA device inside @spaces.GPU calls
# ---------- utils -------------------------------------------
SVG_PAT = re.compile(r"<svg[^>]*>.*?</svg>", re.S | re.I)

def extract_svg(txt: str):
    """Return the last complete <svg>…</svg> block in txt, or None."""
    m = list(SVG_PAT.finditer(txt))
    return m[-1].group(0) if m else None  # last match wins over any echoed examples
def svg2pil(svg: str):
    """Rasterise an SVG string to an RGB PIL image; None if it fails to parse."""
    try:
        png = cairosvg.svg2png(bytestring=svg.encode())
        return Image.open(BytesIO(png)).convert("RGB")
    except Exception:
        return None  # malformed SVG → no image
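
# Typical flow: raw completion → extract_svg() → svg2pil(); each step degrades
# gracefully to None rather than raising, so the UI can show "(no SVG found)".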
# ---------- backbone loaders (CLIP + LPIPS) -----------------
_CLIP = _PREP = _LP = None

def _load_backbones():
    """Lazily load CLIP ViT-L/14 and a VGG-based LPIPS net, once per worker."""
    global _CLIP, _PREP, _LP
    if _CLIP is None:
        _CLIP, _PREP = clip.load("ViT-L/14", device=DEVICE)
        _CLIP.eval()
    if _LP is None:
        _LP = lpips.LPIPS(net="vgg").to(DEVICE).eval()
def fused_sim(a: Image.Image, b: Image.Image, alpha=.5):
    """Blend CLIP cosine similarity and (1 - LPIPS distance), both mapped to [0, 1]."""
    _load_backbones()
    ta = _PREP(a).unsqueeze(0).to(DEVICE)  # CLIP-normalised tensors
    tb = _PREP(b).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        fa = _CLIP.encode_image(ta); fa = fa / fa.norm(dim=-1, keepdim=True)
        fb = _CLIP.encode_image(tb); fb = fb / fb.norm(dim=-1, keepdim=True)
        clip_sim = ((fa @ fb.T).item() + 1) / 2  # cosine [-1, 1] → [0, 1]
        # LPIPS with normalize=True expects plain [0, 1] RGB tensors, so feed it
        # freshly converted images rather than the CLIP-normalised tensors above.
        la = TF.to_tensor(a.resize((224, 224))).unsqueeze(0).to(DEVICE)
        lb = TF.to_tensor(b.resize((224, 224))).unsqueeze(0).to(DEVICE)
        lp_sim = 1 - _LP(la, lb, normalize=True).item()
    return alpha * clip_sim + (1 - alpha) * lp_sim
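
# Note: fused_sim() is not wired into the UI below (sim_lbl only ever shows a
# caption). Scores near 1.0 mean the two renders agree both semantically (CLIP)
# and perceptually (LPIPS); alpha=.5 weights the two terms equally.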
# 4-bit double quantisation so both 7B checkpoints fit comfortably on one GPU
bnb_cfg = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True)
# ---------- load models once at startup ---------------------
_base = _lora = _tok = None

def ensure_models():
    """Create base, lora, tok **once per worker**."""
    from unsloth import FastLanguageModel  # imported lazily: unsloth wants a GPU
    global _base, _lora, _tok
    if _base is None:
        _base, _tok = FastLanguageModel.from_pretrained(
            BASE_MODEL, max_seq_length=2048,
            quantization_config=bnb_cfg, device_map="auto")
        _tok.pad_token = _tok.eos_token
        # Pointing from_pretrained at the adapter dir loads base + LoRA weights
        _lora, _ = FastLanguageModel.from_pretrained(
            ADAPTER_DIR, max_seq_length=2048,
            quantization_config=bnb_cfg, device_map="auto")
    return True
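
# Keeping the unsloth import inside ensure_models() lets the Space boot on a
# CPU-only worker; the heavyweight loads only happen on the first GPU request.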
def draw(model_flag, desc):
    """Generate one SVG with either the base model or the LoRA and rasterise it."""
    ensure_models()
    model = _base if model_flag == "base" else _lora
    prompt = _tok.apply_chat_template(
        [{"role": "system", "content": "You are an SVG illustrator."},
         {"role": "user",
          "content": f"ONLY reply with a valid, complete <svg>…</svg> file that depicts: {desc}"}],
        tokenize=False, add_generation_prompt=True)
    ids = _tok(prompt, return_tensors="pt").to(DEVICE)
    out = model.generate(**ids, max_new_tokens=MAX_NEW,
                         do_sample=True, temperature=.7, top_p=.8)
    svg = extract_svg(_tok.decode(out[0], skip_special_tokens=True))
    img = svg2pil(svg) if svg else None
    return img, svg or "(no SVG found)"
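
# Example: draw("lora", "a red circle") → (PIL image or None, svg text).
# Sampling (temperature .7, top_p .8) is stochastic, so repeated calls on the
# same description yield different drawings.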
# ---------- gradio interface --------------------------------
@spaces.GPU  # ZeroGPU: request a GPU for the duration of this call
def compare(desc):
    img_b, svg_b = draw("base", desc)
    img_l, svg_l = draw("lora", desc)
    caption = ("Thanks for trying our model 🙂\n"
               "If you don't see an image for the base or GRPO model, "
               "it didn't generate a valid SVG!")
    return img_b, img_l, caption, svg_b, svg_l
with gr.Blocks(theme="gradio/base") as demo:
    gr.Markdown("## 🖌️ Qwen-2.5 SVG Generator: base vs GRPO-LoRA")
    gr.Markdown(
        "Type an image **description** (e.g. *a purple forest at dusk*). "
        "Click **Generate** to see what the base model and the fine-tuned LoRA produce."
    )
    inp = gr.Textbox(label="Description", placeholder="a purple forest at dusk")
    btn = gr.Button("Generate")
    with gr.Row():
        out_base = gr.Image(label="Base model", type="pil")
        out_lora = gr.Image(label="LoRA-tuned model", type="pil")
    sim_lbl = gr.Markdown()
    with gr.Accordion("⚙️ Raw SVG code", open=False):
        svg_base_box = gr.Textbox(label="Base SVG", lines=6)
        svg_lora_box = gr.Textbox(label="LoRA SVG", lines=6)
    btn.click(compare, inp, [out_base, out_lora, sim_lbl, svg_base_box, svg_lora_box])

demo.launch()