ZonUI-3B — A lightweight, resolution-aware GUI grounding model trained with only 24K samples on a single RTX 4090.

⭐ Quick start

  1. Load ZonUI-3B Model
import ast
import torch
from PIL import Image, ImageDraw
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from transformers.generation import GenerationConfig
from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import smart_resize

def draw_point(image_input, point=None, radius=5):
    """Draw a point on the image at normalized coordinates [0-1]"""
    if isinstance(image_input, str):
        image = Image.open(image_input)
    else:
        image = image_input.copy()
    
    if point:
        x, y = point[0] * image.width, point[1] * image.height
        ImageDraw.Draw(image).ellipse((x - radius, y - radius, x + radius, y + radius), fill='red')
    
    return image

# Load ZonUI-3B model
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "zonghanHZH/ZonUI-3B",
    device_map="auto", 
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2"
).eval()

processor = AutoProcessor.from_pretrained("zonghanHZH/ZonUI-3B")
tokenizer = AutoTokenizer.from_pretrained("zonghanHZH/ZonUI-3B", trust_remote_code=True)

# Set generation config
generation_config = GenerationConfig.from_pretrained("zonghanHZH/ZonUI-3B", trust_remote_code=True)
generation_config.max_length = 4096
generation_config.do_sample = False
generation_config.temperature = 0.0
model.generation_config = generation_config

min_pixels = 256*28*28
max_pixels = 1280*28*28
  1. UI Grounding Example
# Example usage
img_url = './app_store.png'  # Path to your screenshot
query = 'Get OmniFocus 4'    # Text description of the element to find

# Load and prepare image
image = Image.open(img_url).convert('RGB')

# Calculate resized dimensions
resized_height, resized_width = smart_resize(
    image.height,
    image.width,
    factor=processor.image_processor.patch_size * processor.image_processor.merge_size,
    min_pixels=min_pixels,
    max_pixels=max_pixels,
)
resized_image = image.resize((resized_width, resized_height))

# Prepare messages for the model
_SYSTEM = "Based on the screenshot of the page, I give a text description and you give its corresponding location. The coordinate represents a clickable location [x, y] for an element."

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": _SYSTEM},
            {"type": "image", "image": img_url, "min_pixels": min_pixels, "max_pixels": max_pixels},
            {"type": "text", "text": query}
        ],
    }
]

# Process input
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True,
)

inputs = processor(
    text=[text],
    images=[resized_image],
    return_tensors="pt",
    training=False
).to("cuda")

# Generate prediction
generated_ids = model.generate(**inputs)
generated_ids_trimmed = [
    out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0].strip()

# print(f"Raw output: {output_text}")

# Parse coordinates
try:
    coordinates = ast.literal_eval(output_text)
    if len(coordinates) == 2:
        click_xy = [coord / resized_width if i == 0 else coord / resized_height 
                   for i, coord in enumerate(coordinates)]
    else:
        raise ValueError("Unexpected coordinate format")
    absolute_x = click_xy[0] * image.width
    absolute_y = click_xy[1] * image.height
    
    print(f"Normalized coordinates: {click_xy}")
    print(f"Predicted text: {absolute_x}, {absolute_y}")
    
    # Visualize result
    result_image = draw_point(img_url, click_xy, radius=10)
    result_image.save("./quick_start_output.png")
    print("Result saved to quick_start_output.png")
    display(result_image)
except Exception as e:
    print(f"Error parsing coordinates: {e}")
Downloads last month
161
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Model tree for zonghanHZH/ZonUI-3B

Finetuned
(313)
this model
Quantizations
3 models

Collection including zonghanHZH/ZonUI-3B