tuandunghcmut committed
Commit 7350352 · verified · 1 Parent(s): 8b6974d

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. groundingLMM/LLaVA/.devcontainer/Dockerfile +53 -0
  2. groundingLMM/LLaVA/.devcontainer/devcontainer.env +2 -0
  3. groundingLMM/LLaVA/.devcontainer/devcontainer.json +71 -0
  4. groundingLMM/LLaVA/.devcontainer/postCreateCommand.sh +45 -0
  5. groundingLMM/LLaVA/docs/Customize_Component.md +20 -0
  6. groundingLMM/LLaVA/docs/Data.md +29 -0
  7. groundingLMM/LLaVA/docs/Evaluation.md +167 -0
  8. groundingLMM/LLaVA/docs/Finetune_Custom_Data.md +37 -0
  9. groundingLMM/LLaVA/docs/Intel.md +7 -0
  10. groundingLMM/LLaVA/docs/LLaVA_Bench.md +31 -0
  11. groundingLMM/LLaVA/docs/LLaVA_from_LLaMA2.md +29 -0
  12. groundingLMM/LLaVA/docs/LoRA.md +46 -0
  13. groundingLMM/LLaVA/docs/MODEL_ZOO.md +150 -0
  14. groundingLMM/LLaVA/docs/ScienceQA.md +53 -0
  15. groundingLMM/LLaVA/docs/Windows.md +27 -0
  16. groundingLMM/LLaVA/docs/macOS.md +29 -0
  17. groundingLMM/LLaVA/scripts/convert_gqa_for_eval.py +18 -0
  18. groundingLMM/LLaVA/scripts/convert_mmvet_for_eval.py +18 -0
  19. groundingLMM/LLaVA/scripts/convert_sqa_to_llava_base_prompt.py +334 -0
  20. groundingLMM/LLaVA/scripts/convert_vizwiz_for_submission.py +47 -0
  21. groundingLMM/LLaVA/scripts/extract_mm_projector.py +47 -0
  22. groundingLMM/LLaVA/scripts/finetune_qlora.sh +50 -0
  23. groundingLMM/LLaVA/scripts/merge_lora_weights.py +22 -0
  24. groundingLMM/LLaVA/scripts/pretrain.sh +46 -0
  25. groundingLMM/LLaVA/scripts/upload_pypi.sh +16 -0
  26. groundingLMM/LLaVA/scripts/zero2.json +23 -0
  27. groundingLMM/dataset/caption_datasets/COCO_Caption_ds.py +124 -0
  28. groundingLMM/dataset/caption_datasets/GranD_ShortCaption_ds.py +105 -0
  29. groundingLMM/dataset/caption_datasets/LLavaInstruct_vqa_ds.py +107 -0
  30. groundingLMM/dataset/gcg_datasets/GranDf_gcg_ds.py +353 -0
  31. groundingLMM/dataset/region_datasets/Flickr_Region_ds.py +193 -0
  32. groundingLMM/dataset/region_datasets/GranD_ReferringRegion_ds.py +162 -0
  33. groundingLMM/dataset/region_datasets/RefCOCO_VG_Region_ds.py +300 -0
  34. groundingLMM/dataset/segm_datasets/GranD_ReferringSegm_ds.py +156 -0
  35. groundingLMM/dataset/segm_datasets/RefCOCO_Segm_ds.py +242 -0
  36. groundingLMM/dataset/segm_datasets/Semantic_Segm_ds.py +248 -0
  37. groundingLMM/dataset/utils/ade20k_classes.json +30 -0
  38. groundingLMM/dataset/utils/cocostuff_classes.txt +183 -0
  39. groundingLMM/dataset/utils/grefer.py +352 -0
  40. groundingLMM/dataset/utils/refcoco_refer.py +391 -0
  41. groundingLMM/dataset/utils/utils.py +115 -0
  42. groundingLMM/mmcv/tests/data/config/a.py +5 -0
  43. groundingLMM/mmcv/tests/data/config/b.json +8 -0
  44. groundingLMM/mmcv/tests/data/config/base.py +5 -0
  45. groundingLMM/mmcv/tests/data/config/c.yaml +4 -0
  46. groundingLMM/mmcv/tests/data/config/d.py +6 -0
  47. groundingLMM/mmcv/tests/data/config/delete.py +4 -0
  48. groundingLMM/mmcv/tests/data/config/deprecated.py +6 -0
  49. groundingLMM/mmcv/tests/data/config/deprecated_as_base.py +2 -0
  50. groundingLMM/mmcv/tests/data/config/e.py +3 -0
groundingLMM/LLaVA/.devcontainer/Dockerfile ADDED
@@ -0,0 +1,53 @@
1
+ FROM mcr.microsoft.com/devcontainers/base:ubuntu-20.04
2
+
3
+ SHELL [ "bash", "-c" ]
4
+
5
+ # update apt and install packages
6
+ RUN apt update && \
7
+ apt install -yq \
8
+ ffmpeg \
9
+ dkms \
10
+ build-essential
11
+
12
+ # add user tools
13
+ RUN sudo apt install -yq \
14
+ jq \
15
+ jp \
16
+ tree \
17
+ tldr
18
+
19
+ # add git-lfs and install
20
+ RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash && \
21
+ sudo apt-get install -yq git-lfs && \
22
+ git lfs install
23
+
24
+ ############################################
25
+ # Setup user
26
+ ############################################
27
+
28
+ USER vscode
29
+
30
+ # install azcopy, a tool to copy to/from blob storage
31
+ # for more info: https://learn.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-blobs-upload#upload-a-file
32
+ RUN cd /tmp && \
33
+ wget https://azcopyvnext.azureedge.net/release20230123/azcopy_linux_amd64_10.17.0.tar.gz && \
34
+ tar xvf azcopy_linux_amd64_10.17.0.tar.gz && \
35
+ mkdir -p ~/.local/bin && \
36
+ mv azcopy_linux_amd64_10.17.0/azcopy ~/.local/bin && \
37
+ chmod +x ~/.local/bin/azcopy && \
38
+ rm -rf azcopy_linux_amd64*
39
+
40
+ # Setup conda
41
+ RUN cd /tmp && \
42
+ wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
43
+ bash ./Miniconda3-latest-Linux-x86_64.sh -b && \
44
+ rm ./Miniconda3-latest-Linux-x86_64.sh
45
+
46
+ # Install dotnet
47
+ RUN cd /tmp && \
48
+ wget https://dot.net/v1/dotnet-install.sh && \
49
+ chmod +x dotnet-install.sh && \
50
+ ./dotnet-install.sh --channel 7.0 && \
51
+ ./dotnet-install.sh --channel 3.1 && \
52
+ rm ./dotnet-install.sh
53
+
groundingLMM/LLaVA/.devcontainer/devcontainer.env ADDED
@@ -0,0 +1,2 @@
1
+ SAMPLE_ENV_VAR1="Sample Value"
2
+ SAMPLE_ENV_VAR2=332431bf-68bf
groundingLMM/LLaVA/.devcontainer/devcontainer.json ADDED
@@ -0,0 +1,71 @@
1
+ {
2
+ "name": "LLaVA",
3
+ "build": {
4
+ "dockerfile": "Dockerfile",
5
+ "context": "..",
6
+ "args": {}
7
+ },
8
+ "features": {
9
+ "ghcr.io/devcontainers/features/docker-in-docker:2": {},
10
+ "ghcr.io/devcontainers/features/azure-cli:1": {},
11
+ "ghcr.io/azure/azure-dev/azd:0": {},
12
+ "ghcr.io/devcontainers/features/powershell:1": {},
13
+ "ghcr.io/devcontainers/features/common-utils:2": {},
14
+ "ghcr.io/devcontainers-contrib/features/zsh-plugins:0": {},
15
+ },
16
+ // "forwardPorts": [],
17
+ "postCreateCommand": "bash ./.devcontainer/postCreateCommand.sh",
18
+ "customizations": {
19
+ "vscode": {
20
+ "settings": {
21
+ "python.analysis.autoImportCompletions": true,
22
+ "python.analysis.autoImportUserSymbols": true,
23
+ "python.defaultInterpreterPath": "~/miniconda3/envs/llava/bin/python",
24
+ "python.formatting.provider": "yapf",
25
+ "python.linting.enabled": true,
26
+ "python.linting.flake8Enabled": true,
27
+ "isort.check": true,
28
+ "dev.containers.copyGitConfig": true,
29
+ "terminal.integrated.defaultProfile.linux": "zsh",
30
+ "terminal.integrated.profiles.linux": {
31
+ "zsh": {
32
+ "path": "/usr/bin/zsh"
33
+ },
34
+ }
35
+ },
36
+ "extensions": [
37
+ "aaron-bond.better-comments",
38
+ "eamodio.gitlens",
39
+ "EditorConfig.EditorConfig",
40
+ "foxundermoon.shell-format",
41
+ "GitHub.copilot-chat",
42
+ "GitHub.copilot-labs",
43
+ "GitHub.copilot",
44
+ "lehoanganh298.json-lines-viewer",
45
+ "mhutchie.git-graph",
46
+ "ms-azuretools.vscode-docker",
47
+ "ms-dotnettools.dotnet-interactive-vscode",
48
+ "ms-python.flake8",
49
+ "ms-python.isort",
50
+ "ms-python.python",
51
+ "ms-python.vscode-pylance",
52
+ "njpwerner.autodocstring",
53
+ "redhat.vscode-yaml",
54
+ "stkb.rewrap",
55
+ "yzhang.markdown-all-in-one",
56
+ ]
57
+ }
58
+ },
59
+ "mounts": [],
60
+ "runArgs": [
61
+ "--gpus",
62
+ "all",
63
+ // "--ipc",
64
+ // "host",
65
+ "--ulimit",
66
+ "memlock=-1",
67
+ "--env-file",
68
+ ".devcontainer/devcontainer.env"
69
+ ],
70
+ // "remoteUser": "root"
71
+ }
groundingLMM/LLaVA/.devcontainer/postCreateCommand.sh ADDED
@@ -0,0 +1,45 @@
1
+ git config --global safe.directory '*'
2
+ git config --global core.editor "code --wait"
3
+ git config --global pager.branch false
4
+
5
+ # Set AZCOPY concurrency to auto
6
+ echo "export AZCOPY_CONCURRENCY_VALUE=AUTO" >> ~/.zshrc
7
+ echo "export AZCOPY_CONCURRENCY_VALUE=AUTO" >> ~/.bashrc
8
+
9
+ # Activate conda by default
10
+ echo ". /home/vscode/miniconda3/bin/activate" >> ~/.zshrc
11
+ echo ". /home/vscode/miniconda3/bin/activate" >> ~/.bashrc
12
+
13
+ # Use llava environment by default
14
+ echo "conda activate llava" >> ~/.zshrc
15
+ echo "conda activate llava" >> ~/.bashrc
16
+
17
+ # Add dotnet to PATH
18
+ echo 'export PATH="$PATH:$HOME/.dotnet"' >> ~/.bashrc
19
+ echo 'export PATH="$PATH:$HOME/.dotnet"' >> ~/.zshrc
20
+
21
+ # Create and activate llava environment
22
+ source /home/vscode/miniconda3/bin/activate
23
+ conda create -y -q -n llava python=3.10
24
+ conda activate llava
25
+
26
+ # Install Nvidia Cuda Compiler
27
+ conda install -y -c nvidia cuda-compiler
28
+
29
+ pip install pre-commit==3.0.2
30
+
31
+ # Install package locally
32
+ pip install --upgrade pip # enable PEP 660 support
33
+ pip install -e .
34
+
35
+ # Install additional packages for training
36
+ pip install -e ".[train]"
37
+ pip install flash-attn --no-build-isolation
38
+
39
+ # Download checkpoints to location outside of the repo
40
+ git clone https://huggingface.co/liuhaotian/llava-v1.5-7b ~/llava-v1.5-7b
41
+
42
+ # Commented because it is unlikely for users to have enough local GPU memory to load the model
43
+ # git clone https://huggingface.co/liuhaotian/llava-v1.5-13b ~/llava-v1.5-13b
44
+
45
+ echo "postCreateCommand.sh COMPLETE!"
groundingLMM/LLaVA/docs/Customize_Component.md ADDED
@@ -0,0 +1,20 @@
1
+ # Customize Components in LLaVA
2
+
3
+ This is an initial guide on how to replace the LLMs, visual encoders, etc. with your choice of components.
4
+
5
+ ## LLM
6
+
7
+ It is quite simple to swap LLaMA out for any other LLM. You can refer to our implementation of [`llava_llama.py`](https://raw.githubusercontent.com/haotian-liu/LLaVA/main/llava/model/language_model/llava_llama.py) for an example of how to replace the LLM.
8
+
9
+ Although it may look like ~100 lines of code are still needed, most of them are copied from the original `llama.py` in Hugging Face Transformers. The only difference is a few inserted lines that process the multimodal inputs.
10
+
11
+ In the `forward` function, you can see that we call `self.prepare_inputs_labels_for_multimodal` to process the multimodal inputs. This function is defined in `LlavaMetaForCausalLM`, and you just need to insert the call into the `forward` function of your LLM.
12
+
13
+ In the `prepare_inputs_for_generation` function, you can see that we add `images` to `model_inputs`. This is because we need to pass the images to the LLM during generation.
14
+
15
+ These are basically all the changes you need to make to replace the LLM.
16
+
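A minimal structural sketch of the pattern described above. This is illustrative only: the stand-in base classes and simplified signatures are assumptions, not the actual LLaVA code; they only show where the two hooks go.

```python
# Structural sketch: where the two LLaVA hooks fit into a custom causal LM.
# MyLLMForCausalLM stands in for the HF model you swap in, and
# LlavaMetaForCausalLM stands in for LLaVA's multimodal mixin.

class MyLLMForCausalLM:  # placeholder for the language model you swap in
    def forward(self, input_ids=None, labels=None, **kwargs):
        return {"input_ids": input_ids, "labels": labels}

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids, **kwargs}


class LlavaMetaForCausalLM:  # placeholder for LLaVA's multimodal mixin
    def prepare_inputs_labels_for_multimodal(self, input_ids, labels, images):
        # In LLaVA this splices the projected image features into the text inputs.
        return input_ids, labels


class LlavaMyLLMForCausalLM(MyLLMForCausalLM, LlavaMetaForCausalLM):
    def forward(self, input_ids=None, labels=None, images=None, **kwargs):
        # Hook 1: turn text-only inputs into multimodal inputs before the LLM runs.
        input_ids, labels = self.prepare_inputs_labels_for_multimodal(input_ids, labels, images)
        return super().forward(input_ids=input_ids, labels=labels, **kwargs)

    def prepare_inputs_for_generation(self, input_ids, images=None, **kwargs):
        model_inputs = super().prepare_inputs_for_generation(input_ids, **kwargs)
        # Hook 2: keep the images in model_inputs so generation also sees them.
        model_inputs["images"] = images
        return model_inputs
```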
17
+ ## Visual Encoder
18
+
19
+ You can check out [`clip_encoder.py`](https://github.com/haotian-liu/LLaVA/blob/main/llava/model/multimodal_encoder/clip_encoder.py) on how we implement the CLIP visual encoder.
20
+
groundingLMM/LLaVA/docs/Data.md ADDED
@@ -0,0 +1,29 @@
1
+ ## Data
2
+
3
+ | Data file name | Size |
4
+ | --- | ---: |
5
+ | [llava_instruct_150k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_instruct_150k.json) | 229 MB |
6
+ | [llava_instruct_80k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_instruct_80k.json) | 229 MB |
7
+ | [conversation_58k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/conversation_58k.json) | 126 MB |
8
+ | [detail_23k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/detail_23k.json) | 20.5 MB |
9
+ | [complex_reasoning_77k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/complex_reasoning_77k.json) | 79.6 MB |
10
+
11
+ ### Pretraining Dataset
12
+ The pretraining dataset used in this release is a subset of CC-3M dataset, filtered with a more balanced concept coverage distribution. Please see [here](https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K) for a detailed description of the dataset structure and how to download the images.
13
+
14
+ If you already have the CC-3M dataset on your disk, the image names follow this format: `GCC_train_000000000.jpg`. You may edit the `image` field correspondingly if necessary.
15
+
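For instance, here is a minimal sketch that prepends a local image-root prefix to every `image` field; the file names and the `cc3m_images/` prefix are hypothetical and should be adapted to your layout.

```python
# Sketch: adjust the "image" field of chat.json to match a local CC-3M layout.
import json

with open("chat.json") as f:
    samples = json.load(f)

for sample in samples:
    # e.g. "GCC_train_000000000.jpg" -> "cc3m_images/GCC_train_000000000.jpg"
    sample["image"] = "cc3m_images/" + sample["image"]

with open("chat_local.json", "w") as f:
    json.dump(samples, f)
```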
16
+ | Data | Chat File | Meta Data | Size |
17
+ | --- | --- | --- | ---: |
18
+ | CC-3M Concept-balanced 595K | [chat.json](https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K/blob/main/chat.json) | [metadata.json](https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K/blob/main/metadata.json) | 211 MB
19
+ | LAION/CC/SBU BLIP-Caption Concept-balanced 558K | [blip_laion_cc_sbu_558k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/blob/main/blip_laion_cc_sbu_558k.json) | [metadata.json](#) | 181 MB
20
+
21
+ **Important notice**: At the request of the community, since ~15% of the images in the original CC-3M dataset are no longer accessible, we upload [`images.zip`](https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K/blob/main/images.zip) to help reproduce our work in the research community. It must not be used for any other purpose. The use of these images must comply with the CC-3M license. This may be taken down at any time when requested by the original CC-3M dataset owner or the owners of the referenced images.
22
+
23
+ ### GPT-4 Prompts
24
+
25
+ We provide our prompts and few-shot samples for GPT-4 queries, to better facilitate research in this domain. Please check out the [`prompts`](https://github.com/haotian-liu/LLaVA/tree/main/playground/data/prompts) folder for three kinds of questions: conversation, detail description, and complex reasoning.
26
+
27
+ They are organized in a format of `system_message.txt` for system message, pairs of `abc_caps.txt` for few-shot sample user input, and `abc_conv.txt` for few-shot sample reference output.
28
+
29
+ Note that they may come in different formats. For example, `conversation` is in `jsonl`, and detail description is answer-only. In our preliminary experiments, the selected format works slightly better than the limited set of alternatives that we tried (`jsonl`, a more natural format, answer-only). If interested, you may try other variants or conduct a more careful study of this. Contributions are welcome!
groundingLMM/LLaVA/docs/Evaluation.md ADDED
@@ -0,0 +1,167 @@
1
+ # Evaluation
2
+
3
+ In LLaVA-1.5, we evaluate models on a diverse set of 12 benchmarks. To ensure reproducibility, we evaluate the models with greedy decoding. We do not evaluate using beam search, so that the inference process is consistent with the real-time chat demo.
4
+
5
+ Currently, we mostly utilize the official toolkit or server for the evaluation.
6
+
7
+ ## Evaluate on Custom Datasets
8
+
9
+ You can evaluate LLaVA on your custom datasets by converting your dataset to LLaVA's jsonl format, and evaluate using [`model_vqa.py`](https://github.com/haotian-liu/LLaVA/blob/main/llava/eval/model_vqa.py).
10
+
11
+ Below we provide a general guideline for evaluating datasets with some common formats.
12
+
13
+ 1. Short-answer (e.g. VQAv2, MME).
14
+
15
+ ```
16
+ <question>
17
+ Answer the question using a single word or phrase.
18
+ ```
19
+
20
+ 2. Option-only for multiple-choice (e.g. MMBench, SEED-Bench).
21
+
22
+ ```
23
+ <question>
24
+ A. <option_1>
25
+ B. <option_2>
26
+ C. <option_3>
27
+ D. <option_4>
28
+ Answer with the option's letter from the given choices directly.
29
+ ```
30
+
31
+ 3. Natural QA (e.g. LLaVA-Bench, MM-Vet).
32
+
33
+ No postprocessing is needed.
34
+
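To convert a custom dataset into the question file consumed by `model_vqa.py`, a minimal sketch is shown below. The field names follow the common LLaVA question format (`question_id`, `image`, `text`), but please verify against `model_vqa.py`; the sample data here is made up.

```python
# Sketch: write a LLaVA-style question file (one JSON object per line).
import json

# Hypothetical custom samples; replace with your own data loading.
samples = [
    {"question_id": 0, "image": "images/0001.jpg",
     "text": "What is the color of the car?\nAnswer the question using a single word or phrase."},
    {"question_id": 1, "image": "images/0002.jpg",
     "text": "How many people are in the image?\nAnswer the question using a single word or phrase."},
]

with open("custom_questions.jsonl", "w") as f:
    for sample in samples:
        f.write(json.dumps(sample) + "\n")
```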
35
+ ## Scripts
36
+
37
+ Before preparing task-specific data, **you MUST first download [eval.zip](https://drive.google.com/file/d/1atZSBBrAX54yYpxtVVW33zFvcnaHeFPy/view?usp=sharing)**. It contains custom annotations, scripts, and the prediction files with LLaVA v1.5. Extract to `./playground/data/eval`. This also provides a general structure for all datasets.
38
+
39
+ ### VQAv2
40
+
41
+ 1. Download [`test2015`](http://images.cocodataset.org/zips/test2015.zip) and put it under `./playground/data/eval/vqav2`.
42
+ 2. Multi-GPU inference.
43
+ ```Shell
44
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/v1_5/eval/vqav2.sh
45
+ ```
46
+ 3. Submit the results to the [evaluation server](https://eval.ai/web/challenges/challenge-page/830/my-submission): `./playground/data/eval/vqav2/answers_upload`.
47
+
48
+ ### GQA
49
+
50
+ 1. Download the [data](https://cs.stanford.edu/people/dorarad/gqa/download.html) and [evaluation scripts](https://cs.stanford.edu/people/dorarad/gqa/evaluate.html) following the official instructions and put under `./playground/data/eval/gqa/data`. You may need to modify `eval.py` as [this](https://gist.github.com/haotian-liu/db6eddc2a984b4cbcc8a7f26fd523187) due to the missing assets in the GQA v1.2 release.
51
+ 2. Multi-GPU inference.
52
+ ```Shell
53
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/v1_5/eval/gqa.sh
54
+ ```
55
+
56
+ ### VisWiz
57
+
58
+ 1. Download [`test.json`](https://vizwiz.cs.colorado.edu/VizWiz_final/vqa_data/Annotations.zip) and extract [`test.zip`](https://vizwiz.cs.colorado.edu/VizWiz_final/images/test.zip) to `test`. Put them under `./playground/data/eval/vizwiz`.
59
+ 2. Single-GPU inference.
60
+ ```Shell
61
+ CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/vizwiz.sh
62
+ ```
63
+ 3. Submit the results to the [evaluation server](https://eval.ai/web/challenges/challenge-page/2185/my-submission): `./playground/data/eval/vizwiz/answers_upload`.
64
+
65
+ ### ScienceQA
66
+
67
+ 1. Under `./playground/data/eval/scienceqa`, download `images`, `pid_splits.json`, `problems.json` from the `data/scienceqa` folder of the ScienceQA [repo](https://github.com/lupantech/ScienceQA).
68
+ 2. Single-GPU inference and evaluate.
69
+ ```Shell
70
+ CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/sqa.sh
71
+ ```
72
+
73
+ ### TextVQA
74
+
75
+ 1. Download [`TextVQA_0.5.1_val.json`](https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json) and [images](https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip) and extract to `./playground/data/eval/textvqa`.
76
+ 2. Single-GPU inference and evaluate.
77
+ ```Shell
78
+ CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/textvqa.sh
79
+ ```
80
+
81
+ ### POPE
82
+
83
+ 1. Download `coco` from [POPE](https://github.com/AoiDragon/POPE/tree/e3e39262c85a6a83f26cf5094022a782cb0df58d/output/coco) and put under `./playground/data/eval/pope`.
84
+ 2. Single-GPU inference and evaluate.
85
+ ```Shell
86
+ CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/pope.sh
87
+ ```
88
+
89
+ ### MME
90
+
91
+ 1. Download the data following the official instructions [here](https://github.com/BradyFU/Awesome-Multimodal-Large-Language-Models/tree/Evaluation).
92
+ 2. Download the images to `MME_Benchmark_release_version`.
93
+ 3. Put the official `eval_tool` and `MME_Benchmark_release_version` under `./playground/data/eval/MME`.
94
+ 4. Single-GPU inference and evaluate.
95
+ ```Shell
96
+ CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mme.sh
97
+ ```
98
+
99
+ ### MMBench
100
+
101
+ 1. Download [`mmbench_dev_20230712.tsv`](https://download.openmmlab.com/mmclassification/datasets/mmbench/mmbench_dev_20230712.tsv) and put under `./playground/data/eval/mmbench`.
102
+ 2. Single-GPU inference.
103
+ ```Shell
104
+ CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mmbench.sh
105
+ ```
106
+ 3. Submit the results to the [evaluation server](https://opencompass.org.cn/leaderboard-multimodal): `./playground/data/eval/mmbench/answers_upload/mmbench_dev_20230712`.
107
+
108
+ ### MMBench-CN
109
+
110
+ 1. Download [`mmbench_dev_cn_20231003.tsv`](https://download.openmmlab.com/mmclassification/datasets/mmbench/mmbench_dev_cn_20231003.tsv) and put under `./playground/data/eval/mmbench`.
111
+ 2. Single-GPU inference.
112
+ ```Shell
113
+ CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mmbench_cn.sh
114
+ ```
115
+ 3. Submit the results to the evaluation server: `./playground/data/eval/mmbench/answers_upload/mmbench_dev_cn_20231003`.
116
+
117
+
118
+ ### SEED-Bench
119
+
120
+ 1. Follow the official [instructions](https://github.com/AILab-CVC/SEED-Bench/blob/main/DATASET.md) to download the images and the videos. Put the images under `./playground/data/eval/seed_bench/SEED-Bench-image`.
121
+ 2. Extract the middle frame from each downloaded video and put the frames under `./playground/data/eval/seed_bench/SEED-Bench-video-image`. We provide our script `extract_video_frames.py`, modified from the official one (a minimal extraction sketch is shown after this list).
122
+ 3. Multiple-GPU inference and evaluate.
123
+ ```Shell
124
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/v1_5/eval/seed.sh
125
+ ```
126
+ 4. Optionally, submit the results to the leaderboard: `./playground/data/eval/seed_bench/answers_upload` using the official jupyter notebook.
127
+
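A minimal sketch of the middle-frame extraction mentioned in step 2. This is not the bundled `extract_video_frames.py`; the directory names are hypothetical and OpenCV is assumed to be installed.

```python
# Sketch: save the middle frame of each video with OpenCV.
import glob
import os
import cv2

video_dir = "SEED-Bench-video"          # hypothetical input directory
out_dir = "SEED-Bench-video-image"      # frames used for image-style evaluation
os.makedirs(out_dir, exist_ok=True)

for path in glob.glob(os.path.join(video_dir, "*.mp4")):
    cap = cv2.VideoCapture(path)
    n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.set(cv2.CAP_PROP_POS_FRAMES, n_frames // 2)  # seek to the middle frame
    ok, frame = cap.read()
    if ok:
        name = os.path.splitext(os.path.basename(path))[0] + ".png"
        cv2.imwrite(os.path.join(out_dir, name), frame)
    cap.release()
```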
128
+ ### LLaVA-Bench-in-the-Wild
129
+
130
+ 1. Extract contents of [`llava-bench-in-the-wild`](https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild) to `./playground/data/eval/llava-bench-in-the-wild`.
131
+ 2. Single-GPU inference and evaluate.
132
+ ```Shell
133
+ CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/llavabench.sh
134
+ ```
135
+
136
+ ### MM-Vet
137
+
138
+ 1. Extract [`mm-vet.zip`](https://github.com/yuweihao/MM-Vet/releases/download/v1/mm-vet.zip) to `./playground/data/eval/mmvet`.
139
+ 2. Single-GPU inference.
140
+ ```Shell
141
+ CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/mmvet.sh
142
+ ```
143
+ 3. Evaluate the predictions in `./playground/data/eval/mmvet/results` using the official jupyter notebook.
144
+
145
+ ## More Benchmarks
146
+
147
+ Below are awesome benchmarks for multimodal understanding from the research community that were not initially included in the LLaVA-1.5 release.
148
+
149
+ ### Q-Bench
150
+
151
+ 1. Download [`llvisionqa_dev.json`](https://huggingface.co/datasets/nanyangtu/LLVisionQA-QBench/resolve/main/llvisionqa_dev.json) (for `dev`-subset) and [`llvisionqa_test.json`](https://huggingface.co/datasets/nanyangtu/LLVisionQA-QBench/resolve/main/llvisionqa_test.json) (for `test`-subset). Put them under `./playground/data/eval/qbench`.
152
+ 2. Download and extract [images](https://huggingface.co/datasets/nanyangtu/LLVisionQA-QBench/resolve/main/images_llvisionqa.tar) and put all the images directly under `./playground/data/eval/qbench/images_llviqionqa`.
153
+ 3. Single-GPU inference (change `dev` to `test` for evaluation on test set).
154
+ ```Shell
155
+ CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/qbench.sh dev
156
+ ```
157
+ 4. Submit the results by instruction [here](https://github.com/VQAssessment/Q-Bench#option-1-submit-results): `./playground/data/eval/qbench/llvisionqa_dev_answers.jsonl`.
158
+
159
+ ### Chinese-Q-Bench
160
+
161
+ 1. Download [`质衡-问答-验证集.json`](https://huggingface.co/datasets/nanyangtu/LLVisionQA-QBench/resolve/main/%E8%B4%A8%E8%A1%A1-%E9%97%AE%E7%AD%94-%E9%AA%8C%E8%AF%81%E9%9B%86.json) (for `dev`-subset) and [`质衡-问答-测试集.json`](https://huggingface.co/datasets/nanyangtu/LLVisionQA-QBench/resolve/main/%E8%B4%A8%E8%A1%A1-%E9%97%AE%E7%AD%94-%E6%B5%8B%E8%AF%95%E9%9B%86.json) (for `test`-subset). Put them under `./playground/data/eval/qbench`.
162
+ 2. Download and extract [images](https://huggingface.co/datasets/nanyangtu/LLVisionQA-QBench/resolve/main/images_llvisionqa.tar) and put all the images directly under `./playground/data/eval/qbench/images_llviqionqa`.
163
+ 3. Single-GPU inference (change `dev` to `test` for evaluation on test set).
164
+ ```Shell
165
+ CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/qbench_zh.sh dev
166
+ ```
167
+ 4. Submit the results by instruction [here](https://github.com/VQAssessment/Q-Bench#option-1-submit-results): `./playground/data/eval/qbench/llvisionqa_zh_dev_answers.jsonl`.
groundingLMM/LLaVA/docs/Finetune_Custom_Data.md ADDED
@@ -0,0 +1,37 @@
1
+ # Finetune LLaVA on Custom Datasets
2
+
3
+ ## Dataset Format
4
+
5
+ Convert your data to a JSON file containing a list of all samples. Each sample should contain `id` (a unique identifier), `image` (the path to the image), and `conversations` (the conversation data between the human and the AI).
6
+
7
+ A sample JSON for finetuning LLaVA for generating tag-style captions for Stable Diffusion:
8
+
9
+ ```json
10
+ [
11
+ {
12
+ "id": "997bb945-628d-4724-b370-b84de974a19f",
13
+ "image": "part-000001/997bb945-628d-4724-b370-b84de974a19f.jpg",
14
+ "conversations": [
15
+ {
16
+ "from": "human",
17
+ "value": "<image>\nWrite a prompt for Stable Diffusion to generate this image."
18
+ },
19
+ {
20
+ "from": "gpt",
21
+ "value": "a beautiful painting of chernobyl by nekro, pascal blanche, john harris, greg rutkowski, sin jong hun, moebius, simon stalenhag. in style of cg art. ray tracing. cel shading. hyper detailed. realistic. ue 5. maya. octane render. "
22
+ },
23
+ ]
24
+ },
25
+ ...
26
+ ]
27
+ ```
28
+
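As a quick illustration, here is a minimal sketch that assembles such a file from (image, prompt, answer) triples; the example record and output path are hypothetical.

```python
# Sketch: build a LLaVA-style finetuning JSON from (image, prompt, answer) triples.
import json
import uuid

triples = [  # hypothetical records; load your own data here
    ("part-000001/cat.jpg",
     "Write a prompt for Stable Diffusion to generate this image.",
     "a fluffy orange cat sleeping on a windowsill, soft light, photorealistic"),
]

samples = []
for image_path, prompt, answer in triples:
    samples.append({
        "id": str(uuid.uuid4()),
        "image": image_path,
        "conversations": [
            {"from": "human", "value": "<image>\n" + prompt},
            {"from": "gpt", "value": answer},
        ],
    })

with open("custom_finetune.json", "w") as f:
    json.dump(samples, f, indent=2)
```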
29
+ ## Command
30
+
31
+ If you have limited task-specific data, we recommend finetuning from LLaVA checkpoints with LoRA following this [script](https://github.com/haotian-liu/LLaVA/blob/main/scripts/v1_5/finetune_task_lora.sh).
32
+
33
+ If the amount of task-specific data is sufficient, you can also finetune from LLaVA checkpoints with full-model finetuning following this [script](https://github.com/haotian-liu/LLaVA/blob/main/scripts/v1_5/finetune_task.sh).
34
+
35
+ You may need to adjust the hyperparameters to fit each specific dataset and your hardware constraints.
36
+
37
+
groundingLMM/LLaVA/docs/Intel.md ADDED
@@ -0,0 +1,7 @@
1
+ # Intel Platforms
2
+
3
+ * Support [Intel GPU Max Series](https://www.intel.com/content/www/us/en/products/details/discrete-gpus/data-center-gpu/max-series.html)
4
+ * Support [Intel CPU Sapphire Rapids](https://ark.intel.com/content/www/us/en/ark/products/codename/126212/products-formerly-sapphire-rapids.html)
5
+ * Based on [Intel Extension for PyTorch](https://intel.github.io/intel-extension-for-pytorch)
6
+
7
+ More details in [**intel branch**](https://github.com/haotian-liu/LLaVA/tree/intel/docs/intel)
groundingLMM/LLaVA/docs/LLaVA_Bench.md ADDED
@@ -0,0 +1,31 @@
1
+ # LLaVA-Bench [[Download](https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild)]
2
+
3
+ **-Introduction-** Large commercial multimodal chatbots were released this week, including
4
+ - [Multimodal Bing-Chat by Microsoft](https://blogs.bing.com/search/july-2023/Bing-Chat-Enterprise-announced,-multimodal-Visual-Search-rolling-out-to-Bing-Chat) (July 18, 2023)
5
+ - [Multimodal Bard by Google](https://bard.google.com/).
6
+
7
+ These chatbots are presumably powered by proprietary large multimodal models (LMMs). Compared with open-source LMMs such as LLaVA, proprietary LMMs represent the upper bound of scaling success for current SoTA techniques. They share the goal of developing multimodal chatbots that follow human intents to complete various daily-life visual tasks in the wild. While how to evaluate multimodal chat ability remains underexplored, studying open-source LMMs against the commercial multimodal chatbots provides useful feedback. In addition to the *LLaVA-Bench (COCO)* dataset we used to develop the early versions of LLaVA, we are releasing [*LLaVA-Bench (In-the-Wild)*](https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild) to the community for public use.
8
+
9
+ ## LLaVA-Bench (In-the-Wild *[Ongoing work]*)
10
+
11
+ To evaluate the model's capability on more challenging tasks and its generalizability to novel domains, we collect a diverse set of 24 images with 60 questions in total, including indoor and outdoor scenes, memes, paintings, sketches, etc., and associate each image with a highly detailed and manually curated description and a proper selection of questions. This design also assesses the model's robustness to different prompts. In this release, we also categorize questions into three categories: conversation (simple QA), detailed description, and complex reasoning. We continue to expand and improve the diversity of LLaVA-Bench (In-the-Wild). We manually query Bing-Chat and Bard to get the responses.
12
+
13
+ ### Results
14
+
15
+ The score is measured by comparing against a reference answer generated by text-only GPT-4, which is produced by feeding it the question along with the ground-truth image annotations as context. A text-only GPT-4 evaluator rates both answers. We query GPT-4 with the reference answer first, followed by the answer generated by the candidate model. We upload images at their original resolution to Bard and Bing-Chat to obtain the results.
16
+
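For readers implementing a similar comparison, a minimal sketch of how such a pairwise judging query could be assembled is shown below. The prompt wording and the example data are illustrative assumptions only, not the exact rubric used by the LLaVA evaluation scripts; only the ordering (reference answer first, candidate answer second) follows the description above.

```python
# Sketch: assemble a pairwise judging request for a text-only judge model.
def build_judge_prompt(question, image_annotations, reference_answer, candidate_answer):
    return (
        f"[Context]\n{image_annotations}\n\n"
        f"[Question]\n{question}\n\n"
        f"[Assistant 1 (reference)]\n{reference_answer}\n\n"
        f"[Assistant 2 (candidate)]\n{candidate_answer}\n\n"
        "Rate the helpfulness and accuracy of each assistant on a scale of 1-10."
    )

print(build_judge_prompt(
    "What is unusual about this image?",
    "A man irons clothes on the roof rack of a moving taxi.",
    "The man is ironing on top of a moving car, which is unusual.",
    "A man is standing on a car."))
```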
17
+ | Approach | Conversation | Detail | Reasoning | Overall |
18
+ |----------------|--------------|--------|-----------|---------|
19
+ | Bard-0718 | 83.7 | 69.7 | 78.7 | 77.8 |
20
+ | Bing-Chat-0629 | 59.6 | 52.2 | 90.1 | 71.5 |
21
+ | LLaVA-13B-v1-336px-0719 (beam=1) | 64.3 | 55.9 | 81.7 | 70.1 |
22
+ | LLaVA-13B-v1-336px-0719 (beam=5) | 68.4 | 59.9 | 84.3 | 73.5 |
23
+
24
+ Note that Bard sometimes refuses to answer questions about images containing humans, and Bing-Chat blurs the human faces in the images. We also provide the benchmark score for the subset without humans.
25
+
26
+ | Approach | Conversation | Detail | Reasoning | Overall |
27
+ |----------------|--------------|--------|-----------|---------|
28
+ | Bard-0718 | 94.9 | 74.3 | 84.3 | 84.6 |
29
+ | Bing-Chat-0629 | 55.8 | 53.6 | 93.5 | 72.6 |
30
+ | LLaVA-13B-v1-336px-0719 (beam=1) | 62.2 | 56.4 | 82.2 | 70.0 |
31
+ | LLaVA-13B-v1-336px-0719 (beam=5) | 65.6 | 61.7 | 85.0 | 73.6 |
groundingLMM/LLaVA/docs/LLaVA_from_LLaMA2.md ADDED
@@ -0,0 +1,29 @@
1
+ # LLaVA (based on Llama 2 LLM, Preview)
2
+
3
+ *NOTE: This is a technical preview. We are still running hyperparameter search, and will release the final model soon. If you'd like to contribute to this, please contact us.*
4
+
5
+ :llama: **-Introduction-** [Llama 2 is an open-source LLM released by Meta AI](https://about.fb.com/news/2023/07/llama-2/) today (July 18, 2023). Compared with its earlier version, [Llama 1](https://ai.meta.com/blog/large-language-model-llama-meta-ai/), Llama 2 offers ***stronger language performance***, a ***longer context window***, and, importantly, is ***commercially usable***! While Llama 2 is changing the LLM market landscape in the language space, its multimodal ability remains unknown. We quickly developed a LLaVA variant based on the latest Llama 2 checkpoints and are releasing it to the community for public use.
6
+
7
+ You need to apply for and download the latest Llama 2 checkpoints to start your own training (apply [here](https://ai.meta.com/resources/models-and-libraries/llama-downloads/))
8
+
9
+
10
+ ## Training
11
+
12
+ Please check out [`pretrain.sh`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/pretrain.sh), [`finetune.sh`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/finetune.sh), and [`finetune_lora.sh`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/finetune_lora.sh).
13
+
14
+ ## LLaVA (based on Llama 2), What is different?
15
+
16
+ :volcano: How is the new LLaVA based on Llama 2 different from the one based on Llama 1? The differences in the training process are described below:
17
+ - **Pre-training**. The pre-trained base LLM is changed from Llama 1 to Llama 2
18
+ - **Language instruction-tuning**. The previous LLaVA model starts with Vicuna, which is instruction-tuned on ShareGPT data from Llama 1; the new LLaVA model starts with Llama 2 Chat, which is an instruction-tuned checkpoint trained on dialogue data from Llama 2.
19
+ - **Multimodal instruction-tuning**. The same LLaVA-Lightning process is applied.
20
+
21
+
22
+ ### Results
23
+
24
+ - Llama 2 is better at following role-playing instructions; Llama 2 fails to follow translation instructions.
25
+ - The quantitative evaluation on [LLaVA-Bench](https://github.com/haotian-liu/LLaVA/blob/main/docs/LLaVA_Bench.md) demonstrates on-par performance between Llama 2 and Llama 1 in LLaVA's multimodal chat ability.
26
+
27
+
28
+ <img src="../images/llava_example_cmp.png" width="100%">
29
+
groundingLMM/LLaVA/docs/LoRA.md ADDED
@@ -0,0 +1,46 @@
1
+ # LLaVA (LoRA, Preview)
2
+
3
+ NOTE: This is a technical preview, and is not yet ready for production use. We are still running hyperparameter search for the LoRA model, and will release the final model soon. If you'd like to contribute to this, please contact us.
4
+
5
+ You need the latest codebase for LoRA support (instructions [here](https://github.com/haotian-liu/LLaVA#upgrade-to-latest-code-base)).
6
+
7
+ ## Demo (Web UI)
8
+
9
+ Please execute each of the commands below one by one (after the previous one has finished). The commands are the same as launching other demos except for an additional `--model-base` flag to specify the base model to use. Please make sure the base model corresponds to the LoRA checkpoint that you are using. For this technical preview, you need Vicuna v1.1 (7B) checkpoint (if you do not have that already, follow the instructions [here](https://github.com/lm-sys/FastChat#vicuna-weights)).
10
+
11
+ #### Launch a controller
12
+ ```Shell
13
+ python -m llava.serve.controller --host 0.0.0.0 --port 10000
14
+ ```
15
+
16
+ #### Launch a gradio web server.
17
+ ```Shell
18
+ python -m llava.serve.gradio_web_server --controller http://localhost:10000 --model-list-mode reload
19
+ ```
20
+ You just launched the Gradio web interface. Now, you can open the web interface with the URL printed on the screen. You may notice that there is no model in the model list. Do not worry, as we have not launched any model worker yet. It will be automatically updated when you launch a model worker.
21
+
22
+ #### Launch a model worker
23
+ ```Shell
24
+ python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path liuhaotian/llava-vicuna-7b-v1.1-lcs_558k-instruct_80k_3e-lora-preview-alpha --model-base /path/to/vicuna-v1.1
25
+ ```
26
+ Wait until the process finishes loading the model and you see "Uvicorn running on ...". Now, refresh your Gradio web UI, and you will see the model you just launched in the model list.
27
+
28
+ You can launch as many workers as you want, and compare between different model checkpoints in the same Gradio interface. Please keep the `--controller` the same, and modify the `--port` and `--worker` to a different port number for each worker.
29
+
30
+
31
+ ## Training
32
+
33
+ Please see sample training scripts for [LoRA](https://github.com/haotian-liu/LLaVA/blob/main/scripts/finetune_lora.sh) and [QLoRA](https://github.com/haotian-liu/LLaVA/blob/main/scripts/finetune_qlora.sh).
34
+
35
+ We provide sample DeepSpeed configs: [`zero3.json`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/zero3.json) is more like PyTorch FSDP, and [`zero3_offload.json`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/zero3_offload.json) can further reduce memory consumption by offloading parameters to the CPU. `zero3.json` is usually faster than `zero3_offload.json` but requires more GPU memory; therefore, we recommend trying `zero3.json` first and, if you run out of GPU memory, trying `zero3_offload.json`. You can also tweak `per_device_train_batch_size` and `gradient_accumulation_steps` in the config to save memory; just make sure that the product `per_device_train_batch_size` * `gradient_accumulation_steps` stays the same.
36
+
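For example, a small check (the GPU count and batch sizes below are assumed for illustration, not recommended values) showing two settings with the same effective global batch size:

```python
# Sketch: trading per-device batch size for gradient accumulation keeps the
# effective global batch size (and thus the training recipe) unchanged.
num_gpus = 8  # assumed
configs = [
    {"per_device_train_batch_size": 16, "gradient_accumulation_steps": 1},  # needs more GPU memory
    {"per_device_train_batch_size": 4, "gradient_accumulation_steps": 4},   # needs less GPU memory
]
for cfg in configs:
    global_batch = (num_gpus * cfg["per_device_train_batch_size"]
                    * cfg["gradient_accumulation_steps"])
    print(cfg, "-> effective global batch size:", global_batch)  # 128 in both cases
```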
37
+ If you are having issues with ZeRO-3 configs and there is enough VRAM, you may try [`zero2.json`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/zero2.json). This consumes slightly more memory than ZeRO-3 and behaves more similarly to PyTorch FSDP, while still supporting parameter-efficient tuning.
38
+
39
+ ## Create Merged Checkpoints
40
+
41
+ ```Shell
42
+ python scripts/merge_lora_weights.py \
43
+ --model-path /path/to/lora_model \
44
+ --model-base /path/to/base_model \
45
+ --save-model-path /path/to/merge_model
46
+ ```
groundingLMM/LLaVA/docs/MODEL_ZOO.md ADDED
@@ -0,0 +1,150 @@
1
+ # Model Zoo
2
+
3
+ **To use LLaVA-1.6 checkpoints, your llava package version must be newer than 1.2.0. See the [instructions](https://github.com/haotian-liu/LLaVA#upgrade-to-latest-code-base) on how to upgrade.**
4
+
5
+ If you are interested in including any other details in Model Zoo, please open an issue :)
6
+
7
+ The model weights below are *merged* weights. You do not need to apply delta. The usage of LLaVA checkpoints should comply with the base LLM's model license.
8
+
9
+ ## LLaVA-v1.6
10
+
11
+ | Version | LLM | Schedule | Checkpoint | MMMU | MathVista | VQAv2 | GQA | VizWiz | SQA | TextVQA | POPE | MME | MM-Bench | MM-Bench-CN | SEED-IMG | LLaVA-Bench-Wild | MM-Vet |
12
+ |----------|----------|-----------|-----------|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
13
+ | LLaVA-1.6 | Vicuna-7B | full_ft-1e | [liuhaotian/llava-v1.6-vicuna-7b](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b) | 35.8 | 34.6 | 81.8 | 64.2 | 57.6 | 70.1 | 64.9 | 86.5 | 1519/332 | 67.4 | 60.6 | 70.2 | 81.6 | 43.9 |
14
+ | LLaVA-1.6 | Vicuna-13B | full_ft-1e | [liuhaotian/llava-v1.6-vicuna-13b](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-13b) | 36.2 | 35.3 | 82.8 | 65.4 | 60.5 | 73.6 | 67.1 | 86.2 | 1575/326 | 70 | 64.4 | 71.9 | 87.3 | 48.4 |
15
+ | LLaVA-1.6 | Mistral-7B | full_ft-1e | [liuhaotian/llava-v1.6-mistral-7b](https://huggingface.co/liuhaotian/llava-v1.6-mistral-7b) | 35.3 | 37.7 | 82.2 | 64.8 | 60.0 | 72.8 | 65.7 | 86.7 | 1498/321 | 68.7 | 61.2 | 72.2 | 83.2 | 47.3 |
16
+ | LLaVA-1.6 | Hermes-Yi-34B | full_ft-1e | [liuhaotian/llava-v1.6-34b](https://huggingface.co/liuhaotian/llava-v1.6-34b) | 51.1 | 46.5 | 83.7 | 67.1 | 63.8 | 81.8 | 69.5 | 87.7 | 1631/397 | 79.3 | 79 | 75.9 | 89.6 | 57.4 |
17
+
18
+ *LLaVA-1.6-34B outperforms Gemini Pro on benchmarks like MMMU and MathVista.*
19
+
20
+
21
+ ## LLaVA-v1.5
22
+
23
+ | Version | Size | Schedule | Checkpoint | VQAv2 | GQA | VizWiz | SQA | TextVQA | POPE | MME | MM-Bench | MM-Bench-CN | SEED | LLaVA-Bench-Wild | MM-Vet |
24
+ |----------|----------|-----------|-----------|---|---|---|---|---|---|---|---|---|---|---|---|
25
+ | LLaVA-1.5 | 7B | full_ft-1e | [liuhaotian/llava-v1.5-7b](https://huggingface.co/liuhaotian/llava-v1.5-7b) | 78.5 | 62.0 | 50.0 | 66.8 | 58.2 | 85.9 | 1510.7 | 64.3 | 58.3 | 58.6 | 65.4 | 31.1 |
26
+ | LLaVA-1.5 | 13B | full_ft-1e | [liuhaotian/llava-v1.5-13b](https://huggingface.co/liuhaotian/llava-v1.5-13b) | 80.0 | 63.3 | 53.6 | 71.6 | 61.3 | 85.9 | 1531.3 | 67.7 | 63.6 | 61.6 | 72.5 | 36.1 |
27
+ | LLaVA-1.5 | 7B | lora-1e | [liuhaotian/llava-v1.5-7b-lora](https://huggingface.co/liuhaotian/llava-v1.5-7b-lora) | 79.1 | 63.0 | 47.8 | 68.4 | 58.2 | 86.4 | 1476.9 | 66.1 | 58.9 | 60.1 | 67.9 | 30.2 |
28
+ | LLaVA-1.5 | 13B | lora-1e | [liuhaotian/llava-v1.5-13b-lora](https://huggingface.co/liuhaotian/llava-v1.5-13b-lora) | 80.0 | 63.3 | 58.9 | 71.2 | 60.2 | 86.7 | 1541.7 | 68.5 | 61.5 | 61.3 | 69.5 | 38.3 |
29
+
30
+ Base model: Vicuna v1.5. Training logs: [wandb](https://api.wandb.ai/links/lht/6orh56wc).
31
+
32
+ <p align="center">
33
+ <img src="../images/llava_v1_5_radar.jpg" width="500px"> <br>
34
+ LLaVA-1.5 achieves SoTA performance across 11 benchmarks.
35
+ </p>
36
+
37
+
38
+ ## LLaVA-v1
39
+
40
+ *Note: We recommend using the most capable LLaVA-v1.6 series above for the best performance.*
41
+
42
+ | Base LLM | Vision Encoder | Pretrain Data | Pretraining schedule | Finetuning Data | Finetuning schedule | LLaVA-Bench-Conv | LLaVA-Bench-Detail | LLaVA-Bench-Complex | LLaVA-Bench-Overall | Download |
43
+ |----------|----------------|---------------|----------------------|-----------------|--------------------|------------------|--------------------|---------------------|---------------------|---------------------|
44
+ | Vicuna-13B-v1.3 | CLIP-L-336px | LCS-558K | 1e | LLaVA-Instruct-80K | proj-1e, lora-1e | 64.3 | 55.9 | 81.7 | 70.1 | [LoRA](https://huggingface.co/liuhaotian/llava-v1-0719-336px-lora-vicuna-13b-v1.3) [LoRA-Merged](https://huggingface.co/liuhaotian/llava-v1-0719-336px-lora-merge-vicuna-13b-v1.3) |
45
+ | LLaMA-2-13B-Chat | CLIP-L | LCS-558K | 1e | LLaVA-Instruct-80K | full_ft-1e | 56.7 | 58.6 | 80.0 | 67.9 | [ckpt](https://huggingface.co/liuhaotian/llava-llama-2-13b-chat-lightning-preview) |
46
+ | LLaMA-2-7B-Chat | CLIP-L | LCS-558K | 1e | LLaVA-Instruct-80K | lora-1e | 51.2 | 58.9 | 71.6 | 62.8 | [LoRA](https://huggingface.co/liuhaotian/llava-llama-2-7b-chat-lightning-lora-preview) |
47
+
48
+
49
+ ## Projector weights
50
+
51
+ These are projector weights we have pretrained. You can use these projector weights for visual instruction tuning. They are just pretrained on image-text pairs and are NOT instruction-tuned, which means they do NOT follow instructions as well as our official models and can output repetitive, lengthy, and garbled outputs. If you want to have nice conversations with LLaVA, use the checkpoints above (LLaVA v1.6).
52
+
53
+ NOTE: These projector weights are only compatible with `llava>=1.0.0`. Please check out the latest codebase if your local code version is below v1.0.0.
54
+
55
+ NOTE: When you use our pretrained projector for visual instruction tuning, it is very important to use the same base LLM and vision encoder as the one we used for pretraining the projector. Otherwise, the performance will be very poor.
56
+
57
+ When using these projector weights to instruction-tune your LMM, please make sure that these options are correctly set as follows,
58
+
59
+ ```Shell
60
+ --mm_use_im_start_end False
61
+ --mm_use_im_patch_token False
62
+ ```
63
+
64
+ | Base LLM | Vision Encoder | Projection | Pretrain Data | Pretraining schedule | Download |
65
+ |----------|----------------|---------------|----------------------|----------|----------|
66
+ | Vicuna-13B-v1.5 | CLIP-L-336px | MLP-2x | LCS-558K | 1e | [projector](https://huggingface.co/liuhaotian/llava-v1.5-mlp2x-336px-pretrain-vicuna-13b-v1.5) |
67
+ | Vicuna-7B-v1.5 | CLIP-L-336px | MLP-2x | LCS-558K | 1e | [projector](https://huggingface.co/liuhaotian/llava-v1.5-mlp2x-336px-pretrain-vicuna-7b-v1.5) |
68
+ | LLaMA-2-13B-Chat | CLIP-L-336px | Linear | LCS-558K | 1e | [projector](https://huggingface.co/liuhaotian/llava-336px-pretrain-llama-2-13b-chat) |
69
+ | LLaMA-2-7B-Chat | CLIP-L-336px | Linear | LCS-558K | 1e | [projector](https://huggingface.co/liuhaotian/llava-336px-pretrain-llama-2-7b-chat) |
70
+ | LLaMA-2-13B-Chat | CLIP-L | Linear | LCS-558K | 1e | [projector](https://huggingface.co/liuhaotian/llava-pretrain-llama-2-13b-chat) |
71
+ | LLaMA-2-7B-Chat | CLIP-L | Linear | LCS-558K | 1e | [projector](https://huggingface.co/liuhaotian/llava-pretrain-llama-2-7b-chat) |
72
+ | Vicuna-13B-v1.3 | CLIP-L-336px | Linear | LCS-558K | 1e | [projector](https://huggingface.co/liuhaotian/llava-336px-pretrain-vicuna-13b-v1.3) |
73
+ | Vicuna-7B-v1.3 | CLIP-L-336px | Linear | LCS-558K | 1e | [projector](https://huggingface.co/liuhaotian/llava-336px-pretrain-vicuna-7b-v1.3) |
74
+ | Vicuna-13B-v1.3 | CLIP-L | Linear | LCS-558K | 1e | [projector](https://huggingface.co/liuhaotian/llava-pretrain-vicuna-13b-v1.3) |
75
+ | Vicuna-7B-v1.3 | CLIP-L | Linear | LCS-558K | 1e | [projector](https://huggingface.co/liuhaotian/llava-pretrain-vicuna-7b-v1.3) |
76
+
77
+
78
+ ## Science QA Checkpoints
79
+
80
+ | Base LLM | Vision Encoder | Pretrain Data | Pretraining schedule | Finetuning Data | Finetuning schedule | Download |
81
+ |----------|----------------|---------------|----------------------|-----------------|--------------------|---------------------|
82
+ | Vicuna-13B-v1.3 | CLIP-L | LCS-558K | 1e | ScienceQA | full_ft-12e | [ckpt](https://huggingface.co/liuhaotian/llava-lcs558k-scienceqa-vicuna-13b-v1.3) |
83
+
84
+
85
+ ## Legacy Models (merged weights)
86
+
87
+ The model weights below are *merged* weights. You do not need to apply delta. The usage of LLaVA checkpoints should comply with the base LLM's model license.
88
+
89
+ | Base LLM | Vision Encoder | Pretrain Data | Pretraining schedule | Finetuning Data | Finetuning schedule | Download |
90
+ |----------|----------------|---------------|----------------------|-----------------|--------------------|------------------|
91
+ | MPT-7B-Chat | CLIP-L | LCS-558K | 1e | LLaVA-Instruct-80K | full_ft-1e | [preview](https://huggingface.co/liuhaotian/LLaVA-Lightning-MPT-7B-preview) |
92
+
93
+
94
+ ## Legacy Models (delta weights)
95
+
96
+ The model weights below are *delta* weights. The usage of LLaVA checkpoints should comply with the base LLM's model license: [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md).
97
+
98
+ You can add our delta to the original LLaMA weights to obtain the LLaVA weights.
99
+
100
+ Instructions:
101
+
102
+ 1. Get the original LLaMA weights in the huggingface format by following the instructions [here](https://huggingface.co/docs/transformers/main/model_doc/llama).
103
+ 2. Use the following scripts to get LLaVA weights by applying our delta. It will automatically download delta weights from our Hugging Face account. In the script below, we use the delta weights of [`liuhaotian/LLaVA-7b-delta-v0`](https://huggingface.co/liuhaotian/LLaVA-7b-delta-v0) as an example. It can be adapted for other delta weights by changing the `--delta` argument (and base/target accordingly).
104
+
105
+ ```bash
106
+ python3 -m llava.model.apply_delta \
107
+ --base /path/to/llama-7b \
108
+ --target /output/path/to/LLaVA-7B-v0 \
109
+ --delta liuhaotian/LLaVA-7b-delta-v0
110
+ ```
111
+
112
+ | Base LLM | Vision Encoder | Pretrain Data | Pretraining schedule | Finetuning Data | Finetuning schedule | Download |
113
+ |----------|----------------|---------------|----------------------|-----------------|--------------------|------------------|
114
+ | Vicuna-13B-v1.1 | CLIP-L | CC-595K | 1e | LLaVA-Instruct-158K | full_ft-3e | [delta-weights](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1) |
115
+ | Vicuna-7B-v1.1 | CLIP-L | LCS-558K | 1e | LLaVA-Instruct-80K | full_ft-1e | [delta-weights](https://huggingface.co/liuhaotian/LLaVA-Lightning-7B-delta-v1-1) |
116
+ | Vicuna-13B-v0 | CLIP-L | CC-595K | 1e | LLaVA-Instruct-158K | full_ft-3e | [delta-weights](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v0) |
117
+ | Vicuna-13B-v0 | CLIP-L | CC-595K | 1e | ScienceQA | full_ft-12e | [delta-weights](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v0-science_qa) |
118
+ | Vicuna-7B-v0 | CLIP-L | CC-595K | 1e | LLaVA-Instruct-158K | full_ft-3e | [delta-weights](https://huggingface.co/liuhaotian/LLaVA-7b-delta-v0) |
119
+
120
+
121
+
122
+ ## Legacy Projector weights
123
+
124
+ The following projector weights are deprecated, and the support for them may be removed in the future. They do not support zero-shot inference. Please use the projector weights in the [table above](#projector-weights) if possible.
125
+
126
+ **NOTE**: When you use our pretrained projector for visual instruction tuning, it is very important to **use the same base LLM and vision encoder** as the one we used for pretraining the projector. Otherwise, the performance will be very bad.
127
+
128
+ When using these projector weights to instruction tune your LMM, please make sure that these options are correctly set as follows,
129
+
130
+ ```Shell
131
+ --mm_use_im_start_end True
132
+ --mm_use_im_patch_token False
133
+ ```
134
+
135
+ | Base LLM | Vision Encoder | Pretrain Data | Pretraining schedule | Download |
136
+ |----------|----------------|---------------|----------------------|----------|
137
+ | Vicuna-7B-v1.1 | CLIP-L | LCS-558K | 1e | [projector](https://huggingface.co/liuhaotian/LLaVA-Pretrained-Projectors/blob/main/LLaVA-7b-pretrain-projector-v1-1-LCS-558K-blip_caption.bin) |
138
+ | Vicuna-13B-v0 | CLIP-L | CC-595K | 1e | [projector](https://huggingface.co/liuhaotian/LLaVA-Pretrained-Projectors/blob/main/LLaVA-13b-pretrain-projector-v0-CC3M-595K-original_caption.bin) |
139
+ | Vicuna-7B-v0 | CLIP-L | CC-595K | 1e | [projector](https://huggingface.co/liuhaotian/LLaVA-Pretrained-Projectors/blob/main/LLaVA-7b-pretrain-projector-v0-CC3M-595K-original_caption.bin) |
140
+
141
+ When using these projector weights to instruction tune your LMM, please make sure that these options are correctly set as follows,
142
+
143
+ ```Shell
144
+ --mm_use_im_start_end False
145
+ --mm_use_im_patch_token False
146
+ ```
147
+
148
+ | Base LLM | Vision Encoder | Pretrain Data | Pretraining schedule | Download |
149
+ |----------|----------------|---------------|----------------------|----------|
150
+ | Vicuna-13B-v0 | CLIP-L | CC-595K | 1e | [projector](https://huggingface.co/liuhaotian/LLaVA-Pretrained-Projectors/blob/main/LLaVA-13b-pretrain-projector-v0-CC3M-595K-original_caption-no_im_token.bin) |
groundingLMM/LLaVA/docs/ScienceQA.md ADDED
@@ -0,0 +1,53 @@
1
+ ### ScienceQA
2
+
3
+ #### Prepare Data
4
+ 1. Please see ScienceQA [repo](https://github.com/lupantech/ScienceQA) for setting up the dataset.
5
+ 2. Generate ScienceQA dataset for LLaVA conversation-style format.
6
+
7
+ ```Shell
8
+ python scripts/convert_sqa_to_llava.py \
9
+ convert_to_llava \
10
+ --base-dir /path/to/ScienceQA/data/scienceqa \
11
+ --prompt-format "QCM-LEA" \
12
+ --split {train,val,minival,test,minitest}
13
+ ```
14
+
15
+ #### Training
16
+
17
+ 1. Pretraining
18
+
19
+ You can download our pretrained projector weights from our [Model Zoo](), or train your own projector weights using [`pretrain.sh`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/pretrain.sh).
20
+
21
+ 2. Finetuning
22
+
23
+ See [`finetune_sqa.sh`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/finetune_sqa.sh).
24
+
25
+ #### Evaluation
26
+
27
+ 1. Multiple-GPU inference
28
+ You may evaluate this with multiple GPUs, and concatenate the generated jsonl files. Please refer to our script for [batch evaluation](https://github.com/haotian-liu/LLaVA/blob/main/scripts/sqa_eval_batch.sh) and [results gathering](https://github.com/haotian-liu/LLaVA/blob/main/scripts/sqa_eval_gather.sh).
29
+
30
+ 2. Single-GPU inference
31
+
32
+ (a) Generate LLaVA responses on ScienceQA dataset
33
+
34
+ ```Shell
35
+ python -m llava.eval.model_vqa_science \
36
+ --model-path liuhaotian/llava-lcs558k-scienceqa-vicuna-13b-v1.3 \
37
+ --question-file /path/to/ScienceQA/data/scienceqa/llava_test_QCM-LEA.json \
38
+ --image-folder /path/to/ScienceQA/data/scienceqa/images/test \
39
+ --answers-file vqa/results/ScienceQA/test_llava-13b.jsonl \
40
+ --conv-mode llava_v1
41
+ ```
42
+
43
+ (b) Evaluate the generated responses
44
+
45
+ ```Shell
46
+ python eval_science_qa.py \
47
+ --base-dir /path/to/ScienceQA/data/scienceqa \
48
+ --result-file vqa/results/ScienceQA/test_llava-13b.jsonl \
49
+ --output-file vqa/results/ScienceQA/test_llava-13b_output.json \
50
+ --output-result vqa/results/ScienceQA/test_llava-13b_result.json \
51
+ ```
52
+
53
+ For reference, we attach our prediction file [`test_sqa_llava_lcs_558k_sqa_12e_vicuna_v1_3_13b.json`](https://github.com/haotian-liu/LLaVA/blob/main/llava/eval/table/results/test_sqa_llava_lcs_558k_sqa_12e_vicuna_v1_3_13b.json) and [`test_sqa_llava_13b_v0.json`](https://github.com/haotian-liu/LLaVA/blob/main/llava/eval/table/results/test_sqa_llava_13b_v0.json) for comparison when reproducing our results, as well as for further analysis in detail.
groundingLMM/LLaVA/docs/Windows.md ADDED
@@ -0,0 +1,27 @@
1
+ # Run LLaVA on Windows
2
+
3
+ *NOTE: LLaVA on Windows is not fully supported. Currently we only support 16-bit inference. For more complete support, please use [WSL2](https://learn.microsoft.com/en-us/windows/wsl/install) for now. More functionality on Windows will be added soon; stay tuned.*
4
+
5
+ ## Installation
6
+
7
+ 1. Clone this repository and navigate to LLaVA folder
8
+ ```bash
9
+ git clone https://github.com/haotian-liu/LLaVA.git
10
+ cd LLaVA
11
+ ```
12
+
13
+ 2. Install Package
14
+ ```Shell
15
+ conda create -n llava python=3.10 -y
16
+ conda activate llava
17
+ python -m pip install --upgrade pip # enable PEP 660 support
18
+ pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu117
19
+ pip install -e .
20
+ pip uninstall bitsandbytes
21
+ ```
22
+
23
+ ## Run demo
24
+
25
+ See instructions [here](https://github.com/haotian-liu/LLaVA#demo).
26
+
27
+ Note that quantization (4-bit, 8-bit) is *NOT* supported on Windows. Stay tuned for the 4-bit support on Windows!
groundingLMM/LLaVA/docs/macOS.md ADDED
@@ -0,0 +1,29 @@
1
+ # Run LLaVA on macOS
2
+
3
+ *NOTE: LLaVA on macOS is not fully supported. Currently we only support 16-bit inference. More functionality on macOS will be added soon; stay tuned.*
4
+
5
+ ## Installation
6
+
7
+ 1. Clone this repository and navigate to LLaVA folder
8
+ ```bash
9
+ git clone https://github.com/haotian-liu/LLaVA.git
10
+ cd LLaVA
11
+ ```
12
+
13
+ 2. Install Package
14
+ ```Shell
15
+ conda create -n llava python=3.10 -y
16
+ conda activate llava
17
+ python -mpip install --upgrade pip # enable PEP 660 support
18
+ pip install -e .
19
+ pip install torch==2.1.0 torchvision==0.16.0
20
+ pip uninstall bitsandbytes
21
+ ```
22
+
23
+ ## Run demo
24
+
25
+ Specify `--device mps` when launching model worker or CLI.
26
+
27
+ See instructions [here](https://github.com/haotian-liu/LLaVA#demo).
28
+
29
+ Note that quantization (4-bit, 8-bit) is *NOT* supported on macOS. Stay tuned for the 4-bit support on macOS!
groundingLMM/LLaVA/scripts/convert_gqa_for_eval.py ADDED
@@ -0,0 +1,18 @@
1
+ import os
2
+ import json
3
+ import argparse
4
+
5
+ parser = argparse.ArgumentParser()
6
+ parser.add_argument("--src", type=str)
7
+ parser.add_argument("--dst", type=str)
8
+ args = parser.parse_args()
9
+
10
+ all_answers = []
11
+ for line_idx, line in enumerate(open(args.src)):
12
+ res = json.loads(line)
13
+ question_id = res['question_id']
14
+ text = res['text'].rstrip('.').lower()
15
+ all_answers.append({"questionId": question_id, "prediction": text})
16
+
17
+ with open(args.dst, 'w') as f:
18
+ json.dump(all_answers, f)
groundingLMM/LLaVA/scripts/convert_mmvet_for_eval.py ADDED
@@ -0,0 +1,18 @@
1
+ import os
2
+ import json
3
+ import argparse
4
+
5
+ parser = argparse.ArgumentParser()
6
+ parser.add_argument("--src", type=str)
7
+ parser.add_argument("--dst", type=str)
8
+ args = parser.parse_args()
9
+
10
+ cur_result = {}
11
+
12
+ for line in open(args.src):
13
+ data = json.loads(line)
14
+ qid = data['question_id']
15
+ cur_result[f'v1_{qid}'] = data['text']
16
+
17
+ with open(args.dst, 'w') as f:
18
+ json.dump(cur_result, f, indent=2)
groundingLMM/LLaVA/scripts/convert_sqa_to_llava_base_prompt.py ADDED
@@ -0,0 +1,334 @@
1
+ def get_question_text(problem):
2
+ question = problem['question']
3
+ return question
4
+
5
+
6
+ def get_context_text(problem, use_caption):
7
+ txt_context = problem['hint']
8
+ img_context = problem['caption'] if use_caption else ""
9
+ context = " ".join([txt_context, img_context]).strip()
10
+ if context == "":
11
+ context = "N/A"
12
+ return context
13
+
14
+
15
+ def get_choice_text(problem, options):
16
+ choices = problem['choices']
17
+ choice_list = []
18
+ for i, c in enumerate(choices):
19
+ choice_list.append("({}) {}".format(options[i], c))
20
+ choice_txt = " ".join(choice_list)
21
+ #print(choice_txt)
22
+ return choice_txt
23
+
24
+
25
+ def get_answer(problem, options):
26
+ return options[problem['answer']]
27
+
28
+
29
+ def get_lecture_text(problem):
30
+ # \\n: GPT-3 can generate the lecture with more tokens.
31
+ lecture = problem['lecture'].replace("\n", "\\n")
32
+ return lecture
33
+
34
+
35
+ def get_solution_text(problem):
36
+ # \\n: GPT-3 can generate the solution with more tokens
37
+ solution = problem['solution'].replace("\n", "\\n")
38
+ return solution
39
+
40
+
41
+ def create_one_example_chatbot(format, question, context, choice, answer, lecture, solution, test_example=True):
42
+
43
+ input_format, output_format = format.split("-")
44
+
45
+ ## Inputs
46
+ if input_format == "CQM":
47
+ input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
48
+ elif input_format == "QCM":
49
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
50
+ # upper bound experiment
51
+ elif input_format == "QCML":
52
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
53
+ elif input_format == "QCME":
54
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
55
+ elif input_format == "QCMLE":
56
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
57
+
58
+ elif input_format == "QCLM":
59
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
60
+ elif input_format == "QCEM":
61
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
62
+ elif input_format == "QCLEM":
63
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
64
+
65
+ # Outputs
66
+ if test_example:
67
+ output = "Answer:"
68
+ elif output_format == 'A':
69
+ output = f"Answer: The answer is {answer}."
70
+
71
+ elif output_format == 'AL':
72
+ output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
73
+ elif output_format == 'AE':
74
+ output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
75
+ elif output_format == 'ALE':
76
+ output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
77
+ elif output_format == 'AEL':
78
+ output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
79
+
80
+ elif output_format == 'LA':
81
+ output = f"Answer: {lecture} The answer is {answer}."
82
+ elif output_format == 'EA':
83
+ output = f"Answer: {solution} The answer is {answer}."
84
+ elif output_format == 'LEA':
85
+ output = f"Answer: {lecture} {solution} The answer is {answer}."
86
+ elif output_format == 'ELA':
87
+ output = f"Answer: {solution} {lecture} The answer is {answer}."
88
+ elif output_format == 'LEPA':
89
+ output = ''
90
+ if len(lecture.strip()) > 0:
91
+ output += f"LECTURE: {lecture}\n"
92
+ if len(solution.strip()) > 0:
93
+ output += f"SOLUTION: {solution}\n"
94
+ output += '###\n'
95
+ output += f"ANSWER: {answer}."
96
+
97
+ input = input.replace("  ", " ").strip()
98
+ output = output.replace("  ", " ").strip()
99
+ if input.endswith("BECAUSE:"):
100
+ input = input.replace("BECAUSE:", "").strip()
101
+ if output.endswith("BECAUSE:"):
102
+ output = output.replace("BECAUSE:", "").strip()
103
+ return input, output
104
+
105
+
106
+ def create_one_example(format, question, context, choice, answer, lecture, solution, test_example=True):
107
+
108
+ input_format, output_format = format.split("-")
109
+
110
+ ## Inputs
111
+ if input_format == "CQM":
112
+ input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
113
+ elif input_format == "QCM":
114
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
115
+ # upper bound experiment
116
+ elif input_format == "QCML":
117
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
118
+ elif input_format == "QCME":
119
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
120
+ elif input_format == "QCMLE":
121
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
122
+
123
+ elif input_format == "QCLM":
124
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
125
+ elif input_format == "QCEM":
126
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
127
+ elif input_format == "QCLEM":
128
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
129
+
130
+ # Outputs
131
+ if test_example:
132
+ output = "Answer:"
133
+ elif output_format == 'A':
134
+ output = f"Answer: The answer is {answer}."
135
+
136
+ elif output_format == 'AL':
137
+ output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
138
+ elif output_format == 'AE':
139
+ output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
140
+ elif output_format == 'ALE':
141
+ output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
142
+ elif output_format == 'AEL':
143
+ output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
144
+
145
+ elif output_format == 'LA':
146
+ output = f"Answer: {lecture} The answer is {answer}."
147
+ elif output_format == 'EA':
148
+ output = f"Answer: {solution} The answer is {answer}."
149
+ elif output_format == 'LEA':
150
+ output = f"Answer: {lecture} {solution} The answer is {answer}."
151
+ elif output_format == 'ELA':
152
+ output = f"Answer: {solution} {lecture} The answer is {answer}."
153
+
154
+ text = input + output
155
+ text = text.replace("  ", " ").strip()
156
+ if text.endswith("BECAUSE:"):
157
+ text = text.replace("BECAUSE:", "").strip()
158
+ return text
159
+
160
+
161
+
162
+ def create_one_example_gpt4(format, question, context, choice, answer, lecture, solution, test_example=True):
163
+
164
+ input_format, output_format = format.split("-")
165
+
166
+ ## Inputs
167
+ if input_format == "CQM":
168
+ input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
169
+ elif input_format == "QCM":
170
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
171
+ # upper bound experiment
172
+ elif input_format == "QCML":
173
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
174
+ elif input_format == "QCME":
175
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
176
+ elif input_format == "QCMLE":
177
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
178
+
179
+ elif input_format == "QCLM":
180
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
181
+ elif input_format == "QCEM":
182
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
183
+ elif input_format == "QCLEM":
184
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
185
+
186
+ # Outputs
187
+ if test_example:
188
+ output = "Answer:"
189
+ elif output_format == 'A':
190
+ output = f"Answer: The answer is {answer}."
191
+
192
+ elif output_format == 'AL':
193
+ output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
194
+ elif output_format == 'AE':
195
+ output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
196
+ elif output_format == 'ALE':
197
+ output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
198
+ elif output_format == 'AEL':
199
+ output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
200
+
201
+ elif output_format == 'LA':
202
+ output = f"Answer: {lecture} The answer is {answer}."
203
+ elif output_format == 'EA':
204
+ output = f"Answer: {solution} The answer is {answer}."
205
+ elif output_format == 'LEA':
206
+ output = f"Answer: {lecture} {solution} The answer is {answer}."
207
+ elif output_format == 'ELA':
208
+ output = f"Answer: {solution} {lecture} The answer is {answer}."
209
+
210
+ input = input.replace("  ", " ").strip()
211
+ output = output.replace("  ", " ").strip()
212
+ if output.endswith("BECAUSE:"):
213
+ output = output.replace("BECAUSE:", "").strip()
214
+
215
+ user_prompt = {"role": "user", "content": f"Can you explain {input}?"}
216
+ assistant_prompt = {"role": "assistant", "content": f"{output}"}
217
+
218
+ return user_prompt, assistant_prompt
219
+
220
+
221
+ def build_prompt_chatbot(problems, shot_qids, prompt_format, use_caption=False, options=["A", "B", "C", "D", "E"], is_test=False):
222
+ examples = {}
223
+
224
+ for qid in shot_qids:
225
+ question = get_question_text(problems[qid])
226
+ context = get_context_text(problems[qid], use_caption)
227
+ choice = get_choice_text(problems[qid], options)
228
+ answer = get_answer(problems[qid], options)
229
+ lecture = get_lecture_text(problems[qid]).replace('\\n', '\n')
230
+ solution = get_solution_text(problems[qid]).replace('\\n', '\n')
231
+
232
+ train_example = create_one_example_chatbot(prompt_format,
233
+ question,
234
+ context,
235
+ choice,
236
+ answer,
237
+ lecture,
238
+ solution,
239
+ test_example=is_test)
240
+ examples[qid] = train_example
241
+ return examples
242
+
243
+
244
+ def build_prompt(problems, shot_qids, test_qid, args):
245
+
246
+ examples = []
247
+
248
+ # n-shot training examples
249
+ for qid in shot_qids:
250
+ question = get_question_text(problems[qid])
251
+ context = get_context_text(problems[qid], args.use_caption)
252
+ choice = get_choice_text(problems[qid], args.options)
253
+ answer = get_answer(problems[qid], args.options)
254
+ lecture = get_lecture_text(problems[qid])
255
+ solution = get_solution_text(problems[qid])
256
+
257
+ train_example = create_one_example(args.prompt_format,
258
+ question,
259
+ context,
260
+ choice,
261
+ answer,
262
+ lecture,
263
+ solution,
264
+ test_example=False)
265
+ examples.append(train_example)
266
+
267
+ # test example
268
+ question = get_question_text(problems[test_qid])
269
+ context = get_context_text(problems[test_qid], args.use_caption)
270
+ choice = get_choice_text(problems[test_qid], args.options)
271
+ answer = get_answer(problems[test_qid], args.options)
272
+ lecture = get_lecture_text(problems[test_qid])
273
+ solution = get_solution_text(problems[test_qid])
274
+
275
+ test_example = create_one_example(args.prompt_format,
276
+ question,
277
+ context,
278
+ choice,
279
+ answer,
280
+ lecture,
281
+ solution,
282
+ test_example=True)
283
+ examples.append(test_example)
284
+
285
+ # create the prompt input
286
+ prompt_input = '\n\n'.join(examples)
287
+
288
+ return prompt_input
289
+
290
+
291
+ def build_prompt_gpt4(problems, shot_qids, test_qid, args):
292
+
293
+ prompt_array = [{"role": "system", "content": "You are a helpful assistant."}]
294
+
295
+ # n-shot training examples
296
+ for qid in shot_qids:
297
+ question = get_question_text(problems[qid])
298
+ context = get_context_text(problems[qid], args.use_caption)
299
+ choice = get_choice_text(problems[qid], args.options)
300
+ answer = get_answer(problems[qid], args.options)
301
+ lecture = get_lecture_text(problems[qid])
302
+ solution = get_solution_text(problems[qid])
303
+
304
+ user_prompt, assistant_prompt = create_one_example_gpt4(args.prompt_format,
305
+ question,
306
+ context,
307
+ choice,
308
+ answer,
309
+ lecture,
310
+ solution,
311
+ test_example=False)
312
+ prompt_array.append(user_prompt)
313
+ prompt_array.append(assistant_prompt)
314
+
315
+ # test example
316
+ question = get_question_text(problems[test_qid])
317
+ context = get_context_text(problems[test_qid], args.use_caption)
318
+ choice = get_choice_text(problems[test_qid], args.options)
319
+ answer = get_answer(problems[test_qid], args.options)
320
+ lecture = get_lecture_text(problems[test_qid])
321
+ solution = get_solution_text(problems[test_qid])
322
+
323
+ user_prompt, assistant_prompt = create_one_example_gpt4(args.prompt_format,
324
+ question,
325
+ context,
326
+ choice,
327
+ answer,
328
+ lecture,
329
+ solution,
330
+ test_example=True)
331
+ prompt_array.append(user_prompt)
332
+ prompt_array.append(assistant_prompt)
333
+
334
+ return prompt_array
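A hedged example of driving `build_prompt_chatbot` with a toy `problems` dict. The field names (`question`, `hint`, `caption`, `choices`, `answer`, `lecture`, `solution`) come from the accessors above; the values, the question id, and the `QCM-LEA` format string are illustrative, and the import assumes the script is run from the `scripts/` directory.

```python
# Illustrative call into the prompt builder above; assumes the module is
# importable as convert_sqa_to_llava_base_prompt (e.g. run from scripts/).
from convert_sqa_to_llava_base_prompt import build_prompt_chatbot

problems = {
    "q1": {
        "question": "Which of these is a mammal?",
        "hint": "",
        "caption": "",                 # only used when use_caption=True
        "choices": ["frog", "whale"],
        "answer": 1,                   # index into choices, i.e. option (B)
        "lecture": "Mammals feed their young with milk.",
        "solution": "A whale is a mammal.",
    }
}

inp, out = build_prompt_chatbot(problems, ["q1"], "QCM-LEA", use_caption=False)["q1"]
print(inp)   # Question / Context / Options block
print(out)   # "Answer: <lecture> <solution> The answer is B."
```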
groundingLMM/LLaVA/scripts/convert_vizwiz_for_submission.py ADDED
@@ -0,0 +1,47 @@
1
+ import os
2
+ import argparse
3
+ import json
4
+
5
+ from llava.eval.m4c_evaluator import EvalAIAnswerProcessor
6
+
7
+
8
+ def parse_args():
9
+ parser = argparse.ArgumentParser()
10
+ parser.add_argument('--annotation-file', type=str, required=True)
11
+ parser.add_argument('--result-file', type=str, required=True)
12
+ parser.add_argument('--result-upload-file', type=str, required=True)
13
+ return parser.parse_args()
14
+
15
+
16
+ if __name__ == '__main__':
17
+
18
+ args = parse_args()
19
+
20
+ os.makedirs(os.path.dirname(args.result_upload_file), exist_ok=True)
21
+
22
+ results = []
23
+ error_line = 0
24
+ for line_idx, line in enumerate(open(args.result_file)):
25
+ try:
26
+ results.append(json.loads(line))
27
+ except json.JSONDecodeError:
28
+ error_line += 1
29
+ results = {x['question_id']: x['text'] for x in results}
30
+ test_split = [json.loads(line) for line in open(args.annotation_file)]
31
+ split_ids = set([x['question_id'] for x in test_split])
32
+
33
+ print(f'total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}')
34
+
35
+ all_answers = []
36
+
37
+ answer_processor = EvalAIAnswerProcessor()
38
+
39
+ for x in test_split:
40
+ assert x['question_id'] in results
41
+ all_answers.append({
42
+ 'image': x['image'],
43
+ 'answer': answer_processor(results[x['question_id']])
44
+ })
45
+
46
+ with open(args.result_upload_file, 'w') as f:
47
+ json.dump(all_answers, f)
groundingLMM/LLaVA/scripts/extract_mm_projector.py ADDED
@@ -0,0 +1,47 @@
1
+ """
2
+ This is just a utility that I use to extract the projector for quantized models.
3
+ It is NOT necessary at all to train, or run inference/serve demos.
4
+ Use this script ONLY if you fully understand its implications.
5
+ """
6
+
7
+
8
+ import os
9
+ import argparse
10
+ import torch
11
+ import json
12
+ from collections import defaultdict
13
+
14
+
15
+ def parse_args():
16
+ parser = argparse.ArgumentParser(description='Extract MMProjector weights')
17
+ parser.add_argument('--model-path', type=str, help='model folder')
18
+ parser.add_argument('--output', type=str, help='output file')
19
+ args = parser.parse_args()
20
+ return args
21
+
22
+
23
+ if __name__ == '__main__':
24
+ args = parse_args()
25
+
26
+ keys_to_match = ['mm_projector']
27
+ ckpt_to_key = defaultdict(list)
28
+ try:
29
+ model_indices = json.load(open(os.path.join(args.model_path, 'pytorch_model.bin.index.json')))
30
+ for k, v in model_indices['weight_map'].items():
31
+ if any(key_match in k for key_match in keys_to_match):
32
+ ckpt_to_key[v].append(k)
33
+ except FileNotFoundError:
34
+ # Smaller models or model checkpoints saved by DeepSpeed.
35
+ v = 'pytorch_model.bin'
36
+ for k in torch.load(os.path.join(args.model_path, v), map_location='cpu').keys():
37
+ if any(key_match in k for key_match in keys_to_match):
38
+ ckpt_to_key[v].append(k)
39
+
40
+ loaded_weights = {}
41
+
42
+ for ckpt_name, weight_keys in ckpt_to_key.items():
43
+ ckpt = torch.load(os.path.join(args.model_path, ckpt_name), map_location='cpu')
44
+ for k in weight_keys:
45
+ loaded_weights[k] = ckpt[k]
46
+
47
+ torch.save(loaded_weights, args.output)
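The extractor writes a flat state dict containing only the `mm_projector` tensors. A short sketch of inspecting the result; `mm_projector.bin` stands in for whatever path was passed as `--output`.

```python
# Inspect the projector weights written by extract_mm_projector.py.
# "mm_projector.bin" is a placeholder for the --output path used above.
import torch

weights = torch.load("mm_projector.bin", map_location="cpu")
for name, tensor in weights.items():
    print(f"{name}: shape={tuple(tensor.shape)}, dtype={tensor.dtype}")
```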
groundingLMM/LLaVA/scripts/finetune_qlora.sh ADDED
@@ -0,0 +1,50 @@
1
+ #!/bin/bash
2
+
3
+ # IMPORTANT: this is the training script for the original LLaVA, NOT FOR LLaVA V1.5!
4
+
5
+ # Uncomment and set the following variables correspondingly to run this script:
6
+
7
+ ################## VICUNA ##################
8
+ # PROMPT_VERSION=v1
9
+ # MODEL_VERSION="vicuna-v1-3-7b"
10
+ ################## VICUNA ##################
11
+
12
+ ################## LLaMA-2 ##################
13
+ # PROMPT_VERSION="llava_llama_2"
14
+ # MODEL_VERSION="llama-2-7b-chat"
15
+ ################## LLaMA-2 ##################
16
+
17
+ deepspeed llava/train/train_mem.py \
18
+ --deepspeed ./scripts/zero2.json \
19
+ --lora_enable True \
20
+ --bits 4 \
21
+ --model_name_or_path ./checkpoints/$MODEL_VERSION \
22
+ --version $PROMPT_VERSION \
23
+ --data_path ./playground/data/llava_instruct_80k.json \
24
+ --image_folder /path/to/coco/train2017 \
25
+ --vision_tower openai/clip-vit-large-patch14 \
26
+ --pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \
27
+ --mm_vision_select_layer -2 \
28
+ --mm_use_im_start_end False \
29
+ --mm_use_im_patch_token False \
30
+ --bf16 True \
31
+ --output_dir ./checkpoints/llava-$MODEL_VERSION-finetune_lora \
32
+ --num_train_epochs 1 \
33
+ --per_device_train_batch_size 16 \
34
+ --per_device_eval_batch_size 4 \
35
+ --gradient_accumulation_steps 1 \
36
+ --evaluation_strategy "no" \
37
+ --save_strategy "steps" \
38
+ --save_steps 50000 \
39
+ --save_total_limit 1 \
40
+ --learning_rate 2e-5 \
41
+ --weight_decay 0. \
42
+ --warmup_ratio 0.03 \
43
+ --lr_scheduler_type "cosine" \
44
+ --logging_steps 1 \
45
+ --tf32 True \
46
+ --model_max_length 2048 \
47
+ --gradient_checkpointing True \
48
+ --lazy_preprocess True \
49
+ --dataloader_num_workers 4 \
50
+ --report_to wandb
groundingLMM/LLaVA/scripts/merge_lora_weights.py ADDED
@@ -0,0 +1,22 @@
1
+ import argparse
2
+ from llava.model.builder import load_pretrained_model
3
+ from llava.mm_utils import get_model_name_from_path
4
+
5
+
6
+ def merge_lora(args):
7
+ model_name = get_model_name_from_path(args.model_path)
8
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map='cpu')
9
+
10
+ model.save_pretrained(args.save_model_path)
11
+ tokenizer.save_pretrained(args.save_model_path)
12
+
13
+
14
+ if __name__ == "__main__":
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument("--model-path", type=str, required=True)
17
+ parser.add_argument("--model-base", type=str, required=True)
18
+ parser.add_argument("--save-model-path", type=str, required=True)
19
+
20
+ args = parser.parse_args()
21
+
22
+ merge_lora(args)
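The same merge can be invoked programmatically; a sketch using `types.SimpleNamespace` in place of the parsed CLI arguments. The paths are placeholders, and the import assumes the script is run from the `scripts/` directory.

```python
# Programmatic equivalent of `python scripts/merge_lora_weights.py ...`.
# All paths below are placeholders; merge_lora is the function defined above.
from types import SimpleNamespace
from merge_lora_weights import merge_lora

args = SimpleNamespace(
    model_path="./checkpoints/llava-lora-checkpoint",  # LoRA checkpoint directory
    model_base="./checkpoints/vicuna-v1-3-7b",         # base model it was trained from
    save_model_path="./checkpoints/llava-merged",      # where the merged model is saved
)
merge_lora(args)
```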
groundingLMM/LLaVA/scripts/pretrain.sh ADDED
@@ -0,0 +1,46 @@
1
+ #!/bin/bash
2
+
3
+ # IMPORTANT: this is the training script for the original LLaVA, NOT FOR LLaVA V1.5!
4
+
5
+ # Uncomment and set the following variables correspondingly to run this script:
6
+
7
+ # MODEL_VERSION=vicuna-v1-3-7b
8
+ # MODEL_VERSION=llama-2-7b-chat
9
+
10
+ ########### DO NOT CHANGE ###########
11
+ ########### USE THIS FOR BOTH ###########
12
+ PROMPT_VERSION=plain
13
+ ########### DO NOT CHANGE ###########
14
+
15
+ deepspeed llava/train/train_mem.py \
16
+ --deepspeed ./scripts/zero2.json \
17
+ --model_name_or_path ./checkpoints/$MODEL_VERSION \
18
+ --version $PROMPT_VERSION \
19
+ --data_path /path/to/pretrain_data.json \
20
+ --image_folder /path/to/images \
21
+ --vision_tower openai/clip-vit-large-patch14 \
22
+ --tune_mm_mlp_adapter True \
23
+ --mm_vision_select_layer -2 \
24
+ --mm_use_im_start_end False \
25
+ --mm_use_im_patch_token False \
26
+ --bf16 True \
27
+ --output_dir ./checkpoints/llava-$MODEL_VERSION-pretrain \
28
+ --num_train_epochs 1 \
29
+ --per_device_train_batch_size 16 \
30
+ --per_device_eval_batch_size 4 \
31
+ --gradient_accumulation_steps 1 \
32
+ --evaluation_strategy "no" \
33
+ --save_strategy "steps" \
34
+ --save_steps 24000 \
35
+ --save_total_limit 1 \
36
+ --learning_rate 2e-3 \
37
+ --weight_decay 0. \
38
+ --warmup_ratio 0.03 \
39
+ --lr_scheduler_type "cosine" \
40
+ --logging_steps 1 \
41
+ --tf32 True \
42
+ --model_max_length 2048 \
43
+ --gradient_checkpointing True \
44
+ --dataloader_num_workers 4 \
45
+ --lazy_preprocess True \
46
+ --report_to wandb
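For reference, the effective global batch size of this run is the product of the per-device batch size, the gradient-accumulation steps, and the number of GPUs; a one-line sanity check, assuming an 8-GPU node (the GPU count is an assumption, not something the script fixes).

```python
# Effective global batch size for pretrain.sh, assuming 8 GPUs.
per_device_train_batch_size = 16
gradient_accumulation_steps = 1
num_gpus = 8  # assumption; deepspeed uses however many GPUs are visible
print(per_device_train_batch_size * gradient_accumulation_steps * num_gpus)  # 128
```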
groundingLMM/LLaVA/scripts/upload_pypi.sh ADDED
@@ -0,0 +1,16 @@
1
+ #!/bin/bash
2
+
3
+ # Step 0: Clean up
4
+ rm -rf dist
5
+
6
+ # Step 1: Change the package name to "llava-torch"
7
+ sed -i 's/name = "llava"/name = "llava-torch"/' pyproject.toml
8
+
9
+ # Step 2: Build the package
10
+ python -m build
11
+
12
+ # Step 3: Revert the changes in pyproject.toml to the original
13
+ sed -i 's/name = "llava-torch"/name = "llava"/' pyproject.toml
14
+
15
+ # Step 4: Upload to PyPI
16
+ python -m twine upload dist/*
groundingLMM/LLaVA/scripts/zero2.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+ "train_micro_batch_size_per_gpu": "auto",
14
+ "train_batch_size": "auto",
15
+ "gradient_accumulation_steps": "auto",
16
+ "zero_optimization": {
17
+ "stage": 2,
18
+ "overlap_comm": true,
19
+ "contiguous_gradients": true,
20
+ "sub_group_size": 1e9,
21
+ "reduce_bucket_size": "auto"
22
+ }
23
+ }
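The `"auto"` entries are placeholders that the Hugging Face Trainer fills in from the training arguments at launch time. A hedged sketch of wiring this config into `TrainingArguments` (the output directory and batch sizes are illustrative):

```python
# Point the HF Trainer at the ZeRO-2 config; fields marked "auto" in zero2.json
# (micro batch size, gradient accumulation, fp16/bf16) are resolved from these
# TrainingArguments instead of being hard-coded in the JSON.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./checkpoints/llava-pretrain",  # illustrative
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,
    bf16=True,
    deepspeed="./scripts/zero2.json",
)
```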
groundingLMM/dataset/caption_datasets/COCO_Caption_ds.py ADDED
@@ -0,0 +1,124 @@
1
+ import os
2
+ import cv2
3
+ import random
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from pycocotools.coco import COCO
7
+ from transformers import CLIPImageProcessor
8
+ from model.llava import conversation as conversation_lib
9
+ from model.SAM.utils.transforms import ResizeLongestSide
10
+ from tools.utils import DEFAULT_IMAGE_TOKEN
11
+ from dataset.utils.utils import CAPTION_QUESTIONS
12
+
13
+
14
+ class CocoCapDataset(torch.utils.data.Dataset):
15
+ IMG_MEAN = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
16
+ IMG_STD = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
17
+ IMG_SIZE = 1024
18
+ IGNORE_LABEL = 255
19
+
20
+ def __init__(self, dataset_dir, tokenizer, global_image_encoder, epoch_samples=10000, precision="fp32",
21
+ image_size=224, num_classes_per_sample=3, max_gt_per_img=10, validation=False, random_sampling=True):
22
+ self.epoch_samples = epoch_samples
23
+ self.num_classes_per_sample = num_classes_per_sample
24
+
25
+ self.dataset_dir = dataset_dir
26
+ self.image_size = image_size
27
+ self.tokenizer = tokenizer
28
+ self.precision = precision
29
+ self.transform = ResizeLongestSide(image_size)
30
+ self.global_enc_processor = CLIPImageProcessor.from_pretrained(global_image_encoder)
31
+
32
+ self.max_gt_per_img = max_gt_per_img
33
+ self.validation = validation
34
+ self.random_sampling = random_sampling
35
+
36
+ # Defining paths
37
+ mode = "val" if validation else "train"
38
+ self.base_dir = os.path.join(dataset_dir, "coco_2017")
39
+ self.image_folder = os.path.join(dataset_dir, f"coco_2017/{mode}2017")
40
+ json_files = {'validation': "captions_val2017.json", 'training': "captions_train2017.json"}
41
+ annotations_file = os.path.join(self.base_dir, "annotations",
42
+ json_files['validation'] if validation else json_files['training'])
43
+ self.data_infos = self._load_annotations(annotations_file)
44
+
45
+ self.begin_str = f"""The {DEFAULT_IMAGE_TOKEN} provides an overview of the picture.\n"""
46
+ mode = "Val" if validation else "Train"
47
+ print('\033[92m' + "----CAP-{}: COCO Caption dataset initialized----".format(mode) + '\033[0m')
48
+
49
+ def _load_annotations(self, annotation_file):
50
+ self.coco_api = COCO(annotation_file)
51
+ ann_ids = self.coco_api.getAnnIds()
52
+ # Limiting anns to 1000(optional) for validation
53
+ ann_ids = ann_ids[:1000] if self.validation else ann_ids
54
+ images_info = []
55
+ for i, id in enumerate(ann_ids):
56
+ annotation = self.coco_api.loadAnns([id])[0]
57
+ image_id = annotation['image_id']
58
+ image_info = self.coco_api.loadImgs([image_id])[0]
59
+ image_info['filename'] = image_info['file_name'].split('_')[-1]
60
+ images_info.append(image_info)
61
+ return images_info
62
+
63
+ def _parse_ann_info(self, annotation):
64
+ return {'caption': annotation['caption'].strip()}
65
+
66
+ def __getitem__(self, idx):
67
+ ann_id = random.choice(self.coco_api.getAnnIds())
68
+ annotation = self.coco_api.loadAnns(ann_id)[0]
69
+ image_info = self.coco_api.loadImgs([annotation['image_id']])[0]
70
+
71
+ # Extract caption from annotation
72
+ caption_info = self._parse_ann_info(annotation)
73
+
74
+ data = {"image_path": os.path.join(self.image_folder, image_info['file_name']),
75
+ "filename": image_info['file_name'],
76
+ "caption": caption_info['caption'],
77
+ }
78
+
79
+ processed_data = self.process_data(data)
80
+ return processed_data
81
+
82
+ def __len__(self):
83
+ return len(self.data_infos)
84
+
85
+ def grounding_enc_processor(self, x: torch.Tensor) -> torch.Tensor:
86
+ x = (x - self.IMG_MEAN) / self.IMG_STD
87
+ h, w = x.shape[-2:]
88
+ x = F.pad(x, (0, self.IMG_SIZE - w, 0, self.IMG_SIZE - h))
89
+ return x
90
+
91
+ def create_conversations(self, labels):
92
+ conversations = []
93
+ questions = []
94
+ conv = conversation_lib.default_conversation.copy()
95
+ conv.messages = []
96
+
97
+ question = random.choice(CAPTION_QUESTIONS).strip()
98
+ answer = labels
99
+
100
+ conv.append_message(conv.roles[0], self.begin_str + question)
101
+ conv.append_message(conv.roles[1], answer)
102
+ prompt = conv.get_prompt()
103
+ conversations.append(prompt)
104
+ return questions, conversations
105
+
106
+ def process_data(self, data_item):
107
+ caption = data_item['caption']
108
+ image_path = data_item['image_path']
109
+ image = cv2.imread(image_path)
110
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
111
+ # Prepare input for Global Image Encoder
112
+ global_enc_image = self.global_enc_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
113
+ # Skip input for Grounding Image Encoder
114
+ grounding_enc_image = None
115
+ image_resize = None
116
+
117
+ masks, bboxes = None, None
118
+
119
+ questions, conversations = self.create_conversations(caption)
120
+ label = None
121
+ selected_labels = [caption]
122
+
123
+ return (image_path, global_enc_image, grounding_enc_image, bboxes, conversations, masks, label, image_resize,
124
+ questions, selected_labels)
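The `grounding_enc_processor` above normalizes with SAM-style pixel statistics and zero-pads the image to a fixed 1024x1024 canvas. A standalone sketch of the same transform on a dummy tensor:

```python
# Standalone illustration of the normalize-and-pad step used by
# grounding_enc_processor (SAM-style preprocessing to a 1024x1024 canvas).
import torch
import torch.nn.functional as F

IMG_MEAN = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
IMG_STD = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
IMG_SIZE = 1024

dummy = torch.randint(0, 256, (3, 768, 1024)).float()  # C x H x W after longest-side resize
x = (dummy - IMG_MEAN) / IMG_STD
h, w = x.shape[-2:]
x = F.pad(x, (0, IMG_SIZE - w, 0, IMG_SIZE - h))  # pad right and bottom only
print(x.shape)  # torch.Size([3, 1024, 1024])
```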
groundingLMM/dataset/caption_datasets/GranD_ShortCaption_ds.py ADDED
@@ -0,0 +1,105 @@
1
+ import os
2
+ import cv2
3
+ import lmdb
4
+ import json
5
+ import random
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from transformers import CLIPImageProcessor
9
+ from model.llava import conversation as conversation_lib
10
+ from model.SAM.utils.transforms import ResizeLongestSide
11
+ from dataset.utils.utils import CAPTION_QUESTIONS
12
+ from tools.utils import DEFAULT_IMAGE_TOKEN
13
+
14
+
15
+ class GrandShortCaptionDataset(torch.utils.data.Dataset):
16
+ IMG_MEAN = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
17
+ IMG_STD = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
18
+ IMG_SIZE = 1024
19
+ IGNORE_LABEL = 255
20
+
21
+ def __init__(self, dataset_dir, tokenizer, global_image_encoder, epoch_samples=10000, precision="fp32",
22
+ image_size=224, num_classes_per_sample=3, validation=False, random_sampling=True):
23
+
24
+ self.dataset_dir = dataset_dir
25
+ self.image_size = image_size
26
+ self.tokenizer = tokenizer
27
+ self.precision = precision
28
+ self.transform = ResizeLongestSide(image_size)
29
+ self.global_enc_processor = CLIPImageProcessor.from_pretrained(global_image_encoder)
30
+ self.epoch_samples = epoch_samples
31
+ self.num_classes_per_sample = num_classes_per_sample
32
+ self.validation = validation
33
+ self.random_sampling = random_sampling
34
+
35
+ # Defining paths
36
+ self.base_dir = os.path.join(dataset_dir, "GranD_Data")
37
+ self.image_folder = os.path.join(self.base_dir, "images")
38
+ ann_file_name = "Grand_Caption_Grounding_lmdb"
39
+ ann_path = os.path.join(self.base_dir, ann_file_name)
40
+ self.annos = lmdb.open(ann_path, readonly=True, max_readers=1, lock=False, readahead=False, meminit=False)
41
+ mode = "Val" if validation else "Train"
42
+ self.data_infos = self._load_annotations(os.path.join(self.base_dir, ann_file_name, f'{ann_file_name}_{mode}.txt'))
43
+ self.begin_str = f"""The {DEFAULT_IMAGE_TOKEN} provides an overview of the picture.\n"""
44
+ print('\033[92m' + "----CAP-{}: Grand Short Caption dataset initialized----".format(mode) + '\033[0m')
45
+
46
+ def _load_annotations(self, ann_file):
47
+ with open(ann_file, 'r') as f:
48
+ data_infos = [line.strip() for line in f if line.strip()]
49
+ data_infos = data_infos[0: 1000] if self.validation else data_infos
50
+ return data_infos
51
+
52
+ def __len__(self):
53
+ return len(self.data_infos)
54
+
55
+ def grounding_enc_processor(self, x: torch.Tensor) -> torch.Tensor:
56
+ x = (x - self.IMG_MEAN) / self.IMG_STD
57
+ h, w = x.shape[-2:]
58
+ x = F.pad(x, (0, self.IMG_SIZE - w, 0, self.IMG_SIZE - h))
59
+ return x
60
+
61
+ def create_conversations(self, labels):
62
+ conversations = []
63
+ questions = []
64
+ conv = conversation_lib.default_conversation.copy()
65
+ conv.messages = []
66
+
67
+ question = random.choice(CAPTION_QUESTIONS).strip()
68
+ answer = labels
69
+
70
+ conv.append_message(conv.roles[0], self.begin_str + question)
71
+ conv.append_message(conv.roles[1], answer)
72
+ prompt = conv.get_prompt()
73
+ conversations.append(prompt)
74
+ return questions, conversations
75
+
76
+ def __getitem__(self, idx):
77
+ image_name = self.data_infos[idx] if (self.validation or not self.random_sampling) else self.data_infos[
78
+ random.randint(0, len(self.data_infos) - 1)]
79
+ # Get the annotation from lmdb
80
+ with self.annos.begin() as txn:
81
+ json_contents = txn.get(image_name.encode())
82
+ json_contents = json.loads(json_contents.decode('utf-8'))
83
+ ann_info = random.choice(json_contents[image_name])
84
+ # Process the image
85
+ image_path = os.path.join(self.image_folder, image_name)
86
+ image = cv2.imread(image_path)
87
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
88
+ # Prepare input for Global Image Encoder
89
+ global_enc_image = self.global_enc_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
90
+ # Skip input for Grounding Image Encoder
91
+ grounding_enc_image = None
92
+ image_resize = None
93
+ bboxes = None
94
+
95
+ caption = ann_info["caption"]
96
+ questions, conversations = self.create_conversations(caption)
97
+ selected_labels = conversations
98
+
99
+ masks = torch.rand(0, *image_resize)
100
+ label = None
101
+
102
+ assert len(conversations) == 1
103
+
104
+ return (image_path, global_enc_image, grounding_enc_image, bboxes, conversations, masks, label, image_resize,
105
+ questions, selected_labels)
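A hedged sketch of pulling a single record out of the caption LMDB the same way `__getitem__` does; the database path and image key are placeholders for an actual GranD_Data download.

```python
# Read one annotation record from the GranD caption LMDB, mirroring the
# __getitem__ logic above. Path and key are placeholders.
import json
import lmdb

ann_path = "GranD_Data/Grand_Caption_Grounding_lmdb"  # placeholder
image_name = "example_image.jpg"                      # placeholder key

env = lmdb.open(ann_path, readonly=True, lock=False, readahead=False, meminit=False)
with env.begin() as txn:
    record = json.loads(txn.get(image_name.encode()).decode("utf-8"))
print(record[image_name][0]["caption"])
```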
groundingLMM/dataset/caption_datasets/LLavaInstruct_vqa_ds.py ADDED
@@ -0,0 +1,107 @@
1
+ import os
2
+ import cv2
3
+ import json
4
+ import random
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from transformers import CLIPImageProcessor
8
+ from model.llava import conversation as conversation_lib
9
+ from model.SAM.utils.transforms import ResizeLongestSide
10
+ from tools.utils import DEFAULT_IMAGE_TOKEN
11
+
12
+
13
+ class LLaVAInstructDataset(torch.utils.data.Dataset):
14
+ IMG_MEAN = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
15
+ IMG_STD = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
16
+ IMG_SIZE = 1024
17
+ IGNORE_LABEL = 255
18
+
19
+ def __init__(self, dataset_dir, tokenizer, global_image_encoder, epoch_samples=10000, precision="fp32",
20
+ image_size=224, num_classes_per_sample=3, validation=False, random_sampling=True):
21
+
22
+ self.dataset_dir = dataset_dir
23
+ self.image_size = image_size
24
+ self.tokenizer = tokenizer
25
+ self.precision = precision
26
+ self.transform = ResizeLongestSide(image_size)
27
+ self.global_enc_processor = CLIPImageProcessor.from_pretrained(global_image_encoder)
28
+ self.epoch_samples = epoch_samples
29
+ self.num_classes_per_sample = num_classes_per_sample
30
+ self.validation = validation
31
+ self.random_sampling = random_sampling
32
+
33
+ # Defining paths
34
+ mode = "val" if validation else "train"
35
+ self.base_dir = os.path.join(dataset_dir, "llava_dataset")
36
+ self.image_folder = os.path.join(dataset_dir, f"coco_2017/{mode}2017")
37
+ annotations_file = os.path.join(self.base_dir, "llava_instruct_150k.json")
38
+ self.data_infos = self._load_annotations(annotations_file)
39
+ print('\033[92m' + "----CAP-{}: LLaVA-Instruct VQA dataset initialized----".format(mode) + '\033[0m')
40
+
41
+ def _load_annotations(self, ann_file):
42
+ with open(ann_file, 'r') as f:
43
+ data_infos = json.load(f)
44
+ data_infos = data_infos[0: 1000] if self.validation else data_infos
45
+ return data_infos
46
+
47
+ def __len__(self):
48
+ return len(self.vqa_data)
49
+
50
+ def grounding_enc_processor(self, x: torch.Tensor) -> torch.Tensor:
51
+ x = (x - self.IMG_MEAN) / self.IMG_STD
52
+ h, w = x.shape[-2:]
53
+ x = F.pad(x, (0, self.IMG_SIZE - w, 0, self.IMG_SIZE - h))
54
+ return x
55
+
56
+ def create_conversations(self, conv_ann):
57
+ # Preprocess:
58
+ for sentence in conv_ann:
59
+ if DEFAULT_IMAGE_TOKEN in sentence["value"]:
60
+ sentence["value"] = (sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip())
61
+ sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
62
+ sentence["value"] = sentence["value"].strip()
63
+ if "mmtag" in conversation_lib.default_conversation.version:
64
+ sentence["value"] = sentence["value"].replace(
65
+ DEFAULT_IMAGE_TOKEN, "<Image>" + DEFAULT_IMAGE_TOKEN + "</Image>"
66
+ )
67
+ conversations = []
68
+ conv = conversation_lib.default_conversation.copy()
69
+ conv.messages = []
70
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
71
+ if roles[conv_ann[0]["from"]] != conv.roles[0]:
72
+ # Skip the first one if it is not from human
73
+ conv_ann = conv_ann[1:]
74
+
75
+ for j, sentence in enumerate(conv_ann):
76
+ role = roles[sentence["from"]]
77
+ assert role == conv.roles[j % 2], f"{j}"
78
+ conv.append_message(role, sentence["value"])
79
+ conversations.append(conv.get_prompt())
80
+ questions = conversations
81
+
82
+ return questions, conversations
83
+
84
+ def __getitem__(self, idx):
85
+ ann_info = self.data_infos[idx] if (self.validation or not self.random_sampling) else self.data_infos[
86
+ random.randint(0, len(self.data_infos) - 1)]
87
+ image_path = os.path.join(self.image_folder, ann_info["image"])
88
+ image = cv2.imread(image_path)
89
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
90
+ # Prepare input for Global Image Encoder
91
+ global_enc_image = self.global_enc_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
92
+ # Skip input for Grounding Image Encoder
93
+ grounding_enc_image = None
94
+ image_resize = None
95
+ bboxes = None
96
+
97
+ conv_ann = ann_info["conversations"]
98
+ questions, conversations = self.create_conversations(conv_ann)
99
+ selected_labels = conversations
100
+
101
+ masks = None
102
+ label = None
103
+
104
+ assert len(conversations) == 1
105
+
106
+ return (image_path, global_enc_image, grounding_enc_image, bboxes, conversations, masks, label, image_resize,
107
+ questions, selected_labels)
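For clarity, each entry of `llava_instruct_150k.json` is expected to look roughly like the toy record below: an image file name plus an alternating human/gpt conversation, with the image token in the first human turn. Field names follow the code above; the text and file name are invented.

```python
# Toy record in the shape consumed by LLaVAInstructDataset.
toy_record = {
    "image": "000000033471.jpg",  # resolved against coco_2017/{train,val}2017
    "conversations": [
        {"from": "human", "value": "<image>\nWhat is the man holding?"},
        {"from": "gpt", "value": "He is holding a red umbrella."},
    ],
}
print(toy_record["conversations"][0]["value"])
```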
groundingLMM/dataset/gcg_datasets/GranDf_gcg_ds.py ADDED
@@ -0,0 +1,353 @@
1
+ import os
2
+ import cv2
3
+ import json
4
+ import random
5
+ import numpy as np
6
+ from PIL import Image
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from pycocotools import mask
10
+ from pycocotools.coco import COCO
11
+ from transformers import CLIPImageProcessor
12
+ from model.llava import conversation as conversation_lib
13
+ from model.SAM.utils.transforms import ResizeLongestSide
14
+ from tools.utils import DEFAULT_IMAGE_TOKEN
15
+ from dataset.utils.utils import GCG_QUESTIONS
16
+
17
+
18
+ class GCGBaseDataset(torch.utils.data.Dataset):
19
+ """
20
+ Dataset Class for Grounded Conversation Generation (GCG) proposed in GLaMM.
21
+ """
22
+ CLASSES = ('object',)
23
+ IMG_MEAN = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
24
+ IMG_STD = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
25
+ IMG_SIZE = 1024
26
+ IGNORE_LABEL = 255
27
+
28
+ def __init__(self, dataset_dir, tokenizer, global_image_encoder, epoch_samples=8000, precision="fp32",
29
+ image_size=224, num_classes_per_sample=3, validation=False, random_sampling=True,
30
+ image_dir='', json_path=''):
31
+ self.epoch_samples = epoch_samples
32
+ self.num_classes_per_sample = num_classes_per_sample
33
+ self.dataset_dir = dataset_dir
34
+ self.image_size = image_size
35
+ self.tokenizer = tokenizer
36
+ self.precision = precision
37
+ self.transform = ResizeLongestSide(image_size)
38
+ self.global_enc_processor = CLIPImageProcessor.from_pretrained(global_image_encoder)
39
+ self.validation = validation
40
+ self.random_sampling = random_sampling
41
+
42
+ self.question_templates = GCG_QUESTIONS
43
+ self.begin_str = f"""The {DEFAULT_IMAGE_TOKEN} provides an overview of the picture.\n"""
44
+ self.validation = validation
45
+
46
+ # Defining paths
47
+ self.base_dir = os.path.join(dataset_dir, "GranDf")
48
+ self.image_folder = os.path.join(image_dir)
49
+ self.ann_file = os.path.join(self.base_dir, "annotations", "train", json_path)
50
+ self.data_infos = self._load_annotations(self.ann_file)
51
+
52
+ def _load_annotations(self, ann_file):
53
+ with open(ann_file, 'r') as f:
54
+ data_infos = json.load(f)
55
+ data_infos = data_infos[0: 1000] if self.validation else data_infos
56
+ return data_infos
57
+
58
+ def _parse_annotations(self, ann_info):
59
+ image_path = os.path.join(self.image_folder, ann_info['file_name'])
60
+ annotations = {'labels': [], 'caption': [], 'masks': [], 'tokens_positive': [],
61
+ 'file_name': ann_info['file_name']}
62
+ width, height = Image.open(image_path).size
63
+ annotations['caption'] = ann_info['caption'].strip('"').strip()
64
+
65
+ for word, grounding in ann_info["groundings"].items():
66
+ annotations['labels'].append(word)
67
+ annotations['tokens_positive'].append(grounding["token_positives"])
68
+
69
+ # Convert segmentation to binary mask
70
+ binary_mask = np.zeros((height, width), dtype=np.uint8)
71
+ for rle in grounding["rle_masks"]:
72
+ m = mask.decode(rle).astype(np.uint8)
73
+ binary_mask += m.squeeze()
74
+ annotations['masks'].append(binary_mask)
75
+
76
+ return annotations
77
+
78
+ def __getitem__(self, index):
79
+ while True:
80
+ ann_info = self.data_infos[index] if (self.validation or not self.random_sampling) \
81
+ else self.data_infos[random.randint(0, len(self.data_infos) - 1)]
82
+ # Parse annotation info
83
+ ann = self._parse_annotations(ann_info)
84
+ image_path = os.path.join(self.image_folder, ann['file_name'])
85
+ if len(ann['labels']) > 0:
86
+ break
87
+ else:
88
+ index = random.randint(0, len(self.data_infos) - 1)
89
+ data_item = {"image_path": image_path, "filename": ann['file_name'], "caption": ann['caption'],
90
+ "labels": ann['labels'], "masks": ann['masks'], "tokens_positive": ann['tokens_positive']}
91
+ return self.process_data(data_item)
92
+
93
+ def __len__(self):
94
+ return len(self.data_infos)
95
+
96
+ def grounding_enc_processor(self, x: torch.Tensor) -> torch.Tensor:
97
+ x = (x - self.IMG_MEAN) / self.IMG_STD
98
+ h, w = x.shape[-2:]
99
+ x = F.pad(x, (0, self.IMG_SIZE - w, 0, self.IMG_SIZE - h))
100
+ return x
101
+
102
+ def create_conversations(self, caption, tokens_positive):
103
+ question = random.choice(self.question_templates).strip()
104
+
105
+ # Prepare caption with tags
106
+ def tag_caption(caption, tokens):
107
+ for start, end in sorted(tokens, key=lambda x: x[0], reverse=True):
108
+ caption = f"{caption[:start]}<p> {caption[start:end]} </p> [SEG]{caption[end:]}"
109
+ return caption
110
+
111
+ detailed_answer = tag_caption(caption, tokens_positive)
112
+
113
+ conversations = []
114
+ conv = conversation_lib.default_conversation.copy()
115
+ conv.messages = []
116
+ conv.append_message(conv.roles[0], self.begin_str + question)
117
+ conv.append_message(conv.roles[1], detailed_answer)
118
+ conversations.append(conv.get_prompt())
119
+ questions = [question]
120
+ return questions, conversations
121
+
122
+ def process_data(self, data_item):
123
+ data_labels = data_item['labels']
124
+ masks = data_item['masks']
125
+ caption = data_item['caption']
126
+ tokens_positive = data_item['tokens_positive']
127
+ image_path = data_item['image_path']
128
+
129
+ # Function to sort elements based on the start index of each phrase
130
+ def sort_by_start_index(items, order):
131
+ return [items[i] for i in order]
132
+
133
+ # Sort phrases based on their appearance in the sentence
134
+ phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0])
135
+ masks = sort_by_start_index(masks, phrase_order)
136
+ data_labels = sort_by_start_index(data_labels, phrase_order)
137
+ tokens_positive = sort_by_start_index(tokens_positive, phrase_order)
138
+
139
+ image = cv2.imread(image_path)
140
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
141
+ # Prepare input for Global Image Encoder
142
+ global_enc_image = self.global_enc_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
143
+ # Prepare input for Grounding Image Encoder
144
+ image = self.transform.apply_image(image)
145
+ image_resize = image.shape[:2]
146
+ grounding_enc_image = self.grounding_enc_processor(torch.from_numpy(image).permute(2, 0, 1).contiguous())
147
+ bboxes = None
148
+
149
+ questions, conversations = self.create_conversations(caption, tokens_positive)
150
+ masks = np.stack(masks, axis=0)
151
+ masks = torch.from_numpy(masks)
152
+ label = torch.ones(masks.shape[1:], dtype=torch.long) * self.IGNORE_LABEL
153
+ selected_labels = data_labels
154
+
155
+ return (
156
+ image_path, global_enc_image, grounding_enc_image, bboxes, conversations, masks, label, image_resize, questions,
157
+ selected_labels)
158
+
159
+
160
+ class GranDfDataset(GCGBaseDataset):
161
+ """
162
+ Human annotated dataset proposed in GLaMM as part of GranDf dataset.
163
+ """
164
+ def __init__(self, dataset_dir, tokenizer, global_image_encoder, epoch_samples=8000, precision="fp32",
165
+ image_size=224, num_classes_per_sample=3, validation=False, random_sampling=True):
166
+ self.base_dir = os.path.join(dataset_dir, "GranDf")
167
+ json_path = "GranDf_HA_GCG_train.json"
168
+ image_dir = os.path.join(self.base_dir, "GranDf_HA_images", "train")
169
+ mode = "Val" if validation else "Train"
170
+
171
+ super().__init__(
172
+ dataset_dir, tokenizer, global_image_encoder, epoch_samples, precision, image_size, num_classes_per_sample,
173
+ validation, random_sampling, image_dir, json_path, )
174
+ print('\033[92m' + "----GCG-{}: GranDf-GCG dataset initialized----".format(mode) + '\033[0m')
175
+
176
+
177
+ class OpenPsgGCGDataset(GCGBaseDataset):
178
+ def __init__(self, dataset_dir, tokenizer, global_image_encoder, epoch_samples=8000, precision="fp32",
179
+ image_size=224, num_classes_per_sample=3, validation=False, random_sampling=True):
180
+ json_files = {'validation': "OpenPsgGCG_val.json", 'training': "OpenPsgGCG_train.json"}
181
+ json_path = json_files['validation'] if validation else json_files['training']
182
+ image_dir = os.path.join("coco_2017", "train2017")
183
+ mode = "Val" if validation else "Train"
184
+
185
+ super().__init__(
186
+ dataset_dir, tokenizer, global_image_encoder, epoch_samples, precision, image_size, num_classes_per_sample,
187
+ validation, random_sampling, image_dir, json_path, )
188
+ print('\033[92m' + "----GCG-{}: OpenPSG-GCG dataset initialized----".format(mode) + '\033[0m')
189
+
190
+
191
+ class Flickr30kGCGDataset(GCGBaseDataset):
192
+ def __init__(self, dataset_dir, tokenizer, global_image_encoder, epoch_samples=8000, precision="fp32",
193
+ image_size=224, num_classes_per_sample=3, validation=False, random_sampling=True):
194
+ json_files = {'validation': "flickr_mergedGT_GCG_val.json", 'training': "flickr_mergedGT_GCG_train.json"}
195
+ json_path = json_files['validation'] if validation else json_files['training']
196
+ image_dir = os.path.join("flikcr_30k", "train")
197
+ mode = "Val" if validation else "Train"
198
+
199
+ super().__init__(
200
+ dataset_dir, tokenizer, global_image_encoder, epoch_samples, precision, image_size, num_classes_per_sample,
201
+ validation, random_sampling, image_dir, json_path, )
202
+ # Filter out images smaller than the minimum size
203
+ self.data_infos = [self.data_infos[i] for i in self._filter_images(min_size=32)]
204
+ self.validation = validation
205
+ print('\033[92m' + "----GCG-{}: Flickr30k-GCG dataset initialized----".format(mode) + '\033[0m')
206
+
207
+ def _load_annotations(self, ann_file):
208
+ # Load annotations and filter out images with very short captions
209
+ self.coco = COCO(ann_file)
210
+ self.image_ids = self.coco.getImgIds()
211
+ data_infos = []
212
+ total_ann_ids = []
213
+ removed_img_count = 0
214
+ for img_id in self.image_ids:
215
+ if len(data_infos) == 1000 and self.validation:
216
+ # Only limited images for validation
217
+ break
218
+ info = self.coco.loadImgs([img_id])[0]
219
+ if len(info['caption'].split(' ')) < 3:
220
+ removed_img_count += 1
221
+ continue
222
+ info['filename'] = info['file_name'].split('_')[-1]
223
+ info['height'] = int(info['height'])
224
+ info['width'] = int(info['width'])
225
+ data_infos.append(info)
226
+ ann_ids = self.coco.getAnnIds(imgIds=[img_id])
227
+ total_ann_ids.extend(ann_ids)
228
+ assert len(set(total_ann_ids)) == len(total_ann_ids), f"Non-unique annotation IDs in '{ann_file}'!"
229
+ print(f'Removed {removed_img_count} images.')
230
+ return data_infos
231
+
232
+ def _filter_images(self, min_size):
233
+ return [i for i, info in enumerate(self.data_infos) if min(info['width'], info['height']) >= min_size]
234
+
235
+ def _parse_annotations(self, img_info, ann_info):
236
+ annotations = {'bboxes': [], 'labels': [], 'bboxes_ignore': [], 'caption': img_info['caption'], 'masks': [],
237
+ 'tokens_positive': []}
238
+ for ann in ann_info:
239
+ if ann.get('ignore', False):
240
+ continue
241
+ x1, y1, w, h = ann['bbox']
242
+ inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
243
+ inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
244
+ if inter_w * inter_h == 0 or ann['area'] <= 0 or w < 1 or h < 1:
245
+ continue
246
+ bbox = [x1, y1, x1 + w, y1 + h]
247
+ annotations['bboxes'].append(bbox)
248
+ tokens_positive = ann['tokens_positive']
249
+ gt_label = [img_info['caption'][span[0]:span[1]] for span in tokens_positive]
250
+ annotations['labels'].append(gt_label[0])
251
+ annotations['tokens_positive'].append(tokens_positive[0])
252
+
253
+ rle = ann['sam_mask']
254
+ mask_decoded = mask.decode(rle).astype(np.uint8)
255
+ annotations['masks'].append(mask_decoded)
256
+
257
+ # Convert bounding boxes to numpy arrays
258
+ annotations['bboxes'] = np.array(annotations['bboxes'], dtype=np.float32) if annotations[
259
+ 'bboxes'] else np.zeros((0, 4), dtype=np.float32)
260
+ annotations['bboxes_ignore'] = np.array(annotations['bboxes_ignore'], dtype=np.float32) if annotations[
261
+ 'bboxes_ignore'] else np.zeros((0, 4), dtype=np.float32)
262
+
263
+ return annotations
264
+
265
+ def __getitem__(self, index):
266
+ img_info = self.data_infos[index] if (self.validation or not self.random_sampling) \
267
+ else self.data_infos[random.randint(0, len(self.data_infos) - 1)]
268
+ ann_ids = self.coco.getAnnIds(imgIds=img_info['id'])
269
+ ann_info = self.coco.loadAnns(ann_ids)
270
+ image_path = os.path.join(self.image_folder, img_info['file_name'])
271
+ # Parse annotation info
272
+ ann = self._parse_annotations(img_info, ann_info)
273
+ data_item = {"image_path": image_path, "filename": img_info['file_name'], "width": img_info['width'],
274
+ "height": img_info['height'], "bbox": ann['bboxes'], "caption": ann['caption'],
275
+ "labels": ann['labels'], "masks": ann['masks'], "tokens_positive": ann['tokens_positive']}
276
+ return self.process_data(data_item)
277
+
278
+
279
+ class RefCOCOgGCGDataset(GCGBaseDataset):
280
+ def __init__(self, dataset_dir, tokenizer, global_image_encoder, epoch_samples=8000, precision="fp32",
281
+ image_size=224, num_classes_per_sample=3, validation=False, random_sampling=True):
282
+ json_files = {'validation': "RefCOCOg_GCG_val.json", 'training': "RefCOCOg_GCG_train.json"}
283
+ json_path = json_files['validation'] if validation else json_files['training']
284
+ image_dir = os.path.join("coco_2014", "train2014")
285
+ mode = "Val" if validation else "Train"
286
+
287
+ super().__init__(
288
+ dataset_dir, tokenizer, global_image_encoder, epoch_samples, precision, image_size, num_classes_per_sample,
289
+ validation, random_sampling, image_dir, json_path, )
290
+ print('\033[92m' + "----GCG-{}: RefCOCOg-GCG dataset initialized----".format(mode) + '\033[0m')
291
+
292
+ def _parse_annotations(self, ann_info):
293
+ image_path = os.path.join(self.image_folder, ann_info['img_file_name'])
294
+ annotations = {'labels': [], 'caption': [], 'masks': [], 'tokens_positive': [],
295
+ 'file_name': ann_info['img_file_name']}
296
+ width, height = Image.open(image_path).size
297
+ orig_caption = ann_info['caption'].strip('"').strip()
298
+ annotations['caption'] = orig_caption.lower()
299
+
300
+ for detail in ann_info['refs']:
301
+ phrase = detail['sentence']
302
+ if phrase.lower() in annotations['caption']:
303
+ annotations['labels'].append(phrase)
304
+ index = annotations['caption'].find(phrase)
305
+ end_index = index + len(phrase) if index != -1 else -1
306
+ annotations['tokens_positive'].append([index, end_index])
307
+
308
+ # Convert segmentation to binary mask
309
+ binary_mask = np.zeros((height, width), dtype=np.uint8)
310
+ for seg in detail["segmentation"]:
311
+ rles = mask.frPyObjects([seg], height, width)
312
+ m = mask.decode(rles)
313
+ m = m.astype(np.uint8)
314
+ binary_mask += m.squeeze()
315
+ annotations['masks'].append(binary_mask)
316
+
317
+ # Sort tokens_positive and corresponding lists
318
+ tokens_positive = annotations['tokens_positive']
319
+ sorted_indices = sorted(range(len(tokens_positive)), key=lambda i: tokens_positive[i][0])
320
+ annotations['tokens_positive'] = [tokens_positive[i] for i in sorted_indices]
321
+ annotations['masks'] = [annotations['masks'][i] for i in sorted_indices]
322
+ annotations['labels'] = [annotations['labels'][i] for i in sorted_indices]
323
+
324
+ # Trimming overlapping intervals
325
+ for i in range(len(tokens_positive)):
326
+ for j in range(i + 1, len(tokens_positive)):
327
+ # If there is overlap
328
+ if tokens_positive[i][1] >= tokens_positive[j][0]:
329
+ # Modify the end index of phrase i to be one less than the start index of phrase j
330
+ tokens_positive[i][1] = tokens_positive[j][0] - 1
331
+ # Modify the phrases to reflect the change in indices
332
+ annotations['labels'][i] = orig_caption[tokens_positive[i][0]:tokens_positive[i][1] + 1]
333
+ break # Exit inner loop since i was modified
334
+
335
+ return annotations
336
+
337
+ def __getitem__(self, index):
338
+ while True:
339
+ ann_dict = self.data_infos[index] if (self.validation or not self.random_sampling) \
340
+ else self.data_infos[random.randint(0, len(self.data_infos) - 1)]
341
+ ann_info = next(iter(ann_dict.values()))
342
+ # Parse annotation info
343
+ ann = self._parse_annotations(ann_info)
344
+ image_path = os.path.join(self.image_folder, ann['file_name'])
345
+ # Check if len(gt_phrases) > 0 and if True, break the loop
346
+ if len(ann['labels']) > 0:
347
+ break
348
+ else:
349
+ index = random.randint(0, len(self.data_infos) - 1)
350
+ data_item = {"image_path": image_path, "filename": ann['file_name'], "caption": ann['caption'],
351
+ "labels": ann['labels'], "masks": ann['masks'], "tokens_positive": ann['tokens_positive']}
352
+
353
+ return self.process_data(data_item)
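The grounding-tag logic in `create_conversations` is easiest to see on a toy caption; the snippet below replicates the nested `tag_caption` helper standalone (the caption and character spans are invented).

```python
# Standalone replica of the tag_caption helper from create_conversations:
# wrap each grounded phrase in <p> ... </p> and append a [SEG] token.
def tag_caption(caption, tokens):
    for start, end in sorted(tokens, key=lambda x: x[0], reverse=True):
        caption = f"{caption[:start]}<p> {caption[start:end]} </p> [SEG]{caption[end:]}"
    return caption

caption = "A dog chases a ball on the grass"
tokens_positive = [[2, 5], [15, 19]]  # character spans of "dog" and "ball"
print(tag_caption(caption, tokens_positive))
# A <p> dog </p> [SEG] chases a <p> ball </p> [SEG] on the grass
```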
groundingLMM/dataset/region_datasets/Flickr_Region_ds.py ADDED
@@ -0,0 +1,193 @@
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import random
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from pycocotools.coco import COCO
8
+ from transformers import CLIPImageProcessor
9
+ from model.llava import conversation as conversation_lib
10
+ from model.SAM.utils.transforms import ResizeLongestSide
11
+ from tools.utils import DEFAULT_IMAGE_TOKEN
12
+ from dataset.utils.utils import REGION_QUESTIONS, REGION_GROUP_QUESTIONS
13
+
14
+
15
+ class Flickr30kRegDataset(torch.utils.data.Dataset):
16
+ CLASSES = ('object',)
17
+ IMG_MEAN = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
18
+ IMG_STD = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
19
+ IMG_SIZE = 1024
20
+ IGNORE_LABEL = 255
21
+
22
+ def __init__(self, dataset_dir, tokenizer, global_image_encoder, epoch_samples=8000, precision="fp32",
23
+ image_size=224, num_classes_per_sample=3, max_gt_per_img=10, validation=False, random_sampling=True):
24
+ self.epoch_samples = epoch_samples
25
+ self.num_classes_per_sample = num_classes_per_sample
26
+ self.dataset_dir = dataset_dir
27
+ self.image_size = image_size
28
+ self.tokenizer = tokenizer
29
+ self.precision = precision
30
+ self.transform = ResizeLongestSide(image_size)
31
+ self.global_enc_processor = CLIPImageProcessor.from_pretrained(global_image_encoder)
32
+ self.max_gt_per_img = max_gt_per_img
33
+ self.validation = validation
34
+ self.random_sampling = random_sampling
35
+
36
+ self.base_dir = os.path.join(dataset_dir, "flikcr_30k")
37
+ self.image_folder = os.path.join(self.base_dir, "flickr30k-images")
38
+ self.ann_file = os.path.join(self.base_dir, "mdetr_annotations", "final_flickr_mergedGT_train.json")
39
+
40
+ self.data_infos = self._load_annotations(self.ann_file)
41
+ self.data_infos = [self.data_infos[i] for i in self._filter_images(min_size=32)]
42
+ self.id_cap_dict = dict()
43
+ self.begin_str = f"The {DEFAULT_IMAGE_TOKEN} provides an overview of the picture.\n"
44
+ print('\033[92m' + "----REGION-Train: Loaded Flickr30k dataset ----" + '\033[0m')
45
+
46
+ def _load_annotations(self, ann_file):
47
+ self.coco = COCO(ann_file)
48
+ img_ids = self.coco.getImgIds()
49
+ data_infos = []
50
+ for img_id in img_ids:
51
+ info = self.coco.loadImgs([img_id])[0]
52
+ if len(info['caption'].split(' ')) < 3:
53
+ continue
54
+ info['filename'] = info['file_name'].split('_')[-1]
55
+ info['height'] = int(info['height'])
56
+ info['width'] = int(info['width'])
57
+ data_infos.append(info)
58
+ return data_infos
59
+
60
+ def _filter_images(self, min_size):
61
+ return [i for i, info in enumerate(self.data_infos) if min(info['width'], info['height']) >= min_size]
62
+
63
+ def _parse_annotations(self, img_info, ann_info):
64
+ annotations = {'bboxes': [], 'labels': [], 'bboxes_ignore': [], 'masks_ann': []}
65
+ self.cat_ids = self.coco.getCatIds(catNms=self.CLASSES)
66
+ self.id_cap_dict = dict()
67
+ self.id_cap_dict[img_info['file_name']] = img_info['caption']
68
+
69
+ for ann in ann_info:
70
+ if ann.get('ignore', False) or ann['area'] <= 0 or ann['bbox'][2] < 1 or ann['bbox'][3] < 1:
71
+ continue
72
+ bbox = self._get_valid_bbox(ann['bbox'], img_info['width'], img_info['height'])
73
+ if bbox:
74
+ if ann.get('iscrowd', False):
75
+ annotations['bboxes_ignore'].append(bbox)
76
+ else:
77
+ annotations['bboxes'].append(bbox)
78
+ gt_list = [img_info['caption'][atp[0]:atp[1]] for atp in ann['tokens_positive']]
79
+ annotations['labels'].append(gt_list[0])
80
+ annotations['masks_ann'].append(ann.get('segmentation', None))
81
+
82
+ annotations['bboxes'] = np.array(annotations['bboxes'], dtype=np.float32) if annotations[
83
+ 'bboxes'] else np.zeros((0, 4), dtype=np.float32)
84
+ annotations['bboxes_ignore'] = np.zeros((0, 4), dtype=np.float32)
85
+ return annotations
86
+
87
+ def _get_valid_bbox(self, bbox, img_width, img_height):
88
+ x1, y1, w, h = bbox
89
+ inter_w = max(0, min(x1 + w, img_width) - max(x1, 0))
90
+ inter_h = max(0, min(y1 + h, img_height) - max(y1, 0))
91
+ if inter_w * inter_h == 0:
92
+ return None
93
+ return [x1, y1, x1 + w, y1 + h]
94
+
95
+ def __getitem__(self, index):
96
+ img_info = self.data_infos[index] if (self.validation or not self.random_sampling) \
97
+ else self.data_infos[random.randint(0, len(self.data_infos) - 1)]
98
+ ann_info = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_info['id']))
99
+ ann = self._parse_annotations(img_info, ann_info)
100
+
101
+ data_item = {
102
+ "image_path": os.path.join(self.image_folder, img_info['file_name']),
103
+ "filename": img_info['file_name'],
104
+ "width": img_info['width'],
105
+ "height": img_info['height'],
106
+ "bbox": ann['bboxes'],
107
+ "caption": img_info['caption'],
108
+ "labels": ann['labels'],
109
+ }
110
+
111
+ return self.process_data(data_item)
112
+
113
+ def __len__(self):
114
+ return len(self.coco.imgs)
115
+
116
+ def grounding_enc_processor(self, x: torch.Tensor) -> torch.Tensor:
117
+ x = (x - self.IMG_MEAN) / self.IMG_STD
118
+ h, w = x.shape[-2:]
119
+ x = F.pad(x, (0, self.IMG_SIZE - w, 0, self.IMG_SIZE - h))
120
+ return x
121
+
122
+ def region_enc_processor(self, orig_size, post_size, bboxes, labels, device):
123
+ orig_h, orig_w = orig_size
124
+ post_h, post_w = post_size
125
+ y_scale = post_h / orig_h
126
+ x_scale = post_w / orig_w
127
+ shuffle_ids = torch.randperm(len(labels))
128
+ if len(shuffle_ids) > self.max_gt_per_img:
129
+ shuffle_ids_reg_question = shuffle_ids[:self.max_gt_per_img]
130
+ selected_labels = [labels[i] for i in shuffle_ids_reg_question]
131
+ else:
132
+ selected_labels = [labels[i] for i in shuffle_ids]
133
+ selected_bboxes = bboxes[shuffle_ids]
134
+ # Ensure selected_bboxes is two-dimensional
135
+ if len(selected_bboxes.shape) == 1:
136
+ selected_bboxes = np.expand_dims(selected_bboxes, axis=0)
137
+
138
+ selected_bboxes[:, [0, 2]] *= x_scale
139
+ selected_bboxes[:, [1, 3]] *= y_scale
140
+ selected_bboxes = torch.tensor(selected_bboxes, device=device, dtype=torch.float32) / post_h
141
+ return selected_bboxes, selected_labels
142
+
143
+ def create_conversations(self, labels, caption):
144
+ # DETAILED QUESTION about all objects - the answer is the full image caption
145
+ # (bbox order does not matter because all objects are used)
146
+ questions = []
147
+ question = random.choice(REGION_GROUP_QUESTIONS).strip()
148
+ region_string = ''
149
+ for i in range(len(labels)):
150
+ region_string = region_string + f'region{i + 1} <bbox>,'
151
+ detailed_question = question.replace('<region>', region_string)
152
+ questions.append(detailed_question)
153
+ detailed_answer = caption
154
+
155
+ conversations = []
156
+ conv = conversation_lib.default_conversation.copy()
157
+ conv.messages = []
158
+
159
+ # Start with question of all regions - Create message with roles:
160
+ conv.append_message(conv.roles[0], self.begin_str + detailed_question)
161
+ conv.append_message(conv.roles[1], detailed_answer)
162
+ for i, reg_answer in enumerate(labels):
163
+ reg_question = random.choice(REGION_QUESTIONS).replace('<region>', f'region{i + 1} <bbox>').strip()
164
+ conv.append_message(conv.roles[0], reg_question)
165
+ conv.append_message(conv.roles[1], reg_answer)
166
+ conversations.append(conv.get_prompt())
167
+ return questions, conversations
168
+
169
+ def process_data(self, data_item):
170
+ data_labels = data_item['labels']
171
+ data_bboxes = data_item['bbox']
172
+ caption = data_item['caption']
173
+ image_path = data_item['image_path']
174
+ image = cv2.imread(image_path)
175
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
176
+ orig_h, orig_w = image.shape[:2]
177
+ # Prepare input for Global Image Encoder
178
+ global_enc_image = self.global_enc_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
179
+ post_h, post_w = global_enc_image.shape[1:3]
180
+ # Skip input for Grounding Image Encoder
181
+ grounding_enc_image = None
182
+ image_resize = None
183
+ # Prepare input for Region Image Encoder
184
+ bboxes, selected_labels = self.region_enc_processor(
185
+ (orig_h, orig_w), (post_h, post_w), data_bboxes, data_labels, global_enc_image.device
186
+ )
187
+ masks = None
188
+
189
+ questions, conversations = self.create_conversations(selected_labels, caption)
190
+ label = None
191
+
192
+ return (image_path, global_enc_image, grounding_enc_image, bboxes, conversations, masks, label, image_resize,
193
+ questions, selected_labels)
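Note: a minimal usage sketch for the region dataset defined above. The dataset root, tokenizer checkpoint and CLIP encoder name below are illustrative assumptions, not values taken from this commit, and the data must already be laid out as the class expects.

from transformers import AutoTokenizer

# Hypothetical paths/checkpoints - adjust to the actual setup.
tokenizer = AutoTokenizer.from_pretrained("path/to/llm_tokenizer")
dataset = Flickr30kRegDataset(
    dataset_dir="./data",
    tokenizer=tokenizer,
    global_image_encoder="openai/clip-vit-large-patch14-336",
)

# process_data returns the 10-tuple used throughout these datasets.
(image_path, global_enc_image, grounding_enc_image, bboxes, conversations,
 masks, label, image_resize, questions, selected_labels) = dataset[0]
print(image_path, bboxes.shape, len(conversations))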
groundingLMM/dataset/region_datasets/GranD_ReferringRegion_ds.py ADDED
@@ -0,0 +1,162 @@
1
+ import os
2
+ import cv2
3
+ import lmdb
4
+ import json
5
+ import numpy as np
6
+ import random
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from transformers import CLIPImageProcessor
10
+ from model.llava import conversation as conversation_lib
11
+ from model.SAM.utils.transforms import ResizeLongestSide
12
+ from tools.utils import DEFAULT_IMAGE_TOKEN
13
+ from dataset.utils.utils import REGION_QUESTIONS
14
+
15
+
16
+ class GrandReferRegDataset(torch.utils.data.Dataset):
17
+ CLASSES = ('object',)
18
+ IMG_MEAN = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
19
+ IMG_STD = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
20
+ IMG_SIZE = 1024
21
+ IGNORE_LABEL = 255
22
+
23
+ def __init__(self, dataset_dir, tokenizer, global_image_encoder, epoch_samples=8000, precision="fp32",
24
+ image_size=224, num_classes_per_sample=3, max_gt_per_img=10, validation=False, random_sampling=True):
25
+ self.epoch_samples = epoch_samples
26
+ self.num_classes_per_sample = num_classes_per_sample
27
+ self.dataset_dir = dataset_dir
28
+ self.image_size = image_size
29
+ self.tokenizer = tokenizer
30
+ self.precision = precision
31
+ self.transform = ResizeLongestSide(image_size)
32
+ self.global_enc_processor = CLIPImageProcessor.from_pretrained(global_image_encoder)
33
+ self.max_gt_per_img = max_gt_per_img
34
+ self.validation = validation
35
+ self.random_sampling = random_sampling
36
+
37
+ # Defining paths
38
+ self.base_dir = os.path.join(dataset_dir, "GranD_Data")
39
+ self.image_folder = os.path.join(self.base_dir, "images")
40
+ ann_file_name = "Grand_Referring_Expression_lmdb"
41
+ ann_path = os.path.join(self.base_dir, ann_file_name)
42
+ self.annos = lmdb.open(ann_path, readonly=True, max_readers=1, lock=False, readahead=False, meminit=False)
43
+ mode = "Val" if validation else "Train"
44
+ self.data_infos = self._load_annotations(
45
+ os.path.join(self.base_dir, ann_file_name, f'{ann_file_name}_{mode}.txt')
46
+ )
47
+ self.begin_str = f"""The {DEFAULT_IMAGE_TOKEN} provides an overview of the picture.\n"""
48
+ self.question_templates = REGION_QUESTIONS
49
+ print('\033[92m' + "----REGION-{}: GranD Referring Region dataset initialized----".format(mode) + '\033[0m')
50
+
51
+ def _load_annotations(self, ann_file):
52
+ with open(ann_file, 'r') as f:
53
+ data_infos = [line.strip() for line in f if line.strip()]
54
+ data_infos = data_infos[0: 1000] if self.validation else data_infos
55
+ return data_infos
56
+
57
+ def _parse_annotations(self, ann_info):
58
+ annotations = {'bboxes': [], 'labels': []}
59
+ for ann in ann_info:
60
+ bbox = ann['bbox']
61
+ if bbox:
62
+ annotations['bboxes'].append(bbox)
63
+ annotations['labels'].append(ann['attribute'])
64
+
65
+ annotations['bboxes'] = np.array(annotations['bboxes'], dtype=np.float32) if annotations[
66
+ 'bboxes'] else np.zeros((0, 4), dtype=np.float32)
67
+ return annotations
68
+
69
+ def __getitem__(self, index):
70
+ image_name = self.data_infos[index] if (self.validation or not self.random_sampling) \
71
+ else self.data_infos[random.randint(0, len(self.data_infos) - 1)]
72
+ image_path = os.path.join(self.image_folder, image_name)
73
+ # Get the annotation from lmdb
74
+ with self.annos.begin() as txn:
75
+ json_contents = txn.get(image_name.encode())
76
+ json_contents = json.loads(json_contents.decode('utf-8'))
77
+ ann_info = json_contents[image_name]
78
+ ann = self._parse_annotations(ann_info)
79
+
80
+ data_item = {
81
+ "image_path": image_path,
82
+ "filename": image_name,
83
+ "bbox": ann['bboxes'],
84
+ "labels": ann['labels'],
85
+ }
86
+
87
+ return self.process_data(data_item)
88
+
89
+ def __len__(self):
90
+ return len(self.data_infos)
91
+
92
+ def grounding_enc_processor(self, x: torch.Tensor) -> torch.Tensor:
93
+ x = (x - self.IMG_MEAN) / self.IMG_STD
94
+ h, w = x.shape[-2:]
95
+ x = F.pad(x, (0, self.IMG_SIZE - w, 0, self.IMG_SIZE - h))
96
+ return x
97
+
98
+ def region_enc_processor(self, orig_size, post_size, bboxes, labels, device):
99
+ orig_h, orig_w = orig_size
100
+ post_h, post_w = post_size
101
+ y_scale = post_h / orig_h
102
+ x_scale = post_w / orig_w
103
+ shuffle_ids = torch.randperm(len(labels))
104
+ if len(shuffle_ids) > self.max_gt_per_img:
105
+ shuffle_ids_reg_question = shuffle_ids[:self.max_gt_per_img]
106
+ selected_labels = [labels[i] for i in shuffle_ids_reg_question]
107
+ else:
108
+ selected_labels = [labels[i] for i in shuffle_ids]
109
+ selected_bboxes = bboxes[shuffle_ids]
110
+ # Ensure selected_bboxes is two-dimensional
111
+ if len(selected_bboxes.shape) == 1:
112
+ selected_bboxes = np.expand_dims(selected_bboxes, axis=0)
113
+
114
+ selected_bboxes[:, [0, 2]] *= x_scale
115
+ selected_bboxes[:, [1, 3]] *= y_scale
116
+ selected_bboxes = torch.tensor(selected_bboxes, device=device, dtype=torch.float32) / post_h
117
+ return selected_bboxes, selected_labels
118
+
119
+ def create_conversations(self, labels, question_templates):
120
+ questions = []
121
+ answers = []
122
+ for i, label in enumerate(labels):
123
+ question = random.choice(question_templates).strip().replace('<region>', f'region{i + 1} <bbox>')
124
+ questions.append(question)
125
+ answers.append(label)
126
+
127
+ conversations = []
128
+ conv = conversation_lib.default_conversation.copy()
129
+ conv.messages = []
130
+ for i, (question, answer) in enumerate(zip(questions, answers)):
131
+ if i == 0:
132
+ question = self.begin_str + question
133
+ conv.append_message(conv.roles[0], question)
134
+ conv.append_message(conv.roles[1], answer)
135
+ conversations.append(conv.get_prompt())
136
+ return questions, conversations
137
+
138
+ def process_data(self, data_item):
139
+ data_labels = data_item['labels']
140
+ data_bboxes = data_item['bbox']
141
+
142
+ image_path = data_item['image_path']
143
+ image = cv2.imread(image_path)
144
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
145
+ orig_h, orig_w = image.shape[:2]
146
+ # Prepare input for Global Image Encoder
147
+ global_enc_image = self.global_enc_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
148
+ post_h, post_w = global_enc_image.shape[1:3]
149
+ # Skip input for Grounding Image Encoder
150
+ grounding_enc_image = None
151
+ image_resize = None
152
+ # Prepare input for Region Image Encoder
153
+ bboxes, selected_labels = self.region_enc_processor(
154
+ (orig_h, orig_w), (post_h, post_w), data_bboxes, data_labels, global_enc_image.device
155
+ )
156
+ masks = None
157
+
158
+ questions, conversations = self.create_conversations(selected_labels, question_templates=self.question_templates)
159
+ label = None
160
+
161
+ return (image_path, global_enc_image, grounding_enc_image, bboxes, conversations, masks, label, image_resize,
162
+ questions, selected_labels)
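Note: the GranD readers above fetch annotations from an LMDB keyed by image file name, whose value is UTF-8 JSON keyed again by that file name and holding a list of {"bbox", "attribute"} records (layout inferred from __getitem__ above). A tiny write-side sketch of that assumed layout, with a made-up record:

import json
import lmdb

env = lmdb.open("Grand_Referring_Expression_lmdb", map_size=1 << 30)
record = {"sample.jpg": [{"bbox": [10, 20, 200, 240], "attribute": "a red car"}]}
with env.begin(write=True) as txn:
    # Key and value mirror what txn.get(image_name.encode()) / json.loads(...) expect above.
    txn.put("sample.jpg".encode(), json.dumps(record).encode("utf-8"))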
groundingLMM/dataset/region_datasets/RefCOCO_VG_Region_ds.py ADDED
@@ -0,0 +1,300 @@
1
+ import os
2
+ import cv2
3
+ import random
4
+ import numpy as np
5
+ import torch
6
+ from pycocotools.coco import COCO
7
+ import torch.nn.functional as F
8
+ from transformers import CLIPImageProcessor
9
+ from model.llava import conversation as conversation_lib
10
+ from model.SAM.utils.transforms import ResizeLongestSide
11
+ from tools.utils import DEFAULT_IMAGE_TOKEN
12
+ from dataset.utils.utils import REGION_QUESTIONS
13
+
14
+
15
+ class RegionBaseDataset(torch.utils.data.Dataset):
16
+ CLASSES = ('object',)
17
+ IMG_MEAN = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
18
+ IMG_STD = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
19
+ IMG_SIZE = 1024
20
+ IGNORE_LABEL = 255
21
+
22
+ def __init__(self, dataset_dir, tokenizer, global_image_encoder, epoch_samples=8000, precision="fp32",
23
+ image_size=224, num_classes_per_sample=3, max_gt_per_img=10, validation=False, dataset_name='',
24
+ image_dir='', json_path='', intro_string='', question_templates=None, random_sampling=True):
25
+ self.epoch_samples = epoch_samples
26
+ self.num_classes_per_sample = num_classes_per_sample
27
+ self.dataset_dir = dataset_dir
28
+ self.image_size = image_size
29
+ self.tokenizer = tokenizer
30
+ self.precision = precision
31
+ self.transform = ResizeLongestSide(image_size)
32
+ self.global_enc_processor = CLIPImageProcessor.from_pretrained(global_image_encoder)
33
+ self.max_gt_per_img = max_gt_per_img
34
+ self.validation = validation
35
+ self.random_sampling = random_sampling
36
+
37
+ # Dataset type specific
38
+ self.begin_str = intro_string
39
+ self.base_dir = os.path.join(dataset_dir, dataset_name)
40
+ self.ann_file = os.path.join(self.base_dir, json_path)
41
+ self.question_templates = question_templates
42
+ self.image_folder = os.path.join(self.base_dir, image_dir)
43
+
44
+ self.data_infos = self._load_annotations(self.ann_file)
45
+ self.data_infos = [self.data_infos[i] for i in self._filter_images(min_size=32)]
46
+
47
+ def _load_annotations(self, ann_file):
48
+ self.coco = COCO(ann_file)
49
+ img_ids = self.coco.getImgIds()
50
+ data_infos = []
51
+ for img_id in img_ids:
52
+ if self.validation and len(data_infos) == 1000:
53
+ # limited images during validation
54
+ break
55
+ info = self.coco.loadImgs([img_id])[0]
56
+ info['filename'] = info['file_name'].split('_')[-1]
57
+ info['height'] = int(info['height'])
58
+ info['width'] = int(info['width'])
59
+ data_infos.append(info)
60
+ return data_infos
61
+
62
+ def _filter_images(self, min_size):
63
+ return [i for i, info in enumerate(self.data_infos) if min(info['width'], info['height']) >= min_size]
64
+
65
+ def _parse_annotations(self, img_info, ann_info):
66
+ annotations = {'bboxes': [], 'labels': [], 'bboxes_ignore': [], 'masks_ann': [],
67
+ 'seg_map': img_info['file_name'].replace('jpg', 'png')}
68
+
69
+ for ann in ann_info:
70
+ if ann.get('ignore', False) or ann['area'] <= 0 or ann['bbox'][2] < 1 or ann['bbox'][3] < 1:
71
+ continue
72
+ bbox = self._get_valid_bbox(ann['bbox'], img_info['width'], img_info['height'])
73
+ if bbox:
74
+ annotations['bboxes'].append(bbox)
75
+ annotations['labels'].append(img_info['caption'].strip())
76
+
77
+ annotations['bboxes'] = np.array(annotations['bboxes'], dtype=np.float32) if annotations[
78
+ 'bboxes'] else np.zeros((0, 4), dtype=np.float32)
79
+ annotations['bboxes_ignore'] = np.zeros((0, 4), dtype=np.float32)
80
+ return annotations
81
+
82
+ def _get_valid_bbox(self, bbox, img_width, img_height):
83
+ x1, y1, w, h = bbox
84
+ inter_w = max(0, min(x1 + w, img_width) - max(x1, 0))
85
+ inter_h = max(0, min(y1 + h, img_height) - max(y1, 0))
86
+ if inter_w * inter_h == 0:
87
+ return None
88
+ return [x1, y1, x1 + w, y1 + h]
89
+
90
+ def __getitem__(self, index):
91
+ img_info = self.data_infos[index] if (self.validation or not self.random_sampling) \
92
+ else self.data_infos[random.randint(0, len(self.data_infos) - 1)]
93
+ ann_info = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_info['id']))
94
+ ann = self._parse_annotations(img_info, ann_info)
95
+
96
+ data_item = {
97
+ "image_path": os.path.join(self.image_folder, img_info['file_name']),
98
+ "width": img_info['width'],
99
+ "height": img_info['height'],
100
+ "bbox": ann['bboxes'],
101
+ "caption": img_info['caption'],
102
+ "labels": ann['labels'],
103
+ "seg_map": ann['seg_map'],
104
+ }
105
+
106
+ return self.process_data(data_item)
107
+
108
+ def __len__(self):
109
+ return len(self.data_infos)
110
+
111
+ def grounding_enc_processor(self, x: torch.Tensor) -> torch.Tensor:
112
+ x = (x - self.IMG_MEAN) / self.IMG_STD
113
+ h, w = x.shape[-2:]
114
+ x = F.pad(x, (0, self.IMG_SIZE - w, 0, self.IMG_SIZE - h))
115
+ return x
116
+
117
+ def region_enc_processor(self, orig_size, post_size, bboxes, labels, device):
118
+ orig_h, orig_w = orig_size
119
+ post_h, post_w = post_size
120
+ y_scale = post_h / orig_h
121
+ x_scale = post_w / orig_w
122
+ shuffle_ids = torch.randperm(len(labels))[:self.max_gt_per_img]
123
+ selected_bboxes = bboxes[shuffle_ids]
124
+
125
+ # Ensure selected_bboxes is two-dimensional
126
+ if len(selected_bboxes.shape) == 1:
127
+ selected_bboxes = np.expand_dims(selected_bboxes, axis=0)
128
+
129
+ selected_labels = [labels[i] for i in shuffle_ids]
130
+ selected_bboxes[:, [0, 2]] *= x_scale
131
+ selected_bboxes[:, [1, 3]] *= y_scale
132
+ selected_bboxes = torch.tensor(selected_bboxes, device=device, dtype=torch.float32) / post_h
133
+ return selected_bboxes, selected_labels
134
+
135
+ def create_conversations(self, labels, question_templates):
136
+ questions = []
137
+ answers = []
138
+ for i, label in enumerate(labels):
139
+ question = random.choice(question_templates).strip().replace('<region>', f'region{i + 1} <bbox>')
140
+ questions.append(question)
141
+ answers.append(label)
142
+
143
+ conversations = []
144
+ conv = conversation_lib.default_conversation.copy()
145
+ conv.messages = []
146
+ for i, (question, answer) in enumerate(zip(questions, answers)):
147
+ if i == 0:
148
+ question = self.begin_str + question
149
+ conv.append_message(conv.roles[0], question)
150
+ conv.append_message(conv.roles[1], answer)
151
+ conversations.append(conv.get_prompt())
152
+ return questions, conversations
153
+
154
+ def process_data(self, data_item):
155
+ data_labels = data_item['labels']
156
+ data_bboxes = data_item['bbox']
157
+
158
+ image_path = data_item['image_path']
159
+ image = cv2.imread(image_path)
160
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
161
+ orig_h, orig_w = image.shape[:2]
162
+ # Prepare input for Global Image Encoder
163
+ global_enc_image = self.global_enc_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
164
+ post_h, post_w = global_enc_image.shape[1:3]
165
+ # Skip input for Grounding Image Encoder
166
+ grounding_enc_image = None
167
+ image_resize = None
168
+ # Prepare input for Region Image Encoder
169
+ bboxes, selected_labels = self.region_enc_processor((orig_h, orig_w), (post_h, post_w), data_bboxes, data_labels,
170
+ global_enc_image.device)
171
+ masks = None
172
+
173
+ questions, conversations = self.create_conversations(
174
+ selected_labels, question_templates=self.question_templates
175
+ )
176
+ label = None
177
+
178
+ return (image_path, global_enc_image, grounding_enc_image, bboxes, conversations, masks, label, image_resize,
179
+ questions, selected_labels)
180
+
181
+
182
+ class RefCocoRegDataset(RegionBaseDataset):
183
+ def __init__(self, dataset_dir, tokenizer, global_image_encoder, epoch_samples=8000, precision="fp32",
184
+ image_size=224, num_classes_per_sample=3, max_gt_per_img=10, validation=False, random_sampling=True):
185
+ intro_string = DEFAULT_IMAGE_TOKEN + "\n" + ("I will provide you with only one region containing only one "
186
+ "object, although there may be other objects present in the "
187
+ "image. It is recommended that you describe the object's "
188
+ "relative position with respect to other objects in the image, "
189
+ "as well as its position within the image and its basic "
190
+ "attributes.")
191
+ json_path = os.path.join("mdetr_annotations", "finetune_refcoco_train.json")
192
+ dataset_name = "RefCoco_Reg"
193
+ image_dir = "coco_2014"
194
+ question_templates = ['<region>',]
195
+ mode = "Val" if validation else "Train"
196
+
197
+ super().__init__(
198
+ dataset_dir, tokenizer, global_image_encoder, epoch_samples, precision, image_size, num_classes_per_sample,
199
+ max_gt_per_img, validation, dataset_name, image_dir, json_path,
200
+ intro_string, question_templates, random_sampling
201
+ )
202
+ print('\033[92m' + "----REGION-{}: Loaded RefCOCO dataset ----".format(mode) + '\033[0m')
203
+
204
+
205
+ class RefCocoGRegDataset(RegionBaseDataset):
206
+ def __init__(self, dataset_dir, tokenizer, global_image_encoder, epoch_samples=8000, precision="fp32",
207
+ image_size=224, num_classes_per_sample=3, max_gt_per_img=10, validation=False, random_sampling=True):
208
+ intro_string = f"""The {DEFAULT_IMAGE_TOKEN} provides an overview of the picture.\n"""
209
+ dataset_name = "RefCoco_Reg"
210
+ json_files = {'validation': "finetune_refcocog_val.json", 'training': "finetune_refcocog_train.json"}
211
+ json_path = os.path.join("mdetr_annotations", json_files['validation'] if validation else json_files['training'])
212
+ image_dir = "coco_2014"
213
+ question_templates = REGION_QUESTIONS
214
+ mode = "Val" if validation else "Train"
215
+
216
+ super().__init__(
217
+ dataset_dir, tokenizer, global_image_encoder, epoch_samples, precision, image_size, num_classes_per_sample,
218
+ max_gt_per_img, validation, dataset_name, image_dir, json_path,
219
+ intro_string, question_templates, random_sampling
220
+ )
221
+ print('\033[92m' + "----REGION-{}: Loaded RefCOCO-G dataset ----".format(mode) + '\033[0m')
222
+
223
+
224
+ class RefCocoPRegDataset(RegionBaseDataset):
225
+ def __init__(self, dataset_dir, tokenizer, global_image_encoder, epoch_samples=8000, precision="fp32",
226
+ image_size=224, num_classes_per_sample=3, max_gt_per_img=10, validation=False, random_sampling=True):
227
+ intro_string = DEFAULT_IMAGE_TOKEN + "\n" + ("I will provide you with only one region containing only one "
228
+ "object, although there may be other objects present in the "
229
+ "image. It is recommended that you describe the object's "
230
+ "relative position with respect to other objects in the image, "
231
+ "as well as its position within the image and its basic "
232
+ "attributes.")
233
+ dataset_name = "RefCoco_Reg"
234
+ json_files = {'validation': "finetune_refcoco+_val.json", 'training': "finetune_refcoco+_train.json"}
235
+ json_path = os.path.join(
236
+ "mdetr_annotations", json_files['validation'] if validation else json_files['training']
237
+ )
238
+ image_dir = "coco_2014"
239
+ question_templates = ['<region>', ]
240
+ mode = "Val" if validation else "Train"
241
+
242
+ super().__init__(
243
+ dataset_dir, tokenizer, global_image_encoder, epoch_samples, precision, image_size, num_classes_per_sample,
244
+ max_gt_per_img, validation, dataset_name, image_dir, json_path,
245
+ intro_string, question_templates, random_sampling
246
+ )
247
+ print('\033[92m' + "----REGION-{}: Loaded RefCOCO-P dataset ----".format(mode) + '\033[0m')
248
+
249
+
250
+ class VisualGenomeRegDataset(RegionBaseDataset):
251
+ def __init__(self, dataset_dir, tokenizer, global_image_encoder, epoch_samples=8000, precision="fp32",
252
+ image_size=224, num_classes_per_sample=3, max_gt_per_img=10, validation=False, random_sampling=True):
253
+ intro_string = f"""The {DEFAULT_IMAGE_TOKEN} provides an overview of the picture.\n"""
254
+ dataset_name = "visual_genome"
255
+ json_files = {'validation': "test_caption.json", 'training': "train.json"}
256
+ json_path = json_files['validation'] if validation else json_files['training']
257
+ image_dir = "images"
258
+ question_templates = REGION_QUESTIONS
259
+ mode = "Val" if validation else "Train"
260
+
261
+ super().__init__(
262
+ dataset_dir, tokenizer, global_image_encoder, epoch_samples, precision, image_size, num_classes_per_sample,
263
+ max_gt_per_img, validation, dataset_name, image_dir, json_path,
264
+ intro_string, question_templates, random_sampling
265
+ )
266
+ print('\033[92m' + "----REGION-{}: Loaded VisualGenome dataset ----".format(mode) + '\033[0m')
267
+
268
+ def _parse_annotations(self, img_info, ann_info):
269
+ annotations = {'bboxes': [], 'labels': [], }
270
+
271
+ for ann in ann_info:
272
+ if ann.get('ignore', False):
273
+ continue
274
+ # Check for valid area and dimensions
275
+ if ann['area'] <= 0 or ann['bbox'][2] < 1 or ann['bbox'][3] < 1:
276
+ continue
277
+ bbox = self._get_valid_bbox(ann['bbox'], img_info['width'], img_info['height'])
278
+ if bbox:
279
+ annotations['bboxes'].append(bbox)
280
+ annotations['labels'].append(ann['caption'].strip())
281
+
282
+ annotations['bboxes'] = np.array(annotations['bboxes'], dtype=np.float32) if annotations[
283
+ 'bboxes'] else np.zeros((0, 4), dtype=np.float32)
284
+ return annotations
285
+
286
+ def __getitem__(self, index):
287
+ img_info = self.data_infos[index] if (self.validation or not self.random_sampling) \
288
+ else self.data_infos[random.randint(0, len(self.data_infos) - 1)]
289
+ ann_info = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_info['id']))
290
+ ann = self._parse_annotations(img_info, ann_info)
291
+
292
+ data_item = {
293
+ "image_path": os.path.join(self.image_folder, img_info['file_name']),
294
+ "width": img_info['width'],
295
+ "height": img_info['height'],
296
+ "bbox": ann['bboxes'],
297
+ "labels": ann['labels'],
298
+ }
299
+
300
+ return self.process_data(data_item)
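Note: a small numeric sketch of the box normalisation applied by region_enc_processor in the classes above; the image sizes and the box are made-up example values:

import numpy as np

orig_h, orig_w = 480, 640            # original image size (example)
post_h, post_w = 336, 336            # size after CLIPImageProcessor (example)
box = np.array([[64., 48., 320., 240.]], dtype=np.float32)   # [x1, y1, x2, y2]

box[:, [0, 2]] *= post_w / orig_w    # rescale x to the processed width
box[:, [1, 3]] *= post_h / orig_h    # rescale y to the processed height
box /= post_h                        # final normalisation by post_h, as in the code
print(box)                           # [[0.1 0.1 0.5 0.5]]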
groundingLMM/dataset/segm_datasets/GranD_ReferringSegm_ds.py ADDED
@@ -0,0 +1,156 @@
1
+ import os
2
+ import cv2
3
+ import random
4
+ import lmdb
5
+ import json
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from pycocotools import mask
10
+ from transformers import CLIPImageProcessor
11
+ from model.llava import conversation as conversation_lib
12
+ from model.SAM.utils.transforms import ResizeLongestSide
13
+ from tools.utils import DEFAULT_IMAGE_TOKEN
14
+ from dataset.utils.utils import ANSWER_LIST, SEG_QUESTIONS
15
+
16
+
17
+ class GrandReferSegmDataset(torch.utils.data.Dataset):
18
+ CLASSES = ('object',)
19
+ IMG_MEAN = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
20
+ IMG_STD = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
21
+ IMG_SIZE = 1024
22
+ IGNORE_LABEL = 255
23
+
24
+ def __init__(self, dataset_dir, tokenizer, global_image_encoder, epoch_samples=500 * 8 * 2 * 10,
25
+ precision: str = "fp32", image_size: int = 224, num_classes_per_sample: int = 3,
26
+ validation=False, split='train', random_sampling=True, inference=False):
27
+ self.epoch_samples = epoch_samples
28
+ self.num_classes_per_sample = num_classes_per_sample
29
+
30
+ self.dataset_dir = dataset_dir
31
+ self.image_size = image_size
32
+ self.tokenizer = tokenizer
33
+ self.precision = precision
34
+ self.transform = ResizeLongestSide(image_size)
35
+ self.global_enc_processor = CLIPImageProcessor.from_pretrained(global_image_encoder)
36
+
37
+ self.question_templates = SEG_QUESTIONS
38
+ self.answer_list = ANSWER_LIST
39
+ self.begin_str = f"""The {DEFAULT_IMAGE_TOKEN} provides an overview of the picture.\n"""
40
+ self.validation = validation
41
+ self.random_sampling = random_sampling
+ self.max_gt_per_img = 10  # assumed default (mirrors the region datasets); referenced later in process_data
42
+ # Defining paths
43
+ self.base_dir = os.path.join(dataset_dir, "GranD_Data")
44
+ self.image_folder = os.path.join(self.base_dir, "images")
45
+ ann_file_name = "Grand_Referring_Expression_lmdb"
46
+ ann_path = os.path.join(self.base_dir, ann_file_name)
47
+ self.annos = lmdb.open(ann_path, readonly=True, max_readers=1, lock=False, readahead=False, meminit=False)
48
+ mode = "Val" if validation else "Train"
49
+ self.data_infos = self._load_annotations(
50
+ os.path.join(self.base_dir, ann_file_name, f'{ann_file_name}_{mode}.txt')
51
+ )
52
+ print('\033[92m' + "----SEGM-{}: GranD Referring Segm dataset initialized----".format(mode) + '\033[0m')
53
+
54
+ def _load_annotations(self, ann_file):
55
+ with open(ann_file, 'r') as f:
56
+ data_infos = [line.strip() for line in f if line.strip()]
57
+ data_infos = data_infos[0: 1000] if self.validation else data_infos
58
+ return data_infos
59
+
60
+ def __len__(self):
61
+ return len(self.data_infos)
62
+
63
+ def grounding_enc_processor(self, x: torch.Tensor) -> torch.Tensor:
64
+ x = (x - self.IMG_MEAN) / self.IMG_STD
65
+ h, w = x.shape[-2:]
66
+ x = F.pad(x, (0, self.IMG_SIZE - w, 0, self.IMG_SIZE - h))
67
+ return x
68
+
69
+ def create_conversations(self, labels, question_templates):
70
+ questions = []
71
+ answers = []
72
+ for i, label in enumerate(labels):
73
+ question = random.choice(question_templates)
74
+ questions.append(question)
75
+ answers.append(label)
76
+
77
+ conversations = []
78
+ conv = conversation_lib.default_conversation.copy()
79
+ conv.messages = []
80
+ for i, (question, answer) in enumerate(zip(questions, answers)):
81
+ if i == 0:
82
+ question = self.begin_str + question
83
+ conv.append_message(conv.roles[0], question)
84
+ conv.append_message(conv.roles[1], answer)
85
+ conversations.append(conv.get_prompt())
86
+ return questions, conversations
87
+
88
+ def _parse_annotations(self, ann_info):
89
+ annotations = {'masks': [], 'labels': []}
90
+ for ann in ann_info:
91
+ rle = ann.get("segmentation")
92
+ if rle:
93
+ m = mask.decode(rle)
94
+ m = m.astype(np.uint8)
95
+ annotations['masks'].append(m)
96
+ annotations['labels'].append(ann['attribute'])
97
+
98
+ annotations['masks'] = np.stack(annotations['masks'], axis=0) if annotations[
99
+ 'masks'] else np.zeros((0, 1, 1), dtype=np.uint8)
100
+ return annotations
101
+
102
+ def __getitem__(self, idx):
103
+ image_name = self.data_infos[idx] if (self.validation or not self.random_sampling) else self.data_infos[
104
+ random.randint(0, len(self.data_infos) - 1)]
105
+ image_path = os.path.join(self.image_folder, image_name)
106
+ # Get the annotation from lmdb
107
+ with self.annos.begin() as txn:
108
+ json_contents = txn.get(image_name.encode())
109
+ json_contents = json.loads(json_contents.decode('utf-8'))
110
+ ann_info = json_contents[image_name]
111
+ print(image_path)
112
+ ann = self._parse_annotations(ann_info)
113
+ data_item = {"image_path": image_path,
114
+ "filename": image_name,
115
+ "masks": ann['masks'],
116
+ "labels": ann['labels'], }
117
+
118
+ return self.process_data(data_item)
119
+
120
+ def process_data(self, data_item):
121
+ data_labels = data_item['labels']
122
+ data_masks = data_item['masks']
123
+
124
+ image_path = data_item['image_path']
125
+ image = cv2.imread(image_path)
126
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
127
+ # Prepare input for Global Image Encoder
128
+ global_enc_image = self.global_enc_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
129
+ # Prepare input for Grounding Image Encoder
130
+ image = self.transform.apply_image(image)
131
+ image_resize = image.shape[:2]
132
+ grounding_enc_image = self.grounding_enc_processor(torch.from_numpy(image).permute(2, 0, 1).contiguous())
133
+
134
+ # Prepare input for Segmentation module
135
+ shuffle_ids = torch.randperm(len(data_labels))
136
+ if len(shuffle_ids) > self.max_gt_per_img:
137
+ shuffle_ids_segm_question = shuffle_ids[:self.max_gt_per_img]
138
+ selected_labels = [data_labels[i] for i in shuffle_ids_segm_question]
139
+ else:
140
+ selected_labels = [data_labels[i] for i in shuffle_ids]
141
+ selected_masks = data_masks[shuffle_ids]
142
+
143
+ masks = np.stack(selected_masks, axis=0)
144
+ masks = torch.from_numpy(masks)
145
+
146
+ if len(data_labels) == 0:
147
+ print(image_path)
148
+
149
+ questions, conversations = self.create_conversations(
150
+ selected_labels, self.question_templates)
151
+ label = torch.ones(grounding_enc_image.shape[1], grounding_enc_image.shape[2]) * self.IGNORE_LABEL
152
+ bboxes = None
153
+
154
+ return (
155
+ image_path, global_enc_image, grounding_enc_image, bboxes, conversations, masks, label, image_resize,
156
+ questions, selected_labels)
groundingLMM/dataset/segm_datasets/RefCOCO_Segm_ds.py ADDED
@@ -0,0 +1,242 @@
1
+ import os
2
+ import cv2
3
+ import random
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from pycocotools import mask
8
+ from transformers import CLIPImageProcessor
9
+ from model.llava import conversation as conversation_lib
10
+ from model.SAM.utils.transforms import ResizeLongestSide
11
+ from dataset.utils.grefer import G_REFER
12
+ from dataset.utils.refcoco_refer import REFER
13
+ from tools.utils import DEFAULT_IMAGE_TOKEN
14
+ from dataset.utils.utils import ANSWER_LIST, SEG_QUESTIONS
15
+
16
+
17
+ class ReferSegmDataset(torch.utils.data.Dataset):
18
+ CLASSES = ('object',)
19
+ IMG_MEAN = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
20
+ IMG_STD = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
21
+ IMG_SIZE = 1024
22
+ IGNORE_LABEL = 255
23
+
24
+ def __init__(self, dataset_dir, tokenizer, global_image_encoder, epoch_samples=500 * 8 * 2 * 10,
25
+ precision: str = "fp32", image_size: int = 224, num_classes_per_sample: int = 3,
26
+ refer_segm_data="refcoco||refcoco+||refcocog||refclef", validation=False, split='train',
27
+ random_sampling=True, inference=False):
28
+ self.epoch_samples = epoch_samples
29
+ self.num_classes_per_sample = num_classes_per_sample
30
+
31
+ self.dataset_dir = dataset_dir
32
+ self.image_size = image_size
33
+ self.tokenizer = tokenizer
34
+ self.precision = precision
35
+ self.transform = ResizeLongestSide(image_size)
36
+ self.global_enc_processor = CLIPImageProcessor.from_pretrained(global_image_encoder)
37
+
38
+ self.question_templates = SEG_QUESTIONS
39
+ self.answer_list = ANSWER_LIST
40
+ self.begin_str = f"""The {DEFAULT_IMAGE_TOKEN} provides an overview of the picture.\n"""
41
+ self.validation = validation
42
+ self.split = split
43
+ self.initialize_refer_segm_data(refer_segm_data, inference)
44
+ self.random_sampling = random_sampling
45
+
46
+ def initialize_refer_segm_data(self, refer_segm_data, inference=False):
47
+
48
+ dataset_dir = os.path.join(self.dataset_dir, "Refer_Segm")
49
+ self.refer_seg_ds_list = refer_segm_data.split("||")
50
+ # ['refclef', 'refcoco', 'refcoco+', 'refcocog']
51
+ self.refer_segm_data = {}
52
+
53
+ for dataset_name in self.refer_seg_ds_list:
54
+ splitBy = "umd" if dataset_name == "refcocog" else "unc"
55
+ refer_api = G_REFER(dataset_dir, dataset_name, splitBy) if dataset_name == "grefcoco" else\
56
+ REFER(dataset_dir, dataset_name, splitBy)
57
+ ref_ids_train = refer_api.getRefIds(split=self.split)
58
+ images_ids_train = refer_api.getImgIds(ref_ids=ref_ids_train)
59
+ refs_train = refer_api.loadRefs(ref_ids=ref_ids_train)
60
+ refer_seg_ds = {
61
+ "images": self.load_images(refer_api, images_ids_train, dataset_dir, dataset_name, inference=inference),
62
+ "annotations": refer_api.Anns,
63
+ "img2refs": self.create_img_to_refs_mapping(refs_train)
64
+ }
65
+
66
+ print(f"dataset {dataset_name} (refs {splitBy}) ({self.split} split) has {len(refer_seg_ds['images'])} "
67
+ f"images and {len(refer_seg_ds['annotations'])} annotations.")
68
+ print(f'\033[92m----SEG-{"Val" if self.validation else "Train"}:'
69
+ f' Loaded ReferSeg - {dataset_name} dataset ----\033[0m')
70
+
71
+ self.refer_segm_data[dataset_name] = refer_seg_ds
72
+
73
+ def load_images(self, refer_api, images_ids_train, dataset_dir, dataset_name, inference=False):
74
+ images = []
75
+ loaded_images = refer_api.loadImgs(image_ids=images_ids_train)
76
+ # Limiting images to 1000(optional) for validation
77
+ loaded_images = loaded_images[:1000] if (self.validation and not inference) else loaded_images
78
+ for item in loaded_images:
79
+ item = item.copy()
80
+ if dataset_name == 'refclef':
81
+ item["file_name"] = os.path.join(dataset_dir, "images", "saiapr_tc-12", item["file_name"])
82
+ else:
83
+ item["file_name"] = os.path.join(dataset_dir.replace("Refer_Segm/", ""), "coco_2014/train2014",
84
+ item["file_name"])
85
+ images.append(item)
86
+ return images
87
+
88
+ def create_img_to_refs_mapping(self, refs_train):
89
+ img2refs = {}
90
+ for ref in refs_train:
91
+ img2refs[ref["image_id"]] = img2refs.get(ref["image_id"], []) + [ref, ]
92
+ return img2refs
93
+
94
+ def __len__(self):
95
+ return self.epoch_samples
96
+
97
+ def _set_len(self, length):
98
+ self.epoch_samples = length
99
+
100
+ def grounding_enc_processor(self, x: torch.Tensor) -> torch.Tensor:
101
+ x = (x - self.IMG_MEAN) / self.IMG_STD
102
+ h, w = x.shape[-2:]
103
+ x = F.pad(x, (0, self.IMG_SIZE - w, 0, self.IMG_SIZE - h))
104
+ return x
105
+
106
+ def create_conversations(self, labels):
107
+ questions = []
108
+ answers = []
109
+ for i, label in enumerate(labels):
110
+ label = label.strip()
111
+ assert len(label.split("||")) == 1
112
+ question_template = random.choice(self.question_templates)
113
+ questions.append(question_template.format(class_name=label.lower()))
114
+ answers.append(random.choice(self.answer_list))
115
+
116
+ conversations = []
117
+ conv = conversation_lib.default_conversation.copy()
118
+ conv.messages = []
119
+ for i, (question, answer) in enumerate(zip(questions, answers)):
120
+ if i == 0:
121
+ question = self.begin_str + question
122
+ conv.append_message(conv.roles[0], question)
123
+ conv.append_message(conv.roles[1], answer)
124
+ conversations.append(conv.get_prompt())
125
+ return questions, conversations
126
+
127
+ def __getitem__(self, idx):
128
+ dataset_idx = random.randint(0, len(self.refer_seg_ds_list) - 1)
129
+ dataset_name = self.refer_seg_ds_list[dataset_idx]
130
+ refer_seg_ds = self.refer_segm_data[dataset_name]
131
+ images = refer_seg_ds["images"]
132
+ annotations = refer_seg_ds["annotations"]
133
+ img2refs = refer_seg_ds["img2refs"]
134
+ idx = idx if (self.validation or not self.random_sampling) else random.randint(0, len(images) - 1)
135
+ image_info = images[idx]
136
+ image_id = image_info["id"]
137
+ refs = img2refs[image_id]
138
+ if len(refs) == 0:
139
+ return self.__getitem__(0)
140
+
141
+ sents = []
142
+ ann_ids = []
143
+ for ref in refs:
144
+ for sent in ref["sentences"]:
145
+ text = sent["sent"]
146
+ sents.append(text)
147
+ ann_ids.append(ref["ann_id"])
148
+ if len(sents) >= self.num_classes_per_sample:
149
+ sampled_inds = np.random.choice(
150
+ list(range(len(sents))), size=self.num_classes_per_sample, replace=False
151
+ )
152
+ else:
153
+ sampled_inds = list(range(len(sents)))
154
+ sampled_sents = np.vectorize(sents.__getitem__)(sampled_inds).tolist()
155
+ # sampled_ann_ids = np.vectorize(ann_ids.__getitem__)(sampled_inds).tolist()
156
+ sampled_ann_ids = [ann_ids[ind] for ind in sampled_inds]
157
+ selected_labels = sampled_sents
158
+
159
+ # Load and process the image
160
+ image_path = image_info["file_name"]
161
+ image = cv2.imread(image_path)
162
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
163
+ global_enc_img = self.global_enc_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
164
+ image = self.transform.apply_image(image)
165
+ image_resize = image.shape[:2]
166
+ grounding_enc_img = self.grounding_enc_processor(torch.from_numpy(image).permute(2, 0, 1).contiguous())
167
+
168
+ # Generate questions and answers
169
+ questions, conversations = self.create_conversations(selected_labels)
170
+
171
+ flag = False
172
+ masks = []
173
+ for ann_id in sampled_ann_ids:
174
+ if isinstance(ann_id, list):
175
+ flag = True
176
+ if -1 in ann_id:
177
+ assert len(ann_id) == 1
178
+ m = np.zeros((image_info["height"], image_info["width"])).astype(
179
+ np.uint8
180
+ )
181
+ else:
182
+ m_final = np.zeros(
183
+ (image_info["height"], image_info["width"])
184
+ ).astype(np.uint8)
185
+ for ann_id_i in ann_id:
186
+ ann = annotations[ann_id_i]
187
+
188
+ if len(ann["segmentation"]) == 0:
189
+ m = np.zeros(
190
+ (image_info["height"], image_info["width"])
191
+ ).astype(np.uint8)
192
+ else:
193
+ if type(ann["segmentation"][0]) == list: # polygon
194
+ rle = mask.frPyObjects(
195
+ ann["segmentation"], image_info["height"], image_info["width"], )
196
+ else:
197
+ rle = ann["segmentation"]
198
+ for i in range(len(rle)):
199
+ if not isinstance(rle[i]["counts"], bytes):
200
+ rle[i]["counts"] = rle[i]["counts"].encode()
201
+ m = mask.decode(rle)
202
+ m = np.sum(
203
+ m, axis=2
204
+ ) # sometimes there are multiple binary map (corresponding to multiple segs)
205
+ m = m.astype(np.uint8) # convert to np.uint8
206
+ m_final = m_final | m
207
+ m = m_final
208
+ masks.append(m)
209
+ continue
210
+
211
+ ann = annotations[ann_id]
212
+
213
+ if len(ann["segmentation"]) == 0:
214
+ m = np.zeros((image_info["height"], image_info["width"])).astype(
215
+ np.uint8
216
+ )
217
+ masks.append(m)
218
+ continue
219
+
220
+ if type(ann["segmentation"][0]) == list: # polygon
221
+ rle = mask.frPyObjects(
222
+ ann["segmentation"], image_info["height"], image_info["width"]
223
+ )
224
+ else:
225
+ rle = ann["segmentation"]
226
+ for i in range(len(rle)):
227
+ if not isinstance(rle[i]["counts"], bytes):
228
+ rle[i]["counts"] = rle[i]["counts"].encode()
229
+ m = mask.decode(rle)
230
+ m = np.sum(m, axis=2) # sometimes there are multiple binary map (corresponding to multiple segs)
231
+ m = m.astype(np.uint8) # convert to np.uint8
232
+ masks.append(m)
233
+
234
+ masks = np.stack(masks, axis=0)
235
+
236
+ masks = torch.from_numpy(masks)
237
+ label = torch.ones(masks.shape[1], masks.shape[2]) * self.IGNORE_LABEL
238
+ # set bboxes to None for segmentation datasets
239
+ bboxes = None
240
+
241
+ return (image_path, global_enc_img, grounding_enc_img, bboxes, conversations, masks, label,
242
+ image_resize, questions, selected_labels)
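Note: a compact sketch of the polygon -> RLE -> binary-mask path that the __getitem__ above walks with pycocotools, using a toy polygon rather than a real RefCOCO annotation:

import numpy as np
from pycocotools import mask as mask_utils

height, width = 100, 100
polygon = [[10, 10, 60, 10, 60, 60, 10, 60]]       # one closed polygon: x0, y0, x1, y1, ...
rles = mask_utils.frPyObjects(polygon, height, width)
m = mask_utils.decode(rles)                        # (H, W, num_parts)
m = np.sum(m, axis=2).astype(np.uint8)             # merge parts into one binary mask
print(m.shape, int(m.sum()))                       # (100, 100) and the mask area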
groundingLMM/dataset/segm_datasets/Semantic_Segm_ds.py ADDED
@@ -0,0 +1,248 @@
1
+ import os
2
+ import cv2
3
+ import glob
4
+ import json
5
+ import random
6
+ import numpy as np
7
+ from PIL import Image
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from pycocotools.coco import COCO
11
+ from transformers import CLIPImageProcessor
12
+ from model.llava import conversation as conversation_lib
13
+ from model.SAM.utils.transforms import ResizeLongestSide
14
+ from tools.utils import DEFAULT_IMAGE_TOKEN
15
+ from dataset.utils.utils import ANSWER_LIST, SEG_QUESTIONS
16
+
17
+
18
+ def load_json_file(file_path):
19
+ with open(file_path, 'r') as file:
20
+ return json.load(file)
21
+
22
+
23
+ def init_ade20k(dataset_dir):
24
+ ade20k_classes = load_json_file("dataset/utils/ade20k_classes.json")
25
+ ade20k_image_dir = os.path.join(dataset_dir, "ade20k", "images", "training")
26
+ ade20k_images = [os.path.join(ade20k_image_dir, img) for img in os.listdir(ade20k_image_dir) if
27
+ img.endswith('.jpg')]
28
+ ade20k_labels = [img.replace(".jpg", ".png").replace("images", "annotations") for img in ade20k_images]
29
+ return np.array(ade20k_classes), ade20k_images, ade20k_labels
30
+
31
+
32
+ def init_cocostuff(dataset_dir):
33
+ with open("dataset/utils/cocostuff_classes.txt") as file:
34
+ cocostuff_classes = [line.strip().split(": ")[-1] for line in file.readlines()[1:]]
35
+ # Annotations
36
+ cocostuff_labels = glob.glob(os.path.join(dataset_dir, "cocostuff", "train2017", "*.png"))
37
+ # Images are obtained from COCO 2017 images
38
+ cocostuff_images = [label.replace(".png", ".jpg").replace("cocostuff", "coco_2017").replace("Semantic_Segm/", "") for
39
+ label in cocostuff_labels]
40
+ return np.array(cocostuff_classes), cocostuff_images, cocostuff_labels
41
+
42
+
43
+ def init_paco_lvis(dataset_dir):
44
+ paco_lvis_api = COCO(os.path.join(dataset_dir, "paco_lvis", "annotations", "paco_lvis_v1_train.json"))
45
+ all_classes = paco_lvis_api.loadCats(paco_lvis_api.getCatIds())
46
+ class_map_paco_lvis = {}
47
+
48
+ for cat in all_classes:
49
+ cat_split = cat["name"].strip().split(":")
50
+ if len(cat_split) == 1:
51
+ name = cat_split[0].split("_(")[0]
52
+ else:
53
+ assert len(cat_split) == 2
54
+ obj, part = cat_split
55
+ obj = obj.split("_(")[0]
56
+ part = part.split("_(")[0]
57
+ name = (obj, part)
58
+ class_map_paco_lvis[cat["id"]] = name
59
+
60
+ img_ids = paco_lvis_api.getImgIds()
61
+ return class_map_paco_lvis, img_ids, paco_lvis_api
62
+
63
+
64
+ def init_pascal_part(dataset_dir):
65
+ pascal_part_api = COCO(os.path.join(dataset_dir, "pascal_part", "train.json"))
66
+ all_classes = pascal_part_api.loadCats(pascal_part_api.getCatIds())
67
+ class_map_pascal_part = {}
68
+ for cat in all_classes:
69
+ cat_main, cat_part = cat["name"].strip().split(":")
70
+ name = (cat_main, cat_part)
71
+ class_map_pascal_part[cat["id"]] = name
72
+ img_ids = pascal_part_api.getImgIds()
73
+ return class_map_pascal_part, img_ids, pascal_part_api
74
+
75
+
76
+ def init_mapillary(dataset_dir):
77
+ mapillary_path = os.path.join(dataset_dir, "mapillary")
78
+ mapillary_classes = [cls["readable"].lower() for cls in
79
+ load_json_file(os.path.join(mapillary_path, "config_v2.0.json"))["labels"]]
80
+ mapillary_labels = sorted(glob.glob(os.path.join(mapillary_path, "training", "v2.0", "labels", "*.png")))
81
+ mapillary_images = [label.replace(".png", ".jpg").replace("v2.0/labels", "images") for label in mapillary_labels]
82
+ return np.array(mapillary_classes), mapillary_images, mapillary_labels
83
+
84
+
85
+ class SemanticSegmDataset(torch.utils.data.Dataset):
86
+ CLASSES = ('object',)
87
+ IMG_MEAN = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
88
+ IMG_STD = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
89
+ IMG_SIZE = 1024
90
+ IGNORE_LABEL = 255
91
+
92
+ def __init__(self, dataset_dir, tokenizer, global_image_encoder, epoch_samples=500 * 8 * 2 * 10,
93
+ precision: str = "fp32", image_size: int = 224, num_classes_per_sample: int = 3,
94
+ semantic_segm_data="ade20k||cocostuff||pascal_part||paco_lvis||mapillary", validation=False,
95
+ random_sampling=True):
96
+ self.epoch_samples = epoch_samples
97
+ self.num_classes_per_sample = num_classes_per_sample
98
+
99
+ self.image_size = image_size
100
+ self.tokenizer = tokenizer
101
+ self.precision = precision
102
+ self.transform = ResizeLongestSide(image_size)
103
+ self.global_enc_processor = CLIPImageProcessor.from_pretrained(global_image_encoder)
104
+
105
+ self.question_templates = SEG_QUESTIONS
106
+ self.answer_list = ANSWER_LIST
107
+ self.begin_str = f"""The {DEFAULT_IMAGE_TOKEN} provides an overview of the picture.\n"""
108
+ self.validation = validation
109
+ self.random_sampling = random_sampling
110
+
111
+ self.data2list = {}
112
+ self.data2classes = {}
113
+ self.dataset_dir = os.path.join(dataset_dir, "Semantic_Segm")
114
+ self.semantic_seg_ds_list = semantic_segm_data.split("||")
115
+ for ds in self.semantic_seg_ds_list:
116
+ classes, images, labels = eval("init_{}".format(ds))(self.dataset_dir)
117
+ self.data2list[ds] = (images, labels)
118
+ self.data2classes[ds] = classes
119
+ print(f'\033[92m----SEG-{"Val" if validation else "Train"}: Loaded ReferSeg - {ds} dataset ----\033[0m')
120
+
121
+ if "cocostuff" in self.semantic_seg_ds_list:
122
+ self.cocostuff_class2index = {c: i for i, c in enumerate(self.data2classes["cocostuff"])}
123
+
124
+ def __len__(self):
125
+ return self.epoch_samples
126
+
127
+ def _set_len(self, length):
128
+ self.epoch_samples = length
129
+
130
+ def grounding_enc_processor(self, x: torch.Tensor) -> torch.Tensor:
131
+ x = (x - self.IMG_MEAN) / self.IMG_STD
132
+ h, w = x.shape[-2:]
133
+ x = F.pad(x, (0, self.IMG_SIZE - w, 0, self.IMG_SIZE - h))
134
+ return x
135
+
136
+ def create_conversations(self, labels, dataset_name):
137
+ questions = []
138
+ answers = []
139
+ class_ids = []
140
+ for i, label in enumerate(labels):
141
+ label = label.strip()
142
+ assert len(label.split("||")) == 1
143
+ question_template = random.choice(self.question_templates)
144
+ questions.append(question_template.format(class_name=label.lower()))
145
+ answers.append(random.choice(self.answer_list))
146
+
147
+ if dataset_name in ["paco_lvis", "pascal_part"]:
148
+ continue
149
+ class_id = self.data2classes[dataset_name].tolist().index(label)
150
+ class_ids.append(class_id)
151
+
152
+ conversations = []
153
+ conv = conversation_lib.default_conversation.copy()
154
+ conv.messages = []
155
+ for i, (question, answer) in enumerate(zip(questions, answers)):
156
+ if i == 0:
157
+ question = self.begin_str + question
158
+ conv.append_message(conv.roles[0], question)
159
+ conv.append_message(conv.roles[1], answer)
160
+ conversations.append(conv.get_prompt())
161
+ return questions, conversations, class_ids
162
+
163
+ def __getitem__(self, idx):
164
+ dataset_idx = random.randint(0, len(self.semantic_seg_ds_list) - 1)
165
+ dataset_name = self.semantic_seg_ds_list[dataset_idx]
166
+
167
+ if dataset_name in ["paco_lvis", "pascal_part"]:
168
+ class_map = self.data2classes[dataset_name]
169
+ img_ids, coco_api = self.data2list[dataset_name]
170
+ random_idx = random.randint(0, len(img_ids) - 1)
171
+ img_info = coco_api.loadImgs([img_ids[random_idx]])[0]
172
+ file_name = img_info["file_name"]
173
+ image_path = os.path.join(
174
+ self.dataset_dir, dataset_name, "VOCdevkit", "VOC2010", "JPEGImages", file_name
175
+ ) if dataset_name == "pascal_part" else os.path.join(
176
+ self.dataset_dir.replace("Semantic_Segm/", ""), "coco_2017", file_name)
177
+
178
+ annotation_ids = coco_api.getAnnIds(imgIds=img_info["id"])
179
+ annotations = coco_api.loadAnns(annotation_ids)
180
+ if not annotations:
181
+ return self.__getitem__(0)
182
+
183
+ sampled_anns = np.random.choice(annotations, self.num_classes_per_sample, replace=False) if len(
184
+ annotations
185
+ ) >= self.num_classes_per_sample else annotations
186
+ selected_labels = []
187
+ for ann in sampled_anns:
188
+ category_id = ann["category_id"]
189
+ sampled_cls = class_map[category_id]
190
+ if isinstance(sampled_cls, tuple):
191
+ obj, part = sampled_cls
192
+ name = f"{obj} {part}" if random.random() < 0.5 else f"the {part} of the {obj}"
193
+ else:
194
+ name = sampled_cls
195
+ selected_labels.append(name)
196
+
197
+ elif dataset_name in ["ade20k", "cocostuff", "mapillary"]:
198
+ images, labels = self.data2list[dataset_name]
199
+ idx = idx if (self.validation or not self.random_sampling) else random.randint(0, len(images) - 1)
200
+ image_path, label_path = images[idx], labels[idx]
201
+ label = np.array(Image.open(label_path))
202
+ if dataset_name == "ade20k":
203
+ label = np.where(label == 0, 255, label - 1)
204
+ elif dataset_name == "cocostuff":
205
+ ignored_classes = [index for class_name, index in self.cocostuff_class2index.items() if
206
+ "-" in class_name]
207
+ label = np.where(np.isin(label, ignored_classes), 255, label)
208
+
209
+ unique_labels = [lbl for lbl in np.unique(label) if lbl != 255]
210
+ if not unique_labels:
211
+ return self.__getitem__(0)
212
+
213
+ classes = [self.data2classes[dataset_name][lbl] for lbl in unique_labels]
214
+ selected_labels = np.random.choice(
215
+ classes, min(len(classes), self.num_classes_per_sample), replace=False
216
+ ) if len(classes) >= self.num_classes_per_sample else classes
217
+
218
+ # Load and process the image
219
+ image = cv2.imread(image_path)
220
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
221
+ global_enc_img = self.global_enc_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
222
+ image = self.transform.apply_image(image)
223
+ image_resize = image.shape[:2]
224
+ grounding_enc_img = self.grounding_enc_processor(torch.from_numpy(image).permute(2, 0, 1).contiguous())
225
+
226
+ # Generate questions and answers
227
+ questions, conversations, class_ids = self.create_conversations(selected_labels, dataset_name)
228
+ if dataset_name in ["paco_lvis", "pascal_part"]:
229
+ try:
230
+ masks = [coco_api.annToMask(ann) for ann in sampled_anns]
231
+ except Exception as e:
232
+ print(f"Error generating mask: {e}")
233
+ return self.__getitem__(0)
234
+
235
+ masks = np.stack(masks, axis=0)
236
+ masks = torch.from_numpy(masks)
237
+ label = torch.ones(masks.shape[1], masks.shape[2]) * self.IGNORE_LABEL
238
+ else:
239
+ label = torch.from_numpy(label).long()
240
+ masks = torch.stack([label == class_id for class_id in class_ids], dim=0)
241
+
242
+ assert len(conversations) == 1
243
+ assert conversations[0].count("[SEG]") == masks.shape[0]
244
+ # set bboxes to None for segmentation datasets
245
+ bboxes = None
246
+
247
+ return (image_path, global_enc_img, grounding_enc_img, bboxes, conversations, masks, label,
248
+ image_resize, questions, selected_labels)
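Note: a toy sketch of how the semantic branch above turns a label map into per-class binary masks (values are made up; 255 marks ignored pixels, matching IGNORE_LABEL):

import torch

label = torch.tensor([[0, 0, 1],
                      [2, 1, 1],
                      [2, 2, 255]])
class_ids = [1, 2]
masks = torch.stack([label == cid for cid in class_ids], dim=0)
print(masks.shape)                                  # torch.Size([2, 3, 3])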
groundingLMM/dataset/utils/ade20k_classes.json ADDED
@@ -0,0 +1,30 @@
1
+ [
2
+ "wall", "building", "sky", "floor", "tree", "ceiling", "road",
3
+ "bed", "windowpane", "grass", "cabinet", "sidewalk",
4
+ "person", "earth", "door", "table", "mountain", "plant",
5
+ "curtain", "chair", "car", "water", "painting", "sofa",
6
+ "shelf", "house", "sea", "mirror", "rug", "field", "armchair",
7
+ "seat", "fence", "desk", "rock", "wardrobe", "lamp",
8
+ "bathtub", "railing", "cushion", "base", "box", "column",
9
+ "signboard", "chest of drawers", "counter", "sand", "sink",
10
+ "skyscraper", "fireplace", "refrigerator", "grandstand",
11
+ "path", "stairs", "runway", "case", "pool table", "pillow",
12
+ "screen door", "stairway", "river", "bridge", "bookcase",
13
+ "blind", "coffee table", "toilet", "flower", "book", "hill",
14
+ "bench", "countertop", "stove", "palm", "kitchen island",
15
+ "computer", "swivel chair", "boat", "bar", "arcade machine",
16
+ "hovel", "bus", "towel", "light", "truck", "tower",
17
+ "chandelier", "awning", "streetlight", "booth",
18
+ "television receiver", "airplane", "dirt track", "apparel",
19
+ "pole", "land", "bannister", "escalator", "ottoman", "bottle",
20
+ "buffet", "poster", "stage", "van", "ship", "fountain",
21
+ "conveyer belt", "canopy", "washer", "plaything",
22
+ "swimming pool", "stool", "barrel", "basket", "waterfall",
23
+ "tent", "bag", "minibike", "cradle", "oven", "ball", "food",
24
+ "step", "tank", "trade name", "microwave", "pot", "animal",
25
+ "bicycle", "lake", "dishwasher", "screen", "blanket",
26
+ "sculpture", "hood", "sconce", "vase", "traffic light",
27
+ "tray", "ashcan", "fan", "pier", "crt screen", "plate",
28
+ "monitor", "bulletin board", "shower", "radiator", "glass",
29
+ "clock", "flag"
30
+ ]
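For reference, the list above holds the 150 ADE20K class names in label order; the `label - 1` / `0 -> 255` remapping in the dataset code earlier makes raw label 1 ("wall") line up with index 0 of this list. A small self-contained sketch (the file path is illustrative):

import json
import numpy as np

with open("ade20k_classes.json") as f:                     # illustrative path
    ade20k_classes = json.load(f)

raw_label = np.array([[0, 1, 2], [13, 35, 1]])
remapped = np.where(raw_label == 0, 255, raw_label - 1)    # 0 (unlabeled) -> 255 ignore value
present = {int(i): ade20k_classes[int(i)] for i in np.unique(remapped) if i != 255}
print(present)                                             # {0: 'wall', 1: 'building', 12: 'person', 34: 'rock'}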
groundingLMM/dataset/utils/cocostuff_classes.txt ADDED
@@ -0,0 +1,183 @@
1
+ 0: unlabeled
2
+ 1: person
3
+ 2: bicycle
4
+ 3: car
5
+ 4: motorcycle
6
+ 5: airplane
7
+ 6: bus
8
+ 7: train
9
+ 8: truck
10
+ 9: boat
11
+ 10: traffic light
12
+ 11: fire hydrant
13
+ 12: street sign
14
+ 13: stop sign
15
+ 14: parking meter
16
+ 15: bench
17
+ 16: bird
18
+ 17: cat
19
+ 18: dog
20
+ 19: horse
21
+ 20: sheep
22
+ 21: cow
23
+ 22: elephant
24
+ 23: bear
25
+ 24: zebra
26
+ 25: giraffe
27
+ 26: hat
28
+ 27: backpack
29
+ 28: umbrella
30
+ 29: shoe
31
+ 30: eye glasses
32
+ 31: handbag
33
+ 32: tie
34
+ 33: suitcase
35
+ 34: frisbee
36
+ 35: skis
37
+ 36: snowboard
38
+ 37: sports ball
39
+ 38: kite
40
+ 39: baseball bat
41
+ 40: baseball glove
42
+ 41: skateboard
43
+ 42: surfboard
44
+ 43: tennis racket
45
+ 44: bottle
46
+ 45: plate
47
+ 46: wine glass
48
+ 47: cup
49
+ 48: fork
50
+ 49: knife
51
+ 50: spoon
52
+ 51: bowl
53
+ 52: banana
54
+ 53: apple
55
+ 54: sandwich
56
+ 55: orange
57
+ 56: broccoli
58
+ 57: carrot
59
+ 58: hot dog
60
+ 59: pizza
61
+ 60: donut
62
+ 61: cake
63
+ 62: chair
64
+ 63: couch
65
+ 64: potted plant
66
+ 65: bed
67
+ 66: mirror
68
+ 67: dining table
69
+ 68: window
70
+ 69: desk
71
+ 70: toilet
72
+ 71: door
73
+ 72: tv
74
+ 73: laptop
75
+ 74: mouse
76
+ 75: remote
77
+ 76: keyboard
78
+ 77: cell phone
79
+ 78: microwave
80
+ 79: oven
81
+ 80: toaster
82
+ 81: sink
83
+ 82: refrigerator
84
+ 83: blender
85
+ 84: book
86
+ 85: clock
87
+ 86: vase
88
+ 87: scissors
89
+ 88: teddy bear
90
+ 89: hair drier
91
+ 90: toothbrush
92
+ 91: hair brush
93
+ 92: banner
94
+ 93: blanket
95
+ 94: branch
96
+ 95: bridge
97
+ 96: building-other
98
+ 97: bush
99
+ 98: cabinet
100
+ 99: cage
101
+ 100: cardboard
102
+ 101: carpet
103
+ 102: ceiling-other
104
+ 103: ceiling-tile
105
+ 104: cloth
106
+ 105: clothes
107
+ 106: clouds
108
+ 107: counter
109
+ 108: cupboard
110
+ 109: curtain
111
+ 110: desk-stuff
112
+ 111: dirt
113
+ 112: door-stuff
114
+ 113: fence
115
+ 114: floor-marble
116
+ 115: floor-other
117
+ 116: floor-stone
118
+ 117: floor-tile
119
+ 118: floor-wood
120
+ 119: flower
121
+ 120: fog
122
+ 121: food-other
123
+ 122: fruit
124
+ 123: furniture-other
125
+ 124: grass
126
+ 125: gravel
127
+ 126: ground-other
128
+ 127: hill
129
+ 128: house
130
+ 129: leaves
131
+ 130: light
132
+ 131: mat
133
+ 132: metal
134
+ 133: mirror-stuff
135
+ 134: moss
136
+ 135: mountain
137
+ 136: mud
138
+ 137: napkin
139
+ 138: net
140
+ 139: paper
141
+ 140: pavement
142
+ 141: pillow
143
+ 142: plant-other
144
+ 143: plastic
145
+ 144: platform
146
+ 145: playingfield
147
+ 146: railing
148
+ 147: railroad
149
+ 148: river
150
+ 149: road
151
+ 150: rock
152
+ 151: roof
153
+ 152: rug
154
+ 153: salad
155
+ 154: sand
156
+ 155: sea
157
+ 156: shelf
158
+ 157: sky
159
+ 158: skyscraper
160
+ 159: snow
161
+ 160: solid-other
162
+ 161: stairs
163
+ 162: stone
164
+ 163: straw
165
+ 164: structural-other
166
+ 165: table
167
+ 166: tent
168
+ 167: textile-other
169
+ 168: towel
170
+ 169: tree
171
+ 170: vegetable
172
+ 171: wall-brick
173
+ 172: wall-concrete
174
+ 173: wall-other
175
+ 174: wall-panel
176
+ 175: wall-stone
177
+ 176: wall-tile
178
+ 177: wall-wood
179
+ 178: water-other
180
+ 179: waterdrops
181
+ 180: window-blind
182
+ 181: window-other
183
+ 182: wood
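The `index: name` pairs above presumably back the `cocostuff_class2index` mapping that the segmentation dataset earlier uses to ignore hyphenated stuff classes. A small parsing sketch (the file path is illustrative):

# Parse "index: name" lines into a name -> index dict, mirroring how the dataset code uses it.
with open("cocostuff_classes.txt") as f:                   # illustrative path
    pairs = (line.strip().split(": ", 1) for line in f if line.strip())
    cocostuff_class2index = {name: int(idx) for idx, name in pairs}

ignored = sorted(idx for name, idx in cocostuff_class2index.items() if "-" in name)
print(len(cocostuff_class2index), ignored[:3])             # 183 classes; 96, 102, 103 are ignored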
groundingLMM/dataset/utils/grefer.py ADDED
@@ -0,0 +1,352 @@
1
+ """
2
+ grefer v0.1
3
+ This interface provides access to gRefCOCO.
4
+
5
+ The following API functions are defined:
6
+ G_REFER - REFER api class
7
+ getRefIds - get ref ids that satisfy given filter conditions.
8
+ getAnnIds - get ann ids that satisfy given filter conditions.
9
+ getImgIds - get image ids that satisfy given filter conditions.
10
+ getCatIds - get category ids that satisfy given filter conditions.
11
+ loadRefs - load refs with the specified ref ids.
12
+ loadAnns - load anns with the specified ann ids.
13
+ loadImgs - load images with the specified image ids.
14
+ loadCats - load category names with the specified category ids.
15
+ getRefBox - get ref's bounding box [x, y, w, h] given the ref_id
16
+ showRef - show image, segmentation or box of the referred object with the ref
17
+ getMaskByRef - get mask and area of the referred object given ref or ref ids
18
+ getMask - get mask and area of the referred object given ref
19
+ showMask - show mask of the referred object given ref
20
+ """
21
+
22
+ import itertools
23
+ import json
24
+ import os.path as osp
25
+ import pickle
26
+ import time
27
+
28
+ import matplotlib.pyplot as plt
29
+ import numpy as np
30
+ import skimage.io as io
31
+ from matplotlib.collections import PatchCollection
32
+ from matplotlib.patches import Polygon, Rectangle
33
+ from pycocotools import mask
34
+
35
+
36
+ class G_REFER:
37
+ def __init__(self, data_root, dataset="grefcoco", splitBy="unc"):
38
+ # provide data_root folder which contains grefcoco
39
+ print("loading dataset %s into memory..." % dataset)
40
+ self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
41
+ self.DATA_DIR = osp.join(data_root, dataset)
42
+ if dataset in ["grefcoco"]:
43
+ self.IMAGE_DIR = osp.join(data_root, "images/train2014")
44
+ else:
45
+ raise KeyError("No refer dataset is called [%s]" % dataset)
46
+
47
+ tic = time.time()
48
+
49
+ # load refs from data/dataset/refs(dataset).json
50
+ self.data = {}
51
+ self.data["dataset"] = dataset
52
+
53
+ ref_file = osp.join(self.DATA_DIR, f"grefs({splitBy}).p")
54
+ if osp.exists(ref_file):
55
+ self.data["refs"] = pickle.load(open(ref_file, "rb"), fix_imports=True)
56
+ else:
57
+ ref_file = osp.join(self.DATA_DIR, f"grefs({splitBy}).json")
58
+ if osp.exists(ref_file):
59
+ self.data["refs"] = json.load(open(ref_file, "rb"))
60
+ else:
61
+ raise FileNotFoundError("JSON file not found")
62
+
63
+ # load annotations from data/dataset/instances.json
64
+ instances_file = osp.join(self.DATA_DIR, "instances.json")
65
+ instances = json.load(open(instances_file, "r"))
66
+ self.data["images"] = instances["images"]
67
+ self.data["annotations"] = instances["annotations"]
68
+ self.data["categories"] = instances["categories"]
69
+
70
+ # create index
71
+ self.createIndex()
72
+ print("DONE (t=%.2fs)" % (time.time() - tic))
73
+
74
+ @staticmethod
75
+ def _toList(x):
76
+ return x if isinstance(x, list) else [x]
77
+
78
+ @staticmethod
79
+ def match_any(a, b):
80
+ a = a if isinstance(a, list) else [a]
81
+ b = b if isinstance(b, list) else [b]
82
+ return set(a) & set(b)
83
+
84
+ def createIndex(self):
85
+ # create sets of mapping
86
+ # 1) Refs: {ref_id: ref}
87
+ # 2) Anns: {ann_id: ann}
88
+ # 3) Imgs: {image_id: image}
89
+ # 4) Cats: {category_id: category_name}
90
+ # 5) Sents: {sent_id: sent}
91
+ # 6) imgToRefs: {image_id: refs}
92
+ # 7) imgToAnns: {image_id: anns}
93
+ # 8) refToAnn: {ref_id: ann}
94
+ # 9) annToRef: {ann_id: ref}
95
+ # 10) catToRefs: {category_id: refs}
96
+ # 11) sentToRef: {sent_id: ref}
97
+ # 12) sentToTokens: {sent_id: tokens}
98
+ print("creating index...")
99
+ # fetch info from instances
100
+ Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
101
+ Anns[-1] = None
102
+ for ann in self.data["annotations"]:
103
+ Anns[ann["id"]] = ann
104
+ imgToAnns[ann["image_id"]] = imgToAnns.get(ann["image_id"], []) + [ann]
105
+ for img in self.data["images"]:
106
+ Imgs[img["id"]] = img
107
+ for cat in self.data["categories"]:
108
+ Cats[cat["id"]] = cat["name"]
109
+
110
+ # fetch info from refs
111
+ Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
112
+ Sents, sentToRef, sentToTokens = {}, {}, {}
113
+ availableSplits = []
114
+ for ref in self.data["refs"]:
115
+ # ids
116
+ ref_id = ref["ref_id"]
117
+ ann_id = ref["ann_id"]
118
+ category_id = ref["category_id"]
119
+ image_id = ref["image_id"]
120
+
121
+ if ref["split"] not in availableSplits:
122
+ availableSplits.append(ref["split"])
123
+
124
+ # add mapping related to ref
125
+ if ref_id in Refs:
126
+ print("Duplicate ref id")
127
+ Refs[ref_id] = ref
128
+ imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
129
+
130
+ category_id = self._toList(category_id)
131
+ added_cats = []
132
+ for cat in category_id:
133
+ if cat not in added_cats:
134
+ added_cats.append(cat)
135
+ catToRefs[cat] = catToRefs.get(cat, []) + [ref]
136
+
137
+ ann_id = self._toList(ann_id)
138
+ refToAnn[ref_id] = [Anns[ann] for ann in ann_id]
139
+ for ann_id_n in ann_id:
140
+ annToRef[ann_id_n] = annToRef.get(ann_id_n, []) + [ref]
141
+
142
+ # add mapping of sent
143
+ for sent in ref["sentences"]:
144
+ Sents[sent["sent_id"]] = sent
145
+ sentToRef[sent["sent_id"]] = ref
146
+ sentToTokens[sent["sent_id"]] = sent["tokens"]
147
+
148
+ # create class members
149
+ self.Refs = Refs
150
+ self.Anns = Anns
151
+ self.Imgs = Imgs
152
+ self.Cats = Cats
153
+ self.Sents = Sents
154
+ self.imgToRefs = imgToRefs
155
+ self.imgToAnns = imgToAnns
156
+ self.refToAnn = refToAnn
157
+ self.annToRef = annToRef
158
+ self.catToRefs = catToRefs
159
+ self.sentToRef = sentToRef
160
+ self.sentToTokens = sentToTokens
161
+ self.availableSplits = availableSplits
162
+ print("index created.")
163
+
164
+ def getRefIds(self, image_ids=[], cat_ids=[], split=[]):
165
+ image_ids = self._toList(image_ids)
166
+ cat_ids = self._toList(cat_ids)
167
+ split = self._toList(split)
168
+
169
+ for s in split:
170
+ if s not in self.availableSplits:
171
+ raise ValueError(f"Invalid split name: {s}")
172
+
173
+ refs = self.data["refs"]
174
+
175
+ if len(image_ids) > 0:
176
+ lists = [self.imgToRefs[image_id] for image_id in image_ids]
177
+ refs = list(itertools.chain.from_iterable(lists))
178
+ if len(cat_ids) > 0:
179
+ refs = [ref for ref in refs if self.match_any(ref["category_id"], cat_ids)]
180
+ if len(split) > 0:
181
+ refs = [ref for ref in refs if ref["split"] in split]
182
+
183
+ ref_ids = [ref["ref_id"] for ref in refs]
184
+ return ref_ids
185
+
186
+ def getAnnIds(self, image_ids=[], ref_ids=[]):
187
+ image_ids = self._toList(image_ids)
188
+ ref_ids = self._toList(ref_ids)
189
+
190
+ if any([len(image_ids), len(ref_ids)]):
191
+ if len(image_ids) > 0:
192
+ lists = [
193
+ self.imgToAnns[image_id]
194
+ for image_id in image_ids
195
+ if image_id in self.imgToAnns
196
+ ]
197
+ anns = list(itertools.chain.from_iterable(lists))
198
+ else:
199
+ anns = self.data["annotations"]
200
+ ann_ids = [ann["id"] for ann in anns]
201
+ if len(ref_ids) > 0:
202
+ lists = [self.Refs[ref_id]["ann_id"] for ref_id in ref_ids]
203
+ anns_by_ref_id = list(itertools.chain.from_iterable(lists))
204
+ ann_ids = list(set(ann_ids).intersection(set(anns_by_ref_id)))
205
+ else:
206
+ ann_ids = [ann["id"] for ann in self.data["annotations"]]
207
+
208
+ return ann_ids
209
+
210
+ def getImgIds(self, ref_ids=[]):
211
+ ref_ids = self._toList(ref_ids)
212
+
213
+ if len(ref_ids) > 0:
214
+ image_ids = list(set([self.Refs[ref_id]["image_id"] for ref_id in ref_ids]))
215
+ else:
216
+ image_ids = self.Imgs.keys()
217
+ return image_ids
218
+
219
+ def getCatIds(self):
220
+ return self.Cats.keys()
221
+
222
+ def loadRefs(self, ref_ids=[]):
223
+ return [self.Refs[ref_id] for ref_id in self._toList(ref_ids)]
224
+
225
+ def loadAnns(self, ann_ids=[]):
226
+ if isinstance(ann_ids, str):
227
+ ann_ids = int(ann_ids)
228
+ return [self.Anns[ann_id] for ann_id in self._toList(ann_ids)]
229
+
230
+ def loadImgs(self, image_ids=[]):
231
+ return [self.Imgs[image_id] for image_id in self._toList(image_ids)]
232
+
233
+ def loadCats(self, cat_ids=[]):
234
+ return [self.Cats[cat_id] for cat_id in self._toList(cat_ids)]
235
+
236
+ def getRefBox(self, ref_id):
237
+ anns = self.refToAnn[ref_id]
238
+ return [ann["bbox"] for ann in anns] # [x, y, w, h]
239
+
240
+ def showRef(self, ref, seg_box="seg"):
241
+ ax = plt.gca()
242
+ # show image
243
+ image = self.Imgs[ref["image_id"]]
244
+ I = io.imread(osp.join(self.IMAGE_DIR, image["file_name"]))
245
+ ax.imshow(I)
246
+ # show refer expression
247
+ for sid, sent in enumerate(ref["sentences"]):
248
+ print("%s. %s" % (sid + 1, sent["sent"]))
249
+ # show segmentations
250
+ if seg_box == "seg":
251
+ ann_id = ref["ann_id"]
252
+ ann = self.Anns[ann_id]
253
+ polygons = []
254
+ color = []
255
+ c = "none"
256
+ if type(ann["segmentation"][0]) == list:
257
+ # polygon used for refcoco*
258
+ for seg in ann["segmentation"]:
259
+ poly = np.array(seg).reshape((len(seg) // 2, 2))  # integer division (Python 3)
260
+ polygons.append(Polygon(poly, True, alpha=0.4))
261
+ color.append(c)
262
+ p = PatchCollection(
263
+ polygons,
264
+ facecolors=color,
265
+ edgecolors=(1, 1, 0, 0),
266
+ linewidths=3,
267
+ alpha=1,
268
+ )
269
+ ax.add_collection(p) # thick yellow polygon
270
+ p = PatchCollection(
271
+ polygons,
272
+ facecolors=color,
273
+ edgecolors=(1, 0, 0, 0),
274
+ linewidths=1,
275
+ alpha=1,
276
+ )
277
+ ax.add_collection(p) # thin red polygon
278
+ else:
279
+ # mask used for refclef
280
+ rle = ann["segmentation"]
281
+ m = mask.decode(rle)
282
+ img = np.ones((m.shape[0], m.shape[1], 3))
283
+ color_mask = np.array([2.0, 166.0, 101.0]) / 255
284
+ for i in range(3):
285
+ img[:, :, i] = color_mask[i]
286
+ ax.imshow(np.dstack((img, m * 0.5)))
287
+ # show bounding-box
288
+ elif seg_box == "box":
289
+ ann_id = ref["ann_id"]
290
+ ann = self.Anns[ann_id]
291
+ bbox = self.getRefBox(ref["ref_id"])
292
+ box_plot = Rectangle(
293
+ (bbox[0], bbox[1]),
294
+ bbox[2],
295
+ bbox[3],
296
+ fill=False,
297
+ edgecolor="green",
298
+ linewidth=3,
299
+ )
300
+ ax.add_patch(box_plot)
301
+
302
+ def getMask(self, ann):
303
+ if not ann:
304
+ return None
305
+ if ann["iscrowd"]:
306
+ raise ValueError("Crowd object")
307
+ image = self.Imgs[ann["image_id"]]
308
+ if type(ann["segmentation"][0]) == list: # polygon
309
+ rle = mask.frPyObjects(ann["segmentation"], image["height"], image["width"])
310
+ else:
311
+ rle = ann["segmentation"]
312
+
313
+ m = mask.decode(rle)
314
+ m = np.sum(
315
+ m, axis=2
316
+ ) # sometimes there are multiple binary map (corresponding to multiple segs)
317
+ m = m.astype(np.uint8) # convert to np.uint8
318
+ # compute area
319
+ area = sum(mask.area(rle)) # should be close to ann['area']
320
+ return {"mask": m, "area": area}
321
+
322
+ def getMaskByRef(self, ref=None, ref_id=None, merge=False):
323
+ if not ref and not ref_id:
324
+ raise ValueError
325
+ if ref:
326
+ ann_ids = ref["ann_id"]
327
+ ref_id = ref["ref_id"]
328
+ else:
329
+ ann_ids = self.getAnnIds(ref_ids=ref_id)
330
+
331
+ if ann_ids == [-1]:
332
+ img = self.Imgs[self.Refs[ref_id]["image_id"]]
333
+ return {
334
+ "mask": np.zeros([img["height"], img["width"]], dtype=np.uint8),
335
+ "empty": True,
336
+ }
337
+
338
+ anns = self.loadAnns(ann_ids)
339
+ mask_list = [self.getMask(ann) for ann in anns if not ann["iscrowd"]]
340
+
341
+ if merge:
342
+ merged_masks = sum([mask["mask"] for mask in mask_list])
343
+ merged_masks[np.where(merged_masks > 1)] = 1
344
+ return {"mask": merged_masks, "empty": False}
345
+ else:
346
+ return mask_list
347
+
348
+ def showMask(self, ref):
349
+ M = self.getMask(ref)
350
+ msk = M["mask"]
351
+ ax = plt.gca()
352
+ ax.imshow(msk)
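A minimal usage sketch for the G_REFER class above; the import path and `data_root` are placeholders that depend on where this file and the gRefCOCO annotations live:

from grefer import G_REFER                                 # import path depends on your layout

refer = G_REFER(data_root="./refer_data", dataset="grefcoco", splitBy="unc")
val_ids = refer.getRefIds(split="val")
ref = refer.loadRefs(val_ids[0])[0]
result = refer.getMaskByRef(ref=ref, merge=True)           # {"mask": HxW uint8, "empty": bool}
print(len(val_ids), result["mask"].shape, result["empty"])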
groundingLMM/dataset/utils/refcoco_refer.py ADDED
@@ -0,0 +1,391 @@
1
+ __author__ = "licheng"
2
+
3
+ """
4
+ This interface provides access to four datasets:
5
+ 1) refclef
6
+ 2) refcoco
7
+ 3) refcoco+
8
+ 4) refcocog
9
+ split by unc and google
10
+
11
+ The following API functions are defined:
12
+ REFER - REFER api class
13
+ getRefIds - get ref ids that satisfy given filter conditions.
14
+ getAnnIds - get ann ids that satisfy given filter conditions.
15
+ getImgIds - get image ids that satisfy given filter conditions.
16
+ getCatIds - get category ids that satisfy given filter conditions.
17
+ loadRefs - load refs with the specified ref ids.
18
+ loadAnns - load anns with the specified ann ids.
19
+ loadImgs - load images with the specified image ids.
20
+ loadCats - load category names with the specified category ids.
21
+ getRefBox - get ref's bounding box [x, y, w, h] given the ref_id
22
+ showRef - show image, segmentation or box of the referred object with the ref
23
+ getMask - get mask and area of the referred object given ref
24
+ showMask - show mask of the referred object given ref
25
+ """
26
+
27
+ import itertools
28
+ import json
29
+ import os.path as osp
30
+ import pickle
31
+ import sys
32
+ import time
33
+ from pprint import pprint
34
+
35
+ import matplotlib.pyplot as plt
36
+ import numpy as np
37
+ import skimage.io as io
38
+ from matplotlib.collections import PatchCollection
39
+ from matplotlib.patches import Polygon, Rectangle
40
+ from pycocotools import mask
41
+
42
+
43
+ class REFER:
44
+ def __init__(self, data_root, dataset="refcoco", splitBy="unc"):
45
+ # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog
46
+ # also provide dataset name and splitBy information
47
+ # e.g., dataset = 'refcoco', splitBy = 'unc'
48
+ print("loading dataset %s into memory..." % dataset)
49
+ self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
50
+ self.DATA_DIR = osp.join(data_root, dataset)
51
+ if dataset in ["refcoco", "refcoco+", "refcocog"]:
52
+ self.IMAGE_DIR = osp.join(data_root, "images/mscoco/images/train2014")
53
+ elif dataset == "refclef":
54
+ self.IMAGE_DIR = osp.join(data_root, "images/saiapr_tc-12")
55
+ else:
56
+ print("No refer dataset is called [%s]" % dataset)
57
+ sys.exit()
58
+
59
+ self.dataset = dataset
60
+
61
+ # load refs from data/dataset/refs(dataset).json
62
+ tic = time.time()
63
+
64
+ ref_file = osp.join(self.DATA_DIR, "refs(" + splitBy + ").p")
65
+ print("ref_file: ", ref_file)
66
+ self.data = {}
67
+ self.data["dataset"] = dataset
68
+ self.data["refs"] = pickle.load(open(ref_file, "rb"))
69
+
70
+ # load annotations from data/dataset/instances.json
71
+ instances_file = osp.join(self.DATA_DIR, "instances.json")
72
+ instances = json.load(open(instances_file, "rb"))
73
+ self.data["images"] = instances["images"]
74
+ self.data["annotations"] = instances["annotations"]
75
+ self.data["categories"] = instances["categories"]
76
+
77
+ # create index
78
+ self.createIndex()
79
+ print("DONE (t=%.2fs)" % (time.time() - tic))
80
+
81
+ def createIndex(self):
82
+ # create sets of mapping
83
+ # 1) Refs: {ref_id: ref}
84
+ # 2) Anns: {ann_id: ann}
85
+ # 3) Imgs: {image_id: image}
86
+ # 4) Cats: {category_id: category_name}
87
+ # 5) Sents: {sent_id: sent}
88
+ # 6) imgToRefs: {image_id: refs}
89
+ # 7) imgToAnns: {image_id: anns}
90
+ # 8) refToAnn: {ref_id: ann}
91
+ # 9) annToRef: {ann_id: ref}
92
+ # 10) catToRefs: {category_id: refs}
93
+ # 11) sentToRef: {sent_id: ref}
94
+ # 12) sentToTokens: {sent_id: tokens}
95
+ print("creating index...")
96
+ # fetch info from instances
97
+ Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
98
+ for ann in self.data["annotations"]:
99
+ Anns[ann["id"]] = ann
100
+ imgToAnns[ann["image_id"]] = imgToAnns.get(ann["image_id"], []) + [ann]
101
+ for img in self.data["images"]:
102
+ Imgs[img["id"]] = img
103
+ for cat in self.data["categories"]:
104
+ Cats[cat["id"]] = cat["name"]
105
+
106
+ # fetch info from refs
107
+ Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
108
+ Sents, sentToRef, sentToTokens = {}, {}, {}
109
+ for ref in self.data["refs"]:
110
+ # ids
111
+ ref_id = ref["ref_id"]
112
+ ann_id = ref["ann_id"]
113
+ category_id = ref["category_id"]
114
+ image_id = ref["image_id"]
115
+
116
+ # add mapping related to ref
117
+ Refs[ref_id] = ref
118
+ imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
119
+ catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]
120
+ refToAnn[ref_id] = Anns[ann_id]
121
+ annToRef[ann_id] = ref
122
+
123
+ # add mapping of sent
124
+ for sent in ref["sentences"]:
125
+ Sents[sent["sent_id"]] = sent
126
+ sentToRef[sent["sent_id"]] = ref
127
+ sentToTokens[sent["sent_id"]] = sent["tokens"]
128
+
129
+ # create class members
130
+ self.Refs = Refs
131
+ self.Anns = Anns
132
+ self.Imgs = Imgs
133
+ self.Cats = Cats
134
+ self.Sents = Sents
135
+ self.imgToRefs = imgToRefs
136
+ self.imgToAnns = imgToAnns
137
+ self.refToAnn = refToAnn
138
+ self.annToRef = annToRef
139
+ self.catToRefs = catToRefs
140
+ self.sentToRef = sentToRef
141
+ self.sentToTokens = sentToTokens
142
+ print("index created.")
143
+
144
+ def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=""):
145
+ image_ids = image_ids if type(image_ids) == list else [image_ids]
146
+ cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
147
+ ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
148
+
149
+ if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0:
150
+ refs = self.data["refs"]
151
+ else:
152
+ if not len(image_ids) == 0:
153
+ refs = [self.imgToRefs[image_id] for image_id in image_ids]
154
+ else:
155
+ refs = self.data["refs"]
156
+ if not len(cat_ids) == 0:
157
+ refs = [ref for ref in refs if ref["category_id"] in cat_ids]
158
+ if not len(ref_ids) == 0:
159
+ refs = [ref for ref in refs if ref["ref_id"] in ref_ids]
160
+ if not len(split) == 0:
161
+ if split in ["testA", "testB", "testC"]:
162
+ refs = [
163
+ ref for ref in refs if split[-1] in ref["split"]
164
+ ] # we also consider testAB, testBC, ...
165
+ elif split in ["testAB", "testBC", "testAC"]:
166
+ refs = [
167
+ ref for ref in refs if ref["split"] == split
168
+ ] # rarely used I guess...
169
+ elif split == "test":
170
+ refs = [ref for ref in refs if "test" in ref["split"]]
171
+ elif split == "train" or split == "val":
172
+ refs = [ref for ref in refs if ref["split"] == split]
173
+ else:
174
+ print("No such split [%s]" % split)
175
+ sys.exit()
176
+ ref_ids = [ref["ref_id"] for ref in refs]
177
+ return ref_ids
178
+
179
+ def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
180
+ image_ids = image_ids if type(image_ids) == list else [image_ids]
181
+ cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
182
+ ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
183
+
184
+ if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
185
+ ann_ids = [ann["id"] for ann in self.data["annotations"]]
186
+ else:
187
+ if not len(image_ids) == 0:
188
+ lists = [
189
+ self.imgToAnns[image_id]
190
+ for image_id in image_ids
191
+ if image_id in self.imgToAnns
192
+ ] # list of [anns]
193
+ anns = list(itertools.chain.from_iterable(lists))
194
+ else:
195
+ anns = self.data["annotations"]
196
+ if not len(cat_ids) == 0:
197
+ anns = [ann for ann in anns if ann["category_id"] in cat_ids]
198
+ ann_ids = [ann["id"] for ann in anns]
199
+ if not len(ref_ids) == 0:
200
+ ids = set(ann_ids).intersection(
201
+ set([self.Refs[ref_id]["ann_id"] for ref_id in ref_ids])
202
+ )
203
+ return ann_ids
204
+
205
+ def getImgIds(self, ref_ids=[]):
206
+ ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
207
+
208
+ if not len(ref_ids) == 0:
209
+ image_ids = list(set([self.Refs[ref_id]["image_id"] for ref_id in ref_ids]))
210
+ else:
211
+ image_ids = self.Imgs.keys()
212
+ return image_ids
213
+
214
+ def getCatIds(self):
215
+ return self.Cats.keys()
216
+
217
+ def loadRefs(self, ref_ids=[]):
218
+ if type(ref_ids) == list:
219
+ return [self.Refs[ref_id] for ref_id in ref_ids]
220
+ elif type(ref_ids) == int:
221
+ return [self.Refs[ref_ids]]
222
+
223
+ def loadAnns(self, ann_ids=[]):
224
+ if type(ann_ids) == list:
225
+ return [self.Anns[ann_id] for ann_id in ann_ids]
226
+ elif type(ann_ids) == int or type(ann_ids) == str:  # 'unicode' does not exist in Python 3
227
+ return [self.Anns[ann_ids]]
228
+
229
+ def loadImgs(self, image_ids=[]):
230
+ if type(image_ids) == list:
231
+ return [self.Imgs[image_id] for image_id in image_ids]
232
+ elif type(image_ids) == int:
233
+ return [self.Imgs[image_ids]]
234
+
235
+ def loadCats(self, cat_ids=[]):
236
+ if type(cat_ids) == list:
237
+ return [self.Cats[cat_id] for cat_id in cat_ids]
238
+ elif type(cat_ids) == int:
239
+ return [self.Cats[cat_ids]]
240
+
241
+ def getRefBox(self, ref_id):
242
+ ref = self.Refs[ref_id]
243
+ ann = self.refToAnn[ref_id]
244
+ return ann["bbox"] # [x, y, w, h]
245
+
246
+ def showRef(self, ref, seg_box="seg"):
247
+ ax = plt.gca()
248
+ # show image
249
+ image = self.Imgs[ref["image_id"]]
250
+ I = io.imread(osp.join(self.IMAGE_DIR, image["file_name"]))
251
+ ax.imshow(I)
252
+ # show refer expression
253
+ for sid, sent in enumerate(ref["sentences"]):
254
+ print("%s. %s" % (sid + 1, sent["sent"]))
255
+ # show segmentations
256
+ if seg_box == "seg":
257
+ ann_id = ref["ann_id"]
258
+ ann = self.Anns[ann_id]
259
+ polygons = []
260
+ color = []
261
+ c = "none"
262
+ if type(ann["segmentation"][0]) == list:
263
+ # polygon used for refcoco*
264
+ for seg in ann["segmentation"]:
265
+ poly = np.array(seg).reshape((len(seg) // 2, 2))  # integer division (Python 3)
266
+ polygons.append(Polygon(poly, True, alpha=0.4))
267
+ color.append(c)
268
+ p = PatchCollection(
269
+ polygons,
270
+ facecolors=color,
271
+ edgecolors=(1, 1, 0, 0),
272
+ linewidths=3,
273
+ alpha=1,
274
+ )
275
+ ax.add_collection(p) # thick yellow polygon
276
+ p = PatchCollection(
277
+ polygons,
278
+ facecolors=color,
279
+ edgecolors=(1, 0, 0, 0),
280
+ linewidths=1,
281
+ alpha=1,
282
+ )
283
+ ax.add_collection(p) # thin red polygon
284
+ else:
285
+ # mask used for refclef
286
+ rle = ann["segmentation"]
287
+ m = mask.decode(rle)
288
+ img = np.ones((m.shape[0], m.shape[1], 3))
289
+ color_mask = np.array([2.0, 166.0, 101.0]) / 255
290
+ for i in range(3):
291
+ img[:, :, i] = color_mask[i]
292
+ ax.imshow(np.dstack((img, m * 0.5)))
293
+ # show bounding-box
294
+ elif seg_box == "box":
295
+ ann_id = ref["ann_id"]
296
+ ann = self.Anns[ann_id]
297
+ bbox = self.getRefBox(ref["ref_id"])
298
+ box_plot = Rectangle(
299
+ (bbox[0], bbox[1]),
300
+ bbox[2],
301
+ bbox[3],
302
+ fill=False,
303
+ edgecolor="green",
304
+ linewidth=3,
305
+ )
306
+ ax.add_patch(box_plot)
307
+
308
+ def getMask(self, ref):
309
+ # return mask, area and mask-center
310
+ ann = self.refToAnn[ref["ref_id"]]
311
+ image = self.Imgs[ref["image_id"]]
312
+ if type(ann["segmentation"][0]) == list: # polygon
313
+ rle = mask.frPyObjects(ann["segmentation"], image["height"], image["width"])
314
+ else:
315
+ rle = ann["segmentation"]
316
+ m = mask.decode(rle)
317
+ m = np.sum(
318
+ m, axis=2
319
+ ) # sometimes there are multiple binary map (corresponding to multiple segs)
320
+ m = m.astype(np.uint8) # convert to np.uint8
321
+ # compute area
322
+ area = sum(mask.area(rle)) # should be close to ann['area']
323
+ return {"mask": m, "area": area}
324
+ # # position
325
+ # position_x = np.mean(np.where(m==1)[1]) # [1] means columns (matlab style) -> x (c style)
326
+ # position_y = np.mean(np.where(m==1)[0]) # [0] means rows (matlab style) -> y (c style)
327
+ # # mass position (if there were multiple regions, we use the largest one.)
328
+ # label_m = label(m, connectivity=m.ndim)
329
+ # regions = regionprops(label_m)
330
+ # if len(regions) > 0:
331
+ # largest_id = np.argmax(np.array([props.filled_area for props in regions]))
332
+ # largest_props = regions[largest_id]
333
+ # mass_y, mass_x = largest_props.centroid
334
+ # else:
335
+ # mass_x, mass_y = position_x, position_y
336
+ # # if centroid is not in mask, we find the closest point to it from mask
337
+ # if m[mass_y, mass_x] != 1:
338
+ # print('Finding closes mask point ...')
339
+ # kernel = np.ones((10, 10),np.uint8)
340
+ # me = cv2.erode(m, kernel, iterations = 1)
341
+ # points = zip(np.where(me == 1)[0].tolist(), np.where(me == 1)[1].tolist()) # row, col style
342
+ # points = np.array(points)
343
+ # dist = np.sum((points - (mass_y, mass_x))**2, axis=1)
344
+ # id = np.argsort(dist)[0]
345
+ # mass_y, mass_x = points[id]
346
+ # # return
347
+ # return {'mask': m, 'area': area, 'position_x': position_x, 'position_y': position_y, 'mass_x': mass_x, 'mass_y': mass_y}
348
+ # # show image and mask
349
+ # I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
350
+ # plt.figure()
351
+ # plt.imshow(I)
352
+ # ax = plt.gca()
353
+ # img = np.ones( (m.shape[0], m.shape[1], 3) )
354
+ # color_mask = np.array([2.0,166.0,101.0])/255
355
+ # for i in range(3):
356
+ # img[:,:,i] = color_mask[i]
357
+ # ax.imshow(np.dstack( (img, m*0.5) ))
358
+ # plt.show()
359
+
360
+ def showMask(self, ref):
361
+ M = self.getMask(ref)
362
+ msk = M["mask"]
363
+ ax = plt.gca()
364
+ ax.imshow(msk)
365
+
366
+
367
+ if __name__ == "__main__":
368
+ refer = REFER("./refer_data", dataset="refcocog", splitBy="google")  # "./refer_data" is a placeholder data_root
369
+ ref_ids = refer.getRefIds()
370
+ print(len(ref_ids))
371
+
372
+ print(len(refer.Imgs))
373
+ print(len(refer.imgToRefs))
374
+
375
+ ref_ids = refer.getRefIds(split="train")
376
+ print("There are %s training referred objects." % len(ref_ids))
377
+
378
+ for ref_id in ref_ids:
379
+ ref = refer.loadRefs(ref_id)[0]
380
+ if len(ref["sentences"]) < 2:
381
+ continue
382
+
383
+ pprint(ref)
384
+ print("The label is %s." % refer.Cats[ref["category_id"]])
385
+ plt.figure()
386
+ refer.showRef(ref, seg_box="box")
387
+ plt.show()
388
+
389
+ # plt.figure()
390
+ # refer.showMask(ref)
391
+ # plt.show()
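Besides the plotting demo in `__main__`, the segmentation datasets in this repository presumably use the non-plotting path; a short sketch (`data_root` is a placeholder):

refer = REFER(data_root="./refer_data", dataset="refcoco", splitBy="unc")
ref_id = refer.getRefIds(split="val")[0]
ref = refer.loadRefs(ref_id)[0]
binary_mask = refer.getMask(ref)["mask"]                   # HxW uint8, 1 inside the referred object
print(ref["sentences"][0]["sent"], int(binary_mask.sum()))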
groundingLMM/dataset/utils/utils.py ADDED
@@ -0,0 +1,115 @@
1
+
2
+ CAPTION_QUESTIONS = [
3
+ 'Could you please give me a detailed description of the image?',
4
+ 'Can you provide a thorough description of this image?',
5
+ 'Please provide a thorough description of this image',
6
+ 'Please provide a thorough description of this image.',
7
+ 'Please describe in detail the contents of the image.',
8
+ 'Please describe in detail the contents of the image',
9
+ 'Could you give a comprehensive explanation of what can be found within this picture?',
10
+ 'Could you give me an elaborate explanation of this picture?',
11
+ 'Could you provide me with a detailed analysis of this photo?',
12
+ 'Could you please give me a detailed description of the image?',
13
+ 'Can you provide a thorough description of this image?',
14
+ 'Please describe in detail the contents of the image',
15
+ 'Please describe in detail the contents of the image.',
16
+ 'Can you give a comprehensive explanation of this photo',
17
+ 'Please provide an elaborate explanation of this picture.',
18
+ 'Please provide an elaborate explanation of this picture',
19
+ 'Could you provide me with a detailed analysis of this photo',
20
+ ]
21
+
22
+ REGION_QUESTIONS = [
23
+ 'Can you provide me with a detailed description of the region in the picture marked by <region>?',
24
+ "I'm curious about the region represented by <region> in the picture. Could you describe it in detail?",
25
+ 'What can you tell me about the region indicated by <region> in the image?',
26
+ "I'd like to know more about the area in the photo labeled <region>. Can you give me a detailed description?",
27
+ 'Could you describe the region shown as <region> in the picture in great detail?',
28
+ 'What details can you give me about the region outlined by <region> in the photo?',
29
+ 'Please provide me with a comprehensive description of the region marked with <region> in the image.',
30
+ 'Can you give me a detailed account of the region labeled as <region> in the picture?',
31
+ "I'm interested in learning more about the region represented by <region> in the photo. Can you describe it in detail?",
32
+ 'What is the region outlined by <region> in the picture like? Could you give me a detailed description?',
33
+ 'Can you provide me with a detailed description of the region in the picture marked by <region>, please?',
34
+ "I'm curious about the region represented by <region> in the picture. Could you describe it in detail, please?",
35
+ 'What can you tell me about the region indicated by <region> in the image, exactly?',
36
+ "I'd like to know more about the area in the photo labeled <region>, please. Can you give me a detailed description?",
37
+ 'Could you describe the region shown as <region> in the picture in great detail, please?',
38
+ 'What details can you give me about the region outlined by <region> in the photo, please?',
39
+ 'Please provide me with a comprehensive description of the region marked with <region> in the image, please.',
40
+ 'Can you give me a detailed account of the region labeled as <region> in the picture, please?',
41
+ "I'm interested in learning more about the region represented by <region> in the photo. Can you describe it in detail, please?",
42
+ 'What is the region outlined by <region> in the picture like, please? Could you give me a detailed description?',
43
+ ]
44
+
45
+ REGION_GROUP_QUESTIONS = [
46
+ 'Could you please give me a detailed description of these areas <region>?',
47
+ 'Can you provide a thorough description of the regions <region> in this image?',
48
+ 'Please describe in detail the contents of the boxed areas <region>.',
49
+ 'Could you give a comprehensive explanation of what can be found within <region> in the picture?',
50
+ 'Could you give me an elaborate explanation of the <region> regions in this picture?',
51
+ 'Can you provide a comprehensive description of the areas identified by <region> in this photo?',
52
+ 'Help me understand the specific locations labeled <region> in this picture in detail, please.',
53
+ 'What is the detailed information about the areas marked by <region> in this image?',
54
+ 'Could you provide me with a detailed analysis of the regions designated <region> in this photo?',
55
+ 'What are the specific features of the areas marked <region> in this picture that you can describe in detail?',
56
+ 'Could you elaborate on the regions identified by <region> in this image?',
57
+ 'What can you tell me about the areas labeled <region> in this picture?',
58
+ 'Can you provide a thorough analysis of the specific locations designated <region> in this photo?',
59
+ 'I am interested in learning more about the regions marked <region> in this image. Can you provide me with more information?',
60
+ 'Could you please provide a detailed description of the areas identified by <region> in this photo?',
61
+ 'What is the significance of the regions labeled <region> in this picture?',
62
+ 'I would like to know more about the specific locations designated <region> in this image. Can you provide me with more information?',
63
+ 'Can you provide a detailed breakdown of the regions marked <region> in this photo?',
64
+ 'What specific features can you tell me about the areas identified by <region> in this picture?',
65
+ 'Could you please provide a comprehensive explanation of the locations labeled <region> in this image?',
66
+ 'Can you provide a detailed account of the regions designated <region> in this photo?',
67
+ 'I am curious about the areas marked <region> in this picture. Can you provide me with a detailed analysis?',
68
+ 'What important details can you tell me about the specific locations identified by <region> in this image?',
69
+ 'Could you please provide a detailed description of the regions labeled <region> in this photo?',
70
+ 'What can you tell me about the features of the areas designated <region> in this picture?',
71
+ 'Can you provide a comprehensive overview of the regions marked <region> in this image?',
72
+ 'I would like to know more about the specific locations identified by <region> in this photo. Can you provide me with more information?',
73
+ 'What is the detailed information you have on the areas labeled <region> in this picture?',
74
+ 'Could you provide me with a thorough analysis of the regions designated <region> in this image?',
75
+ 'Can you provide a detailed explanation of the specific locations marked by <region> in this photo?'
76
+ ]
77
+
78
+ GCG_QUESTIONS = [
79
+ 'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.',
80
+ 'Can you provide a thorough description of this image? Please output with interleaved segmentation masks for the corresponding phrases.',
81
+ 'Please describe in detail the contents of the image. Please respond with interleaved segmentation masks for the corresponding parts of the answer.',
82
+ 'Could you give a comprehensive explanation of what can be found within this picture? Please output with interleaved segmentation masks for the corresponding phrases.',
83
+ 'Could you give me an elaborate explanation of this picture? Please respond with interleaved segmentation masks for the corresponding phrases.',
84
+ 'Could you provide me with a detailed analysis of this photo? Please output with interleaved segmentation masks for the corresponding parts of the answer.',
85
+ ]
86
+
87
+ SEG_QUESTIONS = [
88
+ "Can you segment the {class_name} in this image?",
89
+ "Please segment {class_name} in this image.",
90
+ "What is {class_name} in this image? Please respond with segmentation mask.",
91
+ "What is {class_name} in this image? Please output segmentation mask.",
92
+
93
+ "Can you segment the {class_name} in this image",
94
+ "Please segment {class_name} in this image",
95
+ "What is {class_name} in this image? Please respond with segmentation mask",
96
+ "What is {class_name} in this image? Please output segmentation mask",
97
+
98
+ "Could you provide a segmentation mask for the {class_name} in this image?",
99
+ "Please identify and segment the {class_name} in this image.",
100
+ "Where is the {class_name} in this picture? Please respond with a segmentation mask.",
101
+ "Can you highlight the {class_name} in this image with a segmentation mask?",
102
+
103
+ "Could you provide a segmentation mask for the {class_name} in this image",
104
+ "Please identify and segment the {class_name} in this image",
105
+ "Where is the {class_name} in this picture? Please respond with a segmentation mask",
106
+ "Can you highlight the {class_name} in this image with a segmentation mask",
107
+ ]
108
+
109
+ ANSWER_LIST = [
110
+ "It is [SEG].",
111
+ "Sure, [SEG].",
112
+ "Sure, it is [SEG].",
113
+ "Sure, the segmentation result is [SEG].",
114
+ "[SEG].",
115
+ ]
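A sketch (not necessarily the repo's exact pipeline) of how these template lists are typically combined: a sampled class name fills a SEG_QUESTIONS template and is paired with an ANSWER_LIST response carrying the [SEG] token; the conversation dict keys follow the common LLaVA convention and are illustrative:

import random

class_name = "person"
question = random.choice(SEG_QUESTIONS).format(class_name=class_name)
answer = random.choice(ANSWER_LIST)                        # every answer contains the "[SEG]" token
conversation = [{"from": "human", "value": question},
                {"from": "gpt", "value": answer}]
print(conversation)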
groundingLMM/mmcv/tests/data/config/a.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ item1 = [1, 2]
3
+ item2 = {'a': 0}
4
+ item3 = True
5
+ item4 = 'test'
groundingLMM/mmcv/tests/data/config/b.json ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "item1": [1, 2],
3
+ "item2": {
4
+ "a": 0
5
+ },
6
+ "item3": true,
7
+ "item4": "test"
8
+ }
groundingLMM/mmcv/tests/data/config/base.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ item1 = [1, 2]
3
+ item2 = {'a': 0}
4
+ item3 = True
5
+ item4 = 'test'
groundingLMM/mmcv/tests/data/config/c.yaml ADDED
@@ -0,0 +1,4 @@
1
+ item1: [1, 2]
2
+ item2: {'a': 0}
3
+ item3: True
4
+ item4: 'test'
groundingLMM/mmcv/tests/data/config/d.py ADDED
@@ -0,0 +1,6 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ _base_ = './base.py'
3
+ item1 = [2, 3]
4
+ item2 = {'a': 1}
5
+ item3 = False
6
+ item4 = 'test_base'
groundingLMM/mmcv/tests/data/config/delete.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ _base_ = './base.py'
3
+ item1 = {'a': 0, '_delete_': True}
4
+ item2 = {'b': 0}
groundingLMM/mmcv/tests/data/config/deprecated.py ADDED
@@ -0,0 +1,6 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ _base_ = './expected.py'
3
+
4
+ _deprecation_ = dict(
5
+ expected='tests/data/config/expected.py',
6
+ reference='https://github.com/open-mmlab/mmcv/pull/1275')
groundingLMM/mmcv/tests/data/config/deprecated_as_base.py ADDED
@@ -0,0 +1,2 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ _base_ = './deprecated.py'
groundingLMM/mmcv/tests/data/config/e.py ADDED
@@ -0,0 +1,3 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ _base_ = './base.py'
3
+ item3 = {'a': 1}
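These small files are test fixtures for mmcv's Config inheritance (`_base_`, `_delete_`, `_deprecation_`). A sketch of how they behave when loaded (the values follow from the files above; the relative paths assume the tests/data/config directory as the working directory):

from mmcv import Config

cfg_d = Config.fromfile("d.py")
# d.py overrides every field it inherits from base.py:
assert cfg_d.item1 == [2, 3] and cfg_d.item4 == "test_base"

cfg_del = Config.fromfile("delete.py")
# "_delete_": True makes item1 replace base.py's list instead of merging with it,
# while untouched keys such as item3 are still inherited:
assert cfg_del.item1 == {"a": 0} and cfg_del.item3 is True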