Model Description

An end-to-end multimodal LLM for Scene Graph Generation (SGG), introduced in the paper Compile Scene Graphs with Reinforcement Learning.

R1-SGG: Compile Scene Graphs with Reinforcement Learning

Structured Visual Reasoning with Multimodal LLMs and Reinforcement Learning

![Paper](https://img.shields.io/badge/arXiv-2504.13617-b31b1b.svg)

🚀 Update

  • ✅ Hugging Face: R1-SGG-7B, R1-SGG-Zero-7B
  • ✅ Support for the PSG dataset (bbox format only, not panoptic)
  • ✅ Updated loss implementation
  • ✅ Always use custom_per_device_train_batch_size instead of per_device_train_batch_size for faster sampling under gradient accumulation
  • ⚠️ The current loss implementation might still be affected by gradient accumulation: trl issue #3021

🛠️ Setup Environment

bash install.sh

Main dependencies:

- torch == 2.5.0 or 2.5.1 (cu124, optional)
- transformers (supports Qwen2VL, Qwen2.5VL)
- trl
- vLLM
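
After running install.sh, the environment can be sanity-checked from Python. This is a minimal sketch that only verifies the dependencies listed above import and reports their versions:

```python
# Sanity check for the main dependencies listed above.
import torch
import transformers
import trl
import vllm

print("torch:", torch.__version__)                # expected 2.5.0 or 2.5.1 (cu124)
print("transformers:", transformers.__version__)  # needs Qwen2VL / Qwen2.5VL support
print("trl:", trl.__version__)
print("vllm:", vllm.__version__)
print("CUDA available:", torch.cuda.is_available())
```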

📚 Dataset

Load preprocessed datasets via:

from datasets import load_dataset

db_train = load_dataset("JosephZ/vg150_train_sgg_prompt")["train"]
db_val = load_dataset("JosephZ/vg150_val_sgg_prompt")["train"]

or for PSG:

db_train = load_dataset("JosephZ/psg_train_sg")["train"]  # keys: image_id, image, objects, relationships
db_val = load_dataset("JosephZ/psg_test_sg")["train"]

We converted VG150 into the Hugging Face Datasets format with the following keys:

  • image_id
  • image
  • prompt_open
  • prompt_close
  • objects
  • relationships
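
A single record can be inspected directly. The sketch below assumes only the field names listed above; the exact value formats are not documented here and may differ (the PSG splits expose image_id, image, objects, and relationships in the same way):

```python
# Peek at one VG150 training sample (field names from the list above).
from datasets import load_dataset

db_train = load_dataset("JosephZ/vg150_train_sgg_prompt")["train"]
sample = db_train[0]

print(sorted(sample.keys()))   # image_id, image, prompt_open, prompt_close, objects, relationships
print(sample["image_id"])
print(type(sample["image"]))   # typically decoded as a PIL image
```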

🔥 Supported Models

  • Qwen/Qwen2-VL-2B-Instruct
  • Qwen/Qwen2-VL-7B-Instruct
  • Qwen/Qwen2.5-VL-3B-Instruct
  • Qwen/Qwen2.5-VL-7B-Instruct
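
All four are standard transformers checkpoints, so they can be loaded directly for quick experiments. This is a sketch only (the Qwen2.5-VL variants use Qwen2_5_VLForConditionalGeneration instead, and training should go through the scripts below):

```python
# Load a supported Qwen2-VL base model with transformers (illustration only).
import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

model_name = "Qwen/Qwen2-VL-7B-Instruct"
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(model_name)
```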

πŸ‹οΈβ€β™‚οΈ Training

Training with Supervised Fine-Tuning (SFT)

For SLURM users:

sbatch scripts/sft/7B_sgg.sh 

For local machines:

bash scripts/sft_local/7B_sgg.sh

⏱️ Approximate training time:

  • 2B models: ~4 hours (4×A100 SXM4 GPUs)
  • 7B models: ~10 hours (4×A100 SXM4 GPUs)

Training with Reinforcement Learning (GRPO)

**Update (11/05/2025):** to use "Hard Recall", add:

--reward_funcs format_reward edge_hard_reward 
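
The reward functions follow trl's GRPOTrainer convention: each is a callable that scores the sampled completions and returns one float per completion. The sketch below only illustrates that interface with a hypothetical JSON-format check; it is not the repo's actual format_reward or edge_hard_reward:

```python
# Illustrative reward function in the trl GRPOTrainer style (NOT the repo's implementation).
import json

def format_reward(completions, **kwargs):
    """Return 1.0 when a completion parses as JSON with 'objects' and 'relationships'
    keys, else 0.0. The expected output schema here is an assumption."""
    rewards = []
    for completion in completions:
        # Completions may be plain strings or chat-style message lists depending on the dataset format.
        text = completion if isinstance(completion, str) else completion[0]["content"]
        try:
            graph = json.loads(text)
            ok = isinstance(graph, dict) and "objects" in graph and "relationships" in graph
        except (json.JSONDecodeError, TypeError):
            ok = False
        rewards.append(1.0 if ok else 0.0)
    return rewards
```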

For A100 GPUs:

sbatch scripts/grpo/train_a100_2B.sh

(12 hours on 16×A100 GPUs)

For GH200 GPUs:

sbatch scripts/grpo/train_gh200.sh

(16 hours on 16×GH200 GPUs)

For clusters with many RTX_3090/4090 GPUs:

sbatch scripts/grpo/train_fused.sh
  • Training 7B models on 24GB cards is possible with ZeRO-3, but slow due to communication bottlenecks.
  • (Fun fact: training with 120×RTX_4090 is crazy but severely limited by communication latency.)

💡 Recommended learning rate: 6e-7.
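
If you configure the trainer yourself rather than through the provided scripts, the learning rate is an ordinary trl GRPOConfig field. A sketch (the output path and accumulation steps below are assumptions, not values taken from the scripts):

```python
# Hedged sketch: where the recommended learning rate goes when configuring trl directly.
from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="outputs/grpo-7b-sgg",  # assumed path
    learning_rate=6e-7,                # recommendation from this card
    bf16=True,
    gradient_accumulation_steps=8,     # assumption; see the loss caveat in the Update section
)
```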


🧪 Inference and Evaluation

Inference with SFT-trained models:

bash scripts/inference/run_sgg_inference.sh $DATASET $MODEL_NAME $OUTPUT_DIR

For models trained with predefined categories, append true as the fourth argument:

bash scripts/inference/run_sgg_inference.sh $DATASET $MODEL_NAME $OUTPUT_DIR true

Inference with GRPO-trained models (set the fourth argument to false or true as above, and add true as the fifth argument):

bash scripts/inference/run_sgg_inference.sh $DATASET $MODEL_NAME $OUTPUT_DIR false/true true

Evaluation:

DATASET_TYPE=vg # or psg
python src/sgg_gather_preds.py $DATASET_TYPE $OUTPUT_DIR sgg_pred_results.json
python src/vg150_eval.py $DATASET sgg_pred_results.json
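
For orientation, the evaluation reports recall over predicted (subject, predicate, object) triplets. The sketch below shows the core idea on category labels only; the full VG150 protocol also matches predicted boxes to ground-truth boxes (typically by IoU), which this illustration omits, so it is not the metric computed by src/vg150_eval.py:

```python
# Simplified Recall@K over scene-graph triplets (illustration only; ignores box matching).
def triplet_recall_at_k(pred_triplets, gt_triplets, k=50):
    """pred_triplets: ranked list of (subject, predicate, object) tuples;
    gt_triplets: set of ground-truth tuples."""
    if not gt_triplets:
        return 0.0
    topk = set(pred_triplets[:k])
    return sum(1 for t in gt_triplets if t in topk) / len(gt_triplets)

gt = {("man", "riding", "horse"), ("horse", "on", "grass")}
preds = [("man", "riding", "horse"), ("man", "near", "tree"), ("horse", "on", "grass")]
print(triplet_recall_at_k(preds, gt, k=50))  # 1.0
```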

🤝 Acknowledgement

The GRPOTrainer used in this project is based on trl's GRPOTrainer, extended to support multimodal inputs.


📖 Citation

If you find this work helpful, please cite:

@article{chen2025compile,
  title={Compile Scene Graphs with Reinforcement Learning},
  author={Chen, Zuyao and Wu, Jinlin and Lei, Zhen and Pollefeys, Marc and Chen, Chang Wen},
  journal={arXiv preprint arXiv:2504.13617},
  year={2025}
}

✨ Happy Compiling!
