Model Description
R1-SGG: Compile Scene Graphs with Reinforcement Learning
Structured Visual Reasoning with Multimodal LLMs and Reinforcement Learning
An end-to-end multimodal LLM for Scene Graph Generation (SGG), introduced in the paper Compile Scene Graphs with Reinforcement Learning.

🚀 Update
- ✅ Released R1-SGG-7B and R1-SGG-Zero-7B
- ✅ Support for the PSG dataset (bbox format only, not panoptic)
- ✅ Updated loss implementation
- ✅ Always use `custom_per_device_train_batch_size` instead of `per_device_train_batch_size` for faster sampling under gradient accumulation
- ⚠️ The current loss implementation might still be affected by gradient accumulation: see trl issue #3021
🛠️ Setup Environment
```bash
bash install.sh
```
Main dependencies:
- torch == 2.5.0 or 2.5.1 (cu124, optional)
- transformers (supports Qwen2VL, Qwen2.5VL)
- trl
- vLLM
📊 Dataset
Load the preprocessed datasets via:
```python
from datasets import load_dataset

db_train = load_dataset("JosephZ/vg150_train_sgg_prompt")["train"]
db_val = load_dataset("JosephZ/vg150_val_sgg_prompt")["train"]
```
or, for PSG:
```python
db_train = load_dataset("JosephZ/psg_train_sg")["train"]  # keys: image_id, image, objects, relationships
db_val = load_dataset("JosephZ/psg_test_sg")["train"]
```
We transformed VG150 into the Hugging Face Datasets format with the following keys:
- `image_id`
- `image`
- `prompt_open`
- `prompt_close`
- `objects`
- `relationships`
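As a quick sanity check, you can inspect a single sample. This is a minimal sketch; the comments describing `objects` and `relationships` are assumptions about the annotation format, not guaranteed:
```python
from datasets import load_dataset

# Load the preprocessed VG150 training split and peek at one sample.
db_train = load_dataset("JosephZ/vg150_train_sgg_prompt")["train"]
sample = db_train[0]

print(sample["image_id"])           # unique image identifier
print(type(sample["image"]))        # image decoded by `datasets`
print(sample["prompt_open"][:200])  # open-vocabulary SGG prompt (assumed str)
print(sample["objects"])            # object annotations (labels + boxes, assumed)
print(sample["relationships"])      # subject-predicate-object triplets (assumed)
```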
🔥 Supported Models
- Qwen/Qwen2-VL-2B-Instruct
- Qwen/Qwen2-VL-7B-Instruct
- Qwen/Qwen2.5-VL-3B-Instruct
- Qwen/Qwen2.5-VL-7B-Instruct
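All of these are standard Hugging Face checkpoints, so outside the provided scripts they can be loaded with the usual transformers classes. A minimal sketch (Qwen2-VL models use `Qwen2VLForConditionalGeneration` instead; adjust dtype and device placement to your hardware):
```python
import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

model_name = "Qwen/Qwen2.5-VL-7B-Instruct"  # any Qwen2.5-VL entry from the list above
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # bf16 keeps memory manageable on A100/GH200
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(model_name)
```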
🏋️‍♂️ Training
Training with Supervised Fine-Tuning (SFT)
For SLURM users:
```bash
sbatch scripts/sft/7B_sgg.sh
```
For local machines:
```bash
bash scripts/sft_local/7B_sgg.sh
```
⏱️ Approximate training time:
- 2B models: ~4 hours (4×A100 SXM4 GPUs)
- 7B models: ~10 hours (4×A100 SXM4 GPUs)
Training with Reinforcement Learning (GRPO)
**Update (11/05/2025):** to use "Hard Recall", pass:
```bash
--reward_funcs format_reward edge_hard_reward
```
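For context, reward functions in trl's GRPO follow the signature `func(completions, **kwargs) -> list[float]`. The sketch below is a hypothetical format-style reward illustrating that interface, assuming completions arrive as plain strings containing a JSON scene graph; it is not the repo's actual `format_reward` or `edge_hard_reward`:
```python
import json

def format_reward(completions, **kwargs):
    """Hypothetical sketch: 1.0 if a completion parses as JSON with
    'objects' and 'relationships' keys, else 0.0."""
    rewards = []
    for completion in completions:
        try:
            graph = json.loads(completion)
            well_formed = (
                isinstance(graph, dict)
                and "objects" in graph
                and "relationships" in graph
            )
            rewards.append(1.0 if well_formed else 0.0)
        except (json.JSONDecodeError, TypeError):
            rewards.append(0.0)
    return rewards
```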
For A100 GPUs:
```bash
sbatch scripts/grpo/train_a100_2B.sh
```
(12 hours on 16×A100 GPUs)
For GH200 GPUs:
```bash
sbatch scripts/grpo/train_gh200.sh
```
(16 hours on 16×GH200 GPUs)
For clusters with many RTX 3090/4090 GPUs:
```bash
sbatch scripts/grpo/train_fused.sh
```
- Training 7B models on 24GB cards is possible with DeepSpeed ZeRO-3, but slow due to communication bottlenecks.
- (Fun fact: training with 120×RTX 4090s is crazy, but severely limited by communication latency.)
💡 Recommended learning rate: 6e-7.
🧪 Inference and Evaluation
Inference with SFT-trained models:
```bash
bash scripts/inference/run_sgg_inference.sh $DATASET $MODEL_NAME $OUTPUT_DIR
```
For models trained with predefined categories, add `true`:
```bash
bash scripts/inference/run_sgg_inference.sh $DATASET $MODEL_NAME $OUTPUT_DIR true
```
Inference with GRPO-trained models:
```bash
bash scripts/inference/run_sgg_inference.sh $DATASET $MODEL_NAME $OUTPUT_DIR false/true true
```
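For a single image outside the scripts, inference follows the standard Qwen2-VL chat pattern. A minimal sketch, assuming a local checkpoint path and a generic prompt (the prompts actually used for training are the `prompt_open`/`prompt_close` fields of the datasets):
```python
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

model_path = "path/to/your/R1-SGG-checkpoint"  # hypothetical path; or a base Qwen model
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, device_map="auto"
)
processor = AutoProcessor.from_pretrained(model_path)

image = Image.open("example.jpg")
messages = [{
    "role": "user",
    "content": [
        {"type": "image"},
        {"type": "text", "text": "Generate a scene graph for this image."},
    ],
}]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=[image], return_tensors="pt").to(model.device)

output_ids = model.generate(**inputs, max_new_tokens=512)
generated = output_ids[:, inputs["input_ids"].shape[1]:]  # strip the prompt tokens
print(processor.batch_decode(generated, skip_special_tokens=True)[0])
```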
Evaluation:
```bash
DATASET_TYPE=vg  # or psg
python src/sgg_gather_preds.py $DATASET_TYPE $OUTPUT_DIR sgg_pred_results.json
python src/vg150_eval.py $DATASET sgg_pred_results.json
```
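For intuition about what `vg150_eval.py` measures: SGG evaluation reports Recall@K over (subject, predicate, object) triplets. The sketch below is a deliberately simplified, label-only version; the real VG150 protocol additionally requires the predicted subject and object boxes to match ground truth at IoU ≥ 0.5:
```python
def triplet_recall_at_k(gt_triplets, pred_triplets, k=50):
    """Simplified Recall@K over (subject, predicate, object) label triplets.
    pred_triplets should be sorted by confidence, highest first."""
    top_k = set(pred_triplets[:k])
    hits = sum(1 for triplet in gt_triplets if triplet in top_k)
    return hits / max(len(gt_triplets), 1)

# Toy example: 2 of 3 ground-truth triplets appear in the predictions.
gt = [("man", "riding", "horse"), ("horse", "on", "grass"), ("man", "wearing", "hat")]
pred = [("man", "riding", "horse"), ("horse", "on", "grass"), ("dog", "near", "tree")]
print(triplet_recall_at_k(gt, pred))  # 0.666...
```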
🤗 Acknowledgement
The `GRPOTrainer` used in this project is based on trl's `GRPOTrainer`, extended to support multimodal inputs.
📚 Citation
If you find this work helpful, please cite:
```bibtex
@article{chen2025compile,
  title={Compile Scene Graphs with Reinforcement Learning},
  author={Chen, Zuyao and Wu, Jinlin and Lei, Zhen and Pollefeys, Marc and Chen, Chang Wen},
  journal={arXiv preprint arXiv:2504.13617},
  year={2025}
}
```
✨ Happy Compiling!