|
|
--- |
|
|
license: apache-2.0 |
|
|
base_model: |
|
|
- Qwen/Qwen2.5-VL-7B-Instruct |
|
|
tags: |
|
|
- vision |
|
|
- llm |
|
|
- critical |
|
|
- sft |
|
|
- d3.js |
|
|
- visualization |
|
|
--- |
|
|
|
|
|
# VIS-Shepherd: Constructing Critic for LLM-based Data Visualization Generation |
|
|
|
|
|
[GitHub Repo](https://github.com/bopan3/VIS-Shepherd) |
|
|
|
|
|
 |
|
|
|
|
|
This repository is the official implementation of **VIS-Shepherd: Constructing Critic for LLM-based Data Visualization Generation**. |
|
|
|
|
|
|
|
|
## Requirements |
|
|
|
|
|
### Common Dependencies |
|
|
|
|
|
#### Python Environment Setup
|
|
To install the required Python packages (we recommend Python 3.10):
|
|
|
|
|
```bash |
|
|
pip install -r requirements.txt |
|
|
``` |
|
|
|
|
|
We recommend installing the dependencies inside a virtual environment, e.g. conda or venv.
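For example, a typical conda-based setup might look like the following; the environment name `vis-shepherd` is just an illustration:

```bash
# Create and activate an isolated environment (conda shown; venv works similarly)
conda create -n vis-shepherd python=3.10 -y
conda activate vis-shepherd
pip install -r requirements.txt
```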
|
|
|
|
|
#### LLaMA-Factory |
|
|
|
|
|
We use [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory) for training and model inference. To reproduce our training experiments, install it by following the instructions in that repository:
|
|
|
|
|
```bash |
|
|
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git |
|
|
cd LLaMA-Factory |
|
|
# More optional dependencies can be found at https://llamafactory.readthedocs.io/en/latest/getting_started/installation.html |
|
|
pip install -e ".[torch,metrics,deepspeed]" |
|
|
``` |
|
|
|
|
|
## Training |
|
|
The training dataset is available at *train/data/viscrafter_20250521.json*, in the following format:
|
|
```json |
|
|
[
  {
    "input": "the input instruction",
    "output": "the output response",
    "images": [
      "the image path"
    ]
  }
]
|
|
``` |
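As a sanity check before training, the format above can be validated with a short script. This is a minimal sketch: the required field names (`input`, `output`, `images`) come from the example above, and the `validate_dataset` helper is hypothetical, not part of the repository.

```python
import json

REQUIRED_KEYS = {"input", "output", "images"}

def validate_dataset(records):
    """Check that each record carries the expected SFT fields."""
    for i, rec in enumerate(records):
        missing = REQUIRED_KEYS - rec.keys()
        if missing:
            raise ValueError(f"record {i} is missing fields: {sorted(missing)}")
        if not isinstance(rec["images"], list):
            raise TypeError(f"record {i}: 'images' must be a list of paths")
    return len(records)

# Usage:
#   with open("train/data/viscrafter_20250521.json", encoding="utf-8") as f:
#       print(validate_dataset(json.load(f)), "records OK")
```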
|
|
To train the model(s) in the paper, run this command at the project root:
|
|
|
|
|
```bash |
|
|
llamafactory-cli train train/configs/train-sft-full-viscrafter-20250521.yml |
|
|
``` |
|
|
|
|
|
We trained the model on 8 A800 GPUs (80 GB memory) using DeepSpeed. See the [LLaMA-Factory documentation](https://llamafactory.readthedocs.io/en/latest/advanced/arguments.html) for additional arguments to adapt the training parameters to your environment.
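For reference, a full-parameter SFT config for this setup might look roughly like the sketch below. The field names are standard LLaMA-Factory arguments, but the concrete values are illustrative, not the exact settings used in our experiments; consult *train/configs/train-sft-full-viscrafter-20250521.yml* for the real configuration.

```yaml
### model
model_name_or_path: Qwen/Qwen2.5-VL-7B-Instruct

### method
stage: sft
do_train: true
finetuning_type: full
deepspeed: examples/deepspeed/ds_z3_config.json  # ZeRO-3 for multi-GPU training

### dataset
dataset: viscrafter_20250521  # must be registered in LLaMA-Factory's dataset_info.json
template: qwen2_vl
cutoff_len: 4096

### output
output_dir: saves/qwen2.5-vl-7b/full/sft

### train (illustrative values)
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-5
num_train_epochs: 3.0
bf16: true
```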
|
|
|
|
|
## Setup Local Inference Server |
|
|
|
|
|
You can start an OpenAI-API-compatible inference server with the following command and use it to test your model:
|
|
|
|
|
```bash |
|
|
llamafactory-cli api train/configs/infer-sft-full-viscrafter-20250521.yml |
|
|
``` |
|
|
|
|
|
## Evaluation |
|
|
|
|
|
First, move to the evaluation folder and fill in your API base, API key, and the list of model names in *evaluation/config/config.yaml*. Note that we use Azure's API for GPT-4o, the local inference server for locally trained models, and OpenRouter for other models (e.g. llama-4-maverick).
|
|
```bash |
|
|
cd evaluation |
|
|
``` |
|
|
```yaml |
|
|
## config for OpenAI-compatible API access
|
|
OPENAI_API_BASE: "put your api base here" |
|
|
OPENAI_API_KEY: "put your api key here" |
|
|
OPENAI_API_MODEL_LIST: ["gpt-4o", "qwen/qwen-2.5-vl-7b-instruct", "qwen/qwen2.5-vl-72b-instruct", "meta-llama/llama-4-maverick"] |
|
|
OPENAI_TEMPERATURE: 0.01 |
|
|
OPENAI_TOP_P: 0.1 |
|
|
``` |
|
|
|
|
|
To run inference on the test dataset with a given model, execute the following command (set `--model_used` to the name of the model used as critic); the inference results are saved automatically under *critic_outputs*:
|
|
```bash |
|
|
python run_parallel_autoCritic.py --input_base_path test_set --output_base_path critic_outputs --model_used "The name of the LLM used as critic" |
|
|
``` |
|
|
|
|
|
To run automatic evaluation on all inference results under *critic_outputs*, execute:
|
|
```bash |
|
|
./run_all_autoEvaluate.sh |
|
|
``` |
|
|
|
|
|
The evaluation results will be saved to *evaluation/result.md*.
|
|
|
|
|
|
|
|
## Results |
|
|
|
|
|
| Model | Mean Score | % Scores 3-5 | |
|
|
|-------|------------|-------------| |
|
|
| GPT-4o | 3.41 | 72.0% | |
|
|
| VIS-Shepherd | 2.98 | 67.1% | |
|
|
| Llama-4-Maverick | 2.94 | 52.8% | |
|
|
| Qwen-2.5-VL-72B | 2.78 | 49.1% | |
|
|
| Qwen-2.5-VL-7B_1.2k | 2.5 | 52.2% |
| Qwen-2.5-VL-7B_0.3k | 2.4 | 44.1% |
| Qwen-2.5-VL-7B | 2.2 | 44.1% |