This is the official implementation of the paper "SATORI-R1: Incentivizing Multimodal Reasoning with Spatial Grounding and Verifiable Rewards" (arXiv).
SATORI is a vision-language model fine-tuned from Qwen2.5-VL to perform structured visual reasoning for Visual Question Answering (VQA). It generates:
- A concise image caption describing the overall scene.
- Coordinates of relevant bounding boxes that support reasoning.
- A final answer to the user’s question.
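The three outputs are delimited by the `<caption>`, `<bbox>` and `<answer>` tags requested in the system prompt of the inference example. The snippet below is a minimal sketch for splitting a generated string into those parts. It assumes each tag appears at most once and is properly closed; `parse_satori_output` is a hypothetical helper, not part of the released code, and the sample string is made up for illustration rather than real model output.

```python
import re

# Tag names come from the system prompt in the inference example.
# Assumption: each tag appears at most once and is properly closed.
_SECTION_RE = {
    tag: re.compile(rf"<{tag}>(.*?)</{tag}>", re.DOTALL)
    for tag in ("caption", "bbox", "answer")
}


def parse_satori_output(text: str) -> dict:
    """Split generated text into caption / bbox / answer sections.

    A missing section (e.g. because generation was cut off by
    max_new_tokens) maps to None.
    """
    result = {}
    for tag, pattern in _SECTION_RE.items():
        match = pattern.search(text)
        result[tag] = match.group(1).strip() if match else None
    return result


# Hypothetical output string, for illustration only.
sample = (
    "<caption>A woman and a dog sit on a beach at sunset.</caption>"
    "<bbox>[120, 200, 310, 420]</bbox>"
    "<answer>The dog</answer>"
)
parsed = parse_satori_output(sample)
print(parsed["answer"])  # The dog
```

Returning `None` for absent sections makes truncated generations easy to detect before the bounding boxes or answer are used downstream.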
Inference Example
```python
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# Load the SATORI model and processor
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "justairr/SATORI", torch_dtype="auto", device_map="auto"
)
# Bound each image to 256-512 visual tokens (one token per 28x28 pixel area)
min_pixels = 256 * 28 * 28
max_pixels = 512 * 28 * 28
processor = AutoProcessor.from_pretrained(
    "justairr/SATORI", min_pixels=min_pixels, max_pixels=max_pixels
)

# Chat-style messages guiding structured output
messages = [
    {
        "role": "system",
        "content": (
            "Given an image and a question, follow these steps:\n"
            "1. Generate a brief image caption describing the overall scene inside <caption>...</caption>.\n"
            "2. Determine the most relevant image regions, output their coordinates inside <bbox>...</bbox>.\n"
            "3. Provide the final answer inside <answer>...</answer>."
        ),
    },
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
            },
            {
                "type": "text",
                "text": (
                    "What's the girl playing with?\n"
                    "First, provide an image caption inside <caption>...</caption>, "
                    "then bounding boxes inside <bbox>...</bbox>, and finally <answer>...</answer>."
                ),
            },
        ],
    },
]

# Prepare inputs
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
).to(model.device)  # same device as the model's first layer

# Generate output, then drop the prompt tokens from each sequence
generated_ids = model.generate(**inputs, max_new_tokens=128)
trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)]
output_text = processor.batch_decode(
    trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
```
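The coordinate format inside `<bbox>...</bbox>` is not specified here. The sketch below assumes SATORI keeps the Qwen2.5-VL grounding convention: boxes written as `[x1, y1, x2, y2]` in absolute pixels of the *resized* image the processor fed to the model. Under that assumption, the boxes must be rescaled before being drawn on the original image. `parse_boxes` and `rescale_box` are hypothetical helpers; the resized size can be derived from `inputs["image_grid_thw"]`, since Qwen2.5-VL uses 14-pixel patches.

```python
import re

# Assumption: boxes look like [x1, y1, x2, y2] in absolute pixel
# coordinates of the resized model input (Qwen2.5-VL convention).
_NUM = r"-?\d+(?:\.\d+)?"
_BOX_RE = re.compile(
    rf"\[\s*({_NUM})\s*,\s*({_NUM})\s*,\s*({_NUM})\s*,\s*({_NUM})\s*\]"
)


def parse_boxes(bbox_text):
    """Return every 4-number bracketed list in the text as a float tuple."""
    return [tuple(float(v) for v in m.groups()) for m in _BOX_RE.finditer(bbox_text)]


def rescale_box(box, resized_wh, original_wh):
    """Map a box from resized-image pixels to original-image pixels.

    resized_wh / original_wh are (width, height) pairs. The result is
    clamped to the original image bounds.
    """
    rw, rh = resized_wh
    ow, oh = original_wh
    sx, sy = ow / rw, oh / rh
    x1, y1, x2, y2 = box
    return (
        min(max(x1 * sx, 0.0), ow),
        min(max(y1 * sy, 0.0), oh),
        min(max(x2 * sx, 0.0), ow),
        min(max(y2 * sy, 0.0), oh),
    )


# With the objects from the inference example (illustrative only):
#   _, grid_h, grid_w = inputs["image_grid_thw"][0].tolist()
#   resized_wh = (grid_w * 14, grid_h * 14)
#   original_wh = <size of the image you passed in, e.g. PIL Image.size>
boxes = parse_boxes("[10, 20, 100, 200]")
print(rescale_box(boxes[0], (448, 336), (896, 672)))  # (20.0, 40.0, 200.0, 400.0)
```

The regex also matches boxes embedded in JSON-style output such as `{"bbox_2d": [x1, y1, x2, y2]}`, which keeps the parser usable if the exact wrapping differs from the assumption.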
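The `min_pixels` / `max_pixels` values control how many visual tokens each image costs. In Qwen2.5-VL, 14-pixel patches are merged 2x2, so each 28x28 pixel area becomes one token, and `256 * 28 * 28` to `512 * 28 * 28` pixels corresponds to roughly 256 to 512 tokens per image. The sketch below is a simplified re-implementation modeled on the resizing logic in `qwen_vl_utils` (it omits the aspect-ratio check), so treat its numbers as an estimate; `resized_shape` and `visual_tokens` are hypothetical names.

```python
import math

FACTOR = 28  # 14-pixel patches merged 2x2 -> one visual token per 28x28 area


def resized_shape(height, width, min_pixels=256 * 28 * 28,
                  max_pixels=512 * 28 * 28, factor=FACTOR):
    """Approximate the (height, width) the processor resizes an image to."""
    # Round each side to a multiple of `factor`.
    h = max(factor, round(height / factor) * factor)
    w = max(factor, round(width / factor) * factor)
    if h * w > max_pixels:
        # Too large: shrink uniformly, rounding down to stay under the cap.
        beta = math.sqrt(height * width / max_pixels)
        h = max(factor, math.floor(height / beta / factor) * factor)
        w = max(factor, math.floor(width / beta / factor) * factor)
    elif h * w < min_pixels:
        # Too small: enlarge uniformly, rounding up to reach the floor.
        beta = math.sqrt(min_pixels / (height * width))
        h = math.ceil(height * beta / factor) * factor
        w = math.ceil(width * beta / factor) * factor
    return h, w


def visual_tokens(height, width, **kwargs):
    """Estimated number of visual tokens for an image of this size."""
    h, w = resized_shape(height, width, **kwargs)
    return (h // FACTOR) * (w // FACTOR)


print(resized_shape(1000, 1000))  # (616, 616)
print(visual_tokens(1000, 1000))  # 484
```

Raising `max_pixels` gives the model more detail for small regions at the cost of longer sequences and more memory.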
Base model: Qwen/Qwen2.5-VL-3B-Instruct