ZonUI-3B
ZonUI-3B is a lightweight, resolution-aware GUI grounding model optimized for high-resolution screens, trained with only 24K samples on a single RTX 4090.
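The quick-start script below loads ZonUI-3B with Hugging Face transformers, resizes a screenshot to the resolution the model expects, and predicts a clickable [x, y] location for a text description of a UI element. It assumes a transformers release with Qwen2.5-VL support, plus torch, pillow, and accelerate (needed for device_map="auto"); flash-attn is only required if you keep the flash_attention_2 setting.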
import ast
import torch
from PIL import Image, ImageDraw
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from transformers.generation import GenerationConfig
from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import smart_resize
def draw_point(image_input, point=None, radius=5):
    """Draw a point on the image at normalized coordinates [0-1]"""
    if isinstance(image_input, str):
        image = Image.open(image_input)
    else:
        image = image_input.copy()
    if point:
        x, y = point[0] * image.width, point[1] * image.height
        ImageDraw.Draw(image).ellipse((x - radius, y - radius, x + radius, y + radius), fill='red')
    return image
# Load ZonUI-3B model
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
"zonghanHZH/ZonUI-3B",
device_map="auto",
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2"
).eval()
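# Note: attn_implementation="flash_attention_2" requires the flash-attn package;
# without it, drop the argument or use attn_implementation="sdpa".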
processor = AutoProcessor.from_pretrained("zonghanHZH/ZonUI-3B")
tokenizer = AutoTokenizer.from_pretrained("zonghanHZH/ZonUI-3B", trust_remote_code=True)
# Set generation config
generation_config = GenerationConfig.from_pretrained("zonghanHZH/ZonUI-3B", trust_remote_code=True)
generation_config.max_length = 4096
generation_config.do_sample = False
generation_config.temperature = 0.0
model.generation_config = generation_config
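# Greedy decoding (do_sample=False) keeps the predicted coordinates deterministic across runs.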
min_pixels = 256*28*28
max_pixels = 1280*28*28
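# min_pixels / max_pixels bound the image area after resizing (in units of the 28x28-pixel
# grid used by the vision encoder), and thereby the number of visual tokens per screenshot.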
# Example usage
img_url = './app_store.png' # Path to your screenshot
query = 'Get OmniFocus 4' # Text description of the element to find
# Load and prepare image
image = Image.open(img_url).convert('RGB')
# Calculate resized dimensions
resized_height, resized_width = smart_resize(
image.height,
image.width,
factor=processor.image_processor.patch_size * processor.image_processor.merge_size,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
resized_image = image.resize((resized_width, resized_height))
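# smart_resize snaps the dimensions to the 28-pixel grid while keeping the total area within
# [min_pixels, max_pixels]; the model predicts pixel coordinates in this resized image, which
# is why they are normalized by resized_width / resized_height further down.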
# Prepare messages for the model
_SYSTEM = "Based on the screenshot of the page, I give a text description and you give its corresponding location. The coordinate represents a clickable location [x, y] for an element."
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": _SYSTEM},
{"type": "image", "image": img_url, "min_pixels": min_pixels, "max_pixels": max_pixels},
{"type": "text", "text": query}
],
}
]
# Process input
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True,
)
inputs = processor(
    text=[text],
    images=[resized_image],
    return_tensors="pt",
).to(model.device)
# Generate prediction
generated_ids = model.generate(**inputs)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0].strip()
# print(f"Raw output: {output_text}")
# Parse coordinates
try:
    coordinates = ast.literal_eval(output_text)
    if len(coordinates) == 2:
        # Normalize the predicted pixel coordinates by the resized image dimensions
        click_xy = [coord / resized_width if i == 0 else coord / resized_height
                    for i, coord in enumerate(coordinates)]
    else:
        raise ValueError("Unexpected coordinate format")

    # Map the normalized coordinates back onto the original screenshot
    absolute_x = click_xy[0] * image.width
    absolute_y = click_xy[1] * image.height
    print(f"Normalized coordinates: {click_xy}")
    print(f"Absolute coordinates: {absolute_x}, {absolute_y}")

    # Visualize result
    result_image = draw_point(img_url, click_xy, radius=10)
    result_image.save("./quick_start_output.png")
    print("Result saved to quick_start_output.png")
    display(result_image)  # display() is available in Jupyter/IPython
except Exception as e:
    print(f"Error parsing coordinates: {e}")