Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- PaddleMIX/.travis/codestyle/clang_format.hook +4 -0
- PaddleMIX/.travis/codestyle/cpplint_pre_commit.hook +27 -0
- PaddleMIX/.travis/codestyle/pylint_pre_commit.hook +25 -0
- PaddleMIX/applications/Audio2Caption/README.md +49 -0
- PaddleMIX/applications/Audio2Img/README.md +106 -0
- PaddleMIX/applications/Audio2Img/audio2img_imagebind.py +176 -0
- PaddleMIX/applications/Audio2Img/gradio_demo.py +135 -0
- PaddleMIX/applications/AudioChat/README.md +36 -0
- PaddleMIX/applications/Automatic_label/README.md +19 -0
- PaddleMIX/applications/Automatic_label/automatic_label.py +61 -0
- PaddleMIX/applications/CVinW/README.md +19 -0
- PaddleMIX/applications/CVinW/grounded_sam.py +46 -0
- PaddleMIX/applications/Inpainting/README.md +87 -0
- PaddleMIX/applications/Inpainting/grounded_sam_chatglm.py +256 -0
- PaddleMIX/applications/Inpainting/grounded_sam_inpainting.py +234 -0
- PaddleMIX/applications/MusicGeneration/README.md +89 -0
- PaddleMIX/applications/VLChat/README.md +44 -0
- PaddleMIX/applications/image2image/README.md +92 -0
- PaddleMIX/applications/image2text/README.md +66 -0
- PaddleMIX/applications/text2image/README.md +27 -0
- PaddleMIX/applications/text2video/README.md +23 -0
- PaddleMIX/deploy/llava/README.md +83 -0
- PaddleMIX/deploy/llava/export_model.py +98 -0
- PaddleMIX/deploy/llava/llama_inference_model.py +127 -0
- PaddleMIX/deploy/llava/run_static_predict.py +403 -0
- PaddleMIX/deploy/llava/utils.py +83 -0
- PaddleMIX/deploy/qwen2_vl/README.md +50 -0
- PaddleMIX/deploy/qwen2_vl/single_image_infer.py +276 -0
- PaddleMIX/deploy/qwen_vl/run_static_predict.py +203 -0
- PaddleMIX/deploy/sam/README.md +37 -0
- PaddleMIX/deploy/sam/export.py +106 -0
- PaddleMIX/deploy/sam/predict.py +374 -0
- PaddleMIX/docs/hardware_support/ascend_usage.md +222 -0
- PaddleMIX/paddlemix/datasets/__init__.py +37 -0
- PaddleMIX/paddlemix/datasets/caption_dataset.py +109 -0
- PaddleMIX/paddlemix/datasets/cc_sbu_dataset.py +93 -0
- PaddleMIX/paddlemix/datasets/chatml_dataset.py +50 -0
- PaddleMIX/paddlemix/datasets/coco_caption.py +17 -0
- PaddleMIX/paddlemix/datasets/coco_vqa.py +138 -0
- PaddleMIX/paddlemix/datasets/collator.py +362 -0
- PaddleMIX/paddlemix/datasets/dataset.py +1169 -0
- PaddleMIX/paddlemix/datasets/got_dataset.py +439 -0
- PaddleMIX/paddlemix/datasets/internvl_dataset.py +688 -0
- PaddleMIX/paddlemix/datasets/laiondata.py +139 -0
- PaddleMIX/paddlemix/datasets/mixtoken_dataset.py +131 -0
- PaddleMIX/paddlemix/datasets/vg_caption.py +37 -0
- PaddleMIX/paddlemix/demo_images/critic_img_seven.png +0 -0
- PaddleMIX/paddlemix/external_ops/setup.py +107 -0
- PaddleMIX/paddlemix/metrics/clip_zero_shot.py +146 -0
.gitattributes
CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
DeepSeek-VL2/vg.jpg filter=lfs diff=lfs merge=lfs -text
|
37 |
Ovis/temp.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
DeepSeek-VL2/vg.jpg filter=lfs diff=lfs merge=lfs -text
|
37 |
Ovis/temp.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
VLM2Vec/figures/vlm2vec_results.png filter=lfs diff=lfs merge=lfs -text
|
PaddleMIX/.travis/codestyle/clang_format.hook
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
set -e
|
3 |
+
|
4 |
+
clang-format $@
|
PaddleMIX/.travis/codestyle/cpplint_pre_commit.hook
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
TOTAL_ERRORS=0
|
4 |
+
if [[ ! $TRAVIS_BRANCH ]]; then
|
5 |
+
# install cpplint on local machine.
|
6 |
+
if [[ ! $(which cpplint) ]]; then
|
7 |
+
pip install cpplint
|
8 |
+
fi
|
9 |
+
# diff files on local machine.
|
10 |
+
files=$(git diff --cached --name-status | awk '$1 != "D" {print $2}')
|
11 |
+
else
|
12 |
+
# diff files between PR and latest commit on Travis CI.
|
13 |
+
branch_ref=$(git rev-parse "$TRAVIS_BRANCH")
|
14 |
+
head_ref=$(git rev-parse HEAD)
|
15 |
+
files=$(git diff --name-status $branch_ref $head_ref | awk '$1 != "D" {print $2}')
|
16 |
+
fi
|
17 |
+
# The trick to remove deleted files: https://stackoverflow.com/a/2413151
|
18 |
+
for file in $files; do
|
19 |
+
if [[ $file =~ ^(patches/.*) ]]; then
|
20 |
+
continue;
|
21 |
+
else
|
22 |
+
cpplint --filter=-readability/fn_size,-build/include_what_you_use,-build/c++11 $file;
|
23 |
+
TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?);
|
24 |
+
fi
|
25 |
+
done
|
26 |
+
|
27 |
+
exit $TOTAL_ERRORS
|
PaddleMIX/.travis/codestyle/pylint_pre_commit.hook
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
TOTAL_ERRORS=0
|
4 |
+
|
5 |
+
|
6 |
+
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
7 |
+
export PYTHONPATH=$DIR:$PYTHONPATH
|
8 |
+
|
9 |
+
readonly VERSION="2.12.0"
|
10 |
+
version=$(pylint --version | grep 'pylint')
|
11 |
+
|
12 |
+
if ! [[ $version == *"$VERSION"* ]]; then
|
13 |
+
pip install pylint==2.12.0
|
14 |
+
fi
|
15 |
+
|
16 |
+
# The trick to remove deleted files: https://stackoverflow.com/a/2413151
|
17 |
+
for file in $(git diff --name-status | awk '$1 != "D" {print $2}'); do
|
18 |
+
pylint --disable=all --load-plugins=docstring_checker \
|
19 |
+
--enable=doc-string-one-line,doc-string-end-with,doc-string-with-all-args,doc-string-triple-quotes,doc-string-missing,doc-string-indent-error,doc-string-with-returns,doc-string-with-raises $file;
|
20 |
+
TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?);
|
21 |
+
done
|
22 |
+
|
23 |
+
exit $TOTAL_ERRORS
|
24 |
+
#For now, just warning:
|
25 |
+
#exit 0
|
PaddleMIX/applications/Audio2Caption/README.md
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### 音频描述(Audio-to-Caption Generation)
|
2 |
+
|
3 |
+
|
4 |
+
|
5 |
+
#### 1. Application introduction
|
6 |
+
|
7 |
+
Enter audio and prompt words for question and answer.
|
8 |
+
|
9 |
+
*****
|
10 |
+
- No training is need.
|
11 |
+
- Integration with the moedel of [conformer_u2pp_online_wenetspeech](), [chatglm]().
|
12 |
+
|
13 |
+
----
|
14 |
+
|
15 |
+
#### 2. Demo
|
16 |
+
*****
|
17 |
+
example:
|
18 |
+
|
19 |
+
<!-- ```python
|
20 |
+
python applications/AudioChat/audiochat.py \
|
21 |
+
--chatglm_question_prompt "please describe this passage." \
|
22 |
+
--input_audio_file "./zh.wav" \
|
23 |
+
--chatglm_model_name_or_path "THUDM/chatglm-6b" \
|
24 |
+
``` -->
|
25 |
+
```python
|
26 |
+
#audio2caption -- Audio to caption converter
|
27 |
+
|
28 |
+
from paddlemix.appflow import Appflow
|
29 |
+
import paddle
|
30 |
+
paddle.seed(1024)
|
31 |
+
task = Appflow(app="audio2caption", models=["conformer_u2pp_online_wenetspeech", "THUDM/chatglm-6b"])
|
32 |
+
audio_file = "./zh.wav"
|
33 |
+
prompt = (
|
34 |
+
"描述这段话:{}."
|
35 |
+
)
|
36 |
+
result = task(audio=audio_file, prompt=prompt)['prompt']
|
37 |
+
print(result)
|
38 |
+
# 这段话表达了作者认为跑步最重要的好处之一是身体健康。作者认为,通过跑步,身体得到了良好的锻炼,身体健康得到了改善。作者还强调了跑步对身体健康的重要性,并认为这是最值得投资的运动之一。
|
39 |
+
|
40 |
+
```
|
41 |
+
|
42 |
+
<div align="center">
|
43 |
+
|
44 |
+
| Input Audio | Input Prompt | Output ASR | Output Text |
|
45 |
+
| --- | --- | --- | --- |
|
46 |
+
|[zh.wav](https://github.com/luyao-cv/file_download/blob/main/assets/zh.wav) | "描述这段话." |"我认为跑步最重要的就是给我带来了身体健康" |这段话表达了作者认为跑步最重要的好处之一是身体健康。作者认为,通过跑步,身体得到了良好的锻炼,身体健康得到了改善。作者还强调了跑步对身体健康的重要性,并认为这是最值得投资的运动之一。 |
|
47 |
+
|
48 |
+
<div>
|
49 |
+
|
PaddleMIX/applications/Audio2Img/README.md
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### 音频生成图像(Audio-to-Image Generation)
|
2 |
+
|
3 |
+
#### 1. Application introduction
|
4 |
+
|
5 |
+
*****
|
6 |
+
|
7 |
+
Generate image from audio(w/ prompt or image) with [ImageBind](https://facebookresearch.github.io/ImageBind/paper)'s unified latent space and stable-diffusion-2-1-unclip.
|
8 |
+
|
9 |
+
- No training is need.
|
10 |
+
- Integration with [ppdiffusers](https://github.com/PaddlePaddle/PaddleMIX/tree/develop/ppdiffusers).
|
11 |
+
|
12 |
+
----
|
13 |
+
|
14 |
+
**Support Tasks**
|
15 |
+
|
16 |
+
- [Audio To Image](#audio-to-image)
|
17 |
+
- [1. Application Introduction](#1-Application)
|
18 |
+
- [2. Run](#2-Run)
|
19 |
+
- [3. Visualization](#3-Visualization)
|
20 |
+
- [Audio to Image](#audio-to-image-1)
|
21 |
+
- [3.1.1 Instruction](#311-Instruction)
|
22 |
+
- [3.1.2 Result](#312-Result)
|
23 |
+
- [Audio+Text to Image](#audiotext-to-image)
|
24 |
+
- [3.2.1 Instruction](#321-Instruction)
|
25 |
+
- [3.2.2 Result](#322-Result)
|
26 |
+
- [Audio+Image to Image](#audioimage-to-image)
|
27 |
+
- [3.3.1 Instruction](#331-Instruction)
|
28 |
+
- [3.3.2 Result](#332-Result)
|
29 |
+
|
30 |
+
----
|
31 |
+
|
32 |
+
**Update**
|
33 |
+
|
34 |
+
[2023/8/15]:
|
35 |
+
- [v0.0]: Support fusing audio, text(prompt) and imnage in ImageBind latent space.
|
36 |
+
|
37 |
+
|
38 |
+
#### 2. Run
|
39 |
+
*****
|
40 |
+
|
41 |
+
example: Use audio generate image across modalities (e.g. Image, Text and Audio) with the model of ImageBind and StableUnCLIPImg2ImgPipeline.
|
42 |
+
|
43 |
+
```python
|
44 |
+
cd applications/Audio2Img
|
45 |
+
|
46 |
+
python audio2img_imagebind.py \
|
47 |
+
--model_name_or_path imagebind-1.2b/ \
|
48 |
+
--stable_unclip_model_name_or_path stabilityai/stable-diffusion-2-1-unclip \
|
49 |
+
--input_audio https://paddlenlp.bj.bcebos.com/models/community/paddlemix/audio-files/bird_audio.wav \
|
50 |
+
```
|
51 |
+
|
52 |
+
----
|
53 |
+
#### 3. Visualization
|
54 |
+
----
|
55 |
+
|
56 |
+
#### Audio to Image
|
57 |
+
#### 3.1.1 Instruction
|
58 |
+
|
59 |
+
```python
|
60 |
+
cd applications/Audio2Img
|
61 |
+
|
62 |
+
python audio2img_imagebind.py \
|
63 |
+
--model_name_or_path imagebind-1.2b/ \
|
64 |
+
--stable_unclip_model_name_or_path stabilityai/stable-diffusion-2-1-unclip \
|
65 |
+
--input_audio https://paddlenlp.bj.bcebos.com/models/community/paddlemix/audio-files/bird_audio.wav \
|
66 |
+
```
|
67 |
+
#### 3.1.2 Result
|
68 |
+
| Input Audio | Output Image |
|
69 |
+
| --- | --- |
|
70 |
+
|[bird_audio.wav](https://github.com/luyao-cv/file_download/blob/main/assets/bird_audio.wav)|  |
|
71 |
+
|
72 |
+
|
73 |
+
#### Audio+Text to Image
|
74 |
+
#### 3.2.1 Instruction
|
75 |
+
```python
|
76 |
+
cd applications/Audio2Img
|
77 |
+
|
78 |
+
python audio2img_imagebind.py \
|
79 |
+
--model_name_or_path imagebind-1.2b/ \
|
80 |
+
--stable_unclip_model_name_or_path stabilityai/stable-diffusion-2-1-unclip \
|
81 |
+
--input_audio https://paddlenlp.bj.bcebos.com/models/community/paddlemix/audio-files/bird_audio.wav \
|
82 |
+
--input_text 'A photo.' \
|
83 |
+
```
|
84 |
+
#### 3.2.2 Result
|
85 |
+
| Input Audio | Input Text | Output Image |
|
86 |
+
| --- | --- | --- |
|
87 |
+
|[bird_audio.wav](https://paddlenlp.bj.bcebos.com/models/community/paddlemix/audio-files/bird_audio.wav) | 'A photo.' | 
|
88 |
+
|
89 |
+
|
90 |
+
#### Audio+Image to Image
|
91 |
+
#### 3.3.1 Instruction
|
92 |
+
```python
|
93 |
+
cd applications/Audio2Img
|
94 |
+
|
95 |
+
python audio2img_imagebind.py \
|
96 |
+
--model_name_or_path imagebind-1.2b/ \
|
97 |
+
--stable_unclip_model_name_or_path stabilityai/stable-diffusion-2-1-unclip \
|
98 |
+
--input_audio https://paddlenlp.bj.bcebos.com/models/community/paddlemix/audio-files/wave.wav \
|
99 |
+
--input_image https://paddlenlp.bj.bcebos.com/models/community/paddlemix/audio-files/dog_image.jpg \
|
100 |
+
```
|
101 |
+
|
102 |
+
#### 3.3.2 Result
|
103 |
+
| Input Audio | Input Image | Output Image |
|
104 |
+
| --- | --- | --- |
|
105 |
+
|[wave.wav](https://paddlenlp.bj.bcebos.com/models/community/paddlemix/audio-files/wave.wav) |  | 
|
106 |
+
|
PaddleMIX/applications/Audio2Img/audio2img_imagebind.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import argparse # noqa: F401
|
16 |
+
import os
|
17 |
+
import sys # noqa: F401
|
18 |
+
from dataclasses import dataclass, field
|
19 |
+
from types import SimpleNamespace
|
20 |
+
|
21 |
+
import numpy as np # noqa: F401
|
22 |
+
import paddle
|
23 |
+
import requests # noqa: F401
|
24 |
+
from paddlenlp.trainer import PdArgumentParser
|
25 |
+
from PIL import Image
|
26 |
+
|
27 |
+
import paddlemix.models.imagebind as ib # noqa: F401
|
28 |
+
from paddlemix import ImageBindModel, ImageBindProcessor
|
29 |
+
from paddlemix.datasets import * # noqa: F401,F403
|
30 |
+
from paddlemix.models import * # noqa: F401,F403
|
31 |
+
from paddlemix.models.imagebind.modeling import ImageBindModel # noqa: F811
|
32 |
+
from paddlemix.models.imagebind.utils import * # noqa: F401, F403
|
33 |
+
from paddlemix.utils.log import logger
|
34 |
+
from ppdiffusers import StableUnCLIPImg2ImgPipeline
|
35 |
+
from ppdiffusers.utils import load_image
|
36 |
+
|
37 |
+
# from paddlemix.models.imagebind.utils.resample import *
|
38 |
+
# from paddlemix.models.imagebind.utils.paddle_aux import *
|
39 |
+
|
40 |
+
|
41 |
+
ModalityType = SimpleNamespace(
|
42 |
+
VISION="vision",
|
43 |
+
TEXT="text",
|
44 |
+
AUDIO="audio",
|
45 |
+
THERMAL="thermal",
|
46 |
+
DEPTH="depth",
|
47 |
+
IMU="imu",
|
48 |
+
)
|
49 |
+
|
50 |
+
|
51 |
+
class Predictor:
|
52 |
+
def __init__(self, model_args):
|
53 |
+
self.processor = ImageBindProcessor.from_pretrained(model_args.model_name_or_path)
|
54 |
+
self.predictor = ImageBindModel.from_pretrained(model_args.model_name_or_path)
|
55 |
+
self.predictor.eval()
|
56 |
+
|
57 |
+
def run(self, inputs):
|
58 |
+
with paddle.no_grad():
|
59 |
+
embeddings = self.predictor(inputs)
|
60 |
+
|
61 |
+
return embeddings
|
62 |
+
|
63 |
+
|
64 |
+
def main(model_args, data_args):
|
65 |
+
|
66 |
+
# build model
|
67 |
+
logger.info("imagebind_model: {}".format(model_args.model_name_or_path))
|
68 |
+
url = data_args.input_image
|
69 |
+
if os.path.isfile(url):
|
70 |
+
# read image
|
71 |
+
image_pil = Image.open(data_args.input_image).convert("RGB")
|
72 |
+
elif url:
|
73 |
+
image_pil = load_image(url)
|
74 |
+
else:
|
75 |
+
image_pil = None
|
76 |
+
|
77 |
+
url = data_args.input_audio
|
78 |
+
if os.path.isfile(url):
|
79 |
+
# read image
|
80 |
+
input_audio = data_args.input_audio
|
81 |
+
elif url:
|
82 |
+
os.system("wget {}".format(url))
|
83 |
+
input_audio = os.path.basename(data_args.input_audio)
|
84 |
+
else:
|
85 |
+
input_audio = None
|
86 |
+
|
87 |
+
predictor = Predictor(model_args)
|
88 |
+
|
89 |
+
encoding = predictor.processor(images=image_pil, text="", audios=input_audio, return_tensors="pd")
|
90 |
+
inputs = {}
|
91 |
+
|
92 |
+
if image_pil:
|
93 |
+
image_processor = encoding["pixel_values"]
|
94 |
+
inputs.update({ModalityType.VISION: image_processor})
|
95 |
+
if data_args.input_audio:
|
96 |
+
audio_processor = encoding["audio_values"]
|
97 |
+
inputs.update({ModalityType.AUDIO: audio_processor})
|
98 |
+
|
99 |
+
embeddings = predictor.run(inputs)
|
100 |
+
image_proj_embeds = embeddings[ModalityType.AUDIO]
|
101 |
+
|
102 |
+
if image_pil:
|
103 |
+
logger.info("Generate vision embedding: {}".format(embeddings[ModalityType.VISION]))
|
104 |
+
image_proj_embeds += embeddings[ModalityType.VISION]
|
105 |
+
|
106 |
+
if data_args.input_audio:
|
107 |
+
logger.info("Generate audio embedding: {}".format(embeddings[ModalityType.AUDIO]))
|
108 |
+
|
109 |
+
prompt = data_args.input_text
|
110 |
+
|
111 |
+
pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(model_args.stable_unclip_model_name_or_path)
|
112 |
+
pipe.set_progress_bar_config(disable=None)
|
113 |
+
|
114 |
+
output = pipe(image_embeds=image_proj_embeds, prompt=prompt)
|
115 |
+
os.makedirs(model_args.output_dir, exist_ok=True)
|
116 |
+
|
117 |
+
save_path = os.path.join(model_args.output_dir, "audio2img_imagebind_output.jpg")
|
118 |
+
logger.info("Generate image to: {}".format(save_path))
|
119 |
+
output.images[0].save(save_path)
|
120 |
+
|
121 |
+
|
122 |
+
@dataclass
|
123 |
+
class DataArguments:
|
124 |
+
"""
|
125 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
126 |
+
Using `PdArgumentParser` we can turn this class
|
127 |
+
into argparse arguments to be able to specify them on
|
128 |
+
the command line.
|
129 |
+
"""
|
130 |
+
|
131 |
+
input_text: str = field(default="", metadata={"help": "The name of prompt input."})
|
132 |
+
input_image: str = field(
|
133 |
+
default="",
|
134 |
+
# wget https://github.com/facebookresearch/ImageBind/blob/main/.assets/bird_image.jpg
|
135 |
+
metadata={"help": "The name of image input."},
|
136 |
+
)
|
137 |
+
input_audio: str = field(
|
138 |
+
default="",
|
139 |
+
# wget https://github.com/facebookresearch/ImageBind/blob/main/.assets/bird_audio.wav
|
140 |
+
metadata={"help": "The name of audio input."},
|
141 |
+
)
|
142 |
+
|
143 |
+
|
144 |
+
@dataclass
|
145 |
+
class ModelArguments:
|
146 |
+
"""
|
147 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
148 |
+
"""
|
149 |
+
|
150 |
+
model_name_or_path: str = field(
|
151 |
+
default="imagebind-1.2b/",
|
152 |
+
metadata={"help": "Path to pretrained model or model identifier"},
|
153 |
+
)
|
154 |
+
|
155 |
+
stable_unclip_model_name_or_path: str = field(
|
156 |
+
default="stabilityai/stable-diffusion-2-1-unclip",
|
157 |
+
metadata={"help": "Path to pretrained model or model identifier in stable_unclip_model_name_or_path"},
|
158 |
+
)
|
159 |
+
|
160 |
+
output_dir: str = field(default="vis_audio2img", metadata={"help": "The name of imagebind audio input."})
|
161 |
+
|
162 |
+
device: str = field(
|
163 |
+
default="GPU",
|
164 |
+
metadata={"help": "Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU."},
|
165 |
+
)
|
166 |
+
|
167 |
+
|
168 |
+
if __name__ == "__main__":
|
169 |
+
|
170 |
+
parser = PdArgumentParser((ModelArguments, DataArguments))
|
171 |
+
model_args, data_args = parser.parse_args_into_dataclasses()
|
172 |
+
|
173 |
+
model_args.device = model_args.device.upper()
|
174 |
+
assert model_args.device in ["CPU", "GPU", "XPU", "NPU"], "device should be CPU, GPU, XPU or NPU"
|
175 |
+
|
176 |
+
main(model_args, data_args)
|
PaddleMIX/applications/Audio2Img/gradio_demo.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import argparse
|
16 |
+
from types import SimpleNamespace
|
17 |
+
|
18 |
+
import gradio as gr
|
19 |
+
import paddle
|
20 |
+
|
21 |
+
from paddlemix import ImageBindModel, ImageBindProcessor
|
22 |
+
from paddlemix.utils.log import logger
|
23 |
+
from ppdiffusers import StableUnCLIPImg2ImgPipeline
|
24 |
+
|
25 |
+
ModalityType = SimpleNamespace(
|
26 |
+
VISION="vision",
|
27 |
+
TEXT="text",
|
28 |
+
AUDIO="audio",
|
29 |
+
THERMAL="thermal",
|
30 |
+
DEPTH="depth",
|
31 |
+
IMU="imu",
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
class Predictor:
|
36 |
+
def __init__(self, model_args):
|
37 |
+
self.processor = ImageBindProcessor.from_pretrained(model_args.model_name_or_path)
|
38 |
+
self.predictor = ImageBindModel.from_pretrained(model_args.model_name_or_path)
|
39 |
+
self.predictor.eval()
|
40 |
+
|
41 |
+
def run(self, inputs):
|
42 |
+
with paddle.no_grad():
|
43 |
+
embeddings = self.predictor(inputs)
|
44 |
+
return embeddings
|
45 |
+
|
46 |
+
|
47 |
+
def model_init(model_args):
|
48 |
+
predictor = Predictor(model_args)
|
49 |
+
return predictor
|
50 |
+
|
51 |
+
|
52 |
+
def infer(input_image, input_audio, input_text):
|
53 |
+
|
54 |
+
global predictor
|
55 |
+
image_pil = input_image
|
56 |
+
|
57 |
+
encoding = predictor.processor(images=image_pil, text="", audios=input_audio, return_tensors="pd")
|
58 |
+
inputs = {}
|
59 |
+
|
60 |
+
if image_pil is not None:
|
61 |
+
image_processor = encoding["pixel_values"]
|
62 |
+
inputs.update({ModalityType.VISION: image_processor})
|
63 |
+
|
64 |
+
if input_audio is not None:
|
65 |
+
audio_processor = encoding["audio_values"]
|
66 |
+
inputs.update({ModalityType.AUDIO: audio_processor})
|
67 |
+
else:
|
68 |
+
pass
|
69 |
+
|
70 |
+
embeddings = predictor.run(inputs)
|
71 |
+
image_proj_embeds = embeddings[ModalityType.AUDIO]
|
72 |
+
|
73 |
+
if image_pil is not None:
|
74 |
+
logger.info("Generate vision embedding: {}".format(embeddings[ModalityType.VISION]))
|
75 |
+
image_proj_embeds += embeddings[ModalityType.VISION]
|
76 |
+
|
77 |
+
logger.info("Generate audio embedding: {}".format(embeddings[ModalityType.AUDIO]))
|
78 |
+
|
79 |
+
if input_text is not None:
|
80 |
+
prompt = input_text
|
81 |
+
else:
|
82 |
+
prompt = ""
|
83 |
+
|
84 |
+
pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(model_args.stable_unclip_model_name_or_path)
|
85 |
+
pipe.set_progress_bar_config(disable=None)
|
86 |
+
output = pipe(image_embeds=image_proj_embeds, prompt=prompt)
|
87 |
+
|
88 |
+
return output.images[0]
|
89 |
+
|
90 |
+
|
91 |
+
def parse_arguments():
|
92 |
+
parser = argparse.ArgumentParser()
|
93 |
+
parser.add_argument(
|
94 |
+
"--model_name_or_path",
|
95 |
+
type=str,
|
96 |
+
default="imagebind-1.2b/",
|
97 |
+
help="Path to pretrained model or model identifier",
|
98 |
+
)
|
99 |
+
parser.add_argument(
|
100 |
+
"--stable_unclip_model_name_or_path",
|
101 |
+
type=str,
|
102 |
+
default="stabilityai/stable-diffusion-2-1-unclip",
|
103 |
+
help="Path to pretrained model or model identifier in stable_unclip_model_name_or_path",
|
104 |
+
)
|
105 |
+
parser.add_argument(
|
106 |
+
"--device",
|
107 |
+
type=str,
|
108 |
+
default="GPU",
|
109 |
+
choices=["CPU", "GPU", "XPU"],
|
110 |
+
help="Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU.",
|
111 |
+
)
|
112 |
+
return parser.parse_args()
|
113 |
+
|
114 |
+
|
115 |
+
with gr.Blocks() as demo:
|
116 |
+
gr.Markdown("音频生成图像(Audio-to-Image Generation)")
|
117 |
+
with gr.Row():
|
118 |
+
with gr.Column():
|
119 |
+
input_audio = gr.Audio(label="input audio", type="filepath")
|
120 |
+
with gr.Tab(label="input text(可选)") as txttab:
|
121 |
+
input_text = gr.Textbox(label="input text")
|
122 |
+
with gr.Tab(label="input image(可选)") as imgtab:
|
123 |
+
input_image = gr.Image(label="input image")
|
124 |
+
infer_button = gr.Button("推理")
|
125 |
+
output_image = gr.Image(label="result")
|
126 |
+
txttab.select(fn=lambda: None, outputs=input_image)
|
127 |
+
imgtab.select(fn=lambda: None, outputs=input_text)
|
128 |
+
infer_button.click(fn=infer, inputs=[input_image, input_audio, input_text], outputs=[output_image])
|
129 |
+
if __name__ == "__main__":
|
130 |
+
|
131 |
+
model_args = parse_arguments()
|
132 |
+
assert model_args.device in ["CPU", "GPU", "XPU", "NPU"], "device should be CPU, GPU, XPU or NPU"
|
133 |
+
predictor = model_init(model_args)
|
134 |
+
|
135 |
+
demo.launch()
|
PaddleMIX/applications/AudioChat/README.md
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### 音频对话(Audio-to-Chat Generation)
|
2 |
+
|
3 |
+
#### 1. Application introduction
|
4 |
+
|
5 |
+
Enter audio and prompt words for question and answer.
|
6 |
+
|
7 |
+
*****
|
8 |
+
- No training is need.
|
9 |
+
- Integration with the moedel of [conformer_u2pp_online_wenetspeech](), [chatglm](). [fastspeech2]().
|
10 |
+
|
11 |
+
----
|
12 |
+
|
13 |
+
#### 2. Demo
|
14 |
+
*****
|
15 |
+
example:
|
16 |
+
|
17 |
+
```python
|
18 |
+
#audio_chat
|
19 |
+
from paddlemix.appflow import Appflow
|
20 |
+
import paddle
|
21 |
+
paddle.seed(1024)
|
22 |
+
task = Appflow(app="audio_chat", models=["conformer_u2pp_online_wenetspeech", "THUDM/chatglm-6b", "speech"])
|
23 |
+
audio_file = "./zh.wav"
|
24 |
+
prompt = (
|
25 |
+
"描述这段话:{}."
|
26 |
+
)
|
27 |
+
output_path = "tmp.wav"
|
28 |
+
result = task(audio=audio_file, prompt=prompt, output=output_path)
|
29 |
+
|
30 |
+
# 这段话表达了作者认为跑步最重要的好处之一是身体健康。作者认为,通过跑步,身体得到了良好的锻炼,身体健康得到了改善。作者还强调了跑步对身体健康的重要性,并认为这是最值得投资的运动之一。
|
31 |
+
|
32 |
+
```
|
33 |
+
|
34 |
+
| Input Audio | Input Prompt |Output Text| Output Audio|
|
35 |
+
| --- | --- | --- | --- |
|
36 |
+
|[zh.wav](https://github.com/luyao-cv/file_download/blob/main/assets/zh.wav) | "描述这段话." |"这段话表达了作者认为跑步最重要的好处之一是身体健康。作者认为,通过跑步,身体得到了良好的锻炼,身体健康得到了改善。作者还强调了跑步对身体健康的重要性,并认为这是最值得投资的运动之一。" |[audiochat-result.wav](https://github.com/luyao-cv/file_download/blob/main/assets/audiochat-result.wav)|
|
PaddleMIX/applications/Automatic_label/README.md
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
### 自动标注(AutoLabel)
|
4 |
+
|
5 |
+
`automatic_label` 示例:
|
6 |
+
|
7 |
+
```python
|
8 |
+
python applications/Automatic_label/automatic_label.py
|
9 |
+
```
|
10 |
+
|
11 |
+
效果展示
|
12 |
+
|
13 |
+
<div align="center">
|
14 |
+
|
15 |
+
| Input Image | prompt| Generate Description | annotated image|
|
16 |
+
|:----:|:----:|:----:|:----:|
|
17 |
+
| |describe the image| of the dog sitting on the bench in the field | |
|
18 |
+
| |describe the image| of the horse in the field with the mountains in the background | |
|
19 |
+
</div>
|
PaddleMIX/applications/Automatic_label/automatic_label.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import matplotlib.pyplot as plt
|
16 |
+
import numpy as np
|
17 |
+
|
18 |
+
from paddlemix.appflow import Appflow
|
19 |
+
from ppdiffusers.utils import load_image
|
20 |
+
|
21 |
+
|
22 |
+
def show_mask(mask, ax, random_color=False):
|
23 |
+
if random_color:
|
24 |
+
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
25 |
+
else:
|
26 |
+
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
|
27 |
+
h, w = mask.shape[-2:]
|
28 |
+
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
29 |
+
ax.imshow(mask_image)
|
30 |
+
|
31 |
+
|
32 |
+
def show_box(box, ax, label):
|
33 |
+
x0, y0 = box[0], box[1]
|
34 |
+
w, h = box[2] - box[0], box[3] - box[1]
|
35 |
+
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2))
|
36 |
+
ax.text(x0, y0, label)
|
37 |
+
|
38 |
+
|
39 |
+
task = Appflow(
|
40 |
+
app="auto_label",
|
41 |
+
models=["paddlemix/blip2-caption-opt2.7b", "GroundingDino/groundingdino-swint-ogc", "Sam/SamVitH-1024"],
|
42 |
+
)
|
43 |
+
url = "https://paddlenlp.bj.bcebos.com/models/community/CompVis/stable-diffusion-v1-4/overture-creations.png"
|
44 |
+
image_pil = load_image(url)
|
45 |
+
blip2_prompt = "describe the image"
|
46 |
+
result = task(image=image_pil, blip2_prompt=blip2_prompt)
|
47 |
+
|
48 |
+
plt.figure(figsize=(10, 10))
|
49 |
+
plt.imshow(result["image"])
|
50 |
+
for mask in result["seg_masks"]:
|
51 |
+
show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
|
52 |
+
for box, label in zip(result["boxes"], result["labels"]):
|
53 |
+
show_box(box, plt.gca(), label)
|
54 |
+
|
55 |
+
plt.axis("off")
|
56 |
+
plt.savefig(
|
57 |
+
"mask_pred.jpg",
|
58 |
+
bbox_inches="tight",
|
59 |
+
dpi=300,
|
60 |
+
pad_inches=0.0,
|
61 |
+
)
|
PaddleMIX/applications/CVinW/README.md
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
### 开放世界检测分割(Grounded-SAM: Detect and Segment Everything with Text Prompt)
|
4 |
+
|
5 |
+
`Grounded-SAM` 示例:
|
6 |
+
|
7 |
+
```python
|
8 |
+
python applications/CVinW/grounded_sam.py
|
9 |
+
```
|
10 |
+
|
11 |
+
效果展示
|
12 |
+
|
13 |
+
<div align="center">
|
14 |
+
|
15 |
+
| Text prompt | Input Image | Generated Mask |
|
16 |
+
|:----:|:----:|:----:|
|
17 |
+
| dog |  |  |
|
18 |
+
| horse,grasses,sky |  |  |
|
19 |
+
</div>
|
PaddleMIX/applications/CVinW/grounded_sam.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import matplotlib.pyplot as plt
|
16 |
+
import numpy as np
|
17 |
+
|
18 |
+
from paddlemix.appflow import Appflow
|
19 |
+
from ppdiffusers.utils import load_image
|
20 |
+
|
21 |
+
|
22 |
+
def show_mask(mask, ax, random_color=False):
|
23 |
+
if random_color:
|
24 |
+
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
25 |
+
else:
|
26 |
+
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
|
27 |
+
h, w = mask.shape[-2:]
|
28 |
+
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
29 |
+
ax.imshow(mask_image)
|
30 |
+
|
31 |
+
|
32 |
+
task = Appflow(
|
33 |
+
app="openset_det_sam", models=["GroundingDino/groundingdino-swint-ogc", "Sam/SamVitH-1024"], static_mode=False
|
34 |
+
) # 如果开启静态图推理,设置为True,默认动态图
|
35 |
+
url = "https://paddlenlp.bj.bcebos.com/models/community/CompVis/stable-diffusion-v1-4/overture-creations.png"
|
36 |
+
image_pil = load_image(url)
|
37 |
+
result = task(image=image_pil, prompt="dog")
|
38 |
+
|
39 |
+
plt.figure(figsize=(10, 10))
|
40 |
+
plt.imshow(image_pil)
|
41 |
+
for mask in result["seg_masks"]:
|
42 |
+
show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
|
43 |
+
|
44 |
+
|
45 |
+
plt.axis("off")
|
46 |
+
plt.savefig("dog.jpg", bbox_inches="tight", dpi=300, pad_inches=0.0)
|
PaddleMIX/applications/Inpainting/README.md
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
### 检测框引导的图像编辑(Det-Guided-Inpainting)
|
4 |
+
|
5 |
+
`Grounded-SAM-Inpainting` 示例:
|
6 |
+
|
7 |
+
```python
|
8 |
+
from paddlemix.appflow import Appflow
|
9 |
+
from ppdiffusers.utils import load_image
|
10 |
+
import paddle
|
11 |
+
task = Appflow(app="inpainting",
|
12 |
+
models=["GroundingDino/groundingdino-swint-ogc","Sam/SamVitH-1024","stabilityai/stable-diffusion-2-inpainting"]
|
13 |
+
)
|
14 |
+
paddle.seed(1024)
|
15 |
+
url = "https://bj.bcebos.com/v1/paddlenlp/models/community/GroundingDino/000000004505.jpg"
|
16 |
+
image_pil = load_image(url)
|
17 |
+
result = task(image=image_pil,prompt="bus",inpaint_prompt="a yellow van")
|
18 |
+
```
|
19 |
+
<div align="center">
|
20 |
+
|
21 |
+
| Input Image | Det Prompt | Generated Mask | Inpaint Prompt | Inpaint Image |
|
22 |
+
|:----:|:----:|:----:|:----:|:----:|
|
23 |
+
|  | bus | | a yellow van | |
|
24 |
+
|
25 |
+
</div>
|
26 |
+
|
27 |
+
|
28 |
+
### 文本检测框引导的图像编辑(ChatAndDet-Guided-Inpainting)
|
29 |
+
`Grounded-SAM-chatglm` 示例:
|
30 |
+
|
31 |
+
```python
|
32 |
+
import paddle
|
33 |
+
from paddlemix.appflow import Appflow
|
34 |
+
from ppdiffusers.utils import load_image
|
35 |
+
task = Appflow(app="inpainting",
|
36 |
+
models=["THUDM/chatglm-6b","GroundingDino/groundingdino-swint-ogc","Sam/SamVitH-1024","stabilityai/stable-diffusion-2-inpainting"]
|
37 |
+
)
|
38 |
+
paddle.seed(1024)
|
39 |
+
url = "https://bj.bcebos.com/v1/paddlenlp/models/community/GroundingDino/000000004505.jpg"
|
40 |
+
image_pil = load_image(url)
|
41 |
+
inpaint_prompt = "bus is changed to A school bus parked on the roadside"
|
42 |
+
prompt = "Given caption,extract the main object to be replaced and marked it as 'main_object'," \
|
43 |
+
+ "Extract the remaining part as 'other prompt', " \
|
44 |
+
+ "Return main_object, other prompt in English" \
|
45 |
+
+ "Given caption: {}.".format(inpaint_prompt)
|
46 |
+
result = task(image=image_pil,prompt=prompt)
|
47 |
+
```
|
48 |
+
|
49 |
+
一些效果展示
|
50 |
+
|
51 |
+
<div align="center">
|
52 |
+
|
53 |
+
| Input Image | Prompt | Generated Mask | Inpaint Prompt |
|
54 |
+
|:----:|:----:|:----:|:----:|
|
55 |
+
|  | bus is changed to A school bus parked on the roadside | | |
|
56 |
+
|
57 |
+
</div>
|
58 |
+
|
59 |
+
### 文本引导的图像编辑(Text-Guided Image Inpainting)
|
60 |
+
|
61 |
+
```python
|
62 |
+
import paddle
|
63 |
+
from paddlemix.appflow import Appflow
|
64 |
+
from PIL import Image
|
65 |
+
from ppdiffusers.utils import load_image
|
66 |
+
img_url = "https://paddlenlp.bj.bcebos.com/models/community/CompVis/stable-diffusion-v1-4/overture-creations.png"
|
67 |
+
mask_url = "https://paddlenlp.bj.bcebos.com/models/community/CompVis/stable-diffusion-v1-4/overture-creations-mask.png"
|
68 |
+
|
69 |
+
image = load_image(img_url)
|
70 |
+
mask_image = load_image(mask_url)
|
71 |
+
paddle.seed(1024)
|
72 |
+
|
73 |
+
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
|
74 |
+
|
75 |
+
app = Appflow(app='inpainting',models=['stabilityai/stable-diffusion-2-inpainting'])
|
76 |
+
image = app(inpaint_prompt=prompt,image=image,seg_masks=mask_image)['result']
|
77 |
+
|
78 |
+
image.save("a_yellow_cat.png")
|
79 |
+
```
|
80 |
+
|
81 |
+
<div align="center">
|
82 |
+
|
83 |
+
| Input Image | Inpaint Prompt | Mask | Inpaint Image |
|
84 |
+
|:----:|:----:|:----:|:----:|
|
85 |
+
|  | Face of a yellow cat, high resolution, sitting on a park bench| | |
|
86 |
+
|
87 |
+
</div>
|
PaddleMIX/applications/Inpainting/grounded_sam_chatglm.py
ADDED
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
from dataclasses import dataclass, field
|
17 |
+
|
18 |
+
import matplotlib.pyplot as plt
|
19 |
+
import numpy as np
|
20 |
+
import paddle
|
21 |
+
import paddle.nn.functional as F
|
22 |
+
import requests
|
23 |
+
from paddlenlp import Taskflow
|
24 |
+
from paddlenlp.trainer import PdArgumentParser
|
25 |
+
from PIL import Image
|
26 |
+
|
27 |
+
from paddlemix.models.groundingdino.modeling import GroundingDinoModel
|
28 |
+
from paddlemix.models.sam.modeling import SamModel
|
29 |
+
from paddlemix.processors.groundingdino_processing import GroundingDinoProcessor
|
30 |
+
from paddlemix.processors.sam_processing import SamProcessor
|
31 |
+
from paddlemix.utils.log import logger
|
32 |
+
from ppdiffusers import StableDiffusionInpaintPipeline
|
33 |
+
|
34 |
+
|
35 |
+
def show_mask(mask, ax, random_color=False):
|
36 |
+
if random_color:
|
37 |
+
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
38 |
+
else:
|
39 |
+
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
|
40 |
+
h, w = mask.shape[-2:]
|
41 |
+
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
42 |
+
ax.imshow(mask_image)
|
43 |
+
|
44 |
+
|
45 |
+
def show_box(box, ax, label):
|
46 |
+
x0, y0 = box[0], box[1]
|
47 |
+
w, h = box[2] - box[0], box[3] - box[1]
|
48 |
+
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2))
|
49 |
+
ax.text(x0, y0, label)
|
50 |
+
|
51 |
+
|
52 |
+
@dataclass
|
53 |
+
class DataArguments:
|
54 |
+
"""
|
55 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
56 |
+
Using `PdArgumentParser` we can turn this class
|
57 |
+
into argparse arguments to be able to specify them on
|
58 |
+
the command line.
|
59 |
+
"""
|
60 |
+
|
61 |
+
input_image: str = field(
|
62 |
+
metadata={"help": "The name of input image."},
|
63 |
+
)
|
64 |
+
|
65 |
+
prompt: str = field(
|
66 |
+
default=None,
|
67 |
+
metadata={"help": "The prompt of the image to be inpaint."},
|
68 |
+
)
|
69 |
+
|
70 |
+
|
71 |
+
@dataclass
|
72 |
+
class ModelArguments:
|
73 |
+
"""
|
74 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
75 |
+
"""
|
76 |
+
|
77 |
+
stable_diffusion_pipeline_name_or_path: str = field(
|
78 |
+
default="stabilityai/stable-diffusion-2-inpainting",
|
79 |
+
metadata={"help": "Path to pretrained model or model identifier"},
|
80 |
+
)
|
81 |
+
dino_model_name_or_path: str = field(
|
82 |
+
default="GroundingDino/groundingdino-swint-ogc",
|
83 |
+
metadata={"help": "Path to pretrained model or model identifier"},
|
84 |
+
)
|
85 |
+
sam_model_name_or_path: str = field(
|
86 |
+
default="Sam/SamVitH-1024",
|
87 |
+
metadata={"help": "Path to pretrained model or model identifier"},
|
88 |
+
)
|
89 |
+
chatglm_model_name_or_path: str = field(
|
90 |
+
default="THUDM/chatglm-6b",
|
91 |
+
metadata={"help": "Path to pretrained model or model identifier"},
|
92 |
+
)
|
93 |
+
box_threshold: float = field(
|
94 |
+
default=0.3,
|
95 |
+
metadata={"help": "box threshold."},
|
96 |
+
)
|
97 |
+
text_threshold: float = field(
|
98 |
+
default=0.25,
|
99 |
+
metadata={"help": "text threshold."},
|
100 |
+
)
|
101 |
+
output_dir: str = field(
|
102 |
+
default="inpainting_output",
|
103 |
+
metadata={"help": "output directory."},
|
104 |
+
)
|
105 |
+
visual: bool = field(
|
106 |
+
default=True,
|
107 |
+
metadata={"help": "save visual image."},
|
108 |
+
)
|
109 |
+
|
110 |
+
|
111 |
+
def filter_prompts_with_chatglm(caption, model_name_or_path="THUDM/chatglm-6b"):
|
112 |
+
prompt = (
|
113 |
+
"Given caption,extract the main object to be replaced and marked it as 'main_object', "
|
114 |
+
+ "Extract the remaining part as 'other prompt', "
|
115 |
+
+ "Return main_object, other prompt in English"
|
116 |
+
+ "Given caption: {}.".format(caption)
|
117 |
+
)
|
118 |
+
|
119 |
+
logger.info("chatglm: {}".format(model_name_or_path))
|
120 |
+
textGen = Taskflow("text2text_generation", model=model_name_or_path)
|
121 |
+
|
122 |
+
reply = textGen(prompt)["result"][0]
|
123 |
+
|
124 |
+
det_prompt, inpaint_prompt = (
|
125 |
+
reply.split("\n")[0].split(":")[-1].strip(),
|
126 |
+
reply.split("\n")[-1].split(":")[-1].strip(),
|
127 |
+
)
|
128 |
+
|
129 |
+
return det_prompt, inpaint_prompt
|
130 |
+
|
131 |
+
|
132 |
+
def main():
|
133 |
+
parser = PdArgumentParser((ModelArguments, DataArguments))
|
134 |
+
model_args, data_args = parser.parse_args_into_dataclasses()
|
135 |
+
url = data_args.input_image
|
136 |
+
|
137 |
+
logger.info("dino_model: {}".format(model_args.dino_model_name_or_path))
|
138 |
+
# build dino processor
|
139 |
+
dino_processor = GroundingDinoProcessor.from_pretrained(model_args.dino_model_name_or_path)
|
140 |
+
# build dino model
|
141 |
+
dino_model = GroundingDinoModel.from_pretrained(model_args.dino_model_name_or_path)
|
142 |
+
dino_model.eval()
|
143 |
+
logger.info("dino_model build finish!")
|
144 |
+
|
145 |
+
# build sam processor
|
146 |
+
sam_processor = SamProcessor.from_pretrained(model_args.sam_model_name_or_path)
|
147 |
+
# build model
|
148 |
+
logger.info("SamModel: {}".format(model_args.sam_model_name_or_path))
|
149 |
+
sam_model = SamModel.from_pretrained(model_args.sam_model_name_or_path, input_type="boxs")
|
150 |
+
logger.info("SamModel build finish!")
|
151 |
+
|
152 |
+
# read image
|
153 |
+
if os.path.isfile(url):
|
154 |
+
# read image
|
155 |
+
image_pil = Image.open(url)
|
156 |
+
else:
|
157 |
+
image_pil = Image.open(requests.get(url, stream=True).raw)
|
158 |
+
|
159 |
+
det_prompt, inpaint_prompt = filter_prompts_with_chatglm(data_args.prompt, model_args.chatglm_model_name_or_path)
|
160 |
+
logger.info("det prompt: {}".format(det_prompt))
|
161 |
+
logger.info("inpaint prompt: {}".format(inpaint_prompt))
|
162 |
+
|
163 |
+
image_pil = image_pil.convert("RGB")
|
164 |
+
|
165 |
+
# preprocess image text_prompt
|
166 |
+
image_tensor, mask, tokenized_out = dino_processor(images=image_pil, text=det_prompt)
|
167 |
+
|
168 |
+
with paddle.no_grad():
|
169 |
+
outputs = dino_model(
|
170 |
+
image_tensor,
|
171 |
+
mask,
|
172 |
+
input_ids=tokenized_out["input_ids"],
|
173 |
+
attention_mask=tokenized_out["attention_mask"],
|
174 |
+
text_self_attention_masks=tokenized_out["text_self_attention_masks"],
|
175 |
+
position_ids=tokenized_out["position_ids"],
|
176 |
+
)
|
177 |
+
|
178 |
+
logits = F.sigmoid(outputs["pred_logits"])[0] # (nq, 256)
|
179 |
+
boxes = outputs["pred_boxes"][0] # (nq, 4)
|
180 |
+
|
181 |
+
# filter output
|
182 |
+
logits_filt = logits.clone()
|
183 |
+
boxes_filt = boxes.clone()
|
184 |
+
filt_mask = logits_filt.max(axis=1) > model_args.box_threshold
|
185 |
+
logits_filt = logits_filt[filt_mask] # num_filt, 256
|
186 |
+
boxes_filt = boxes_filt[filt_mask] # num_filt, 4
|
187 |
+
|
188 |
+
# build pred
|
189 |
+
pred_phrases = []
|
190 |
+
for logit, box in zip(logits_filt, boxes_filt):
|
191 |
+
pred_phrase = dino_processor.decode(logit > model_args.text_threshold)
|
192 |
+
pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
|
193 |
+
|
194 |
+
size = image_pil.size
|
195 |
+
pred_dict = {
|
196 |
+
"boxes": boxes_filt,
|
197 |
+
"size": [size[1], size[0]], # H,W
|
198 |
+
"labels": pred_phrases,
|
199 |
+
}
|
200 |
+
logger.info("dino output{}".format(pred_dict))
|
201 |
+
|
202 |
+
H, W = size[1], size[0]
|
203 |
+
boxes = []
|
204 |
+
for box in zip(boxes_filt):
|
205 |
+
box = box[0] * paddle.to_tensor([W, H, W, H])
|
206 |
+
box[:2] -= box[2:] / 2
|
207 |
+
box[2:] += box[:2]
|
208 |
+
x0, y0, x1, y1 = box.numpy()
|
209 |
+
x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
|
210 |
+
boxes.append([x0, y0, x1, y1])
|
211 |
+
boxes = np.array(boxes)
|
212 |
+
image_seg, prompt = sam_processor(image_pil, input_type="boxs", box=boxes, point_coords=None)
|
213 |
+
seg_masks = sam_model(img=image_seg, prompt=prompt)
|
214 |
+
seg_masks = sam_processor.postprocess_masks(seg_masks)
|
215 |
+
|
216 |
+
logger.info("Sam finish!")
|
217 |
+
|
218 |
+
if model_args.visual:
|
219 |
+
# make dir
|
220 |
+
os.makedirs(model_args.output_dir, exist_ok=True)
|
221 |
+
# draw output image
|
222 |
+
plt.figure(figsize=(10, 10))
|
223 |
+
plt.imshow(image_pil)
|
224 |
+
for mask in seg_masks:
|
225 |
+
show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
|
226 |
+
for box, label in zip(boxes, pred_phrases):
|
227 |
+
show_box(box, plt.gca(), label)
|
228 |
+
|
229 |
+
plt.axis("off")
|
230 |
+
plt.savefig(
|
231 |
+
os.path.join(model_args.output_dir, "mask_pred.jpg"),
|
232 |
+
bbox_inches="tight",
|
233 |
+
dpi=300,
|
234 |
+
pad_inches=0.0,
|
235 |
+
)
|
236 |
+
|
237 |
+
logger.info("stable diffusion pipeline: {}".format(model_args.stable_diffusion_pipeline_name_or_path))
|
238 |
+
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_args.stable_diffusion_pipeline_name_or_path)
|
239 |
+
logger.info("stable diffusion pipeline build finish!")
|
240 |
+
|
241 |
+
merge_mask = paddle.sum(seg_masks, axis=0).unsqueeze(0)
|
242 |
+
merge_mask = merge_mask > 0
|
243 |
+
mask_pil = Image.fromarray(merge_mask[0][0].cpu().numpy())
|
244 |
+
|
245 |
+
image_pil = image_pil.resize((512, 512))
|
246 |
+
mask_pil = mask_pil.resize((512, 512))
|
247 |
+
|
248 |
+
image = pipe(prompt=inpaint_prompt, image=image_pil, mask_image=mask_pil).images[0]
|
249 |
+
image = image.resize(size)
|
250 |
+
image.save(os.path.join(model_args.output_dir, "grounded_sam_chatglm_output.jpg"))
|
251 |
+
|
252 |
+
logger.info("finish!")
|
253 |
+
|
254 |
+
|
255 |
+
if __name__ == "__main__":
|
256 |
+
main()
|
PaddleMIX/applications/Inpainting/grounded_sam_inpainting.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
from dataclasses import dataclass, field
|
17 |
+
|
18 |
+
import matplotlib.pyplot as plt
|
19 |
+
import numpy as np
|
20 |
+
import paddle
|
21 |
+
import paddle.nn.functional as F
|
22 |
+
import requests
|
23 |
+
from paddlenlp.trainer import PdArgumentParser
|
24 |
+
from PIL import Image
|
25 |
+
|
26 |
+
from paddlemix.models.groundingdino.modeling import GroundingDinoModel
|
27 |
+
from paddlemix.models.sam.modeling import SamModel
|
28 |
+
from paddlemix.processors.groundingdino_processing import GroundingDinoProcessor
|
29 |
+
from paddlemix.processors.sam_processing import SamProcessor
|
30 |
+
from paddlemix.utils.log import logger
|
31 |
+
from ppdiffusers import StableDiffusionInpaintPipeline
|
32 |
+
|
33 |
+
|
34 |
+
def show_mask(mask, ax, random_color=False):
|
35 |
+
if random_color:
|
36 |
+
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
37 |
+
else:
|
38 |
+
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
|
39 |
+
h, w = mask.shape[-2:]
|
40 |
+
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
41 |
+
ax.imshow(mask_image)
|
42 |
+
|
43 |
+
|
44 |
+
def show_box(box, ax, label):
|
45 |
+
x0, y0 = box[0], box[1]
|
46 |
+
w, h = box[2] - box[0], box[3] - box[1]
|
47 |
+
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2))
|
48 |
+
ax.text(x0, y0, label)
|
49 |
+
|
50 |
+
|
51 |
+
@dataclass
|
52 |
+
class DataArguments:
|
53 |
+
"""
|
54 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
55 |
+
Using `PdArgumentParser` we can turn this class
|
56 |
+
into argparse arguments to be able to specify them on
|
57 |
+
the command line.
|
58 |
+
"""
|
59 |
+
|
60 |
+
input_image: str = field(
|
61 |
+
metadata={"help": "The name of input image."},
|
62 |
+
)
|
63 |
+
|
64 |
+
det_prompt: str = field(
|
65 |
+
default=None,
|
66 |
+
metadata={"help": "The prompt of the image to be det."},
|
67 |
+
)
|
68 |
+
|
69 |
+
inpaint_prompt: str = field(
|
70 |
+
default=None,
|
71 |
+
metadata={"help": "The prompt of the image to be inpaint."},
|
72 |
+
)
|
73 |
+
|
74 |
+
|
75 |
+
@dataclass
|
76 |
+
class ModelArguments:
|
77 |
+
"""
|
78 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
79 |
+
"""
|
80 |
+
|
81 |
+
stable_diffusion_pipeline_name_or_path: str = field(
|
82 |
+
default="stabilityai/stable-diffusion-2-inpainting",
|
83 |
+
metadata={"help": "Path to pretrained model or model identifier"},
|
84 |
+
)
|
85 |
+
dino_model_name_or_path: str = field(
|
86 |
+
default="GroundingDino/groundingdino-swint-ogc",
|
87 |
+
metadata={"help": "Path to pretrained model or model identifier"},
|
88 |
+
)
|
89 |
+
sam_model_name_or_path: str = field(
|
90 |
+
default="Sam/SamVitH-1024",
|
91 |
+
metadata={"help": "Path to pretrained model or model identifier"},
|
92 |
+
)
|
93 |
+
box_threshold: float = field(
|
94 |
+
default=0.3,
|
95 |
+
metadata={"help": "box threshold."},
|
96 |
+
)
|
97 |
+
text_threshold: float = field(
|
98 |
+
default=0.25,
|
99 |
+
metadata={"help": "text threshold."},
|
100 |
+
)
|
101 |
+
output_dir: str = field(
|
102 |
+
default="inpainting_output",
|
103 |
+
metadata={"help": "output directory."},
|
104 |
+
)
|
105 |
+
visual: bool = field(
|
106 |
+
default=True,
|
107 |
+
metadata={"help": "save visual image."},
|
108 |
+
)
|
109 |
+
|
110 |
+
|
111 |
+
def main():
|
112 |
+
parser = PdArgumentParser((ModelArguments, DataArguments))
|
113 |
+
model_args, data_args = parser.parse_args_into_dataclasses()
|
114 |
+
url = data_args.input_image
|
115 |
+
|
116 |
+
logger.info("stable diffusion pipeline: {}".format(model_args.stable_diffusion_pipeline_name_or_path))
|
117 |
+
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_args.stable_diffusion_pipeline_name_or_path)
|
118 |
+
logger.info("stable diffusion pipeline build finish!")
|
119 |
+
|
120 |
+
logger.info("dino_model: {}".format(model_args.dino_model_name_or_path))
|
121 |
+
# build dino processor
|
122 |
+
dino_processor = GroundingDinoProcessor.from_pretrained(model_args.dino_model_name_or_path)
|
123 |
+
# build dino model
|
124 |
+
dino_model = GroundingDinoModel.from_pretrained(model_args.dino_model_name_or_path)
|
125 |
+
dino_model.eval()
|
126 |
+
logger.info("dino_model build finish!")
|
127 |
+
|
128 |
+
# build sam processor
|
129 |
+
sam_processor = SamProcessor.from_pretrained(model_args.sam_model_name_or_path)
|
130 |
+
# build model
|
131 |
+
logger.info("SamModel: {}".format(model_args.sam_model_name_or_path))
|
132 |
+
sam_model = SamModel.from_pretrained(model_args.sam_model_name_or_path, input_type="boxs")
|
133 |
+
logger.info("SamModel build finish!")
|
134 |
+
|
135 |
+
# read image
|
136 |
+
if os.path.isfile(url):
|
137 |
+
# read image
|
138 |
+
image_pil = Image.open(url)
|
139 |
+
else:
|
140 |
+
image_pil = Image.open(requests.get(url, stream=True).raw)
|
141 |
+
|
142 |
+
logger.info("det prompt: {}".format(data_args.det_prompt))
|
143 |
+
logger.info("inpaint prompt: {}".format(data_args.inpaint_prompt))
|
144 |
+
|
145 |
+
image_pil = image_pil.convert("RGB")
|
146 |
+
|
147 |
+
# preprocess image text_prompt
|
148 |
+
image_tensor, mask, tokenized_out = dino_processor(images=image_pil, text=data_args.det_prompt)
|
149 |
+
|
150 |
+
with paddle.no_grad():
|
151 |
+
outputs = dino_model(
|
152 |
+
image_tensor,
|
153 |
+
mask,
|
154 |
+
input_ids=tokenized_out["input_ids"],
|
155 |
+
attention_mask=tokenized_out["attention_mask"],
|
156 |
+
text_self_attention_masks=tokenized_out["text_self_attention_masks"],
|
157 |
+
position_ids=tokenized_out["position_ids"],
|
158 |
+
)
|
159 |
+
|
160 |
+
logits = F.sigmoid(outputs["pred_logits"])[0] # (nq, 256)
|
161 |
+
boxes = outputs["pred_boxes"][0] # (nq, 4)
|
162 |
+
|
163 |
+
# filter output
|
164 |
+
logits_filt = logits.clone()
|
165 |
+
boxes_filt = boxes.clone()
|
166 |
+
filt_mask = logits_filt.max(axis=1) > model_args.box_threshold
|
167 |
+
logits_filt = logits_filt[filt_mask] # num_filt, 256
|
168 |
+
boxes_filt = boxes_filt[filt_mask] # num_filt, 4
|
169 |
+
|
170 |
+
# build pred
|
171 |
+
pred_phrases = []
|
172 |
+
for logit, box in zip(logits_filt, boxes_filt):
|
173 |
+
pred_phrase = dino_processor.decode(logit > model_args.text_threshold)
|
174 |
+
pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
|
175 |
+
|
176 |
+
size = image_pil.size
|
177 |
+
pred_dict = {
|
178 |
+
"boxes": boxes_filt,
|
179 |
+
"size": [size[1], size[0]], # H,W
|
180 |
+
"labels": pred_phrases,
|
181 |
+
}
|
182 |
+
logger.info("dino output{}".format(pred_dict))
|
183 |
+
|
184 |
+
H, W = size[1], size[0]
|
185 |
+
boxes = []
|
186 |
+
for box in zip(boxes_filt):
|
187 |
+
box = box[0] * paddle.to_tensor([W, H, W, H])
|
188 |
+
box[:2] -= box[2:] / 2
|
189 |
+
box[2:] += box[:2]
|
190 |
+
x0, y0, x1, y1 = box.numpy()
|
191 |
+
x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
|
192 |
+
boxes.append([x0, y0, x1, y1])
|
193 |
+
boxes = np.array(boxes)
|
194 |
+
image_seg, prompt = sam_processor(image_pil, input_type="boxs", box=boxes, point_coords=None)
|
195 |
+
seg_masks = sam_model(img=image_seg, prompt=prompt)
|
196 |
+
seg_masks = sam_processor.postprocess_masks(seg_masks)
|
197 |
+
|
198 |
+
logger.info("Sam finish!")
|
199 |
+
|
200 |
+
if model_args.visual:
|
201 |
+
# make dir
|
202 |
+
os.makedirs(model_args.output_dir, exist_ok=True)
|
203 |
+
# draw output image
|
204 |
+
plt.figure(figsize=(10, 10))
|
205 |
+
plt.imshow(image_pil)
|
206 |
+
for mask in seg_masks:
|
207 |
+
show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
|
208 |
+
for box, label in zip(boxes, pred_phrases):
|
209 |
+
show_box(box, plt.gca(), label)
|
210 |
+
|
211 |
+
plt.axis("off")
|
212 |
+
plt.savefig(
|
213 |
+
os.path.join(model_args.output_dir, "mask_pred.jpg"),
|
214 |
+
bbox_inches="tight",
|
215 |
+
dpi=300,
|
216 |
+
pad_inches=0.0,
|
217 |
+
)
|
218 |
+
|
219 |
+
merge_mask = paddle.sum(seg_masks, axis=0).unsqueeze(0)
|
220 |
+
merge_mask = merge_mask > 0
|
221 |
+
mask_pil = Image.fromarray(merge_mask[0][0].cpu().numpy())
|
222 |
+
|
223 |
+
image_pil = image_pil.resize((512, 512))
|
224 |
+
mask_pil = mask_pil.resize((512, 512))
|
225 |
+
|
226 |
+
image = pipe(prompt=data_args.inpaint_prompt, image=image_pil, mask_image=mask_pil).images[0]
|
227 |
+
image = image.resize(size)
|
228 |
+
image.save(os.path.join(model_args.output_dir, "grounded_sam_inpainting_output.jpg"))
|
229 |
+
|
230 |
+
logger.info("finish!")
|
231 |
+
|
232 |
+
|
233 |
+
if __name__ == "__main__":
|
234 |
+
main()
|
PaddleMIX/applications/MusicGeneration/README.md
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### 音乐生成(Music Generation)
|
2 |
+
|
3 |
+
#### 1. Application introduction
|
4 |
+
|
5 |
+
Enter audio and prompt words for question and answer.
|
6 |
+
|
7 |
+
*****
|
8 |
+
- No training is need.
|
9 |
+
- Integration with the moedel of [minigpt4](), [chatglm](), [audioldm]().
|
10 |
+
|
11 |
+
----
|
12 |
+
|
13 |
+
#### 2. Demo
|
14 |
+
*****
|
15 |
+
example:
|
16 |
+
|
17 |
+
|
18 |
+
使用miniGPT4前,需要下载相应权重进行转换,具体可参考[miniGPT4](../../paddlemix/examples/minigpt4/README.md),在完成权重转换后,根据模型权重文件以及配置文件按下存放:
|
19 |
+
```bash
|
20 |
+
--PPMIX_HOME #默认路径 /root/.paddlemix 可通过export PPMIX_HOME 设置
|
21 |
+
--models
|
22 |
+
--miniGPT4
|
23 |
+
--MiniGPT4-7B
|
24 |
+
config.json
|
25 |
+
model_state.pdparams
|
26 |
+
special_tokens_map.json
|
27 |
+
image_preprocessor_config.json
|
28 |
+
preprocessor_config.json
|
29 |
+
tokenizer_config.json
|
30 |
+
model_config.json
|
31 |
+
sentencepiece.bpe.model
|
32 |
+
tokenizer.json
|
33 |
+
--MiniGPT4-13B
|
34 |
+
...
|
35 |
+
...
|
36 |
+
...
|
37 |
+
|
38 |
+
```
|
39 |
+
完成之后,可使用appflow 一键预测
|
40 |
+
|
41 |
+
```python
|
42 |
+
#music generation
|
43 |
+
from paddlemix.appflow import Appflow
|
44 |
+
import paddle
|
45 |
+
from PIL import Image
|
46 |
+
import scipy
|
47 |
+
paddle.seed(1024)
|
48 |
+
|
49 |
+
# Text to music
|
50 |
+
task = Appflow(app="music_generation", models=["cvssp/audioldm"])
|
51 |
+
prompt = "A classic cocktail lounge vibe with smooth jazz piano and a cool, relaxed atmosphere."
|
52 |
+
negative_prompt = 'low quality, average quality, muffled quality, noise interference, poor and low-grade quality, inaudible quality, low-fidelity quality'
|
53 |
+
audio_length_in_s = 5
|
54 |
+
num_inference_steps = 20
|
55 |
+
output_path = "tmp.wav"
|
56 |
+
result = task(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, audio_length_in_s=audio_length_in_s, generator = paddle.Generator().manual_seed(120))['result']
|
57 |
+
scipy.io.wavfile.write(output_path, rate=16000, data=result)
|
58 |
+
|
59 |
+
# image to music
|
60 |
+
task1 = Appflow(app="music_generation", models=["miniGPT4/MiniGPT4-7B"])
|
61 |
+
negative_prompt = 'low quality, average quality, muffled quality, noise interference, poor and low-grade quality, inaudible quality, low-fidelity quality'
|
62 |
+
audio_length_in_s = 5
|
63 |
+
num_inference_steps = 20
|
64 |
+
output_path = "tmp.wav"
|
65 |
+
minigpt4_text = 'describe the image, '
|
66 |
+
image_pil = Image.open("dance.png").convert("RGB")
|
67 |
+
result = task1(image=image_pil, minigpt4_text=minigpt4_text )['result'].split('#')[0]
|
68 |
+
paddle.device.cuda.empty_cache()
|
69 |
+
# miniGPT4 output: The image shows a crowded nightclub with people dancing on the dance floor. The lights on the dance floor are green and red, and there are several people on the dance floor. The stage is at the back of the room, and there are several people on stage. The walls of the nightclub are decorated with neon lights and there are several people sitting at tables in the background. The atmosphere is lively and energetic.
|
70 |
+
|
71 |
+
prompt = "Given the scene description in the following paragraph, please create a musical style sentence that fits the scene. Description:{}.".format(result)
|
72 |
+
task2 = Appflow(app="music_generation", models=["THUDM/chatglm-6b", "cvssp/audioldm"])
|
73 |
+
result = task2(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, audio_length_in_s=audio_length_in_s, generator = paddle.Generator().manual_seed(120))['result']
|
74 |
+
scipy.io.wavfile.write(output_path, rate=16000, data=result)
|
75 |
+
# chatglm ouptput: The music is playing, and the crowd is dancing like never before. The lights are bright and the atmosphere is electric, with people swaying to the rhythm of the music and the energy of the night. The dance floor is a sea of movement, with people moving to the music and feeling the rhythm of their feet. The stage is a place of magic, with people on it, performing their best. The neon lights of the nightclub are a testament to the energy and excitement of the night, with people's faces lit up as they perform. And as the music continues to play, the crowd continues to dance, never letting up, until the night is over.
|
76 |
+
```
|
77 |
+
|
78 |
+
|
79 |
+
#### Text to music
|
80 |
+
| Input Prompt | Output Music |
|
81 |
+
| --- | --- |
|
82 |
+
|'A classic cocktail lounge vibe with smooth jazz piano and a cool, relaxed atmosphere.'| [jazz_output.wav](https://github.com/luyao-cv/file_download/blob/main/assets/jazz_output.wav)
|
83 |
+
|
84 |
+
---
|
85 |
+
|
86 |
+
#### image to music
|
87 |
+
| Input Image | Output Caption | Output Text | Output Music |
|
88 |
+
| --- | --- | --- | --- |
|
89 |
+
| | 'The image shows a crowded nightclub with people dancing on the dance floor. The lights on the dance floor are green and red, and there are several people on the dance floor. The stage is at the back of the room, and there are several people on stage. The walls of the nightclub are decorated with neon lights and there are several people sitting at tables in the background. The atmosphere is lively and energetic.' | 'The music is playing, and the crowd is dancing like never before. The lights are bright and the atmosphere is electric, with people swaying to the rhythm of the music and the energy of the night. The dance floor is a sea of movement, with people moving to the music and feeling the rhythm of their feet. The stage is a place of magic, with people on it, performing their best. The neon lights of the nightclub are a testament to the energy and excitement of the night, with people's faces lit up as they perform. And as the music continues to play, the crowd continues to dance, never letting up, until the night is over.' | [dance_output.wav](https://github.com/luyao-cv/file_download/blob/main/assets/dance_output.wav)
|
PaddleMIX/applications/VLChat/README.md
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### 视觉语言对话(Vision-Language-Chat)
|
2 |
+
|
3 |
+
#### 1. 应用介绍
|
4 |
+
输入图像或文字进行多轮对话,包括captions、grounding、视觉定位能力
|
5 |
+
|
6 |
+
|
7 |
+
#### 2. Demo
|
8 |
+
|
9 |
+
example:
|
10 |
+
|
11 |
+
```python
|
12 |
+
|
13 |
+
import paddle
|
14 |
+
from paddlemix.appflow import Appflow
|
15 |
+
from ppdiffusers.utils import load_image
|
16 |
+
paddle.seed(1234)
|
17 |
+
task = Appflow(app="image2text_generation",
|
18 |
+
models=["qwen-vl/qwen-vl-chat-7b"])
|
19 |
+
image= "https://bj.bcebos.com/v1/paddlenlp/models/community/GroundingDino/000000004505.jpg"
|
20 |
+
prompt = "这是什么?"
|
21 |
+
result = task(image=image,prompt=prompt)
|
22 |
+
|
23 |
+
print(result["result"])
|
24 |
+
|
25 |
+
prompt2 = "框出图中公交车的位置"
|
26 |
+
result = task(prompt=prompt2)
|
27 |
+
print(result["result"])
|
28 |
+
|
29 |
+
```
|
30 |
+
|
31 |
+
输入图片:<center><img src="https://github.com/LokeZhou/PaddleMIX/assets/13300429/95f73037-097e-4712-95be-17d5ca489f11" /></center>
|
32 |
+
|
33 |
+
prompt:“这是什么?”
|
34 |
+
|
35 |
+
输出:
|
36 |
+
```
|
37 |
+
这是一张红色城市公交车的图片,它正在道路上行驶,穿越城市。该区域似乎是一个住宅区,因为可以在背景中看到一些房屋。除了公交车之外,还有其他车辆,包括一辆汽车和一辆卡车,共同构成了交通场景。此外,图片中还显示了一一个人,他站在路边,可能是在等待公交车或进行其他活动。
|
38 |
+
```
|
39 |
+
prompt2:“框出图中公交车的位置”
|
40 |
+
|
41 |
+
输出:
|
42 |
+
```
|
43 |
+
<ref>公交车</ref><box>(178,280),(803,894)</box>
|
44 |
+
```
|
PaddleMIX/applications/image2image/README.md
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### 文本引导的图像放大(Text-Guided Image Upscaling
|
2 |
+
|
3 |
+
```python
|
4 |
+
from paddlemix.appflow import Appflow
|
5 |
+
from PIL import Image
|
6 |
+
from ppdiffusers.utils import load_image
|
7 |
+
|
8 |
+
url = "https://paddlenlp.bj.bcebos.com/models/community/CompVis/data/low_res_cat.png"
|
9 |
+
|
10 |
+
low_res_img = load_image(url).resize((128, 128))
|
11 |
+
|
12 |
+
prompt = "a white cat"
|
13 |
+
|
14 |
+
app = Appflow(app='image2image_text_guided_upscaling',models=['stabilityai/stable-diffusion-x4-upscaler'])
|
15 |
+
image = app(prompt=prompt,image=low_res_img)['result']
|
16 |
+
|
17 |
+
image.save("upscaled_white_cat.png")
|
18 |
+
```
|
19 |
+
|
20 |
+
效果展示
|
21 |
+
|
22 |
+
<div align="center">
|
23 |
+
|
24 |
+
| prompt |image | Generated Image |
|
25 |
+
|:----:|:----:|:----:|
|
26 |
+
| a white cat|  | |
|
27 |
+
</div>
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
|
32 |
+
### 文本图像双引导图像生成(Dual Text and Image Guided Generation)
|
33 |
+
|
34 |
+
```python
|
35 |
+
from paddlemix.appflow import Appflow
|
36 |
+
from PIL import Image
|
37 |
+
from ppdiffusers.utils import load_image
|
38 |
+
|
39 |
+
url = "https://paddlenlp.bj.bcebos.com/models/community/CompVis/data/benz.jpg"
|
40 |
+
image = load_image(url)
|
41 |
+
prompt = "a red car in the sun"
|
42 |
+
|
43 |
+
|
44 |
+
app = Appflow(app='dual_text_and_image_guided_generation',models=['shi-labs/versatile-diffusion'])
|
45 |
+
image = app(prompt=prompt,image=image)['result']
|
46 |
+
image.save("versatile-diffusion-red_car.png")
|
47 |
+
|
48 |
+
```
|
49 |
+
|
50 |
+
效果展示
|
51 |
+
|
52 |
+
<div align="center">
|
53 |
+
|
54 |
+
| prompt |image | Generated Image |
|
55 |
+
|:----:|:----:|:----:|
|
56 |
+
| a red car in the sun |  | |
|
57 |
+
</div>
|
58 |
+
|
59 |
+
|
60 |
+
|
61 |
+
### 文本引导的图像变换(Image-to-Image Text-Guided Generation)
|
62 |
+
|
63 |
+
```python
|
64 |
+
from paddlemix.appflow import Appflow
|
65 |
+
from PIL import Image
|
66 |
+
from ppdiffusers.utils import load_image
|
67 |
+
import paddle
|
68 |
+
|
69 |
+
url = "https://paddlenlp.bj.bcebos.com/models/community/CompVis/data/image_Kurisu.png"
|
70 |
+
image = load_image(url).resize((512, 768))
|
71 |
+
prompt = "a red car in the sun"
|
72 |
+
|
73 |
+
paddle.seed(42)
|
74 |
+
prompt = "Kurisu Makise, looking at viewer, long hair, standing, 1girl, hair ornament, hair flower, cute, jacket, white flower, white dress"
|
75 |
+
negative_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry"
|
76 |
+
|
77 |
+
|
78 |
+
app = Appflow(app='image2image_text_guided_generation',models=['admruul/anything-v3.0'])
|
79 |
+
image = app(prompt=prompt,negative_prompt=negative_prompt,image=image)['result']
|
80 |
+
|
81 |
+
image.save("image_Kurisu_img2img.png")
|
82 |
+
|
83 |
+
```
|
84 |
+
|
85 |
+
效果展示
|
86 |
+
|
87 |
+
<div align="center">
|
88 |
+
|
89 |
+
| prompt | negative_prompt |image | Generated Image |
|
90 |
+
|:----:|:----:|:----:| :----:|
|
91 |
+
| Kurisu Makise, looking at viewer, long hair, standing, 1girl, hair ornament, hair flower, cute, jacket, white flower, white dress | lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry | | |
|
92 |
+
</div>
|
PaddleMIX/applications/image2text/README.md
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
### 图文生成(Image-to-Text Generation)
|
4 |
+
|
5 |
+
## miniGPT4
|
6 |
+
使用miniGPT4前,需要下载相应权重进行转换,具体可参考[miniGPT4](../../paddlemix/examples/minigpt4/README.md),在完成权重转换后,根据模型权重文件以及配置文件按下存放:
|
7 |
+
```bash
|
8 |
+
--PPMIX_HOME #默认路径 /root/.paddlemix 可通过export PPMIX_HOME 设置
|
9 |
+
--models
|
10 |
+
--miniGPT4
|
11 |
+
--MiniGPT4-7B
|
12 |
+
config.json
|
13 |
+
model_state.pdparams
|
14 |
+
special_tokens_map.json
|
15 |
+
image_preprocessor_config.json
|
16 |
+
preprocessor_config.json
|
17 |
+
tokenizer_config.json
|
18 |
+
model_config.json
|
19 |
+
sentencepiece.bpe.model
|
20 |
+
tokenizer.json
|
21 |
+
--MiniGPT4-13B
|
22 |
+
...
|
23 |
+
...
|
24 |
+
...
|
25 |
+
|
26 |
+
```
|
27 |
+
完成之后,可使用appflow 一键预测
|
28 |
+
```python
|
29 |
+
from paddlemix.appflow import Appflow
|
30 |
+
import requests
|
31 |
+
|
32 |
+
task = Appflow(app="image2text_generation",
|
33 |
+
models=["miniGPT4/MiniGPT4-7B"])
|
34 |
+
url = "https://paddlenlp.bj.bcebos.com/data/images/mugs.png"
|
35 |
+
image = Image.open(requests.get(url, stream=True).raw)
|
36 |
+
minigpt4_text = "describe the image"
|
37 |
+
result = task(image=image,minigpt4_text=minigpt4_text)
|
38 |
+
```
|
39 |
+
|
40 |
+
效果展示
|
41 |
+
|
42 |
+
<div align="center">
|
43 |
+
|
44 |
+
| Image | text | Generated text|
|
45 |
+
|:----:|:----:|:----:|
|
46 |
+
| | describe the image|The image shows two mugs with cats on them, one is black and white and the other is blue and white. The mugs are sitting on a table with a book in the background. The mugs have a whimsical, cartoon-like appearance. The cats on the mugs are looking at each other with a playful expression. The overall style of the image is cute and fun.###|
|
47 |
+
</div>
|
48 |
+
|
49 |
+
## blip2
|
50 |
+
|
51 |
+
```python
|
52 |
+
from paddlemix.appflow import Appflow
|
53 |
+
from ppdiffusers.utils import load_image
|
54 |
+
|
55 |
+
task = Appflow(app="image2text_generation",
|
56 |
+
models=["paddlemix/blip2-caption-opt2.7b"])
|
57 |
+
url = "https://paddlenlp.bj.bcebos.com/data/images/mugs.png"
|
58 |
+
image_pil = load_image(url)
|
59 |
+
blip2_prompt = 'describe the image'
|
60 |
+
result = task(image=image_pil,blip2_prompt=blip2_prompt)
|
61 |
+
```
|
62 |
+
|
63 |
+
| Image | text | Generated text|
|
64 |
+
|:----:|:----:|:----:|
|
65 |
+
| | describe the image|of the two coffee mugs with cats on them|
|
66 |
+
</div>
|
PaddleMIX/applications/text2image/README.md
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
### 文图生成(Text-to-Image Generation)
|
4 |
+
|
5 |
+
|
6 |
+
```python
|
7 |
+
import paddle
|
8 |
+
from paddlemix.appflow import Appflow
|
9 |
+
|
10 |
+
paddle.seed(42)
|
11 |
+
task = Appflow(app="text2image_generation",
|
12 |
+
models=["stabilityai/stable-diffusion-xl-base-1.0"]
|
13 |
+
)
|
14 |
+
prompt = "a photo of an astronaut riding a horse on mars."
|
15 |
+
result = task(prompt=prompt)['result']
|
16 |
+
```
|
17 |
+
|
18 |
+
效果展示
|
19 |
+
|
20 |
+
<div align="center">
|
21 |
+
|
22 |
+
| model| prompt | Generated Image |
|
23 |
+
|:----:|:----:|:----:|
|
24 |
+
|stabilityai/stable-diffusion-v1-5| a photo of an astronaut riding a horse on mars | |
|
25 |
+
|stabilityai/stable-diffusion-xl-base-1.0| a photo of an astronaut riding a horse on mars | |
|
26 |
+
</div>
|
27 |
+
|
PaddleMIX/applications/text2video/README.md
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### 文本条件的视频生成(Text-to-Video Generation)
|
2 |
+
|
3 |
+
```python
|
4 |
+
from paddlemix.appflow import Appflow
|
5 |
+
import imageio
|
6 |
+
|
7 |
+
|
8 |
+
prompt = "An astronaut riding a horse."
|
9 |
+
|
10 |
+
app = Appflow(app='text_to_video_generation',models=['damo-vilab/text-to-video-ms-1.7b'])
|
11 |
+
video_frames = app(prompt=prompt,num_inference_steps=25)['result']
|
12 |
+
|
13 |
+
imageio.mimsave("text_to_video_generation-synth-result-astronaut_riding_a_horse.gif", video_frames,duration=8)
|
14 |
+
|
15 |
+
```
|
16 |
+
|
17 |
+
<div align="center">
|
18 |
+
|
19 |
+
| Prompt | video |
|
20 |
+
|:----:|:----:|
|
21 |
+
| An astronaut riding a horse.| |
|
22 |
+
|
23 |
+
</div>
|
PaddleMIX/deploy/llava/README.md
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# LLaVA
|
2 |
+
|
3 |
+
## 1. 模型介绍
|
4 |
+
|
5 |
+
[LLaVA](https://arxiv.org/pdf/2310.03744.pdf) 是基于大规模语言模型 llama 的视觉语言模型。支持多个多模态任务,包括零样本图像描述生成(Zero-shot Image Caption)、视觉问答(VQA)、细粒度视觉定位(Referring Expression Comprehension)等任务。
|
6 |
+
|
7 |
+
其性能优于其他模型,在多个任务上取得了更好的效果。
|
8 |
+
|
9 |
+
<p align="center">
|
10 |
+
<img src="https://github.com/haotian-liu/LLaVA/blob/main/images/llava_v1_5_radar.jpg" align="middle" width = "600" />
|
11 |
+
</p>
|
12 |
+
|
13 |
+
注:图片引用自[LLaVA](https://github.com/haotian-liu/LLaVA).
|
14 |
+
|
15 |
+
本目录提供paddle版本的llava静态图推理部署示例,推荐使用A100进行推理部署。
|
16 |
+
|
17 |
+
|
18 |
+
## 2. 安装依赖
|
19 |
+
|
20 |
+
* `paddlenlp_ops`依赖安装
|
21 |
+
|
22 |
+
```bash
|
23 |
+
git submodule update --init --recursive
|
24 |
+
cd PaddleNLP
|
25 |
+
git reset --hard 498f70988431be278dac618411fbfb0287853cd9
|
26 |
+
pip install -e .
|
27 |
+
cd csrc
|
28 |
+
python setup_cuda.py install
|
29 |
+
```
|
30 |
+
* 如果在V100上安装报错,可屏蔽 /PaddleNLP/csrc/generation/quant_int8.cu 以下语句:
|
31 |
+
|
32 |
+
```bash
|
33 |
+
# template<>
|
34 |
+
# __forceinline__ __device__ __nv_bfloat16 add_mul<__nv_bfloat16>(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
|
35 |
+
# return __hmul(__hadd(a, b), c);
|
36 |
+
# }
|
37 |
+
```
|
38 |
+
|
39 |
+
* `fused_ln`需要安装 /PaddleNLP/model_zoo/gpt-3/external_ops 下的自定义OP, `python setup.py install`
|
40 |
+
|
41 |
+
## 3. 示例
|
42 |
+
|
43 |
+
### 3.1 转出静态图推理所需的视觉模型和语言模型
|
44 |
+
|
45 |
+
* 在`PaddleMIX`目录下,执行转换脚本,得到视觉模型部分静态图
|
46 |
+
|
47 |
+
```bash
|
48 |
+
#!/bin/bash
|
49 |
+
export PYTHONPATH=/path/to/PaddleNLP/:/path/to/PaddleMIX
|
50 |
+
python deploy/llava/export_model.py \
|
51 |
+
--model_name_or_path "paddlemix/llava/llava-v1.5-7b" \
|
52 |
+
--save_path "./llava_static" \
|
53 |
+
--encode_image \
|
54 |
+
--fp16
|
55 |
+
```
|
56 |
+
|
57 |
+
* 在`PaddleMIX`目录下,执行转换脚本,得到语言模型部分静态图
|
58 |
+
|
59 |
+
```bash
|
60 |
+
#!/bin/bash
|
61 |
+
export PYTHONPATH=/path/to/PaddleNLP/:/path/to/PaddleMIX
|
62 |
+
python deploy/llava/export_model.py \
|
63 |
+
--model_name_or_path "paddlemix/llava/llava-v1.5-7b" \
|
64 |
+
--save_path "./llava_static" \
|
65 |
+
--encode_text \
|
66 |
+
--fp16
|
67 |
+
```
|
68 |
+
|
69 |
+
|
70 |
+
### 3.2 静态图推理
|
71 |
+
|
72 |
+
* 在`PaddleMIX`目录下,运行执行脚本,进行静态图推理
|
73 |
+
|
74 |
+
```bash
|
75 |
+
#!/bin/bash
|
76 |
+
|
77 |
+
python deploy/llava/run_static_predict.py --model_name_or_path "paddlemix/llava/llava-v1.5-7b" \
|
78 |
+
--image_file "https://bj.bcebos.com/v1/paddlenlp/models/community/GroundingDino/000000004505.jpg" \
|
79 |
+
--first_model_path "llava_static/encode_image/clip" \
|
80 |
+
--second_model_path "llava_static/encode_text/llama" \
|
81 |
+
--fp16
|
82 |
+
|
83 |
+
```
|
PaddleMIX/deploy/llava/export_model.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import argparse
|
16 |
+
|
17 |
+
import paddle
|
18 |
+
|
19 |
+
from llama_inference_model import LlamaForClipInferenceModel
|
20 |
+
from paddlemix.auto import AutoConfigMIX, AutoModelMIX
|
21 |
+
from paddlemix.utils.log import logger
|
22 |
+
|
23 |
+
|
24 |
+
def export_encode_text(model, config, compute_dtype):
|
25 |
+
|
26 |
+
# save to static model
|
27 |
+
save_path = args.save_path + "/encode_text/llama"
|
28 |
+
model.to_static(save_path, config, compute_dtype)
|
29 |
+
logger.info(f"static model has been to {save_path}")
|
30 |
+
|
31 |
+
|
32 |
+
def export_encode_image(model, compute_dtype):
|
33 |
+
paddle.save(model.llama.image_newline,args.save_path + "/encode_image/clip/image_newline.pdparams")
|
34 |
+
# convert to static graph with specific input description
|
35 |
+
model = paddle.jit.to_static(
|
36 |
+
model.encode_images,
|
37 |
+
input_spec=[
|
38 |
+
paddle.static.InputSpec(shape=[None,3, 336, 336], dtype=compute_dtype), # images
|
39 |
+
]
|
40 |
+
)
|
41 |
+
|
42 |
+
# save to static model
|
43 |
+
save_path = args.save_path + "/encode_image/clip"
|
44 |
+
paddle.jit.save(model, save_path)
|
45 |
+
logger.info(f"static model has been to {save_path}")
|
46 |
+
|
47 |
+
|
48 |
+
if __name__ == "__main__":
|
49 |
+
parser = argparse.ArgumentParser()
|
50 |
+
parser.add_argument(
|
51 |
+
"--model_name_or_path",
|
52 |
+
default="paddlemix/llava/llava-v1.5-7b",
|
53 |
+
type=str,
|
54 |
+
help="The dir name of llava checkpoint.",
|
55 |
+
)
|
56 |
+
parser.add_argument(
|
57 |
+
"--save_path",
|
58 |
+
default="./llava_static",
|
59 |
+
type=str,
|
60 |
+
help="The saving path of static llava vision.",
|
61 |
+
)
|
62 |
+
parser.add_argument("--encode_image", action="store_true")
|
63 |
+
parser.add_argument("--encode_text", action="store_true")
|
64 |
+
parser.add_argument("--fp16", action="store_true")
|
65 |
+
|
66 |
+
args = parser.parse_args()
|
67 |
+
|
68 |
+
compute_dtype = "float16" if args.fp16 else "bfloat16"
|
69 |
+
if not paddle.amp.is_bfloat16_supported() and compute_dtype == "bfloat16":
|
70 |
+
logger.warning("bfloat16 is not supported on your device,change to float32")
|
71 |
+
compute_dtype = "float32"
|
72 |
+
|
73 |
+
if args.encode_image:
|
74 |
+
|
75 |
+
model = AutoModelMIX.from_pretrained(args.model_name_or_path, dtype=compute_dtype)
|
76 |
+
vision_tower = model.get_vision_tower()
|
77 |
+
vision_tower.load_model()
|
78 |
+
model.eval()
|
79 |
+
|
80 |
+
export_encode_image(model, compute_dtype)
|
81 |
+
|
82 |
+
elif args.encode_text:
|
83 |
+
|
84 |
+
config = AutoConfigMIX.from_pretrained(args.model_name_or_path)
|
85 |
+
config.tensor_parallel_degree = 1
|
86 |
+
config.tensor_parallel_rank = 0
|
87 |
+
config.weight_only_quant_bits = -1
|
88 |
+
config.quant_type = None
|
89 |
+
|
90 |
+
model = LlamaForClipInferenceModel.from_pretrained(args.model_name_or_path, config=config)
|
91 |
+
|
92 |
+
model.to(dtype=compute_dtype)
|
93 |
+
model.eval()
|
94 |
+
|
95 |
+
export_encode_text(model, config, compute_dtype)
|
96 |
+
|
97 |
+
else:
|
98 |
+
logger.info("please specify the task to export,--encode_image or --encode_text")
|
PaddleMIX/deploy/llava/llama_inference_model.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import paddle
|
16 |
+
from paddlenlp.experimental.transformers import LlamaForCausalLMInferenceModel
|
17 |
+
|
18 |
+
|
19 |
+
class LlamaForClipInferenceModel(LlamaForCausalLMInferenceModel):
|
20 |
+
"""
|
21 |
+
This class is 99% like LlamaForCausalLMInferenceModel.
|
22 |
+
Used only for llava's second part.
|
23 |
+
"""
|
24 |
+
|
25 |
+
@paddle.no_grad()
|
26 |
+
def generate_text_with_image_features(
|
27 |
+
self,
|
28 |
+
input_ids: paddle.Tensor,
|
29 |
+
image_features: paddle.Tensor,
|
30 |
+
img_pos: paddle.Tensor,
|
31 |
+
attention_mask=None,
|
32 |
+
position_ids=None,
|
33 |
+
penalty_score=None,
|
34 |
+
frequency_score=None,
|
35 |
+
presence_score=None,
|
36 |
+
min_length=None,
|
37 |
+
max_length=None,
|
38 |
+
temperature=None,
|
39 |
+
top_p=None,
|
40 |
+
eos_token_id=None,
|
41 |
+
seq_len_encoder=None,
|
42 |
+
seq_len_decoder=None,
|
43 |
+
step_idx=None,
|
44 |
+
stop_flags=None,
|
45 |
+
tgt_ids=None,
|
46 |
+
tgt_pos=None,
|
47 |
+
tgt_generation_mask=None,
|
48 |
+
pre_ids=None,
|
49 |
+
stop_nums=None,
|
50 |
+
cache_kvs=[],
|
51 |
+
**generate_kwargs
|
52 |
+
) -> paddle.Tensor:
|
53 |
+
|
54 |
+
inputs_embeds = self.llama.embed_tokens(input_ids)
|
55 |
+
for batch_idx, pos in enumerate(img_pos):
|
56 |
+
for idx, p in enumerate(pos):
|
57 |
+
index = paddle.arange(p[0], p[1]).unsqueeze(-1)
|
58 |
+
inputs_embeds[batch_idx] = paddle.scatter(inputs_embeds[batch_idx], index, image_features[idx])
|
59 |
+
|
60 |
+
outputs = self.generate(
|
61 |
+
inputs_embeds=inputs_embeds,
|
62 |
+
attention_mask=attention_mask,
|
63 |
+
position_ids=position_ids,
|
64 |
+
penalty_score=penalty_score,
|
65 |
+
frequency_score=frequency_score,
|
66 |
+
presence_score=presence_score,
|
67 |
+
min_length=min_length,
|
68 |
+
max_length=max_length,
|
69 |
+
temperature=temperature,
|
70 |
+
top_p=top_p,
|
71 |
+
eos_token_id=eos_token_id,
|
72 |
+
seq_len_encoder=seq_len_encoder,
|
73 |
+
seq_len_decoder=seq_len_decoder,
|
74 |
+
step_idx=step_idx,
|
75 |
+
stop_flags=stop_flags,
|
76 |
+
tgt_ids=tgt_ids,
|
77 |
+
tgt_pos=tgt_pos,
|
78 |
+
tgt_generation_mask=tgt_generation_mask,
|
79 |
+
pre_ids=pre_ids,
|
80 |
+
stop_nums=stop_nums,
|
81 |
+
cache_kvs=cache_kvs,
|
82 |
+
)
|
83 |
+
return outputs
|
84 |
+
|
85 |
+
def to_static(self, output_path: str, config: dict, compute_dtype: str):
|
86 |
+
|
87 |
+
cache_kvs_shapes = self.get_cache_kvs_shape(config, max_length=config.get("max_length", None))
|
88 |
+
|
89 |
+
input_spec = [
|
90 |
+
paddle.static.InputSpec(shape=[None, None], dtype="int32", name="inputs_ids"),
|
91 |
+
paddle.static.InputSpec(
|
92 |
+
shape=[None, None, None], dtype=compute_dtype, name="image_features"
|
93 |
+
), # image_features
|
94 |
+
paddle.static.InputSpec(shape=[None, None, 2], dtype="int64", name="img_pos"), # img_pos
|
95 |
+
paddle.static.InputSpec(
|
96 |
+
shape=[None, None, None, None], dtype="int64", name="attention_mask"
|
97 |
+
), # attention_mask
|
98 |
+
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="position_ids"), # position_ids
|
99 |
+
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="penalty_score"), # penalty_score
|
100 |
+
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="frequency_score"), # frequency_score
|
101 |
+
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="presence_score"), # presence_score
|
102 |
+
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="min_length"), # min_decode_length
|
103 |
+
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="max_length"), # max_decode_length
|
104 |
+
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="temperature"), # temperature
|
105 |
+
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="top_p"), # top_p
|
106 |
+
paddle.static.InputSpec(shape=[None], dtype="int64", name="eos_token_id"), # eos_token_id
|
107 |
+
paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_encoder"), # seq_len_encoder
|
108 |
+
paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_decoder"), # seq_len_decoder
|
109 |
+
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="step_idx"), # step_idx
|
110 |
+
paddle.static.InputSpec(shape=[None, 1], dtype="bool", name="stop_flags"), # stop_flags
|
111 |
+
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_ids"), # tgt_ids
|
112 |
+
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_pos"), # tgt_pos
|
113 |
+
paddle.static.InputSpec(shape=[None, 1, 1, None], name="tgt_generation_mask"), # tgt_generation_mask
|
114 |
+
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="pre_ids"), # pre_ids
|
115 |
+
paddle.static.InputSpec(shape=[1], dtype="int64", name="stop_nums"), # stop_nums
|
116 |
+
[
|
117 |
+
paddle.static.InputSpec(
|
118 |
+
shape=shape,
|
119 |
+
dtype=compute_dtype,
|
120 |
+
name="cache_kvs_{}".format(i),
|
121 |
+
)
|
122 |
+
for i, shape in enumerate(cache_kvs_shapes)
|
123 |
+
], # cache_kvs
|
124 |
+
]
|
125 |
+
|
126 |
+
model = paddle.jit.to_static(self.generate_text_with_image_features, input_spec=input_spec)
|
127 |
+
paddle.jit.save(model, output_path, skip_prune_program=True)
|
PaddleMIX/deploy/llava/run_static_predict.py
ADDED
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import argparse
|
16 |
+
import os
|
17 |
+
|
18 |
+
import paddle
|
19 |
+
from utils import load_real_time_tokens
|
20 |
+
|
21 |
+
from paddlemix.auto import AutoConfigMIX, AutoProcessorMIX, AutoTokenizerMIX
|
22 |
+
from paddlemix.models.llava.constants import (
|
23 |
+
DEFAULT_IM_END_TOKEN,
|
24 |
+
DEFAULT_IM_START_TOKEN,
|
25 |
+
DEFAULT_IMAGE_TOKEN,
|
26 |
+
IMAGE_TOKEN_INDEX,
|
27 |
+
)
|
28 |
+
from paddlemix.models.llava.conversation import conv_templates
|
29 |
+
from paddlemix.models.llava.mm_utils import load_image,get_anyres_image_grid_shape
|
30 |
+
from paddlemix.models.llava.llava_arch import unpad_image
|
31 |
+
from paddlemix.utils.log import logger
|
32 |
+
|
33 |
+
|
34 |
+
class Predictor(object):
|
35 |
+
def __init__(self, args):
|
36 |
+
|
37 |
+
self.compute_dtype = "float16" if args.fp16 else "bfloat16"
|
38 |
+
if not paddle.amp.is_bfloat16_supported() and self.compute_dtype == "bfloat16":
|
39 |
+
logger.warning("bfloat16 is not supported on your device,change to float32")
|
40 |
+
self.compute_dtype = "float32"
|
41 |
+
|
42 |
+
self.args = args
|
43 |
+
self.config = AutoConfigMIX.from_pretrained(args.model_name_or_path)
|
44 |
+
self.clip_config = AutoConfigMIX.from_pretrained(self.config.mm_vision_tower)
|
45 |
+
|
46 |
+
|
47 |
+
self.tokenizer = AutoTokenizerMIX.from_pretrained(args.model_name_or_path)
|
48 |
+
self.processor, _ = AutoProcessorMIX.from_pretrained(args.model_name_or_path, image_aspect_ratio=self.config.image_aspect_ratio,eval="eval")
|
49 |
+
|
50 |
+
self.first_predictor = self.create_predictor(args.first_model_path)
|
51 |
+
print(f"first_model_path: {args.first_model_path}, {self.first_predictor}")
|
52 |
+
|
53 |
+
self.second_predictor = self.create_predictor(args.second_model_path)
|
54 |
+
print(f"second_model_path: {args.second_model_path}, {self.second_predictor}")
|
55 |
+
|
56 |
+
self.image_newline = paddle.load(os.path.join(args.first_model_path, "image_newline.pdparams"))
|
57 |
+
|
58 |
+
def create_predictor(self, model_path):
|
59 |
+
|
60 |
+
from paddlenlp.utils.import_utils import import_module
|
61 |
+
|
62 |
+
import_module("paddlenlp_ops.encode_rotary_qk")
|
63 |
+
import_module("paddlenlp_ops.get_padding_offset")
|
64 |
+
import_module("paddlenlp_ops.qkv_transpose_split")
|
65 |
+
import_module("paddlenlp_ops.rebuild_padding")
|
66 |
+
import_module("paddlenlp_ops.transpose_remove_padding")
|
67 |
+
import_module("paddlenlp_ops.write_cache_kv")
|
68 |
+
|
69 |
+
model_file = model_path + ".pdmodel"
|
70 |
+
params_file = model_path + ".pdiparams"
|
71 |
+
if not os.path.exists(model_file):
|
72 |
+
raise ValueError("not find model file path {}".format(model_file))
|
73 |
+
if not os.path.exists(params_file):
|
74 |
+
raise ValueError("not find params file path {}".format(params_file))
|
75 |
+
config = paddle.inference.Config(model_file, params_file)
|
76 |
+
|
77 |
+
config.switch_ir_optim(True)
|
78 |
+
|
79 |
+
if self.args.device == "gpu":
|
80 |
+
config.enable_use_gpu(100, 0)
|
81 |
+
|
82 |
+
config.switch_use_feed_fetch_ops(False)
|
83 |
+
predictor = paddle.inference.create_predictor(config)
|
84 |
+
return predictor
|
85 |
+
|
86 |
+
@paddle.no_grad()
|
87 |
+
def encode_images(self, images, image_sizes):
|
88 |
+
if type(images) is list or images.ndim == 5:
|
89 |
+
if type(images) is list:
|
90 |
+
images = [(x.unsqueeze(axis=0) if x.ndim == 3 else x) for x in images]
|
91 |
+
concat_images = paddle.concat(x=[image for image in images], axis=0)
|
92 |
+
|
93 |
+
image_features = self.first_predictor.run(concat_images)[0]
|
94 |
+
|
95 |
+
split_sizes = [image.shape[0] for image in images]
|
96 |
+
image_features = paddle.split(image_features, split_sizes, axis=0)
|
97 |
+
mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
|
98 |
+
image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
|
99 |
+
if mm_patch_merge_type == "flat":
|
100 |
+
image_features = [x.flatten(start_axis=0, stop_axis=1) for x in image_features]
|
101 |
+
elif mm_patch_merge_type.startswith("spatial"):
|
102 |
+
new_image_features = []
|
103 |
+
for image_idx, image_feature in enumerate(image_features):
|
104 |
+
if image_feature.shape[0] > 1:
|
105 |
+
base_image_feature = image_feature[0]
|
106 |
+
image_feature = image_feature[1:]
|
107 |
+
height = width = self.clip_config.image_resolution // self.clip_config.vision_patch_size
|
108 |
+
assert height * width == base_image_feature.shape[0]
|
109 |
+
if image_aspect_ratio == "anyres":
|
110 |
+
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
|
111 |
+
image_sizes[image_idx],
|
112 |
+
self.config.image_grid_pinpoints,
|
113 |
+
self.clip_config.image_resolution,
|
114 |
+
)
|
115 |
+
|
116 |
+
image_feature = paddle.reshape(
|
117 |
+
image_feature, (num_patch_height, num_patch_width, height, width, -1)
|
118 |
+
)
|
119 |
+
else:
|
120 |
+
raise NotImplementedError
|
121 |
+
if "unpad" in mm_patch_merge_type:
|
122 |
+
image_feature = image_feature.transpose(perm=[4, 0, 2, 1, 3])
|
123 |
+
image_feature = image_feature.flatten(start_axis=1, stop_axis=2).flatten(
|
124 |
+
start_axis=2, stop_axis=3
|
125 |
+
)
|
126 |
+
image_feature = unpad_image(image_feature, image_sizes[image_idx])
|
127 |
+
image_feature = paddle.concat(
|
128 |
+
x=(
|
129 |
+
image_feature,
|
130 |
+
self.image_newline[:, (None), (None)].expand(
|
131 |
+
shape=[*image_feature.shape[:-1], 1]
|
132 |
+
).astype(image_feature.dtype),
|
133 |
+
),
|
134 |
+
axis=-1,
|
135 |
+
)
|
136 |
+
x = image_feature.flatten(start_axis=1, stop_axis=2)
|
137 |
+
perm_12 = list(range(x.ndim))
|
138 |
+
perm_12[0] = 1
|
139 |
+
perm_12[1] = 0
|
140 |
+
image_feature = x.transpose(perm=perm_12)
|
141 |
+
else:
|
142 |
+
image_feature = image_feature.transpose(perm=[0, 2, 1, 3, 4])
|
143 |
+
image_feature = image_feature.flatten(start_axis=0, stop_axis=3)
|
144 |
+
image_feature = paddle.concat(x=(base_image_feature, image_feature), axis=0)
|
145 |
+
else:
|
146 |
+
image_feature = image_feature[0]
|
147 |
+
if "unpad" in mm_patch_merge_type:
|
148 |
+
image_feature = paddle.concat(
|
149 |
+
x=(image_feature, self.image_newline[None].to(image_feature.place)), axis=0
|
150 |
+
)
|
151 |
+
new_image_features.append(image_feature)
|
152 |
+
image_features = new_image_features
|
153 |
+
image_features = paddle.stack(x=image_features, axis=0)
|
154 |
+
else:
|
155 |
+
raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
|
156 |
+
else:
|
157 |
+
image_features = self.first_predictor.run(images)[0]
|
158 |
+
|
159 |
+
return image_features
|
160 |
+
|
161 |
+
@paddle.no_grad()
|
162 |
+
def generate_with_image_features(self, image_features, input_ids):
|
163 |
+
max_len = 2048
|
164 |
+
total_max_length = max_len + 1024
|
165 |
+
batch, seq, _ = image_features.shape
|
166 |
+
seq += input_ids.shape[1] - 1
|
167 |
+
|
168 |
+
_attention_mask = paddle.ones_like(x=input_ids, dtype="bool")
|
169 |
+
input_ids = [
|
170 |
+
cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, _attention_mask)
|
171 |
+
]
|
172 |
+
cur_image_idx = 0
|
173 |
+
new_input_ids = []
|
174 |
+
img_pos = []
|
175 |
+
for batch_idx, cur_input_ids in enumerate(input_ids):
|
176 |
+
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
|
177 |
+
|
178 |
+
image_token_indices = (
|
179 |
+
[-1]
|
180 |
+
+ paddle.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].squeeze(axis=1).tolist()
|
181 |
+
+ [cur_input_ids.shape[0]]
|
182 |
+
)
|
183 |
+
cur_input_ids_noim = []
|
184 |
+
|
185 |
+
for i in range(len(image_token_indices) - 1):
|
186 |
+
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
|
187 |
+
|
188 |
+
split_sizes = [x.shape[0] for x in cur_input_ids_noim]
|
189 |
+
|
190 |
+
split_start = 0
|
191 |
+
cur_new_input_ids = []
|
192 |
+
cur_img_pos = []
|
193 |
+
|
194 |
+
for i in range(num_images + 1):
|
195 |
+
cur_new_input_ids.append(cur_input_ids_noim[i])
|
196 |
+
|
197 |
+
if i < num_images:
|
198 |
+
cur_image_features = image_features[cur_image_idx]
|
199 |
+
cur_image_idx += 1
|
200 |
+
cur_new_input_ids.append(paddle.full([cur_image_features.shape[0]], 1, dtype="int64"))
|
201 |
+
split_start += split_sizes[i - 1] if i > 0 else split_sizes[i]
|
202 |
+
cur_img_pos.append([split_start, split_start + cur_image_features.shape[0]])
|
203 |
+
split_start += cur_image_features.shape[0]
|
204 |
+
|
205 |
+
cur_new_input_ids = paddle.concat(x=cur_new_input_ids)
|
206 |
+
new_input_ids.append(cur_new_input_ids)
|
207 |
+
img_pos.append(cur_img_pos)
|
208 |
+
|
209 |
+
new_input_ids = paddle.to_tensor(new_input_ids)
|
210 |
+
img_pos = paddle.to_tensor(img_pos)
|
211 |
+
|
212 |
+
tgt_generation_mask = paddle.full([batch, 1, 1, total_max_length], 1)
|
213 |
+
|
214 |
+
attention_mask = paddle.zeros(
|
215 |
+
shape=(batch, 1, total_max_length, total_max_length),
|
216 |
+
dtype="int64",
|
217 |
+
)
|
218 |
+
length = seq
|
219 |
+
attention_mask[:, 0, :length, :length] = paddle.tril(paddle.ones(shape=(length, length), dtype="int64"))
|
220 |
+
|
221 |
+
position_ids = paddle.full([batch, total_max_length], 0, dtype="int64")
|
222 |
+
position_ids[:, :seq] = paddle.arange(0, seq)
|
223 |
+
|
224 |
+
inputs = [
|
225 |
+
new_input_ids, # input_ids
|
226 |
+
image_features, # image_features
|
227 |
+
img_pos,
|
228 |
+
attention_mask,
|
229 |
+
position_ids,
|
230 |
+
paddle.full([batch, 1], 1.0, dtype="float32"), # penalty_score
|
231 |
+
paddle.full([batch, 1], 0.0, dtype="float32"), # frequency_score,
|
232 |
+
paddle.full([batch, 1], 0.0, dtype="float32"), # presence_score,
|
233 |
+
paddle.full([batch, 1], 1, dtype="int64"), # min_length,
|
234 |
+
paddle.full([batch, 1], max_len, dtype="int64"), # max_length,
|
235 |
+
paddle.full([batch, 1], 0.7, dtype="float32"), # temperature,
|
236 |
+
paddle.full([batch, 1], 0.95, dtype="float32"), # top_p,
|
237 |
+
paddle.full([1], self.config.eos_token_id, dtype="int64"), # eos_token_id,
|
238 |
+
paddle.full([batch, 1], seq, dtype="int32"), # seq_len_encoder,
|
239 |
+
paddle.full([batch, 1], seq, dtype="int32"), # seq_len_decoder,
|
240 |
+
paddle.full([batch, 1], 0, dtype="int64"), # step_idx,
|
241 |
+
paddle.full([batch, 1], False, dtype="bool"), # stop_flags,
|
242 |
+
paddle.full([batch, 1], 29962, dtype="int64"), # tgt_ids can be be initialized arbitrarily
|
243 |
+
paddle.full([batch, 1], seq - 1, dtype="int64"), # tgt_pos,
|
244 |
+
tgt_generation_mask, # tgt_generation_mask,
|
245 |
+
paddle.full([batch, total_max_length], -1, dtype="int64"), # pre_ids, can be initialized arbitrarily
|
246 |
+
paddle.full([1], batch, dtype="int64"), # stop_nums, be batch
|
247 |
+
]
|
248 |
+
|
249 |
+
for i in range(self.config.num_hidden_layers):
|
250 |
+
tmp = paddle.zeros(
|
251 |
+
shape=[
|
252 |
+
2,
|
253 |
+
batch,
|
254 |
+
self.config.num_attention_heads,
|
255 |
+
total_max_length,
|
256 |
+
self.config.hidden_size // self.config.num_attention_heads,
|
257 |
+
],
|
258 |
+
dtype=self.compute_dtype,
|
259 |
+
)
|
260 |
+
|
261 |
+
inputs.append(tmp)
|
262 |
+
|
263 |
+
self.second_predictor.run(inputs)
|
264 |
+
tokens = load_real_time_tokens()
|
265 |
+
generate_ids = tokens.tolist()
|
266 |
+
|
267 |
+
return generate_ids, None
|
268 |
+
|
269 |
+
def pre_processing(self, inp, first_message):
|
270 |
+
model_name = self.args.model_name_or_path
|
271 |
+
if "llama-2" in model_name.lower():
|
272 |
+
conv_mode = "llava_llama_2"
|
273 |
+
elif "mistral" in model_name.lower():
|
274 |
+
conv_mode = "mistral_instruct"
|
275 |
+
elif "v1.6-34b" in model_name.lower():
|
276 |
+
conv_mode = "chatml_direct"
|
277 |
+
elif "v1" in model_name.lower():
|
278 |
+
conv_mode = "llava_v1"
|
279 |
+
elif "mpt" in model_name.lower():
|
280 |
+
conv_mode = "mpt"
|
281 |
+
else:
|
282 |
+
conv_mode = "llava_v0"
|
283 |
+
if self.args.conv_mode is not None and conv_mode != self.args.conv_mode:
|
284 |
+
print(
|
285 |
+
"[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
|
286 |
+
conv_mode, self.args.conv_mode, self.args.conv_mode
|
287 |
+
)
|
288 |
+
)
|
289 |
+
else:
|
290 |
+
self.args.conv_mode = conv_mode
|
291 |
+
conv = conv_templates[self.args.conv_mode].copy()
|
292 |
+
|
293 |
+
if self.args.image_file is not None and first_message:
|
294 |
+
if self.config.mm_use_im_start_end:
|
295 |
+
inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + inp
|
296 |
+
else:
|
297 |
+
inp = DEFAULT_IMAGE_TOKEN + "\n" + inp
|
298 |
+
conv.append_message(conv.roles[0], inp)
|
299 |
+
first_message = False
|
300 |
+
else:
|
301 |
+
conv.append_message(conv.roles[0], inp)
|
302 |
+
conv.append_message(conv.roles[1], None)
|
303 |
+
prompt = conv.get_prompt()
|
304 |
+
record = {"image": self.args.image_file, "conversations": prompt}
|
305 |
+
image_size = load_image(args.image_file).size
|
306 |
+
data_dict = self.processor(record=record, image_aspect_ratio=self.config.image_aspect_ratio)
|
307 |
+
data_dict['image_size'] = [image_size]
|
308 |
+
return data_dict
|
309 |
+
|
310 |
+
def post_processing(self, generate_ids):
|
311 |
+
msg = self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
312 |
+
return msg
|
313 |
+
|
314 |
+
def run_benchmark(self):
|
315 |
+
first_message = True
|
316 |
+
import time
|
317 |
+
start = 0.0
|
318 |
+
total = 0.0
|
319 |
+
for i in range(20):
|
320 |
+
if i>10:
|
321 |
+
start = time.time()
|
322 |
+
inp = "user: Generate the caption in English with grounding"
|
323 |
+
data_dict = self.pre_processing(inp, first_message)
|
324 |
+
image = paddle.cast(data_dict["images"], self.compute_dtype)
|
325 |
+
|
326 |
+
image_features = self.encode_images(image,data_dict['image_size'])
|
327 |
+
|
328 |
+
generate_ids, _ = self.generate_with_image_features(
|
329 |
+
image_features,
|
330 |
+
data_dict["input_ids"],
|
331 |
+
)
|
332 |
+
|
333 |
+
msg = self.post_processing(generate_ids)
|
334 |
+
if i > 10:
|
335 |
+
total += time.time()-start
|
336 |
+
|
337 |
+
print("Time: ", total/10)
|
338 |
+
|
339 |
+
def predict(self):
|
340 |
+
roles = "user", "assistant"
|
341 |
+
first_message = True
|
342 |
+
|
343 |
+
if self.args.benchmark:
|
344 |
+
self.run_benchmark()
|
345 |
+
else:
|
346 |
+
while True:
|
347 |
+
try:
|
348 |
+
inp = input(f"{roles[0]}: ")
|
349 |
+
except EOFError:
|
350 |
+
inp = ""
|
351 |
+
if not inp:
|
352 |
+
print("exit...")
|
353 |
+
break
|
354 |
+
print(f"{roles[1]}: ", end="")
|
355 |
+
data_dict = self.pre_processing(inp, first_message)
|
356 |
+
image = paddle.cast(data_dict["images"], self.compute_dtype)
|
357 |
+
|
358 |
+
image_features = self.encode_images(image,data_dict['image_size'])
|
359 |
+
|
360 |
+
generate_ids, _ = self.generate_with_image_features(
|
361 |
+
image_features,
|
362 |
+
data_dict["input_ids"],
|
363 |
+
)
|
364 |
+
|
365 |
+
msg = self.post_processing(generate_ids)
|
366 |
+
print("Outputs: ", msg)
|
367 |
+
|
368 |
+
|
369 |
+
if __name__ == "__main__":
|
370 |
+
parser = argparse.ArgumentParser()
|
371 |
+
parser.add_argument(
|
372 |
+
"--first_model_path",
|
373 |
+
default="The dir name of image encoder model",
|
374 |
+
type=str,
|
375 |
+
help="",
|
376 |
+
)
|
377 |
+
parser.add_argument(
|
378 |
+
"--second_model_path",
|
379 |
+
default="The dir name of language model",
|
380 |
+
type=str,
|
381 |
+
help="",
|
382 |
+
)
|
383 |
+
parser.add_argument(
|
384 |
+
"--model_name_or_path",
|
385 |
+
type=str,
|
386 |
+
default="qwen-vl/qwen-vl-7b",
|
387 |
+
help="The path of extraction model path that you want to load.",
|
388 |
+
)
|
389 |
+
parser.add_argument(
|
390 |
+
"--device", default="gpu", choices=["gpu", "cpu", "xpu"], help="Device selected for inference."
|
391 |
+
)
|
392 |
+
parser.add_argument("--seed", default=0)
|
393 |
+
parser.add_argument("--fp16", action="store_true")
|
394 |
+
parser.add_argument("--image_file", type=str, required=True)
|
395 |
+
parser.add_argument("--conv_mode", type=str, default=None)
|
396 |
+
parser.add_argument("--benchmark", action="store_true")
|
397 |
+
|
398 |
+
args = parser.parse_args()
|
399 |
+
|
400 |
+
paddle.seed(args.seed)
|
401 |
+
|
402 |
+
predictor = Predictor(args)
|
403 |
+
predictor.predict()
|
PaddleMIX/deploy/llava/utils.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from __future__ import annotations
|
15 |
+
|
16 |
+
import glob
|
17 |
+
import math
|
18 |
+
import os
|
19 |
+
import struct
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
|
23 |
+
|
24 |
+
def deserialize_from_file(fp):
|
25 |
+
x_type = fp.read(1)
|
26 |
+
x_type_out = struct.unpack("c", x_type)[0]
|
27 |
+
# data
|
28 |
+
data_list = []
|
29 |
+
if x_type_out == b"0":
|
30 |
+
data = fp.read(4)
|
31 |
+
data_out = struct.unpack("f", data)[0]
|
32 |
+
while data:
|
33 |
+
data_out = struct.unpack("f", data)[0]
|
34 |
+
data_list.append(data_out)
|
35 |
+
data = fp.read(4)
|
36 |
+
elif x_type_out == b"1":
|
37 |
+
data = fp.read(8)
|
38 |
+
while data:
|
39 |
+
data_out = struct.unpack("l", data)[0]
|
40 |
+
data_list.append(data_out)
|
41 |
+
data = fp.read(8)
|
42 |
+
elif x_type_out == b"2":
|
43 |
+
data = fp.read(4)
|
44 |
+
while data:
|
45 |
+
data_out = struct.unpack("i", data)[0]
|
46 |
+
data_list.append(data_out)
|
47 |
+
data = fp.read(4)
|
48 |
+
else:
|
49 |
+
print("type error")
|
50 |
+
data_arr = np.array(data_list)
|
51 |
+
return data_arr
|
52 |
+
|
53 |
+
|
54 |
+
def load_real_time_tokens():
|
55 |
+
tokens = []
|
56 |
+
files = glob.glob(os.path.join("./real_time_save.*"))
|
57 |
+
for j in range(1, len(files) + 1):
|
58 |
+
filename = "./real_time_save.temp_ids_rank_0_step_{}".format(j)
|
59 |
+
if not os.path.exists(filename):
|
60 |
+
break
|
61 |
+
fp = open(filename, "rb+")
|
62 |
+
fp.read(1)
|
63 |
+
data_list = deserialize_from_file(fp)
|
64 |
+
fp.close()
|
65 |
+
tokens.append(np.array(data_list).reshape(-1, 1))
|
66 |
+
os.system("rm -f ./real_time_save.temp_ids_rank_*")
|
67 |
+
tokens = np.concatenate(tokens, axis=1)
|
68 |
+
return tokens
|
69 |
+
|
70 |
+
|
71 |
+
def get_alibi_slopes(num_heads):
|
72 |
+
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
|
73 |
+
base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))
|
74 |
+
powers = np.arange(1, 1 + closest_power_of_2)
|
75 |
+
slopes = np.power(base, powers)
|
76 |
+
|
77 |
+
if closest_power_of_2 != num_heads:
|
78 |
+
extra_base = 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3)))
|
79 |
+
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
|
80 |
+
extra_powers = np.arange(1, 1 + 2 * num_remaining_heads, 2)
|
81 |
+
slopes = np.concatante([slopes, np.power(extra_base, extra_powers)], axis=0)
|
82 |
+
|
83 |
+
return slopes.astype("float32")
|
PaddleMIX/deploy/qwen2_vl/README.md
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Qwen2-VL
|
2 |
+
|
3 |
+
## 1. 模型介绍
|
4 |
+
|
5 |
+
[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/) 是 Qwen 团队推出的一个专注于视觉与语言(Vision-Language, VL)任务的多模态大模型。它旨在通过结合图像和文本信息,提供强大的跨模态理解能力,可以处理涉及图像描述、视觉问答(VQA)、图文检索等多种任务。Qwen2-VL通过引入创新性的技术如 Naive Dynamic Resolution 和 M-RoPE,以及深入探讨大型多模态模型的潜力,显著地提高了多模态内容的视觉理解能力。
|
6 |
+
|
7 |
+
## 2 环境准备
|
8 |
+
|
9 |
+
- **python >= 3.10**
|
10 |
+
- **paddlepaddle-gpu 要求是develop版本**
|
11 |
+
```bash
|
12 |
+
# 安装示例
|
13 |
+
python -m pip install paddlepaddle-gpu==0.0.0.post118 -f https://www.paddlepaddle.org.cn/whl/linux/gpu/develop.html
|
14 |
+
```
|
15 |
+
|
16 |
+
- **paddlenlp 需要特定版本**
|
17 |
+
|
18 |
+
在PaddleMIX/代码目录下执行以下命令安装特定版本的paddlenlp:
|
19 |
+
```bash
|
20 |
+
# 安装示例
|
21 |
+
git submodule update --init --recursive
|
22 |
+
cd PaddleNLP
|
23 |
+
git reset --hard e91c2d3d634b12769c30aa419ddf931c20b7ca9f
|
24 |
+
pip install -e .
|
25 |
+
cd csrc
|
26 |
+
python setup_cuda.py install
|
27 |
+
```
|
28 |
+
|
29 |
+
> 注:
|
30 |
+
* 请确保安装了以上依赖,否则无法运行。同时,需要安装 paddlemix/external_ops 下的自定义OP, `python setup.py install`。如果安装后仍然找不到算子,需要额外设置PYTHONPATH
|
31 |
+
* (默认开启flash_attn)使用flash_attn 要求A100/A800显卡或者H20显卡
|
32 |
+
|
33 |
+
## 3 高性能推理
|
34 |
+
|
35 |
+
在Qwen2-VL的高性能推理优化中,**视觉模型部分继续使用PaddleMIX中的模型组网;但是语言模型部分调用PaddleNLP中高性能的Qwen2语言模型**,以得到高性能的Qwen2-VL推理版本。
|
36 |
+
|
37 |
+
### 3.1. 文本&单张图像输入高性能推理
|
38 |
+
```bash
|
39 |
+
python deploy/qwen2_vl/single_image_infer.py \
|
40 |
+
--model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
|
41 |
+
--dtype bfloat16 \
|
42 |
+
--benchmark True \
|
43 |
+
```
|
44 |
+
|
45 |
+
- 在 NVIDIA A100-SXM4-80GB 上测试的单图端到端速度性能如下:
|
46 |
+
|
47 |
+
| model | Paddle Inference| PyTorch | Paddle 动态图 |
|
48 |
+
| ---------------------- | --------------- | ------------ | ------------ |
|
49 |
+
| Qwen2-VL-2B-Instruct | 1.44 s | 2.35 s | 5.215 s |
|
50 |
+
| Qwen2-VL-7B-Instruct | 1.73 s | 4.4s | 6.339 s |
|
PaddleMIX/deploy/qwen2_vl/single_image_infer.py
ADDED
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import datetime
|
16 |
+
from dataclasses import dataclass, field
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import paddle
|
20 |
+
from paddlenlp.generation import GenerationConfig
|
21 |
+
from paddlenlp.trainer import PdArgumentParser
|
22 |
+
from paddlenlp.transformers import AutoConfig, AutoInferenceModelForCausalLM
|
23 |
+
from paddlenlp.trl import llm_utils
|
24 |
+
|
25 |
+
from paddlemix.models.qwen2_vl import MIXQwen2Tokenizer
|
26 |
+
from paddlemix.models.qwen2_vl.modeling_qwen2_vl import (
|
27 |
+
Qwen2RotaryEmbedding,
|
28 |
+
Qwen2VLForConditionalGeneration,
|
29 |
+
)
|
30 |
+
from paddlemix.processors.qwen2_vl_processing import (
|
31 |
+
Qwen2VLImageProcessor,
|
32 |
+
Qwen2VLProcessor,
|
33 |
+
process_vision_info,
|
34 |
+
)
|
35 |
+
|
36 |
+
MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct"
|
37 |
+
vl_model = Qwen2VLForConditionalGeneration.from_pretrained(MODEL_NAME, dtype="bfloat16")
|
38 |
+
|
39 |
+
# NOTE: (zhoukangkang、changwenbin) Because we only use the visual model here,
|
40 |
+
# in order to reduce video memory,we delete the language model.
|
41 |
+
del vl_model.model
|
42 |
+
paddle.device.cuda.empty_cache()
|
43 |
+
|
44 |
+
image_processor = Qwen2VLImageProcessor()
|
45 |
+
tokenizer = MIXQwen2Tokenizer.from_pretrained(MODEL_NAME)
|
46 |
+
processor = Qwen2VLProcessor(image_processor, tokenizer)
|
47 |
+
|
48 |
+
# min_pixels = 256*28*28 # 200704
|
49 |
+
# max_pixels = 1280*28*28 # 1003520
|
50 |
+
# processor = Qwen2VLProcessor(image_processor, tokenizer, min_pixels=min_pixels, max_pixels=max_pixels)
|
51 |
+
|
52 |
+
messages = [
|
53 |
+
{
|
54 |
+
"role": "user",
|
55 |
+
"content": [
|
56 |
+
{
|
57 |
+
"type": "image",
|
58 |
+
"image": "paddlemix/demo_images/examples_image1.jpg",
|
59 |
+
},
|
60 |
+
{"type": "text", "text": "Describe this image."},
|
61 |
+
],
|
62 |
+
}
|
63 |
+
]
|
64 |
+
|
65 |
+
# Preparation for inference
|
66 |
+
image_inputs, video_inputs = process_vision_info(messages)
|
67 |
+
|
68 |
+
question = "Describe this image."
|
69 |
+
image_pad_token = "<|vision_start|><|image_pad|><|vision_end|>"
|
70 |
+
text = f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{image_pad_token}{question}<|im_end|>\n<|im_start|>assistant\n"
|
71 |
+
|
72 |
+
|
73 |
+
@dataclass
|
74 |
+
class PredictorArgument:
|
75 |
+
# NOTE: (zhoukangkang、changwenbin)
|
76 |
+
# These parameters are all copied from https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/predict/predictor.py
|
77 |
+
# For simplicity and ease of use, only the necessary parameters are retained here.
|
78 |
+
# If you want to know the exact meaning of these parameters, please refer to the link above.
|
79 |
+
|
80 |
+
model_name_or_path: str = field(default=None, metadata={"help": "The directory of model."})
|
81 |
+
src_length = 1024
|
82 |
+
min_length = 2
|
83 |
+
max_length = 200
|
84 |
+
top_k = 0
|
85 |
+
top_p = 0.0
|
86 |
+
temperature = 0.95
|
87 |
+
repetition_penalty = 1.0
|
88 |
+
dtype: str = field(default=None, metadata={"help": "Model dtype"})
|
89 |
+
decode_strategy = "sampling"
|
90 |
+
mode = "dynamic"
|
91 |
+
inference_model = True
|
92 |
+
quant_type = ""
|
93 |
+
benchmark: bool = field(
|
94 |
+
default=False,
|
95 |
+
metadata={
|
96 |
+
"help": "If benchmark set as `True`, we will force model decode to max_length, which is helpful to compute throughput. "
|
97 |
+
},
|
98 |
+
)
|
99 |
+
use_fake_parameter = False
|
100 |
+
block_attn = True
|
101 |
+
block_size = 64
|
102 |
+
cachekv_int8_type = None
|
103 |
+
append_attn = True
|
104 |
+
total_max_length = 4096
|
105 |
+
speculate_method = None
|
106 |
+
|
107 |
+
|
108 |
+
@dataclass
|
109 |
+
class ModelArgument:
|
110 |
+
model_type: str = field(
|
111 |
+
default=None,
|
112 |
+
metadata={"help": "the type of the model, which can be one of ['gpt-3', 'ernie-3.5-se', 'llama-img2txt']"},
|
113 |
+
)
|
114 |
+
|
115 |
+
|
116 |
+
def init_llm_model_inputs(vision_model_inputs, inputs_embeds, arg_config: PredictorArgument):
|
117 |
+
assert len(inputs_embeds.shape) == 3
|
118 |
+
batch_size = inputs_embeds.shape[0]
|
119 |
+
|
120 |
+
model_inputs = {}
|
121 |
+
model_inputs["input_ids"] = paddle.zeros(shape=[batch_size, arg_config.total_max_length], dtype="int64")
|
122 |
+
model_inputs["inputs_embeds"] = inputs_embeds
|
123 |
+
|
124 |
+
# I dislike write (arg_config.total_max_length + arg_config.block_size -1 ) // arg_config.block_size
|
125 |
+
assert arg_config.total_max_length % arg_config.block_size == 0
|
126 |
+
|
127 |
+
model_inputs["top_p"] = paddle.full(shape=[batch_size, 1], fill_value=arg_config.top_p, dtype="float32")
|
128 |
+
model_inputs["temperature"] = paddle.full(
|
129 |
+
shape=[batch_size, 1], fill_value=arg_config.temperature, dtype="float32"
|
130 |
+
)
|
131 |
+
model_inputs["eos_token_id"] = paddle.to_tensor(
|
132 |
+
np.array(llm_utils.get_eos_token_id(tokenizer, generation_config)).reshape(-1, 1).astype("int64")
|
133 |
+
)
|
134 |
+
model_inputs["penalty_score"] = paddle.full(
|
135 |
+
shape=[batch_size, 1], fill_value=arg_config.repetition_penalty, dtype="float32"
|
136 |
+
)
|
137 |
+
model_inputs["frequency_score"] = paddle.full(shape=[batch_size, 1], fill_value=0.0, dtype="float32")
|
138 |
+
model_inputs["presence_score"] = paddle.full(shape=[batch_size, 1], fill_value=0.0, dtype="float32")
|
139 |
+
model_inputs["min_length"] = paddle.full(shape=[batch_size, 1], fill_value=arg_config.min_length, dtype="int64")
|
140 |
+
model_inputs["max_length"] = paddle.full(shape=[batch_size, 1], fill_value=arg_config.max_length, dtype="int64")
|
141 |
+
|
142 |
+
position_ids, _ = vl_model.get_rope_index(
|
143 |
+
config.vision_config["spatial_merge_size"],
|
144 |
+
config.image_token_id,
|
145 |
+
config.video_token_id,
|
146 |
+
config.vision_start_token_id,
|
147 |
+
vision_model_inputs.get("input_ids"),
|
148 |
+
vision_model_inputs.get("image_grid_thw"),
|
149 |
+
vision_model_inputs.get("video_grid_thw", None),
|
150 |
+
vision_model_inputs.get("attention_mask"),
|
151 |
+
)
|
152 |
+
position_start = position_ids[0][0][-1].item()
|
153 |
+
position_end = 4096 - position_ids.shape[-1] + position_start
|
154 |
+
position_value = (
|
155 |
+
paddle.arange(position_start, position_end).reshape([1, 1, -1]).expand([position_ids.shape[0], 1, -1])
|
156 |
+
)
|
157 |
+
position_ids = paddle.concat([position_ids, position_value], axis=-1)
|
158 |
+
|
159 |
+
head_dim = config.hidden_size // config.num_attention_heads
|
160 |
+
qwen2_Embedding = Qwen2RotaryEmbedding(head_dim, 4096, config.rope_theta)
|
161 |
+
cos = qwen2_Embedding.cos_cached
|
162 |
+
sin = qwen2_Embedding.sin_cached
|
163 |
+
|
164 |
+
# NOTE: (zhoukangkang、changwenbin) Copied from PaddleMIX/paddlemix/models/qwen2_vl/modeling_qwen2_vl.py,
|
165 |
+
# for calculating M-ROPE.
|
166 |
+
cos = cos[position_ids]
|
167 |
+
sin = sin[position_ids]
|
168 |
+
mrope_section = config.rope_scaling["mrope_section"] * 2
|
169 |
+
cos = paddle.concat(x=[m[i % 3] for i, m in enumerate(cos.split(mrope_section, axis=-1))], axis=-1)
|
170 |
+
sin = paddle.concat(x=[m[i % 3] for i, m in enumerate(sin.split(mrope_section, axis=-1))], axis=-1)
|
171 |
+
|
172 |
+
rope_emb = paddle.stack([cos, sin], axis=0)
|
173 |
+
rope_emb = rope_emb.reshape([rope_emb.shape[0], 1, rope_emb.shape[2], 1, rope_emb.shape[-1]])
|
174 |
+
model_inputs["rope_emb"] = rope_emb
|
175 |
+
|
176 |
+
model_inputs["bad_tokens"] = paddle.to_tensor([-1], dtype="int64")
|
177 |
+
model_inputs["is_block_step"] = paddle.full(shape=[batch_size], fill_value=False, dtype="bool")
|
178 |
+
|
179 |
+
cache_kvs_shape = fast_llm_model.get_cache_kvs_shape(fast_llm_model.config, batch_size)
|
180 |
+
cachekv_dtype = config.dtype if arg_config.cachekv_int8_type is None else "uint8"
|
181 |
+
model_inputs["cache_kvs"] = [paddle.zeros(shape, dtype=cachekv_dtype) for shape in cache_kvs_shape]
|
182 |
+
|
183 |
+
block_nums = arg_config.total_max_length // arg_config.block_size
|
184 |
+
model_inputs["block_tables"] = paddle.arange(block_nums, dtype="int32").tile([batch_size, 1])
|
185 |
+
|
186 |
+
seq_lens = inputs_embeds.shape[1]
|
187 |
+
model_inputs["seq_lens_this_time"] = paddle.to_tensor(np.array(seq_lens).astype("int32").reshape(-1, 1))
|
188 |
+
model_inputs["seq_lens_encoder"] = paddle.to_tensor(np.array(seq_lens).astype("int32").reshape(-1, 1))
|
189 |
+
model_inputs["seq_lens_decoder"] = paddle.full(shape=[batch_size, 1], fill_value=0, dtype="int32")
|
190 |
+
model_inputs["step_idx"] = paddle.full(shape=[batch_size, 1], fill_value=0, dtype="int64")
|
191 |
+
model_inputs["not_need_stop"] = paddle.full(shape=[1], fill_value=True, dtype="bool")
|
192 |
+
model_inputs["stop_flags"] = paddle.full(shape=[batch_size, 1], fill_value=False, dtype="bool")
|
193 |
+
model_inputs["stop_nums"] = paddle.full(shape=[1], fill_value=batch_size, dtype="int64")
|
194 |
+
model_inputs["pre_ids"] = paddle.full(shape=[batch_size, arg_config.max_length], fill_value=-1, dtype="int64")
|
195 |
+
model_inputs["next_tokens"] = paddle.full(shape=[batch_size, 1], fill_value=-1, dtype="int64")
|
196 |
+
|
197 |
+
return model_inputs
|
198 |
+
|
199 |
+
|
200 |
+
parser = PdArgumentParser((PredictorArgument, ModelArgument))
|
201 |
+
predictor_args, model_args = parser.parse_args_into_dataclasses()
|
202 |
+
|
203 |
+
paddle.set_default_dtype(predictor_args.dtype)
|
204 |
+
config = AutoConfig.from_pretrained(predictor_args.model_name_or_path)
|
205 |
+
|
206 |
+
# NOTE: (changwenbin) This is for using the inference optimization of paddlenlp qwen2.
|
207 |
+
config.model_type = "qwen2"
|
208 |
+
generation_config = GenerationConfig.from_pretrained(predictor_args.model_name_or_path)
|
209 |
+
fast_llm_model = AutoInferenceModelForCausalLM.from_pretrained(
|
210 |
+
predictor_args.model_name_or_path,
|
211 |
+
config=config,
|
212 |
+
predictor_args=predictor_args,
|
213 |
+
model_args=model_args,
|
214 |
+
dtype=predictor_args.dtype,
|
215 |
+
tensor_parallel_degree=1,
|
216 |
+
tensor_parallel_rank=0,
|
217 |
+
)
|
218 |
+
fast_llm_model.eval()
|
219 |
+
|
220 |
+
vl_model.model = fast_llm_model
|
221 |
+
|
222 |
+
|
223 |
+
def run_model():
|
224 |
+
|
225 |
+
vision_model_inputs = processor(
|
226 |
+
text=[text],
|
227 |
+
images=image_inputs,
|
228 |
+
videos=video_inputs,
|
229 |
+
padding=True,
|
230 |
+
return_tensors="pd",
|
231 |
+
)
|
232 |
+
inputs_embeds = vl_model.vision_forward(**vision_model_inputs)
|
233 |
+
llm_model_inputs = init_llm_model_inputs(vision_model_inputs, inputs_embeds, arg_config=predictor_args)
|
234 |
+
generated_text = ""
|
235 |
+
while llm_model_inputs["not_need_stop"]:
|
236 |
+
generated_ids = fast_llm_model.generate(**llm_model_inputs) # already trimmed in paddle
|
237 |
+
llm_model_inputs["input_ids"] = generated_ids
|
238 |
+
llm_model_inputs["inputs_embeds"] = None
|
239 |
+
new_text_piece = processor.batch_decode(
|
240 |
+
generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False
|
241 |
+
)[0]
|
242 |
+
if new_text_piece == "<|im_end|>":
|
243 |
+
break
|
244 |
+
generated_text += new_text_piece
|
245 |
+
return generated_text
|
246 |
+
|
247 |
+
|
248 |
+
if predictor_args.benchmark:
|
249 |
+
print(f"Benchmarking {predictor_args.model_name_or_path} ...")
|
250 |
+
warm_up = 3
|
251 |
+
repeat_times = 10
|
252 |
+
sumtime = 0.0
|
253 |
+
times = repeat_times + warm_up
|
254 |
+
for i in range(times):
|
255 |
+
if i > 2:
|
256 |
+
paddle.device.synchronize()
|
257 |
+
starttime = datetime.datetime.now()
|
258 |
+
generated_text = run_model()
|
259 |
+
if i > 2:
|
260 |
+
paddle.device.synchronize()
|
261 |
+
endtime = datetime.datetime.now()
|
262 |
+
print("Final output_text:\n", generated_text)
|
263 |
+
|
264 |
+
if i > 2:
|
265 |
+
duringtime = endtime - starttime
|
266 |
+
duringtime = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
|
267 |
+
sumtime += duringtime
|
268 |
+
print(f"Single {predictor_args.model_name_or_path} end to end time : ", duringtime, "ms")
|
269 |
+
inference_global_mem = paddle.device.cuda.memory_reserved() / (1024**3)
|
270 |
+
print(f"Inference used CUDA memory : {inference_global_mem:.3f} GiB")
|
271 |
+
|
272 |
+
print(f"Single {predictor_args.model_name_or_path} ave end to end time : ", sumtime / repeat_times, "ms")
|
273 |
+
|
274 |
+
else:
|
275 |
+
generated_text = run_model()
|
276 |
+
print("Final output_text:\n", generated_text)
|
PaddleMIX/deploy/qwen_vl/run_static_predict.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import argparse
|
16 |
+
import os
|
17 |
+
|
18 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
19 |
+
os.environ["FLAGS_use_cuda_managed_memory"] = "true"
|
20 |
+
|
21 |
+
import paddle
|
22 |
+
from paddlenlp.transformers.configuration_utils import PretrainedConfig
|
23 |
+
from utils import load_real_time_tokens
|
24 |
+
|
25 |
+
from paddlemix import QwenVLProcessor, QWenVLTokenizer
|
26 |
+
|
27 |
+
|
28 |
+
class Predictor(object):
|
29 |
+
def __init__(self, args):
|
30 |
+
self.args = args
|
31 |
+
self.config = PretrainedConfig.from_pretrained(args.model_name_or_path)
|
32 |
+
self.tokenizer = QWenVLTokenizer.from_pretrained(args.model_name_or_path)
|
33 |
+
self.processor = QwenVLProcessor(tokenizer=self.tokenizer)
|
34 |
+
self.first_predictor = self.create_predictor(args.first_model_path)
|
35 |
+
print(f"first_model_path: {args.first_model_path}, {self.first_predictor}")
|
36 |
+
self.second_predictor = self.create_predictor(args.second_model_path)
|
37 |
+
print(f"second_model_path: {args.second_model_path}, {self.second_predictor}")
|
38 |
+
|
39 |
+
def create_predictor(self, model_path):
|
40 |
+
|
41 |
+
from paddlenlp.utils.import_utils import import_module
|
42 |
+
|
43 |
+
import_module("paddlenlp_ops.encode_rotary_qk")
|
44 |
+
import_module("paddlenlp_ops.get_padding_offset")
|
45 |
+
import_module("paddlenlp_ops.qkv_transpose_split")
|
46 |
+
import_module("paddlenlp_ops.rebuild_padding")
|
47 |
+
import_module("paddlenlp_ops.transpose_remove_padding")
|
48 |
+
import_module("paddlenlp_ops.write_cache_kv")
|
49 |
+
|
50 |
+
model_file = model_path + ".pdmodel"
|
51 |
+
params_file = model_path + ".pdiparams"
|
52 |
+
if not os.path.exists(model_file):
|
53 |
+
raise ValueError("not find model file path {}".format(model_file))
|
54 |
+
if not os.path.exists(params_file):
|
55 |
+
raise ValueError("not find params file path {}".format(params_file))
|
56 |
+
config = paddle.inference.Config(model_file, params_file)
|
57 |
+
|
58 |
+
config.switch_ir_optim(True)
|
59 |
+
|
60 |
+
if self.args.device == "gpu":
|
61 |
+
config.enable_use_gpu(100, 0)
|
62 |
+
|
63 |
+
config.switch_use_feed_fetch_ops(False)
|
64 |
+
predictor = paddle.inference.create_predictor(config)
|
65 |
+
return predictor
|
66 |
+
|
67 |
+
@paddle.no_grad()
|
68 |
+
def encode_images(self, pixel_values):
|
69 |
+
[language_model_inputs] = self.first_predictor.run([pixel_values])
|
70 |
+
return language_model_inputs
|
71 |
+
|
72 |
+
@paddle.no_grad()
|
73 |
+
def generate_with_image_features(self, image_features, input_ids):
|
74 |
+
batch, seq, _ = image_features.shape
|
75 |
+
seq = input_ids.shape[1]
|
76 |
+
max_len = 1024
|
77 |
+
dtype = "float16"
|
78 |
+
tgt_generation_mask = paddle.full([batch, 1, 1, max_len], 1, dtype=dtype)
|
79 |
+
|
80 |
+
img_pos = None
|
81 |
+
if paddle.any(input_ids == self.config.visual["image_start_id"]):
|
82 |
+
bos_pos = paddle.where(input_ids == self.config.visual["image_start_id"])
|
83 |
+
eos_pos = paddle.where(input_ids == self.config.visual["image_start_id"] + 1)
|
84 |
+
assert (bos_pos[0] == eos_pos[0]).astype("bool").all()
|
85 |
+
img_pos = paddle.concat((bos_pos[0], bos_pos[1], eos_pos[1]), axis=1)
|
86 |
+
|
87 |
+
attention_mask = paddle.full([batch, 1, max_len, max_len], 0, dtype=dtype)
|
88 |
+
attention_mask[:, 0, :seq, :seq] = paddle.tril(paddle.ones(shape=(seq, seq), dtype=dtype))
|
89 |
+
position_ids = paddle.full([batch, seq], 0, dtype="int64")
|
90 |
+
for i in range(batch):
|
91 |
+
position_ids[i, :] = paddle.to_tensor([i for i in range(seq)], dtype="int64")
|
92 |
+
|
93 |
+
inputs = [
|
94 |
+
input_ids, # input_ids
|
95 |
+
image_features, # image_features
|
96 |
+
img_pos, # img_pos
|
97 |
+
attention_mask, # attention_mask
|
98 |
+
position_ids, # position_ids
|
99 |
+
paddle.full([batch, 1], 1.0, dtype="float32"), # penalty_score
|
100 |
+
paddle.full([batch, 1], 0.0, dtype="float32"), # frequency_score,
|
101 |
+
paddle.full([batch, 1], 0.0, dtype="float32"), # presence_score,
|
102 |
+
paddle.full([batch, 1], 1, dtype="int64"), # min_length,
|
103 |
+
paddle.full([batch, 1], max_len - seq, dtype="int64"), # max_length,
|
104 |
+
paddle.full([batch, 1], 1.0, dtype="float32"), # temperature,
|
105 |
+
paddle.full([batch, 1], 0.0, dtype="float32"), # top_p,
|
106 |
+
paddle.full([1], 151643, dtype="int64"), # eos_token_id,
|
107 |
+
paddle.full([batch, 1], seq, dtype="int32"), # seq_len_encoder,
|
108 |
+
paddle.full([batch, 1], seq, dtype="int32"), # seq_len_decoder,
|
109 |
+
paddle.full([batch, 1], 0, dtype="int64"), # step_idx,
|
110 |
+
paddle.full([batch, 1], False, dtype="bool"), # stop_flags,
|
111 |
+
paddle.full([batch, 1], -123, dtype="int64"), # tgt_ids can be be initialized arbitrarily
|
112 |
+
paddle.full([batch, 1], seq - 1, dtype="int64"), # tgt_pos,
|
113 |
+
tgt_generation_mask, # tgt_generation_mask,
|
114 |
+
paddle.full([batch, max_len], -100, dtype="int64"), # pre_ids, can be initialized arbitrarily
|
115 |
+
paddle.full([1], batch, dtype="int64"), # stop_nums, be batch
|
116 |
+
]
|
117 |
+
for i in range(32):
|
118 |
+
tmp = paddle.rand(shape=[2, batch, 32, max_len, 128], dtype=dtype)
|
119 |
+
inputs.append(tmp)
|
120 |
+
|
121 |
+
self.second_predictor.run(inputs)
|
122 |
+
tokens = load_real_time_tokens()
|
123 |
+
generate_ids = tokens.tolist()
|
124 |
+
return generate_ids, None
|
125 |
+
|
126 |
+
def pre_processing(self, url, prompt):
|
127 |
+
# input query
|
128 |
+
query = []
|
129 |
+
query.append({"image": url})
|
130 |
+
query.append({"text": prompt})
|
131 |
+
inputs = self.processor(query=query, return_tensors="pd")
|
132 |
+
return inputs
|
133 |
+
|
134 |
+
def post_processing(self, generate_ids):
|
135 |
+
msg = self.processor.batch_decode(generate_ids)
|
136 |
+
return msg
|
137 |
+
|
138 |
+
def predict(self, url, prompt):
|
139 |
+
inputs = self.pre_processing(url, prompt)
|
140 |
+
images = inputs["images"]
|
141 |
+
second_input_ids = inputs["input_ids"]
|
142 |
+
|
143 |
+
image_features = self.encode_images(images)
|
144 |
+
generate_ids, _ = self.generate_with_image_features(
|
145 |
+
image_features,
|
146 |
+
second_input_ids,
|
147 |
+
)
|
148 |
+
|
149 |
+
msg = self.post_processing(generate_ids)
|
150 |
+
|
151 |
+
return msg
|
152 |
+
|
153 |
+
|
154 |
+
if __name__ == "__main__":
|
155 |
+
parser = argparse.ArgumentParser()
|
156 |
+
parser.add_argument(
|
157 |
+
"--first_model_path",
|
158 |
+
default="The dir name of image encoder model",
|
159 |
+
type=str,
|
160 |
+
help="",
|
161 |
+
)
|
162 |
+
parser.add_argument(
|
163 |
+
"--second_model_path",
|
164 |
+
default="The dir name of language model",
|
165 |
+
type=str,
|
166 |
+
help="",
|
167 |
+
)
|
168 |
+
parser.add_argument(
|
169 |
+
"--model_name_or_path",
|
170 |
+
type=str,
|
171 |
+
default="qwen-vl/qwen-vl-7b",
|
172 |
+
help="The path of extraction model path that you want to load.",
|
173 |
+
)
|
174 |
+
parser.add_argument(
|
175 |
+
"--device", default="gpu", choices=["gpu", "cpu", "xpu"], help="Device selected for inference."
|
176 |
+
)
|
177 |
+
parser.add_argument("--seed", default=1234)
|
178 |
+
parser.add_argument("--benchmark", action="store_true")
|
179 |
+
args = parser.parse_args()
|
180 |
+
|
181 |
+
paddle.seed(args.seed)
|
182 |
+
|
183 |
+
predictor = Predictor(args)
|
184 |
+
|
185 |
+
url = "https://bj.bcebos.com/v1/paddlenlp/models/community/GroundingDino/000000004505.jpg"
|
186 |
+
prompt = "Generate the caption in English with grounding:"
|
187 |
+
|
188 |
+
if not args.benchmark:
|
189 |
+
msg = predictor.predict(url, prompt)
|
190 |
+
print("Outputs: ", msg)
|
191 |
+
else:
|
192 |
+
import time
|
193 |
+
start = 0.0
|
194 |
+
total = 0.0
|
195 |
+
for i in range(20):
|
196 |
+
if i>10:
|
197 |
+
start = time.time()
|
198 |
+
msg = predictor.predict(url, prompt)
|
199 |
+
|
200 |
+
if i > 10:
|
201 |
+
total += time.time()-start
|
202 |
+
|
203 |
+
print("Time :",total/10)
|
PaddleMIX/deploy/sam/README.md
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Segment Anything
|
2 |
+
|
3 |
+
## 1. 模型简介
|
4 |
+
|
5 |
+
[Segment Anything](https://ai.facebook.com/research/publications/segment-anything/) 是 Meta AI Research, FAIR
|
6 |
+
的图像分割模型。根据输入提示(如点或框)生成高质量mask,可为图像中的所有对象进行分割。它已经在1100万张图像和11亿个掩模的数据集上进行了训练,并在各种分割任务上具有强大的零样本性能。
|
7 |
+
本仓库提供该模型的Paddle部署实现。
|
8 |
+
|
9 |
+
## 2. 快速开始
|
10 |
+
|
11 |
+
## 2.1 静态图导出与预测
|
12 |
+
```bash
|
13 |
+
#导出输入类型是 bbox 的静态图
|
14 |
+
python export.py --model_type Sam/SamVitH-1024 --input_type boxs --save_dir sam_export
|
15 |
+
|
16 |
+
#导出输入类型是 points 的静态图
|
17 |
+
python export.py --model_type Sam/SamVitH-1024 --input_type points --save_dir sam_export
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
#bbox 提示词推理
|
22 |
+
python predict.py \
|
23 |
+
--input_image https://bj.bcebos.com/v1/paddlenlp/models/community/GroundingDino/000000004505.jpg \
|
24 |
+
--box_prompt 112 118 513 382 \
|
25 |
+
--input_type boxs \
|
26 |
+
--model_name_or_path Sam/SamVitH-1024 \
|
27 |
+
--cfg Sam/SamVitH-1024_boxs/deploy.yaml
|
28 |
+
|
29 |
+
|
30 |
+
#points 提示词推理
|
31 |
+
python predict.py \
|
32 |
+
--input_image https://bj.bcebos.com/v1/paddlenlp/models/community/GroundingDino/000000004505.jpg \
|
33 |
+
--points_prompt 548 372 \
|
34 |
+
--input_type points \
|
35 |
+
--model_name_or_path Sam/SamVitH-1024 \
|
36 |
+
--cfg Sam/SamVitH-1024_points/deploy.yaml
|
37 |
+
```
|
PaddleMIX/deploy/sam/export.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import argparse
|
16 |
+
import os
|
17 |
+
|
18 |
+
import paddle
|
19 |
+
import yaml
|
20 |
+
|
21 |
+
from paddlemix.models.sam.modeling import SamModel
|
22 |
+
from paddlemix.utils.log import logger
|
23 |
+
|
24 |
+
|
25 |
+
def parse_args():
|
26 |
+
parser = argparse.ArgumentParser(description="Export Inference Model.")
|
27 |
+
parser.add_argument(
|
28 |
+
"--model_type",
|
29 |
+
choices=["Sam/SamVitH-1024", "Sam/SamVitB", "Sam/SamVitL"],
|
30 |
+
required=True,
|
31 |
+
help="The model type.",
|
32 |
+
type=str,
|
33 |
+
)
|
34 |
+
parser.add_argument(
|
35 |
+
"--input_type",
|
36 |
+
choices=["boxs", "points", "points_grid"],
|
37 |
+
required=True,
|
38 |
+
help="The model type.",
|
39 |
+
type=str,
|
40 |
+
)
|
41 |
+
parser.add_argument(
|
42 |
+
"--save_dir",
|
43 |
+
help="The directory for saving the exported inference model",
|
44 |
+
type=str,
|
45 |
+
default="./output/inference_model",
|
46 |
+
)
|
47 |
+
parser.add_argument(
|
48 |
+
"--input_img_shape",
|
49 |
+
nargs="+",
|
50 |
+
help="Export the model with fixed input shape, e.g., `--input_img_shape 1 3 512 1024`.",
|
51 |
+
type=int,
|
52 |
+
default=[1, 3, 1024, 1024],
|
53 |
+
)
|
54 |
+
|
55 |
+
return parser.parse_args()
|
56 |
+
|
57 |
+
|
58 |
+
def main(args):
|
59 |
+
|
60 |
+
os.environ["PADDLESEG_EXPORT_STAGE"] = "True"
|
61 |
+
|
62 |
+
model = SamModel.from_pretrained(args.model_type, input_type=args.input_type)
|
63 |
+
|
64 |
+
shape = [None, 3, None, None] if args.input_img_shape is None else args.input_img_shape
|
65 |
+
if args.input_type == "points":
|
66 |
+
shape2 = [1, 1, 2]
|
67 |
+
elif args.input_type == "boxs":
|
68 |
+
shape2 = [None, 4]
|
69 |
+
elif args.input_type == "points_grid":
|
70 |
+
shape2 = [64, 1, 2]
|
71 |
+
|
72 |
+
input_spec = [
|
73 |
+
paddle.static.InputSpec(shape=shape, dtype="float32"),
|
74 |
+
paddle.static.InputSpec(shape=shape2, dtype="int32"),
|
75 |
+
]
|
76 |
+
model.eval()
|
77 |
+
model = paddle.jit.to_static(model, input_spec=input_spec)
|
78 |
+
save_path = f"{args.model_type}_{args.input_type}"
|
79 |
+
paddle.jit.save(model, os.path.join(save_path, "model"))
|
80 |
+
|
81 |
+
# TODO add test config
|
82 |
+
deploy_info = {
|
83 |
+
"Deploy": {
|
84 |
+
"model": "model.pdmodel",
|
85 |
+
"params": "model.pdiparams",
|
86 |
+
"input_img_shape": shape,
|
87 |
+
"input_prompt_shape": shape2,
|
88 |
+
"input_prompt_type": args.input_type,
|
89 |
+
"model_type": args.model_type,
|
90 |
+
"output_dtype": "float32",
|
91 |
+
}
|
92 |
+
}
|
93 |
+
msg = "\n---------------Deploy Information---------------\n"
|
94 |
+
msg += str(yaml.dump(deploy_info))
|
95 |
+
logger.info(msg)
|
96 |
+
|
97 |
+
yml_file = os.path.join(save_path, "deploy.yaml")
|
98 |
+
with open(yml_file, "w") as file:
|
99 |
+
yaml.dump(deploy_info, file)
|
100 |
+
|
101 |
+
logger.info(f"The inference model is saved in {save_path}")
|
102 |
+
|
103 |
+
|
104 |
+
if __name__ == "__main__":
|
105 |
+
args = parse_args()
|
106 |
+
main(args)
|
PaddleMIX/deploy/sam/predict.py
ADDED
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import codecs
|
16 |
+
import os
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from typing import List
|
19 |
+
|
20 |
+
import matplotlib.pyplot as plt
|
21 |
+
import numpy as np
|
22 |
+
import requests
|
23 |
+
import yaml
|
24 |
+
from paddle.inference import Config as PredictConfig
|
25 |
+
from paddle.inference import create_predictor
|
26 |
+
from paddlenlp.trainer import PdArgumentParser
|
27 |
+
from PIL import Image
|
28 |
+
|
29 |
+
from paddlemix.processors.sam_processing import SamProcessor
|
30 |
+
from paddlemix.utils.log import logger
|
31 |
+
|
32 |
+
|
33 |
+
def show_mask(mask, ax, random_color=False):
|
34 |
+
if random_color:
|
35 |
+
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
36 |
+
else:
|
37 |
+
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
|
38 |
+
h, w = mask.shape[-2:]
|
39 |
+
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
40 |
+
ax.imshow(mask_image)
|
41 |
+
|
42 |
+
|
43 |
+
class DeployConfig:
|
44 |
+
def __init__(self, path):
|
45 |
+
with codecs.open(path, "r", "utf-8") as file:
|
46 |
+
self.dic = yaml.load(file, Loader=yaml.FullLoader)
|
47 |
+
|
48 |
+
self._dir = os.path.dirname(path)
|
49 |
+
|
50 |
+
@property
|
51 |
+
def model(self):
|
52 |
+
return os.path.join(self._dir, self.dic["Deploy"]["model"])
|
53 |
+
|
54 |
+
@property
|
55 |
+
def params(self):
|
56 |
+
return os.path.join(self._dir, self.dic["Deploy"]["params"])
|
57 |
+
|
58 |
+
|
59 |
+
def use_auto_tune(args):
|
60 |
+
return (
|
61 |
+
hasattr(PredictConfig, "collect_shape_range_info")
|
62 |
+
and hasattr(PredictConfig, "enable_tuned_tensorrt_dynamic_shape")
|
63 |
+
and args.device == "gpu"
|
64 |
+
and args.use_trt
|
65 |
+
and args.enable_auto_tune
|
66 |
+
)
|
67 |
+
|
68 |
+
|
69 |
+
def auto_tune(args, imgs, img_nums):
|
70 |
+
"""
|
71 |
+
Use images to auto tune the dynamic shape for trt sub graph.
|
72 |
+
The tuned shape saved in args.auto_tuned_shape_file.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
args(dict): input args.
|
76 |
+
imgs(str, list[str], numpy): the path for images or the origin images.
|
77 |
+
img_nums(int): the nums of images used for auto tune.
|
78 |
+
Returns:
|
79 |
+
None
|
80 |
+
"""
|
81 |
+
logger.info("Auto tune the dynamic shape for GPU TRT.")
|
82 |
+
|
83 |
+
assert use_auto_tune(args), (
|
84 |
+
"Do not support auto_tune, which requires " "device==gpu && use_trt==True && paddle >= 2.2"
|
85 |
+
)
|
86 |
+
|
87 |
+
if not isinstance(imgs, (list, tuple)):
|
88 |
+
imgs = [imgs]
|
89 |
+
num = min(len(imgs), img_nums)
|
90 |
+
|
91 |
+
cfg = DeployConfig(args.cfg)
|
92 |
+
pred_cfg = PredictConfig(cfg.model, cfg.params)
|
93 |
+
pass_builder = pred_cfg.pass_builder()
|
94 |
+
pass_builder.delete_pass("identity_op_clean_pass")
|
95 |
+
pred_cfg.enable_use_gpu(100, 0)
|
96 |
+
if not args.print_detail:
|
97 |
+
pred_cfg.disable_glog_info()
|
98 |
+
pred_cfg.collect_shape_range_info(args.auto_tuned_shape_file)
|
99 |
+
|
100 |
+
# todo
|
101 |
+
predictor = create_predictor(pred_cfg)
|
102 |
+
input_names = predictor.get_input_names()
|
103 |
+
input_handle = predictor.get_input_handle(input_names[0])
|
104 |
+
|
105 |
+
for i in range(0, num):
|
106 |
+
if isinstance(imgs[i], str):
|
107 |
+
data = {"img": imgs[i]}
|
108 |
+
data = np.array([cfg.transforms(data)["img"]])
|
109 |
+
else:
|
110 |
+
data = imgs[i]
|
111 |
+
input_handle.reshape(data.shape)
|
112 |
+
input_handle.copy_from_cpu(data)
|
113 |
+
try:
|
114 |
+
predictor.run()
|
115 |
+
except Exception as e:
|
116 |
+
logger.info(str(e))
|
117 |
+
logger.info(
|
118 |
+
"Auto tune failed. Usually, the error is out of GPU memory " "for the model or image is too large. \n"
|
119 |
+
)
|
120 |
+
del predictor
|
121 |
+
if os.path.exists(args.auto_tuned_shape_file):
|
122 |
+
os.remove(args.auto_tuned_shape_file)
|
123 |
+
return
|
124 |
+
|
125 |
+
logger.info("Auto tune success.\n")
|
126 |
+
|
127 |
+
|
128 |
+
class Predictor:
|
129 |
+
def __init__(self, args):
|
130 |
+
"""
|
131 |
+
Prepare for prediction.
|
132 |
+
The usage and docs of paddle inference, please refer to
|
133 |
+
https://paddleinference.paddlepaddle.org.cn/product_introduction/summary.html
|
134 |
+
"""
|
135 |
+
self.args = args
|
136 |
+
self.cfg = DeployConfig(args.cfg)
|
137 |
+
self.processor = SamProcessor.from_pretrained(args.model_name_or_path)
|
138 |
+
|
139 |
+
self._init_base_config()
|
140 |
+
|
141 |
+
if args.device == "cpu":
|
142 |
+
self._init_cpu_config()
|
143 |
+
elif args.device == "npu":
|
144 |
+
self.pred_cfg.enable_custom_device("npu")
|
145 |
+
elif args.device == "xpu":
|
146 |
+
self.pred_cfg.enable_xpu()
|
147 |
+
else:
|
148 |
+
self._init_gpu_config()
|
149 |
+
|
150 |
+
try:
|
151 |
+
self.predictor = create_predictor(self.pred_cfg)
|
152 |
+
except Exception as e:
|
153 |
+
logger.info(str(e))
|
154 |
+
logger.info(
|
155 |
+
"If the above error is '(InvalidArgument) some trt inputs dynamic shape info not set, "
|
156 |
+
"..., Expected all_dynamic_shape_set == true, ...', "
|
157 |
+
"please set --enable_auto_tune=True to use auto_tune. \n"
|
158 |
+
)
|
159 |
+
exit()
|
160 |
+
|
161 |
+
def _init_base_config(self):
|
162 |
+
self.pred_cfg = PredictConfig(self.cfg.model, self.cfg.params)
|
163 |
+
pass_builder = self.pred_cfg.pass_builder()
|
164 |
+
pass_builder.delete_pass("identity_op_clean_pass")
|
165 |
+
self.pred_cfg.enable_memory_optim()
|
166 |
+
self.pred_cfg.switch_ir_optim(True)
|
167 |
+
|
168 |
+
def _init_cpu_config(self):
|
169 |
+
"""
|
170 |
+
Init the config for x86 cpu.
|
171 |
+
"""
|
172 |
+
logger.info("Use CPU")
|
173 |
+
self.pred_cfg.disable_gpu()
|
174 |
+
if self.args.enable_mkldnn:
|
175 |
+
logger.info("Use MKLDNN")
|
176 |
+
# cache 10 different shapes for mkldnn
|
177 |
+
self.pred_cfg.set_mkldnn_cache_capacity(10)
|
178 |
+
self.pred_cfg.enable_mkldnn()
|
179 |
+
self.pred_cfg.set_cpu_math_library_num_threads(self.args.cpu_threads)
|
180 |
+
|
181 |
+
def _init_gpu_config(self):
|
182 |
+
"""
|
183 |
+
Init the config for nvidia gpu.
|
184 |
+
"""
|
185 |
+
logger.info("Use GPU")
|
186 |
+
self.pred_cfg.enable_use_gpu(100, 0)
|
187 |
+
|
188 |
+
def run(self, image, prompt_out):
|
189 |
+
image, prompt_out = self.preprocess(image, prompt_out)
|
190 |
+
input_names = self.predictor.get_input_names()
|
191 |
+
input_handle1 = self.predictor.get_input_handle(input_names[0])
|
192 |
+
input_handle2 = self.predictor.get_input_handle(input_names[1])
|
193 |
+
output_names = self.predictor.get_output_names()
|
194 |
+
output_handle = self.predictor.get_output_handle(output_names[0])
|
195 |
+
|
196 |
+
input_handle1.reshape(image.shape)
|
197 |
+
input_handle1.copy_from_cpu(image.numpy())
|
198 |
+
if self.args.input_type == "boxs":
|
199 |
+
prompt_out = prompt_out.reshape([-1, 4])
|
200 |
+
input_handle2.reshape(prompt_out.shape)
|
201 |
+
input_handle2.copy_from_cpu(prompt_out.numpy())
|
202 |
+
|
203 |
+
self.predictor.run()
|
204 |
+
|
205 |
+
results = output_handle.copy_to_cpu()
|
206 |
+
|
207 |
+
results = self.postprocess(results)
|
208 |
+
|
209 |
+
return results
|
210 |
+
|
211 |
+
def preprocess(self, image, prompts):
|
212 |
+
|
213 |
+
image_seg, prompt = self.processor(
|
214 |
+
image,
|
215 |
+
input_type=self.args.input_type,
|
216 |
+
box=prompts["boxs"],
|
217 |
+
point_coords=prompts["points"],
|
218 |
+
)
|
219 |
+
|
220 |
+
return [image_seg, prompt]
|
221 |
+
|
222 |
+
def postprocess(self, results):
|
223 |
+
return self.processor.postprocess_masks(results)
|
224 |
+
|
225 |
+
|
226 |
+
@dataclass
|
227 |
+
class DataArguments:
|
228 |
+
"""
|
229 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
230 |
+
Using `PdArgumentParser` we can turn this class
|
231 |
+
into argparse arguments to be able to specify them on
|
232 |
+
the command line.
|
233 |
+
"""
|
234 |
+
|
235 |
+
input_image: str = field(metadata={"help": "The name of input image."})
|
236 |
+
box_prompt: List[int] = field(default=None, metadata={"help": "box promt format as xyxyxyxy...]."})
|
237 |
+
points_prompt: List[int] = field(default=None, metadata={"help": "point promt format as [[xy],[xy]...]."})
|
238 |
+
|
239 |
+
|
240 |
+
@dataclass
|
241 |
+
class ModelArguments:
|
242 |
+
"""
|
243 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
244 |
+
"""
|
245 |
+
|
246 |
+
model_name_or_path: str = field(
|
247 |
+
default="Sam/SamVitH-1024",
|
248 |
+
metadata={"help": "Path to pretrained model or model identifier"},
|
249 |
+
)
|
250 |
+
input_type: str = field(
|
251 |
+
default="boxs",
|
252 |
+
metadata={"help": "The model prompt type, choices ['boxs', 'points', 'points_grid']."},
|
253 |
+
)
|
254 |
+
cfg: str = field(
|
255 |
+
default=None,
|
256 |
+
metadata={"help": "The config file."},
|
257 |
+
)
|
258 |
+
use_trt: bool = field(
|
259 |
+
default=False,
|
260 |
+
metadata={"help": "Whether to use Nvidia TensorRT to accelerate prediction."},
|
261 |
+
)
|
262 |
+
precision: str = field(
|
263 |
+
default="fp32",
|
264 |
+
metadata={"help": "The tensorrt precision."},
|
265 |
+
)
|
266 |
+
min_subgraph_size: int = field(
|
267 |
+
default=3,
|
268 |
+
metadata={"help": "The min subgraph size in tensorrt prediction.'"},
|
269 |
+
)
|
270 |
+
enable_auto_tune: bool = field(
|
271 |
+
default=False,
|
272 |
+
metadata={
|
273 |
+
"help": "Whether to enable tuned dynamic shape. We uses some images to collect \
|
274 |
+
the dynamic shape for trt sub graph, which avoids setting dynamic shape manually."
|
275 |
+
},
|
276 |
+
)
|
277 |
+
device: str = field(
|
278 |
+
default="GPU",
|
279 |
+
metadata={"help": "Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU."},
|
280 |
+
)
|
281 |
+
cpu_threads: int = field(
|
282 |
+
default=10,
|
283 |
+
metadata={"help": "Number of threads to predict when using cpu."},
|
284 |
+
)
|
285 |
+
enable_mkldnn: bool = field(
|
286 |
+
default=False,
|
287 |
+
metadata={"help": "Enable to use mkldnn to speed up when using cpu."},
|
288 |
+
)
|
289 |
+
|
290 |
+
output_dir: str = field(
|
291 |
+
default="seg_output",
|
292 |
+
metadata={"help": "output directory."},
|
293 |
+
)
|
294 |
+
visual: bool = field(
|
295 |
+
default=True,
|
296 |
+
metadata={"help": "save visual image."},
|
297 |
+
)
|
298 |
+
benchmark: bool = field(
|
299 |
+
default=False,
|
300 |
+
metadata={"help": "benchmark"}
|
301 |
+
)
|
302 |
+
|
303 |
+
|
304 |
+
def main(model_args, data_args):
|
305 |
+
|
306 |
+
url = data_args.input_image
|
307 |
+
# read image
|
308 |
+
if os.path.isfile(url):
|
309 |
+
|
310 |
+
image_pil = Image.open(data_args.input_image).convert("RGB")
|
311 |
+
else:
|
312 |
+
image_pil = Image.open(requests.get(url, stream=True).raw).convert("RGB")
|
313 |
+
|
314 |
+
if data_args.box_prompt is not None:
|
315 |
+
data_args.box_prompt = np.array(data_args.box_prompt)
|
316 |
+
if data_args.points_prompt is not None:
|
317 |
+
data_args.points_prompt = np.array([data_args.points_prompt])
|
318 |
+
|
319 |
+
if use_auto_tune(model_args):
|
320 |
+
tune_img_nums = 10
|
321 |
+
auto_tune(model_args, [image_pil], tune_img_nums)
|
322 |
+
|
323 |
+
predictor = Predictor(model_args)
|
324 |
+
|
325 |
+
if model_args.benchmark:
|
326 |
+
import time
|
327 |
+
start = 0.0
|
328 |
+
total = 0.0
|
329 |
+
for i in range(20):
|
330 |
+
if i>10:
|
331 |
+
start = time.time()
|
332 |
+
seg_masks = predictor.run(image_pil, {"points": data_args.points_prompt, "boxs": data_args.box_prompt})
|
333 |
+
if i > 10:
|
334 |
+
total += time.time()-start
|
335 |
+
|
336 |
+
print("Time:",total/10)
|
337 |
+
|
338 |
+
seg_masks = predictor.run(image_pil, {"points": data_args.points_prompt, "boxs": data_args.box_prompt})
|
339 |
+
|
340 |
+
if model_args.visual:
|
341 |
+
# make dir
|
342 |
+
os.makedirs(model_args.output_dir, exist_ok=True)
|
343 |
+
# draw output image
|
344 |
+
plt.figure(figsize=(10, 10))
|
345 |
+
plt.imshow(image_pil)
|
346 |
+
for mask in seg_masks:
|
347 |
+
show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
|
348 |
+
|
349 |
+
plt.axis("off")
|
350 |
+
plt.savefig(
|
351 |
+
os.path.join(model_args.output_dir, "mask_pred.jpg"),
|
352 |
+
bbox_inches="tight",
|
353 |
+
dpi=300,
|
354 |
+
pad_inches=0.0,
|
355 |
+
)
|
356 |
+
|
357 |
+
if use_auto_tune(model_args) and os.path.exists(model_args.auto_tuned_shape_file):
|
358 |
+
os.remove(model_args.auto_tuned_shape_file)
|
359 |
+
|
360 |
+
|
361 |
+
if __name__ == "__main__":
|
362 |
+
|
363 |
+
parser = PdArgumentParser((ModelArguments, DataArguments))
|
364 |
+
model_args, data_args = parser.parse_args_into_dataclasses()
|
365 |
+
|
366 |
+
model_args.device = model_args.device.upper()
|
367 |
+
assert model_args.device in [
|
368 |
+
"CPU",
|
369 |
+
"GPU",
|
370 |
+
"XPU",
|
371 |
+
"NPU",
|
372 |
+
], "device should be CPU, GPU, XPU or NPU"
|
373 |
+
|
374 |
+
main(model_args, data_args)
|
PaddleMIX/docs/hardware_support/ascend_usage.md
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# PaddleMIX昇腾使用说明
|
2 |
+
|
3 |
+
为了满足用户对AI芯片多样化的需求, PaddleMIX 团队基于飞桨框架在硬件兼容性和灵活性方面的优势,深度适配了昇腾910芯片,为用户提供了国产计算芯片上的训推能力。只需安装说明安装多硬件版本的飞桨框架后,在模型配置文件中添加一个配置设备的参数,即可在相关硬件上使用PaddleMIX。当前PaddleMIX昇腾版适配涵盖了多模态理解模型InternVL2、LLaVA和多模态生成模型SD3、SDXL。未来我们将继续在用户使用的多种算力平台上适配 PaddleMIX 更多的模型,敬请期待。
|
4 |
+
|
5 |
+
## 1. 模型列表
|
6 |
+
<table align="center">
|
7 |
+
<tbody>
|
8 |
+
<tr align="center" valign="center">
|
9 |
+
<td>
|
10 |
+
<b>多模态理解</b>
|
11 |
+
</td>
|
12 |
+
<td>
|
13 |
+
<b>多模态生成</b>
|
14 |
+
</td>
|
15 |
+
</tr>
|
16 |
+
<tr valign="top">
|
17 |
+
<td>
|
18 |
+
<ul>
|
19 |
+
</ul>
|
20 |
+
<li><b>图文预训练</b></li>
|
21 |
+
<ul>
|
22 |
+
<li><a href="../../paddlemix/examples/llava">LLaVA-1.6</a></li>
|
23 |
+
<li><a href="../../paddlemix/examples/internvl2">InternVL2</a></li>
|
24 |
+
</ul>
|
25 |
+
</td>
|
26 |
+
<td>
|
27 |
+
<ul>
|
28 |
+
</ul>
|
29 |
+
<li><b>文生图</b></li>
|
30 |
+
<ul>
|
31 |
+
<li><a href="../../ppdiffusers/examples/stable_diffusion">Stable Diffusion</a></li>
|
32 |
+
<li><a href="../../ppdiffusers/examples/dreambooth/README_sd3.md">Stable Diffusion 3 (SD3)</a></li>
|
33 |
+
</ul>
|
34 |
+
</td>
|
35 |
+
</tr>
|
36 |
+
</tbody>
|
37 |
+
</table>
|
38 |
+
|
39 |
+
## 2. 安装说明
|
40 |
+
|
41 |
+
### 2.1 创建标准化环境
|
42 |
+
|
43 |
+
当前 PaddleMIX 支持昇腾 910B 芯片,昇腾驱动版本为 23.0.3。考虑到环境差异性,我们推荐使用飞桨官方提供的标准镜像(支持x86服务器与Arm服务器)完成环境准备。
|
44 |
+
|
45 |
+
参考如下命令启动容器,ASCEND_RT_VISIBLE_DEVICES 指定可见的 NPU 卡号
|
46 |
+
|
47 |
+
```shell
|
48 |
+
docker run -it --name paddle-npu-dev -v $(pwd):/work \
|
49 |
+
--privileged --network=host --shm-size=128G -w=/work \
|
50 |
+
-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
|
51 |
+
-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
|
52 |
+
-v /usr/local/dcmi:/usr/local/dcmi \
|
53 |
+
-e ASCEND_RT_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" \
|
54 |
+
registry.baidubce.com/device/paddle-npu:cann80T13-ubuntu20-$(uname -m)-gcc84-py39 /bin/bash
|
55 |
+
```
|
56 |
+
|
57 |
+
### 2.2 安装飞桨
|
58 |
+
|
59 |
+
在容器内安装飞桨
|
60 |
+
|
61 |
+
```shell
|
62 |
+
# 注意需要先安装飞桨 cpu 版本,目前仅支持python3.9版本
|
63 |
+
python -m pip install --pre paddlepaddle -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/
|
64 |
+
python -m pip install --pre paddle-custom-npu -i https://www.paddlepaddle.org.cn/packages/nightly/npu/
|
65 |
+
```
|
66 |
+
|
67 |
+
### 2.3 安装PaddleMIX
|
68 |
+
|
69 |
+
克隆PaddleMIX仓库
|
70 |
+
|
71 |
+
```shell
|
72 |
+
# 使用最新发布的release/2.1分支
|
73 |
+
git clone https://github.com/PaddlePaddle/PaddleMIX -b release/2.1
|
74 |
+
cd PaddleMIX
|
75 |
+
```
|
76 |
+
|
77 |
+
### 2.4 安装依赖
|
78 |
+
|
79 |
+
```shell
|
80 |
+
sh build_env.sh
|
81 |
+
python -m pip install -U librosa
|
82 |
+
```
|
83 |
+
|
84 |
+
## 3. 多模态理解
|
85 |
+
|
86 |
+
多模态大模型(Multimodal LLM)是当前研究的热点,在 2024 年迎来了井喷式的发展,它将多模态输入经由特定的多模态 encoder 转化为与文本对齐的 token ,随后被输入到大语言模型中来执行多模态任务。PaddleMIX 2.1 新增了两大系列多模态大模型:InternVL2 系列和 Qwen2-VL 系列,同时支持指令微调训练和推理部署,模型能力覆盖了图片问答、文档图表理解、关键信息提取、场景文本理解、 OCR 识别、科学数学问答、视频理解、多图联合理解等。
|
87 |
+
|
88 |
+
InternVL2系列模型支持昇腾 910B 芯片上训练和推理,使用昇腾 910B 芯片训练推理时请先参考本文安装说明章节中的内容安装相应版本的飞桨框架。InternVL2模型训练推理使用方法参考如下:
|
89 |
+
|
90 |
+
### 3.1 微调训练
|
91 |
+
|
92 |
+
#### 3.1.1 数据准备
|
93 |
+
|
94 |
+
参照[文档](../../paddlemix/examples/internvl2)进行数据准备
|
95 |
+
|
96 |
+
#### 3.1.2 环境设置
|
97 |
+
|
98 |
+
设置NPU相关环境变量
|
99 |
+
|
100 |
+
```shell
|
101 |
+
export FLAGS_use_stride_kernel=0
|
102 |
+
export FLAGS_npu_storage_format=0 # 关闭私有格式
|
103 |
+
export FLAGS_npu_jit_compile=0 # 关闭即时编译
|
104 |
+
export FLAGS_npu_scale_aclnn=True # aclnn加速
|
105 |
+
export FLAGS_npu_split_aclnn=True # aclnn加速
|
106 |
+
export CUSTOM_DEVICE_BLACK_LIST=set_value,set_value_with_tensor # set_value加入黑名单
|
107 |
+
|
108 |
+
# 将ppdiffusers加入到PYTHONPATH中
|
109 |
+
export PYTHONPATH=`pwd`/ppdiffusers:$PYTHONPATH
|
110 |
+
```
|
111 |
+
#### 3.1.3 微调训练
|
112 |
+
|
113 |
+
执行微调训练,可以从[PaddleMIX工具箱介绍](../..//paddlemix/tools/README.md)查看详细的参数说明
|
114 |
+
|
115 |
+
```shell
|
116 |
+
# 以2B权重为例子
|
117 |
+
sh paddlemix/examples/internvl2/shell/internvl2.0/2nd_finetune/internvl2_2b_internlm2_1_8b_dynamic_res_2nd_finetune_full.sh
|
118 |
+
```
|
119 |
+
|
120 |
+
### 3.2 推理
|
121 |
+
|
122 |
+
#### 3.2.1 环境设置
|
123 |
+
|
124 |
+
参考上述步骤设置环境变量
|
125 |
+
|
126 |
+
#### 3.2.2 执行推理
|
127 |
+
|
128 |
+
```shell
|
129 |
+
python paddlemix/examples/internvl2/chat_demo.py \
|
130 |
+
--model_name_or_path "OpenGVLab/InternVL2-2B" \
|
131 |
+
--image_path 'paddlemix/demo_images/examples_image1.jpg' \
|
132 |
+
--text "Please describe this image in detail."
|
133 |
+
```
|
134 |
+
|
135 |
+
## 4. 多模态生成
|
136 |
+
|
137 |
+
PPDiffusers 提供了 SD3 的的个性化微调训练样例,只需要少量主题图像即可定制个性化 SD3 模型,支持 DreamBooth LoRA 微调及 DreamBooth 全参数微调。在推理上,提供 SD3 模型高性能推理实现。
|
138 |
+
|
139 |
+
多模态生成Stable Diffusion系列模型支持昇腾 910B 芯片上训练和推理,使用昇腾 910B 芯片训练推理时请先参考本文安装说明章节中的内容安装相应版本的飞桨框架。SDXL模型训练推理使用方法参考如下:
|
140 |
+
|
141 |
+
### 4.1 训练
|
142 |
+
|
143 |
+
#### 4.1.1 环境设置
|
144 |
+
|
145 |
+
昇腾 910B 芯片上进行SDXL训练时设置相应的环境变量
|
146 |
+
|
147 |
+
```shell
|
148 |
+
export FLAGS_npu_storage_format=0
|
149 |
+
export FLAGS_use_stride_kernel=0
|
150 |
+
export FLAGS_npu_scale_aclnn=True
|
151 |
+
export FLAGS_allocator_strategy=auto_growth
|
152 |
+
|
153 |
+
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
|
154 |
+
export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
|
155 |
+
export DATASET_NAME="lambdalabs/naruto-blip-captions"
|
156 |
+
export HF_ENDPOINT=https://hf-mirror.com
|
157 |
+
export FLAGS_conv_workspace_size_limit=4096
|
158 |
+
|
159 |
+
# 将ppdiffusers加入到PYTHONPATH中
|
160 |
+
export PYTHONPATH=`pwd`/ppdiffusers:$PYTHONPATH
|
161 |
+
```
|
162 |
+
|
163 |
+
#### 4.1.2 启动SDXL微调训练
|
164 |
+
|
165 |
+
```shell
|
166 |
+
python -u ppdiffusers/examples/text_to_image/train_text_to_image_sdxl.py \
|
167 |
+
--pretrained_model_name_or_path=$MODEL_NAME \
|
168 |
+
--pretrained_vae_model_name_or_path=$VAE_NAME \
|
169 |
+
--dataset_name=$DATASET_NAME \
|
170 |
+
--resolution=512 --center_crop --random_flip \
|
171 |
+
--proportion_empty_prompts=0.2 \
|
172 |
+
--train_batch_size=1 \
|
173 |
+
--gradient_accumulation_steps=4 --gradient_checkpointing \
|
174 |
+
--max_train_steps=10000 \
|
175 |
+
--learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0 \
|
176 |
+
--mixed_precision="fp16" \
|
177 |
+
--report_to="wandb" \
|
178 |
+
--validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5 \
|
179 |
+
--checkpointing_steps=5000 \
|
180 |
+
--output_dir="sdxl-pokemon-model"
|
181 |
+
```
|
182 |
+
|
183 |
+
#### 4.1.3 启动SDXL LoRA训练
|
184 |
+
|
185 |
+
```shell
|
186 |
+
python -u ppdiffusers/examples/text_to_image/train_text_to_image_lora_sdxl.py \
|
187 |
+
--pretrained_model_name_or_path=$MODEL_NAME \
|
188 |
+
--pretrained_vae_model_name_or_path=$VAE_NAME \
|
189 |
+
--dataset_name=$DATASET_NAME --caption_column="text" \
|
190 |
+
--resolution=1024 --random_flip \
|
191 |
+
--train_batch_size=1 \
|
192 |
+
--num_train_epochs=2 --checkpointing_steps=500 \
|
193 |
+
--learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
|
194 |
+
--mixed_precision="fp16" \
|
195 |
+
--seed=42 \
|
196 |
+
--output_dir="sd-pokemon-model-lora-sdxl" \
|
197 |
+
--validation_prompt="cute dragon creature" \
|
198 |
+
--report_to="wandb"
|
199 |
+
```
|
200 |
+
|
201 |
+
### 4.2 推理
|
202 |
+
|
203 |
+
推理脚本参考如下
|
204 |
+
|
205 |
+
```python
|
206 |
+
from ppdiffusers import StableDiffusionXLPipeline
|
207 |
+
from ppdiffusers import (
|
208 |
+
AutoencoderKL,
|
209 |
+
StableDiffusionXLPipeline,
|
210 |
+
UNet2DConditionModel,
|
211 |
+
)
|
212 |
+
import paddle
|
213 |
+
|
214 |
+
unet_path = "your-checkpoint/unet"
|
215 |
+
pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", paddle_dtype=paddle.float16)
|
216 |
+
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix")
|
217 |
+
unet = UNet2DConditionModel.from_pretrained(unet_path)
|
218 |
+
|
219 |
+
prompt = "A pokemon with green eyes and red legs."
|
220 |
+
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
|
221 |
+
image.save("pokemon.png")
|
222 |
+
```
|
PaddleMIX/paddlemix/datasets/__init__.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# Standard imports
|
16 |
+
|
17 |
+
# Local imports
|
18 |
+
from .caption_dataset import *
|
19 |
+
from .chatml_dataset import *
|
20 |
+
from .coco_caption import *
|
21 |
+
from .coco_clip import *
|
22 |
+
from .collator import *
|
23 |
+
from .dataset import *
|
24 |
+
from .mixtoken_dataset import *
|
25 |
+
from .vg_caption import *
|
26 |
+
|
27 |
+
import pkg_resources
|
28 |
+
|
29 |
+
version = pkg_resources.get_distribution("paddlenlp").version
|
30 |
+
try:
|
31 |
+
if version.startswith('3'):
|
32 |
+
from .internvl_dataset import *
|
33 |
+
else:
|
34 |
+
print(f"paddlenlp version {version} is not 3.x, skipping import internvl2 datasets.")
|
35 |
+
|
36 |
+
except ImportError:
|
37 |
+
print("paddlenlp is not installed.")
|
PaddleMIX/paddlemix/datasets/caption_dataset.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import collections
|
16 |
+
import json
|
17 |
+
import os
|
18 |
+
|
19 |
+
from paddle.utils.download import get_path_from_url
|
20 |
+
|
21 |
+
from paddlemix.utils.env import DATA_HOME
|
22 |
+
from paddlemix.utils.log import logger
|
23 |
+
|
24 |
+
from .dataset import DatasetBuilder
|
25 |
+
|
26 |
+
__all__ = ["CaptionDataset"]
|
27 |
+
|
28 |
+
|
29 |
+
class CaptionDataset(DatasetBuilder):
|
30 |
+
"""
|
31 |
+
Caption dataset.
|
32 |
+
"""
|
33 |
+
|
34 |
+
URL = "https://bj.bcebos.com/v1/paddlenlp/datasets/paddlemix/coco.tar"
|
35 |
+
META_INFO = collections.namedtuple("META_INFO", ("images", "annotations", "images_md5", "annotations_md5"))
|
36 |
+
MD5 = ""
|
37 |
+
SPLITS = {
|
38 |
+
"train": META_INFO(
|
39 |
+
os.path.join("coco", "images"),
|
40 |
+
os.path.join("coco", "annotations/coco_karpathy_train.json"),
|
41 |
+
"",
|
42 |
+
"",
|
43 |
+
),
|
44 |
+
"val": META_INFO(
|
45 |
+
os.path.join("coco", "images"),
|
46 |
+
os.path.join("coco", "annotations/coco_karpathy_val.json"),
|
47 |
+
"",
|
48 |
+
"",
|
49 |
+
),
|
50 |
+
"test": META_INFO(
|
51 |
+
os.path.join("coco", "images"),
|
52 |
+
os.path.join("coco", "annotations/coco_karpathy_test.json"),
|
53 |
+
"",
|
54 |
+
"",
|
55 |
+
),
|
56 |
+
}
|
57 |
+
|
58 |
+
def _get_data(self, mode, **kwargs):
|
59 |
+
logger.info("default dataset root is {}".format(DATA_HOME))
|
60 |
+
images, annotations, image_hash, anno_hash = self.SPLITS[mode]
|
61 |
+
image_fullname = os.path.join(DATA_HOME, images)
|
62 |
+
anno_fullname = os.path.join(DATA_HOME, annotations)
|
63 |
+
if not os.path.exists(image_fullname) or not os.path.exists(anno_fullname):
|
64 |
+
get_path_from_url(self.URL, DATA_HOME)
|
65 |
+
|
66 |
+
return image_fullname, anno_fullname, mode
|
67 |
+
|
68 |
+
def _gen_image_id(self, anno):
|
69 |
+
img_ids = {}
|
70 |
+
n = 0
|
71 |
+
for ann in anno:
|
72 |
+
img_id = ann["image_id"]
|
73 |
+
if img_id not in img_ids.keys():
|
74 |
+
img_ids[img_id] = n
|
75 |
+
n += 1
|
76 |
+
return img_ids
|
77 |
+
|
78 |
+
def _gen_image_id_eval(self, anno):
|
79 |
+
img_ids = {}
|
80 |
+
n = 0
|
81 |
+
for ann in anno:
|
82 |
+
img_id = ann["image"].split("/")[-1].strip(".jpg").split("_")[-1]
|
83 |
+
if img_id not in img_ids.keys():
|
84 |
+
img_ids[img_id] = n
|
85 |
+
n += 1
|
86 |
+
return img_ids
|
87 |
+
|
88 |
+
def _read(self, filename, *args):
|
89 |
+
image_root, anno_path, mode = filename
|
90 |
+
annotations = json.load(open(anno_path, "r"))
|
91 |
+
if mode == "val" or mode == "test":
|
92 |
+
image_ids = self._gen_image_id_eval(annotations)
|
93 |
+
else:
|
94 |
+
image_ids = self._gen_image_id(annotations)
|
95 |
+
for ann in annotations:
|
96 |
+
image_path = os.path.join(image_root, ann["image"])
|
97 |
+
if mode == "train":
|
98 |
+
yield_data = {
|
99 |
+
"image": image_path,
|
100 |
+
"image_id": image_ids[ann["image_id"]],
|
101 |
+
}
|
102 |
+
# only train mode has text input
|
103 |
+
yield_data["text_input"] = ann["caption"]
|
104 |
+
else:
|
105 |
+
yield_data = {
|
106 |
+
"image": image_path,
|
107 |
+
"image_id": ann["image"].split("/")[-1].strip(".jpg").split("_")[-1],
|
108 |
+
}
|
109 |
+
yield yield_data
|
PaddleMIX/paddlemix/datasets/cc_sbu_dataset.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import collections
|
17 |
+
import json
|
18 |
+
import os
|
19 |
+
|
20 |
+
from paddle.dataset.common import md5file
|
21 |
+
from paddle.utils.download import get_path_from_url
|
22 |
+
|
23 |
+
from paddlemix.utils.env import DATA_HOME
|
24 |
+
from paddlemix.utils.log import logger
|
25 |
+
|
26 |
+
# from dataset import DatasetBuilder
|
27 |
+
from .dataset import DatasetBuilder
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
__all__ = ["CCSBUAlignDataset"]
|
32 |
+
|
33 |
+
|
34 |
+
class CCSBUAlignDataset(DatasetBuilder):
|
35 |
+
"""
|
36 |
+
CCSBUAlignDataset dataset.
|
37 |
+
"""
|
38 |
+
|
39 |
+
URL = "https://paddlenlp.bj.bcebos.com/datasets/cc_sbu_align.zip"
|
40 |
+
META_INFO = collections.namedtuple(
|
41 |
+
"META_INFO", ("images", "annotations", "num_images", "annotations_md5")
|
42 |
+
)
|
43 |
+
MD5 = "d5fa38be915c8a2aee7ebf3a9c56a95c"
|
44 |
+
SPLITS = {
|
45 |
+
"train": META_INFO(
|
46 |
+
os.path.join("cc_sbu_align", "image"),
|
47 |
+
os.path.join("cc_sbu_align", "filter_cap.json"),
|
48 |
+
3439,
|
49 |
+
"fa3508b6ac29e0ddc7246683d0c3d9a2",
|
50 |
+
),
|
51 |
+
}
|
52 |
+
|
53 |
+
def count_files(self, path):
|
54 |
+
if not os.path.isdir(path):
|
55 |
+
raise ValueError("A directory expected for path, but received {}".format(path))
|
56 |
+
pathes = os.listdir(path)
|
57 |
+
return len(pathes)
|
58 |
+
|
59 |
+
def _get_data(self, mode, **kwargs):
|
60 |
+
logger.info("default dataset root is {}".format(DATA_HOME))
|
61 |
+
images, annotations, num_images, anno_hash = self.SPLITS[mode]
|
62 |
+
image_fullname = os.path.join(DATA_HOME, images)
|
63 |
+
anno_fullname = os.path.join(DATA_HOME, annotations)
|
64 |
+
|
65 |
+
if (not os.path.exists(image_fullname)) or (not os.path.exists(anno_fullname)) or (not md5file(anno_fullname) == anno_hash) or num_images != self.count_files(image_fullname):
|
66 |
+
get_path_from_url(self.URL, DATA_HOME, self.MD5)
|
67 |
+
|
68 |
+
return image_fullname, anno_fullname, mode
|
69 |
+
|
70 |
+
def _gen_image_id(self, anno):
|
71 |
+
img_ids = {}
|
72 |
+
n = 0
|
73 |
+
for ann in anno:
|
74 |
+
# an ann example: {'image_id': '2', 'caption': 'The image shows a man fishing on a lawn next to a river with a bridge in the background. Trees can be seen on the other side of the river, and the sky is cloudy.'}
|
75 |
+
img_id = ann["image_id"]
|
76 |
+
if img_id not in img_ids.keys():
|
77 |
+
img_ids[img_id] = n
|
78 |
+
n += 1
|
79 |
+
return img_ids
|
80 |
+
|
81 |
+
def _read(self, filename, *args):
|
82 |
+
image_root, anno_path, mode = filename
|
83 |
+
with open(anno_path, "r", encoding="utf8") as f:
|
84 |
+
annotations = json.load(f)["annotations"]
|
85 |
+
image_ids = self._gen_image_id(annotations)
|
86 |
+
|
87 |
+
for ann in annotations:
|
88 |
+
image_path = os.path.join(image_root, ann["image_id"]+".jpg")
|
89 |
+
yield_data = {"image": image_path, "image_id": image_ids[ann["image_id"]]}
|
90 |
+
if mode == "train":
|
91 |
+
# only train mode has text input
|
92 |
+
yield_data["text_input"] = ann["caption"]
|
93 |
+
yield yield_data
|
PaddleMIX/paddlemix/datasets/chatml_dataset.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import json
|
17 |
+
|
18 |
+
from paddlenlp.transformers.tokenizer_utils import ChatTemplateMixin
|
19 |
+
|
20 |
+
from .dataset import DatasetBuilder
|
21 |
+
|
22 |
+
__all__ = ["ChatMLDataset"]
|
23 |
+
|
24 |
+
|
25 |
+
class ChatMLDataset(DatasetBuilder, ChatTemplateMixin):
|
26 |
+
"""
|
27 |
+
ChatMLDataset dataset.
|
28 |
+
"""
|
29 |
+
|
30 |
+
SPLITS = {"train": "train.json", "val": "eval.json", "test": "test.json"}
|
31 |
+
|
32 |
+
def _read(self, filename, *args):
|
33 |
+
if self.config["chat_template"] is not None:
|
34 |
+
self.init_chat_template(self.config["chat_template"])
|
35 |
+
raw_data = json.load(open(filename, "r"))
|
36 |
+
annotations = raw_data
|
37 |
+
|
38 |
+
for ann in annotations:
|
39 |
+
yield_data = {}
|
40 |
+
conversations = ann["conversations"]
|
41 |
+
if self.config["chat_template"] is not None:
|
42 |
+
conversations.append([""])
|
43 |
+
yield_data["conversations"] = self.apply_chat_template(conversations, tokenize=False)
|
44 |
+
else:
|
45 |
+
yield_data["conversations"] = conversations
|
46 |
+
|
47 |
+
if "image" in ann.keys():
|
48 |
+
yield_data["image"] = ann["image"]
|
49 |
+
|
50 |
+
yield yield_data
|
PaddleMIX/paddlemix/datasets/coco_caption.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from paddlemix.datasets.caption_dataset import CaptionDataset
|
16 |
+
|
17 |
+
COCOCaption = CaptionDataset
|
PaddleMIX/paddlemix/datasets/coco_vqa.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import collections
|
16 |
+
import json
|
17 |
+
import os
|
18 |
+
|
19 |
+
from paddle.utils.download import get_path_from_url
|
20 |
+
|
21 |
+
from paddlemix.utils.env import DATA_HOME
|
22 |
+
from paddlemix.utils.log import logger
|
23 |
+
|
24 |
+
from .dataset import DatasetBuilder
|
25 |
+
|
26 |
+
__all__ = ["VQADataset"]
|
27 |
+
|
28 |
+
|
29 |
+
class VQADataset(DatasetBuilder):
|
30 |
+
"""
|
31 |
+
Caption dataset.
|
32 |
+
"""
|
33 |
+
|
34 |
+
URL = "https://bj.bcebos.com/v1/paddlenlp/datasets/paddlemix/coco.tar"
|
35 |
+
META_INFO = collections.namedtuple("META_INFO", ("images", "annotations", "images_md5", "annotations_md5"))
|
36 |
+
MD5 = ""
|
37 |
+
SPLITS = {
|
38 |
+
"train": META_INFO(
|
39 |
+
os.path.join("coco", "images"),
|
40 |
+
[os.path.join("coco", "annotations/vqa_train.json"), os.path.join("coco", "annotations/vqa_val.json")],
|
41 |
+
"",
|
42 |
+
"",
|
43 |
+
),
|
44 |
+
"val": META_INFO(
|
45 |
+
os.path.join("coco", "images"),
|
46 |
+
[
|
47 |
+
os.path.join("coco", "annotations/vqa_val_eval.json"),
|
48 |
+
os.path.join("coco", "annotations/answer_list.json"),
|
49 |
+
os.path.join("coco", "annotations/v2_OpenEnded_mscoco_val2014_questions.json"),
|
50 |
+
os.path.join("coco", "annotations/v2_mscoco_val2014_annotations.json"),
|
51 |
+
],
|
52 |
+
"",
|
53 |
+
"",
|
54 |
+
),
|
55 |
+
"test": META_INFO(
|
56 |
+
os.path.join("coco", "images"),
|
57 |
+
[
|
58 |
+
os.path.join("coco", "annotation/vqa_test.json"),
|
59 |
+
os.path.join("coco", "annotation/vqa_test.json"),
|
60 |
+
],
|
61 |
+
"",
|
62 |
+
"",
|
63 |
+
),
|
64 |
+
}
|
65 |
+
|
66 |
+
def _get_data(self, mode, **kwargs):
|
67 |
+
logger.info("default dataset root is {}".format(DATA_HOME))
|
68 |
+
images, annotations, image_hash, anno_hash = self.SPLITS[mode]
|
69 |
+
image_fullname = os.path.join(DATA_HOME, images)
|
70 |
+
if isinstance(annotations, (list, tuple)):
|
71 |
+
anno_fullname = []
|
72 |
+
for ann in annotations:
|
73 |
+
anno_fullname.append(os.path.join(DATA_HOME, ann))
|
74 |
+
if not os.path.exists(image_fullname) or not os.path.exists(os.path.join(DATA_HOME, ann)):
|
75 |
+
get_path_from_url(self.URL, DATA_HOME)
|
76 |
+
else:
|
77 |
+
anno_fullname = os.path.join(DATA_HOME, annotations)
|
78 |
+
if not os.path.exists(image_fullname) or not os.path.exists(anno_fullname):
|
79 |
+
get_path_from_url(self.URL, DATA_HOME)
|
80 |
+
return image_fullname, anno_fullname, mode
|
81 |
+
|
82 |
+
def _read(self, filename, *args):
|
83 |
+
if isinstance(filename, (list, tuple)):
|
84 |
+
image_root, anno_path, mode = filename
|
85 |
+
else:
|
86 |
+
anno_path = [filename]
|
87 |
+
image_root = ""
|
88 |
+
mode = "train"
|
89 |
+
annotations = []
|
90 |
+
if mode == "val" or mode == "test":
|
91 |
+
annotations = json.load(open(anno_path[0]))
|
92 |
+
image_ids = self._gen_image_id_eval(annotations)
|
93 |
+
else:
|
94 |
+
for ann_p in anno_path:
|
95 |
+
annotations.extend(json.load(open(ann_p, "r")))
|
96 |
+
image_ids = self._gen_image_id(annotations)
|
97 |
+
for ann in annotations:
|
98 |
+
image_path = os.path.join(image_root, ann["image"])
|
99 |
+
if mode == "train":
|
100 |
+
yield_data = {
|
101 |
+
"image": image_path,
|
102 |
+
}
|
103 |
+
yield_data["text_input"] = ann["question"]
|
104 |
+
yield_data["answers"] = ann["answer"]
|
105 |
+
yield_data["image_ids"] = ann["image"].split("/")[-1].strip(".jpg").split("_")[-1]
|
106 |
+
|
107 |
+
else:
|
108 |
+
yield_data = {
|
109 |
+
"image": image_path,
|
110 |
+
"text_input": ann["question"],
|
111 |
+
"question_id": ann["question_id"],
|
112 |
+
"image_id": ann["image"].split("/")[-1].strip(".jpg").split("_")[-1],
|
113 |
+
}
|
114 |
+
yield_data["image_ids"] = ann["image_ids"]
|
115 |
+
yield yield_data
|
116 |
+
|
117 |
+
def _gen_image_id(self, anno):
|
118 |
+
img_ids = {}
|
119 |
+
n = 0
|
120 |
+
for ann in anno:
|
121 |
+
if "image_id" not in ann.keys():
|
122 |
+
img_id = ann["image"].split("/")[-1].strip(".jpg").split("_")[-1]
|
123 |
+
else:
|
124 |
+
img_id = ann["image_id"]
|
125 |
+
if img_id not in img_ids.keys():
|
126 |
+
img_ids[img_id] = n
|
127 |
+
n += 1
|
128 |
+
return img_ids
|
129 |
+
|
130 |
+
def _gen_image_id_eval(self, anno):
|
131 |
+
img_ids = {}
|
132 |
+
n = 0
|
133 |
+
for ann in anno:
|
134 |
+
img_id = ann["image"].split("/")[-1].strip(".jpg").split("_")[-1]
|
135 |
+
if img_id not in img_ids.keys():
|
136 |
+
img_ids[img_id] = n
|
137 |
+
n += 1
|
138 |
+
return img_ids
|
PaddleMIX/paddlemix/datasets/collator.py
ADDED
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
import paddle
|
17 |
+
|
18 |
+
|
19 |
+
class CLIPCollator:
|
20 |
+
"""
|
21 |
+
Data collator that will dynamically pad the inputs to the longest sequence in the batch.
|
22 |
+
Args:
|
23 |
+
processor (`paddlemix.processors.ProcessorMixin`):
|
24 |
+
The processor used for pre-process the data.
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(self, processor):
|
28 |
+
self.processor = processor
|
29 |
+
|
30 |
+
def __call__(self, data_list):
|
31 |
+
if isinstance(data_list[0], dict):
|
32 |
+
images = [sample["image"] for sample in data_list]
|
33 |
+
text = [sample["text_input"] for sample in data_list]
|
34 |
+
batch = self.processor(
|
35 |
+
images=images,
|
36 |
+
text=text,
|
37 |
+
max_length=77,
|
38 |
+
return_tensors="pd",
|
39 |
+
return_attention_mask=False,
|
40 |
+
mode="train",
|
41 |
+
padding_zero=True,
|
42 |
+
)
|
43 |
+
return batch
|
44 |
+
else:
|
45 |
+
images = [sample[0] for sample in data_list]
|
46 |
+
labels = [sample[1] for sample in data_list]
|
47 |
+
batch = self.processor(
|
48 |
+
images=images,
|
49 |
+
text=None,
|
50 |
+
max_length=77,
|
51 |
+
return_tensors="pd",
|
52 |
+
return_attention_mask=False,
|
53 |
+
mode="eval",
|
54 |
+
do_resize=True,
|
55 |
+
do_crop=True,
|
56 |
+
padding_zero=True,
|
57 |
+
)
|
58 |
+
batch["labels"] = paddle.to_tensor(np.array(labels))
|
59 |
+
return batch
|
60 |
+
|
61 |
+
|
62 |
+
class EVA02Collator:
|
63 |
+
"""
|
64 |
+
Data collator that will dynamically pad the inputs to the longest sequence in the batch.
|
65 |
+
Args:
|
66 |
+
processor (`paddlemix.processors.ProcessorMixin`):
|
67 |
+
The processor used for pre-process the data.
|
68 |
+
"""
|
69 |
+
|
70 |
+
def __init__(self, processor, mode="train"):
|
71 |
+
self.processor = processor
|
72 |
+
self.mode = mode
|
73 |
+
|
74 |
+
def __call__(self, data_list):
|
75 |
+
images = [sample[0] for sample in data_list]
|
76 |
+
# get labels from teacher's clip_features
|
77 |
+
batch = self.processor(
|
78 |
+
images=images,
|
79 |
+
return_tensors="pd",
|
80 |
+
mode=self.mode,
|
81 |
+
)
|
82 |
+
return batch
|
83 |
+
|
84 |
+
|
85 |
+
class MiniGPT4Collator:
|
86 |
+
"""
|
87 |
+
Data collator that will dynamically pad the inputs to the longest sequence in the batch.
|
88 |
+
Args:
|
89 |
+
processor (`paddlemix.processors.ProcessorMixin`):
|
90 |
+
The processor used for pre-process the data.
|
91 |
+
"""
|
92 |
+
|
93 |
+
def __init__(self, processor, mode="test"):
|
94 |
+
self.processor = processor
|
95 |
+
self.mode = mode
|
96 |
+
|
97 |
+
def __call__(self, data_list):
|
98 |
+
images = [sample["image"] for sample in data_list]
|
99 |
+
target_texts = [sample["text_input"] for sample in data_list]
|
100 |
+
# random text from text_list read by processor and combine it with default prompt
|
101 |
+
batch_data = self.processor(images=images, mode="train")
|
102 |
+
target_outputs = self.processor.process_target_texts(target_texts)
|
103 |
+
batch_data.update(target_outputs)
|
104 |
+
return batch_data
|
105 |
+
|
106 |
+
|
107 |
+
class QwenVLCollator:
|
108 |
+
"""
|
109 |
+
Data collator that will dynamically pad the inputs to the longest sequence in the batch.
|
110 |
+
Args:
|
111 |
+
processor (`paddlemix.processors.ProcessorMixin`):
|
112 |
+
The processor used for pre-process the data.
|
113 |
+
"""
|
114 |
+
|
115 |
+
def __init__(self, processor, mode="test"):
|
116 |
+
self.processor = processor
|
117 |
+
self.mode = mode
|
118 |
+
|
119 |
+
def __call__(self, data_list):
|
120 |
+
input_ids = []
|
121 |
+
labels = []
|
122 |
+
images = []
|
123 |
+
IGNORE_TOKEN_ID = -100
|
124 |
+
for record in data_list:
|
125 |
+
|
126 |
+
if isinstance(record, dict) and "input_ids" in record.keys():
|
127 |
+
raw_data = record
|
128 |
+
else:
|
129 |
+
raw_data = self.processor(query=record, mode=self.mode)
|
130 |
+
|
131 |
+
raw_data["input_ids"] += [self.processor.tokenizer.pad_token_id] * (
|
132 |
+
self.processor.max_len - len(raw_data["input_ids"])
|
133 |
+
)
|
134 |
+
raw_data["labels"] += [IGNORE_TOKEN_ID] * (self.processor.max_len - len(raw_data["labels"]))
|
135 |
+
input_ids.append(raw_data["input_ids"])
|
136 |
+
labels.append(raw_data["labels"])
|
137 |
+
|
138 |
+
if "images" in raw_data:
|
139 |
+
|
140 |
+
if isinstance(raw_data["images"], list):
|
141 |
+
if not isinstance(raw_data["images"][0], list):
|
142 |
+
raw_data["images"] = [raw_data["images"]]
|
143 |
+
raw_data["images"] = [self.processor.image_processor(path) for path in raw_data["images"]]
|
144 |
+
raw_data["images"] = paddle.stack(x=raw_data["images"], axis=0)
|
145 |
+
|
146 |
+
images.append(raw_data["images"])
|
147 |
+
|
148 |
+
input_ids = paddle.to_tensor(data=input_ids, dtype="int32")
|
149 |
+
labels = paddle.to_tensor(data=labels, dtype="int32")
|
150 |
+
attention_mask = input_ids.not_equal(y=paddle.to_tensor(self.processor.tokenizer.pad_token_id, dtype="int32"))
|
151 |
+
|
152 |
+
if len(images) > 0:
|
153 |
+
images = paddle.concat(images, axis=0)
|
154 |
+
image_shape = [-1, 3] + images.shape[-2:]
|
155 |
+
images = images.reshape(image_shape)
|
156 |
+
|
157 |
+
batch_data = dict(
|
158 |
+
input_ids=input_ids,
|
159 |
+
labels=labels,
|
160 |
+
images=images if 0 < len(images) else None,
|
161 |
+
attention_mask=attention_mask,
|
162 |
+
)
|
163 |
+
|
164 |
+
return batch_data
|
165 |
+
|
166 |
+
|
167 |
+
class VisualglmCollator:
|
168 |
+
"""
|
169 |
+
Data collator that will dynamically pad the inputs to the longest sequence in the batch.
|
170 |
+
Args:
|
171 |
+
processor (`paddlemix.processors.ProcessorMixin`):
|
172 |
+
The processor used for pre-process the data.
|
173 |
+
"""
|
174 |
+
|
175 |
+
def __init__(self, processor, mode="test", max_seq_length=2048):
|
176 |
+
self.processor = processor
|
177 |
+
self.mode = mode
|
178 |
+
self.max_seq_length = max_seq_length
|
179 |
+
|
180 |
+
def __call__(self, data_list):
|
181 |
+
|
182 |
+
input_ids = []
|
183 |
+
labels = []
|
184 |
+
images = []
|
185 |
+
|
186 |
+
for record in data_list:
|
187 |
+
if "input_ids" not in record.keys():
|
188 |
+
raw_data = self.processor(record=record, mode=self.mode)
|
189 |
+
else:
|
190 |
+
raw_data = record
|
191 |
+
|
192 |
+
pad_len = self.max_seq_length - len(raw_data["input_ids"])
|
193 |
+
raw_data["input_ids"] = raw_data["input_ids"] + [self.processor.tokenizer.pad_token_id] * pad_len
|
194 |
+
raw_data["labels"] = raw_data["labels"] + [self.processor.tokenizer.pad_token_id] * pad_len
|
195 |
+
raw_data["labels"] = [
|
196 |
+
(l if l != self.processor.tokenizer.pad_token_id else -100) for l in raw_data["labels"]
|
197 |
+
]
|
198 |
+
|
199 |
+
if "images" in raw_data:
|
200 |
+
if isinstance(raw_data["images"], list):
|
201 |
+
raw_data["images"] = paddle.stack(x=raw_data["images"], axis=0)
|
202 |
+
images.append(raw_data["images"])
|
203 |
+
|
204 |
+
input_ids.append(raw_data["input_ids"])
|
205 |
+
labels.append(raw_data["labels"])
|
206 |
+
|
207 |
+
input_ids = paddle.to_tensor(data=input_ids, dtype="int64")
|
208 |
+
labels = paddle.to_tensor(data=labels, dtype="int64")
|
209 |
+
|
210 |
+
if 0 < len(images):
|
211 |
+
images = paddle.concat(images, axis=0)
|
212 |
+
image_shape = [-1, 3] + images.shape[-2:]
|
213 |
+
images = images.reshape(image_shape)
|
214 |
+
|
215 |
+
batch_data = dict(input_ids=input_ids, labels=labels, images=images if 0 < len(images) else None)
|
216 |
+
return batch_data
|
217 |
+
|
218 |
+
|
219 |
+
class LLaVACollator:
|
220 |
+
"""
|
221 |
+
Data collator that will dynamically pad the inputs to the longest sequence in the batch.
|
222 |
+
Args:
|
223 |
+
processor (`paddlemix.processors.ProcessorMixin`):
|
224 |
+
The processor used for pre-process the data.
|
225 |
+
"""
|
226 |
+
|
227 |
+
def __init__(self, processor, mode="test", mixtokens=False):
|
228 |
+
self.processor = processor
|
229 |
+
self.mode = mode
|
230 |
+
self.mixtokens = mixtokens
|
231 |
+
|
232 |
+
def __call__(self, data_list):
|
233 |
+
IGNORE_INDEX = -100
|
234 |
+
input_ids = []
|
235 |
+
labels = []
|
236 |
+
images = []
|
237 |
+
for record in data_list:
|
238 |
+
|
239 |
+
if isinstance(record, dict) and "input_ids" in record.keys():
|
240 |
+
raw_data = record
|
241 |
+
else:
|
242 |
+
raw_data = self.processor(record=record, mode=self.mode)
|
243 |
+
|
244 |
+
raw_data["input_ids"] += [self.processor.tokenizer.pad_token_id] * (
|
245 |
+
self.processor.max_len - len(raw_data["input_ids"])
|
246 |
+
)
|
247 |
+
raw_data["labels"] += [IGNORE_INDEX] * (self.processor.max_len - len(raw_data["labels"]))
|
248 |
+
|
249 |
+
input_ids.append(raw_data["input_ids"])
|
250 |
+
labels.append(raw_data["labels"])
|
251 |
+
|
252 |
+
if "images" in raw_data:
|
253 |
+
if isinstance(raw_data["images"], list):
|
254 |
+
raw_data["images"] = paddle.stack(x=raw_data["images"], axis=0)
|
255 |
+
|
256 |
+
images.append(raw_data["images"])
|
257 |
+
|
258 |
+
input_ids = paddle.to_tensor(data=input_ids, dtype="int32")
|
259 |
+
labels = paddle.to_tensor(data=labels, dtype="int32")
|
260 |
+
attention_mask = input_ids.not_equal(y=paddle.to_tensor(self.processor.tokenizer.pad_token_id, dtype="int32"))
|
261 |
+
|
262 |
+
if len(images) > 0:
|
263 |
+
images = paddle.concat(images, axis=0)
|
264 |
+
image_shape = [-1, 3] + images.shape[-2:]
|
265 |
+
images = images.reshape(image_shape)
|
266 |
+
|
267 |
+
batch_data = dict(
|
268 |
+
input_ids=input_ids,
|
269 |
+
labels=labels,
|
270 |
+
images=images if len(images) > 0 else None,
|
271 |
+
attention_mask=attention_mask,
|
272 |
+
)
|
273 |
+
|
274 |
+
return batch_data
|
275 |
+
|
276 |
+
|
277 |
+
class InternLMXComposer2Collator:
|
278 |
+
"""Collate examples for InternLMXComposer2Collator"""
|
279 |
+
|
280 |
+
def __init__(self, processor, mode="train"):
|
281 |
+
self.processor = processor
|
282 |
+
self.mode = mode
|
283 |
+
|
284 |
+
def __call__(self, instances):
|
285 |
+
|
286 |
+
instances = [self.processor(query=instance, mode=self.mode) for instance in instances]
|
287 |
+
|
288 |
+
input_tokens, input_text = tuple(
|
289 |
+
[instance[key] for instance in instances] for key in ("input_tokens", "input_text")
|
290 |
+
)
|
291 |
+
batch = dict(
|
292 |
+
input_tokens=input_tokens,
|
293 |
+
input_text=input_text,
|
294 |
+
)
|
295 |
+
if "images" in instances[0].keys():
|
296 |
+
input_images = tuple([instance["images"] for instance in instances])
|
297 |
+
batch["images"] = input_images
|
298 |
+
|
299 |
+
return dict(samples=batch)
|
300 |
+
|
301 |
+
|
302 |
+
class InternVL2Collator:
|
303 |
+
"""Collate examples for InternVL2Collator"""
|
304 |
+
|
305 |
+
def __init__(self, processor, mode="test"):
|
306 |
+
self.processor = processor
|
307 |
+
self.mode = mode
|
308 |
+
|
309 |
+
def __call__(self, features):
|
310 |
+
pad_id = self.processor.tokenizer.pad_token_id
|
311 |
+
IGNORE_INDEX = -100
|
312 |
+
first = features[0]
|
313 |
+
batch = {}
|
314 |
+
|
315 |
+
batch_lens = [feat["input_ids"].shape for feat in features]
|
316 |
+
max_item_length = max(batch_lens)[0]
|
317 |
+
for idx in range(len(features)):
|
318 |
+
feat = self.processor(features[idx])
|
319 |
+
temp_input_ids = paddle.to_tensor([pad_id] * max_item_length, dtype=paddle.int64)
|
320 |
+
temp_input_ids[: feat["input_ids"].shape[0]] = feat["input_ids"]
|
321 |
+
feat["input_ids"] = temp_input_ids
|
322 |
+
temp_labels = paddle.to_tensor([IGNORE_INDEX] * max_item_length, dtype=paddle.int64)
|
323 |
+
temp_labels[: feat["labels"].shape[0]] = feat["labels"]
|
324 |
+
feat["labels"] = temp_labels
|
325 |
+
feat["attention_mask"] = feat["input_ids"].ne(pad_id)
|
326 |
+
|
327 |
+
# Special handling for labels.
|
328 |
+
# Ensure that tensor is created with the correct type
|
329 |
+
# (it should be automatically the case, but let's make sure of it.)
|
330 |
+
if "label" in first and first["label"] is not None:
|
331 |
+
label = first["label"].item() if isinstance(first["label"], paddle.Tensor) else first["label"]
|
332 |
+
dtype = paddle.int64 if isinstance(label, int) else paddle.float32
|
333 |
+
batch["labels"] = paddle.to_tensor([f["label"] for f in features], dtype=dtype)
|
334 |
+
elif "label_ids" in first and first["label_ids"] is not None:
|
335 |
+
if isinstance(first["label_ids"], paddle.Tensor):
|
336 |
+
batch["labels"] = paddle.stack([f["label_ids"] for f in features])
|
337 |
+
else:
|
338 |
+
dtype = paddle.int64 if isinstance(first["label_ids"][0], int) else paddle.float32
|
339 |
+
batch["labels"] = paddle.to_tensor([f["label_ids"] for f in features], dtype=dtype)
|
340 |
+
|
341 |
+
# Handling of all other possible keys.
|
342 |
+
# Again, we will use the first element to figure out which key/values are not None for this model.
|
343 |
+
for k, v in first.items():
|
344 |
+
if (
|
345 |
+
k not in ("label", "label_ids", "pixel_values", "image_flags")
|
346 |
+
and v is not None
|
347 |
+
and not isinstance(v, str)
|
348 |
+
):
|
349 |
+
if isinstance(v, paddle.Tensor):
|
350 |
+
batch[k] = paddle.stack([f[k] for f in features])
|
351 |
+
elif isinstance(v, np.ndarray):
|
352 |
+
batch[k] = paddle.to_tensor(np.stack([f[k] for f in features]))
|
353 |
+
else:
|
354 |
+
batch[k] = paddle.to_tensor([f[k] for f in features])
|
355 |
+
if k in ("pixel_values", "image_flags"):
|
356 |
+
if isinstance(v, paddle.Tensor):
|
357 |
+
batch[k] = paddle.concat([f[k] for f in features])
|
358 |
+
elif isinstance(v, np.ndarray):
|
359 |
+
batch[k] = paddle.concat(np.stack([f[k] for f in features]))
|
360 |
+
else:
|
361 |
+
batch[k] = paddle.concat([f[k] for f in features])
|
362 |
+
return batch
|
PaddleMIX/paddlemix/datasets/dataset.py
ADDED
@@ -0,0 +1,1169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import atexit
|
16 |
+
import inspect
|
17 |
+
import os
|
18 |
+
import time
|
19 |
+
import warnings
|
20 |
+
from collections import namedtuple
|
21 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
|
22 |
+
|
23 |
+
import cv2
|
24 |
+
import datasets
|
25 |
+
import numpy as np
|
26 |
+
from multiprocess import Pool, RLock
|
27 |
+
from PIL import Image
|
28 |
+
|
29 |
+
import paddlemix
|
30 |
+
|
31 |
+
try:
|
32 |
+
import paddle.distributed as dist
|
33 |
+
except Exception:
|
34 |
+
warnings.warn("paddle.distributed is not contains in you paddle!")
|
35 |
+
|
36 |
+
import importlib
|
37 |
+
from functools import partial
|
38 |
+
|
39 |
+
from paddle.io import Dataset, IterableDataset
|
40 |
+
from paddle.utils.download import _get_unique_endpoints
|
41 |
+
|
42 |
+
from paddlemix.utils.env import DATA_HOME
|
43 |
+
|
44 |
+
__all__ = ["MapDataset", "DatasetBuilder", "IterDataset", "load_dataset", "MixDataset"]
|
45 |
+
|
46 |
+
DATASETS_MODULE_PATH = "paddlemix.datasets."
|
47 |
+
|
48 |
+
# Patch for intranet
|
49 |
+
from datasets import load_dataset as origin_load_dataset # noqa: E402
|
50 |
+
|
51 |
+
|
52 |
+
def load_from_ppvlp(path, *args, **kwargs):
|
53 |
+
ppvlp_path = paddlemix.datasets.__path__[0]
|
54 |
+
new_path = os.path.split(path)[-1]
|
55 |
+
new_path = os.path.join(ppvlp_path, "hf_datasets", new_path + ".py")
|
56 |
+
if os.path.exists(new_path):
|
57 |
+
return origin_load_dataset(new_path, *args, **kwargs)
|
58 |
+
else:
|
59 |
+
return origin_load_dataset(path, *args, **kwargs)
|
60 |
+
|
61 |
+
|
62 |
+
datasets.load_dataset = load_from_ppvlp
|
63 |
+
|
64 |
+
|
65 |
+
class DatasetTuple:
|
66 |
+
def __init__(self, splits):
|
67 |
+
self.identifier_map, identifiers = self._gen_identifier_map(splits)
|
68 |
+
self.tuple_cls = namedtuple("datasets", identifiers)
|
69 |
+
self.tuple = self.tuple_cls(*[None for _ in splits])
|
70 |
+
|
71 |
+
def __getitem__(self, key):
|
72 |
+
if isinstance(key, (int, slice)):
|
73 |
+
return self.tuple[key]
|
74 |
+
if isinstance(key, str):
|
75 |
+
return getattr(self.tuple, self.identifier_map[key])
|
76 |
+
|
77 |
+
def __setitem__(self, key, value):
|
78 |
+
self.tuple = self.tuple._replace(**{self.identifier_map[key]: value})
|
79 |
+
|
80 |
+
def _gen_identifier_map(self, splits):
|
81 |
+
identifier_map = {}
|
82 |
+
identifiers = []
|
83 |
+
for i in range(len(splits)):
|
84 |
+
identifiers.append("splits_" + str(i))
|
85 |
+
identifier_map[splits[i]] = "splits_" + str(i)
|
86 |
+
return identifier_map, identifiers
|
87 |
+
|
88 |
+
def __len__(self):
|
89 |
+
return len(self.tuple)
|
90 |
+
|
91 |
+
|
92 |
+
def import_main_class(module_path):
|
93 |
+
"""
|
94 |
+
Import a module at module_path and return its DatasetBuilder class.
|
95 |
+
|
96 |
+
"""
|
97 |
+
module_path = DATASETS_MODULE_PATH + module_path
|
98 |
+
module = importlib.import_module(module_path)
|
99 |
+
main_cls_type = DatasetBuilder
|
100 |
+
|
101 |
+
# Find the main class in our imported module
|
102 |
+
module_main_cls = None
|
103 |
+
for name, obj in module.__dict__.items():
|
104 |
+
if isinstance(obj, type) and issubclass(obj, main_cls_type):
|
105 |
+
if name == "DatasetBuilder":
|
106 |
+
continue
|
107 |
+
module_main_cls = obj
|
108 |
+
break
|
109 |
+
|
110 |
+
return module_main_cls
|
111 |
+
|
112 |
+
|
113 |
+
def load_from_hf(path, name=None, splits=None, **kwargs):
|
114 |
+
from datasets import DatasetDict
|
115 |
+
from datasets import load_dataset as load_hf_dataset
|
116 |
+
from datasets.features import ClassLabel
|
117 |
+
|
118 |
+
try:
|
119 |
+
hf_datasets = load_hf_dataset(path, name=name, split=splits, **kwargs)
|
120 |
+
except FileNotFoundError:
|
121 |
+
raise FileNotFoundError("Couldn't find the dataset script for '" + path + "' on PaddleNLP or HuggingFace")
|
122 |
+
else:
|
123 |
+
label_list = []
|
124 |
+
if isinstance(hf_datasets, DatasetDict):
|
125 |
+
datasets = DatasetTuple(list(hf_datasets.keys()))
|
126 |
+
for split, ds in hf_datasets.items():
|
127 |
+
for feature in ds.features.values():
|
128 |
+
if isinstance(feature, ClassLabel):
|
129 |
+
label_list = feature.names
|
130 |
+
datasets[split] = MapDataset(ds, label_list=label_list)
|
131 |
+
elif isinstance(hf_datasets, list):
|
132 |
+
datasets = DatasetTuple(splits)
|
133 |
+
for i, split in enumerate(splits):
|
134 |
+
for feature in hf_datasets[i].features.values():
|
135 |
+
if isinstance(feature, ClassLabel):
|
136 |
+
label_list = feature.names
|
137 |
+
datasets[split] = MapDataset(hf_datasets[i], label_list=label_list)
|
138 |
+
else:
|
139 |
+
for feature in hf_datasets.features.values():
|
140 |
+
if isinstance(feature, ClassLabel):
|
141 |
+
label_list = feature.names
|
142 |
+
datasets = MapDataset(hf_datasets, label_list=label_list)
|
143 |
+
return datasets
|
144 |
+
|
145 |
+
|
146 |
+
def load_dataset(path_or_read_func, name=None, data_files=None, splits=None, lazy=None, **kwargs):
|
147 |
+
"""
|
148 |
+
This method will load a dataset, either form PaddleNLP library or from a
|
149 |
+
self-defined data loading script, by calling functions in `DatasetBuilder`.
|
150 |
+
|
151 |
+
For all the names of datasets in PaddleNLP library, see here: `dataset_list
|
152 |
+
<https://paddlenlp.readthedocs.io/zh/latest/data_prepare/dataset_list.html>`__.
|
153 |
+
|
154 |
+
Either `splits` or `data_files` must be specified.
|
155 |
+
|
156 |
+
Args:
|
157 |
+
path_or_read_func (str|callable): Name of the dataset processing script
|
158 |
+
in PaddleNLP library or a custom data reading function.
|
159 |
+
name (str, optional): Additional name to select a more specific dataset.
|
160 |
+
Defaults to None.
|
161 |
+
data_files (str|list|tuple|dict, optional): Defining the path of dataset
|
162 |
+
files. If None. `splits` must be specified. Defaults to None.
|
163 |
+
splits (str|list|tuple, optional): Which split of the data to load. If None.
|
164 |
+
`data_files` must be specified. Defaults to None.
|
165 |
+
lazy (bool, optional): Weather to return `MapDataset` or an `IterDataset`.
|
166 |
+
True for `IterDataset`. False for `MapDataset`. If None, return the
|
167 |
+
default type of this dataset. Defaults to None.
|
168 |
+
kwargs (dict): Other keyword arguments to be passed to the `DatasetBuilder`.
|
169 |
+
|
170 |
+
Returns:
|
171 |
+
A `MapDataset` or `IterDataset` or a tuple of those.
|
172 |
+
|
173 |
+
For how to use this function, please see `dataset_load
|
174 |
+
<https://paddlenlp.readthedocs.io/zh/latest/data_prepare/dataset_load.html>`__
|
175 |
+
and `dataset_self_defined
|
176 |
+
<https://paddlenlp.readthedocs.io/zh/latest/data_prepare/dataset_self_defined.html>`__
|
177 |
+
|
178 |
+
"""
|
179 |
+
if inspect.isfunction(path_or_read_func):
|
180 |
+
assert lazy is not None, "lazy can not be None in custom mode."
|
181 |
+
kwargs["name"] = name
|
182 |
+
kwargs["data_files"] = data_files
|
183 |
+
kwargs["splits"] = splits
|
184 |
+
custom_kwargs = {}
|
185 |
+
for name in inspect.signature(path_or_read_func).parameters.keys():
|
186 |
+
if name in kwargs.keys():
|
187 |
+
custom_kwargs[name] = kwargs[name]
|
188 |
+
|
189 |
+
reader_instance = SimpleBuilder(lazy=lazy, read_func=path_or_read_func)
|
190 |
+
return reader_instance.read(**custom_kwargs)
|
191 |
+
else:
|
192 |
+
try:
|
193 |
+
reader_cls = import_main_class(path_or_read_func)
|
194 |
+
except ModuleNotFoundError:
|
195 |
+
datasets = load_from_hf(path_or_read_func, name=name, splits=splits, **kwargs)
|
196 |
+
else:
|
197 |
+
reader_instance = reader_cls(lazy=lazy, name=name, **kwargs)
|
198 |
+
# Check if selected name and split is valid in this DatasetBuilder
|
199 |
+
if hasattr(reader_instance, "BUILDER_CONFIGS"):
|
200 |
+
if name in reader_cls.BUILDER_CONFIGS.keys():
|
201 |
+
split_names = reader_cls.BUILDER_CONFIGS[name]["splits"].keys()
|
202 |
+
else:
|
203 |
+
raise ValueError(
|
204 |
+
'Invalid name "{}". Should be one of {}.'.format(name, list(reader_cls.BUILDER_CONFIGS.keys()))
|
205 |
+
)
|
206 |
+
elif hasattr(reader_instance, "SPLITS"):
|
207 |
+
split_names = reader_instance.SPLITS.keys()
|
208 |
+
else:
|
209 |
+
raise AttributeError("Either 'SPLITS' or 'BUILDER_CONFIGS' must be implemented for DatasetBuilder.")
|
210 |
+
|
211 |
+
selected_splits = []
|
212 |
+
if isinstance(splits, list) or isinstance(splits, tuple):
|
213 |
+
selected_splits.extend(splits)
|
214 |
+
else:
|
215 |
+
selected_splits += [splits]
|
216 |
+
|
217 |
+
for split_name in selected_splits:
|
218 |
+
if split_name not in split_names and split_name is not None:
|
219 |
+
raise ValueError('Invalid split "{}". Should be one of {}.'.format(split_name, list(split_names)))
|
220 |
+
|
221 |
+
datasets = reader_instance.read_datasets(data_files=data_files, splits=splits)
|
222 |
+
return datasets
|
223 |
+
|
224 |
+
|
225 |
+
class MapDataset(Dataset):
|
226 |
+
"""
|
227 |
+
Wraps a map-style dataset-like object as an instance of `MapDataset`, and equips it
|
228 |
+
with `map` and other utility methods. All non-magic methods of the raw object
|
229 |
+
are also accessible.
|
230 |
+
|
231 |
+
Args:
|
232 |
+
data (list|Dataset): An object with `__getitem__` and `__len__` methods. It could
|
233 |
+
be a list or a subclass of `paddle.io.Dataset`.
|
234 |
+
kwargs (dict, optional): Other information to be passed to the dataset.
|
235 |
+
|
236 |
+
For examples of this class, please see `dataset_self_defined
|
237 |
+
<https://paddlenlp.readthedocs.io/zh/latest/data_prepare/dataset_self_defined.html>`__.
|
238 |
+
|
239 |
+
"""
|
240 |
+
|
241 |
+
def __init__(self, data, **kwargs):
|
242 |
+
self.data = data
|
243 |
+
self._transform_pipline = []
|
244 |
+
self.new_data = self.data
|
245 |
+
self.info = kwargs
|
246 |
+
self.label_list = self.info.pop("label_list", None)
|
247 |
+
self.vocab_info = self.info.pop("vocab_info", None)
|
248 |
+
|
249 |
+
def _transform(self, data):
|
250 |
+
for fn in self._transform_pipline:
|
251 |
+
data = fn(data)
|
252 |
+
return data
|
253 |
+
|
254 |
+
def __getitem__(self, idx):
|
255 |
+
"""
|
256 |
+
Basic function of `MapDataset` to get sample from dataset with a given
|
257 |
+
index.
|
258 |
+
"""
|
259 |
+
return self._transform(self.new_data[idx]) if self._transform_pipline else self.new_data[idx]
|
260 |
+
|
261 |
+
def __len__(self):
|
262 |
+
"""
|
263 |
+
Returns the number of samples in dataset.
|
264 |
+
"""
|
265 |
+
return len(self.new_data)
|
266 |
+
|
267 |
+
def filter(self, fn, num_workers=0):
|
268 |
+
"""
|
269 |
+
Filters samples by the filter function and uses the filtered data to
|
270 |
+
update this dataset.
|
271 |
+
|
272 |
+
Args:
|
273 |
+
fn (callable): A filter function that takes a sample as input and
|
274 |
+
returns a boolean. Samples that return False would be discarded.
|
275 |
+
num_workers(int, optional): Number of processes for multiprocessing. If
|
276 |
+
set to 0, it doesn't use multiprocessing. Defaults to `0`.
|
277 |
+
"""
|
278 |
+
assert num_workers >= 0, "num_workers should be a non-negative value"
|
279 |
+
if num_workers > 1:
|
280 |
+
shards = [
|
281 |
+
self._shard(num_shards=num_workers, index=index, contiguous=True) for index in range(num_workers)
|
282 |
+
]
|
283 |
+
kwds_per_shard = [dict(self=shards[rank], fn=fn) for rank in range(num_workers)]
|
284 |
+
pool = Pool(num_workers, initargs=(RLock(),))
|
285 |
+
|
286 |
+
results = [pool.apply_async(self.__class__._filter, kwds=kwds) for kwds in kwds_per_shard]
|
287 |
+
transformed_shards = [r.get() for r in results]
|
288 |
+
|
289 |
+
pool.close()
|
290 |
+
pool.join()
|
291 |
+
self.new_data = []
|
292 |
+
for i in range(num_workers):
|
293 |
+
self.new_data += transformed_shards[i].new_data
|
294 |
+
return self
|
295 |
+
else:
|
296 |
+
return self._filter(fn)
|
297 |
+
|
298 |
+
def _filter(self, fn):
|
299 |
+
self.new_data = [self.new_data[idx] for idx in range(len(self.new_data)) if fn(self.new_data[idx])]
|
300 |
+
return self
|
301 |
+
|
302 |
+
def shard(self, num_shards=None, index=None, contiguous=False):
|
303 |
+
self.new_data = self._shard(num_shards=num_shards, index=index, contiguous=contiguous).data
|
304 |
+
return self
|
305 |
+
|
306 |
+
def _shard(self, num_shards=None, index=None, contiguous=False):
|
307 |
+
"""
|
308 |
+
Split the dataset into `num_shards` pieces. Note that the size of each
|
309 |
+
shard might be different because the original dataset may not be evenly
|
310 |
+
divisible.
|
311 |
+
|
312 |
+
Args:
|
313 |
+
num_shards (int, optional): An integer representing the number of
|
314 |
+
data shards. If None, `num_shards` would be number of trainers.
|
315 |
+
Defaults to `None`.
|
316 |
+
index (int, optional): An integer representing the index of the
|
317 |
+
current shard. If None, `index` would be the current trainer rank
|
318 |
+
id. Defaults to `None`.
|
319 |
+
contiguous: (bool, optional): If true, contiguous chunks of data
|
320 |
+
will be select for sharding. And total number of examples will
|
321 |
+
be the same. Otherwise each shard will contain all examples of
|
322 |
+
dataset whose index mod `num_shards` = `index`. Defaults to `False`.
|
323 |
+
"""
|
324 |
+
if num_shards is None:
|
325 |
+
num_shards = dist.get_world_size()
|
326 |
+
if index is None:
|
327 |
+
index = dist.get_rank()
|
328 |
+
|
329 |
+
if contiguous:
|
330 |
+
div = len(self) // num_shards
|
331 |
+
mod = len(self) % num_shards
|
332 |
+
start = div * index + min(index, mod)
|
333 |
+
end = start + div + (1 if index < mod else 0)
|
334 |
+
new_data = [self.new_data[idx] for idx in range(start, end)]
|
335 |
+
else:
|
336 |
+
new_data = [self.new_data[idx] for idx in range(len(self.new_data)) if idx % num_shards == index]
|
337 |
+
|
338 |
+
return MapDataset(new_data)
|
339 |
+
|
340 |
+
def map(self, fn, lazy=True, batched=False, num_workers=0):
|
341 |
+
"""
|
342 |
+
Performs specific function on the dataset to transform and update every sample.
|
343 |
+
|
344 |
+
Args:
|
345 |
+
fn (callable): Transformations to be performed. It receives single
|
346 |
+
sample as argument if batched is False. Else it receives all examples.
|
347 |
+
lazy (bool, optional): If True, transformations would be delayed and
|
348 |
+
performed on demand. Otherwise, transforms all samples at once. Note that
|
349 |
+
if `fn` is stochastic, `lazy` should be True or you will get the same
|
350 |
+
result on all epochs. Defaults to False.
|
351 |
+
batched(bool, optional): If True, transformations would take all examples as
|
352 |
+
input and return a collection of transformed examples. Note that if set
|
353 |
+
True, `lazy` option would be ignored. Defaults to False.
|
354 |
+
num_workers(int, optional): Number of processes for multiprocessing. If
|
355 |
+
set to 0, it doesn't use multiprocessing. Note that if set to positive
|
356 |
+
value, `lazy` option would be ignored. Defaults to 0.
|
357 |
+
"""
|
358 |
+
|
359 |
+
assert num_workers >= 0, "num_workers should be a non-negative value"
|
360 |
+
if num_workers > 1:
|
361 |
+
shards = [
|
362 |
+
self._shard(num_shards=num_workers, index=index, contiguous=True) for index in range(num_workers)
|
363 |
+
]
|
364 |
+
kwds_per_shard = [
|
365 |
+
dict(self=shards[rank], fn=fn, lazy=False, batched=batched) for rank in range(num_workers)
|
366 |
+
]
|
367 |
+
pool = Pool(num_workers, initargs=(RLock(),))
|
368 |
+
results = [pool.apply_async(self.__class__._map, kwds=kwds) for kwds in kwds_per_shard]
|
369 |
+
transformed_shards = [r.get() for r in results]
|
370 |
+
pool.close()
|
371 |
+
pool.join()
|
372 |
+
self.new_data = []
|
373 |
+
for i in range(num_workers):
|
374 |
+
self.new_data += transformed_shards[i].new_data
|
375 |
+
return self
|
376 |
+
else:
|
377 |
+
return self._map(fn, lazy=lazy, batched=batched)
|
378 |
+
|
379 |
+
def _map(self, fn, lazy=True, batched=False):
|
380 |
+
if batched:
|
381 |
+
self.new_data = fn(self.new_data)
|
382 |
+
elif lazy:
|
383 |
+
self._transform_pipline.append(fn)
|
384 |
+
else:
|
385 |
+
self.new_data = [fn(self.new_data[idx]) for idx in range(len(self.new_data))]
|
386 |
+
return self
|
387 |
+
|
388 |
+
|
389 |
+
class IterDataset(IterableDataset):
|
390 |
+
"""
|
391 |
+
Wraps a dataset-like object as an instance of `IterDataset`, and equips it with
|
392 |
+
`map` and other utility methods. All non-magic methods of the raw object
|
393 |
+
also accessible.
|
394 |
+
|
395 |
+
Args:
|
396 |
+
data (Iterable): An object with `__iter__` function. It can be a Iterable or a
|
397 |
+
subclass of `paddle.io.IterableDataset`.
|
398 |
+
kwargs (dict, optional): Other information to be passed to the dataset.
|
399 |
+
|
400 |
+
For examples of this class, please see `dataset_self_defined
|
401 |
+
<https://paddlenlp.readthedocs.io/zh/latest/data_prepare/dataset_self_defined.html>`__.
|
402 |
+
"""
|
403 |
+
|
404 |
+
def __init__(self, data, **kwargs):
|
405 |
+
self.data = data
|
406 |
+
self._transform_pipline = []
|
407 |
+
self._filter_pipline = []
|
408 |
+
|
409 |
+
self.label_list = kwargs.pop("label_list", None)
|
410 |
+
self.vocab_info = kwargs.pop("vocab_info", None)
|
411 |
+
|
412 |
+
def _transform(self, data):
|
413 |
+
for fn in self._transform_pipline:
|
414 |
+
data = fn(data)
|
415 |
+
return data
|
416 |
+
|
417 |
+
def _shard_filter(self, num_samples):
|
418 |
+
return True
|
419 |
+
|
420 |
+
def _filter(self, data):
|
421 |
+
for fn in self._filter_pipline:
|
422 |
+
if not fn(data):
|
423 |
+
return False
|
424 |
+
return True
|
425 |
+
|
426 |
+
def __iter__(self):
|
427 |
+
"""
|
428 |
+
yields sample sequentially.
|
429 |
+
"""
|
430 |
+
num_samples = 0
|
431 |
+
if inspect.isfunction(self.data):
|
432 |
+
for example in self.data():
|
433 |
+
if (not self._filter_pipline or self._filter(self._filter_pipline)) and self._shard_filter(
|
434 |
+
num_samples=num_samples
|
435 |
+
):
|
436 |
+
yield self._transform(example) if self._transform_pipline else example
|
437 |
+
num_samples += 1
|
438 |
+
else:
|
439 |
+
if inspect.isgenerator(self.data):
|
440 |
+
warnings.warn("Reciving generator as data source, data can only be iterated once")
|
441 |
+
for example in self.data:
|
442 |
+
if (not self._filter_pipline or self._filter(self._filter_pipline)) and self._shard_filter(
|
443 |
+
num_samples=num_samples
|
444 |
+
):
|
445 |
+
yield self._transform(example) if self._transform_pipline else example
|
446 |
+
num_samples += 1
|
447 |
+
|
448 |
+
def filter(self, fn):
|
449 |
+
"""
|
450 |
+
Filters samples by the filter function and uses the filtered data to
|
451 |
+
update this dataset.
|
452 |
+
|
453 |
+
Args:
|
454 |
+
fn (callable): A filter function that takes a sample as input and
|
455 |
+
returns a boolean. Samples that return False are discarded.
|
456 |
+
"""
|
457 |
+
|
458 |
+
self._filter_pipline.append(fn)
|
459 |
+
|
460 |
+
return self
|
461 |
+
|
462 |
+
def shard(self, num_shards=None, index=None):
|
463 |
+
"""
|
464 |
+
Split the dataset into `num_shards` pieces.
|
465 |
+
|
466 |
+
Args:
|
467 |
+
num_shards (int, optional): An integer representing the number of
|
468 |
+
data shards. If None, `num_shards` would be number of trainers.
|
469 |
+
Defaults to None.
|
470 |
+
index (int, optional): An integer representing the index of the
|
471 |
+
current shard. If None, `index` would be the current trainer rank
|
472 |
+
id. Defaults to None.
|
473 |
+
"""
|
474 |
+
if num_shards is None:
|
475 |
+
num_shards = dist.get_world_size()
|
476 |
+
if index is None:
|
477 |
+
index = dist.get_rank()
|
478 |
+
|
479 |
+
def sharder(num_shards, index, num_samples):
|
480 |
+
if num_samples % num_shards == index:
|
481 |
+
return True
|
482 |
+
else:
|
483 |
+
return False
|
484 |
+
|
485 |
+
fn = partial(sharder, num_shards=num_shards, index=index)
|
486 |
+
self._shard_filter = fn
|
487 |
+
return self
|
488 |
+
|
489 |
+
def map(self, fn):
|
490 |
+
"""
|
491 |
+
Performs specific function on the dataset to transform and update every sample.
|
492 |
+
|
493 |
+
Args:
|
494 |
+
fn (callable): Transformations to be performed. It receives single
|
495 |
+
sample as argument.
|
496 |
+
"""
|
497 |
+
|
498 |
+
self._transform_pipline.append(fn)
|
499 |
+
|
500 |
+
return self
|
501 |
+
|
502 |
+
|
503 |
+
class DatasetBuilder:
|
504 |
+
"""
|
505 |
+
A base class for all DatasetBuilder. It provides a `read()` function to turn
|
506 |
+
a data file into a MapDataset or IterDataset.
|
507 |
+
|
508 |
+
`_get_data()` function and `_read()` function should be implemented to download
|
509 |
+
data file and read data file into a `Iterable` of the examples.
|
510 |
+
|
511 |
+
For how to define a custom `DatasetBuilder`, please see `contribute_dataset
|
512 |
+
<https://paddlenlp.readthedocs.io/zh/latest/community/contribute_dataset.html>`__.
|
513 |
+
"""
|
514 |
+
|
515 |
+
lazy = False
|
516 |
+
|
517 |
+
def __init__(self, lazy=None, name=None, **config):
|
518 |
+
if lazy is not None:
|
519 |
+
self.lazy = lazy
|
520 |
+
self.name = name
|
521 |
+
self.config = config
|
522 |
+
|
523 |
+
def read_datasets(self, splits=None, data_files=None):
|
524 |
+
def remove_if_exit(filepath):
|
525 |
+
if isinstance(filepath, (list, tuple)):
|
526 |
+
for file in filepath:
|
527 |
+
try:
|
528 |
+
os.remove(file)
|
529 |
+
except OSError:
|
530 |
+
pass
|
531 |
+
else:
|
532 |
+
try:
|
533 |
+
os.remove(filepath)
|
534 |
+
except OSError:
|
535 |
+
pass
|
536 |
+
|
537 |
+
if data_files is None:
|
538 |
+
if splits is None:
|
539 |
+
splits = (
|
540 |
+
list(self.BUILDER_CONFIGS[self.name]["splits"].keys())
|
541 |
+
if hasattr(self, "BUILDER_CONFIGS")
|
542 |
+
else list(self.SPLITS.keys())
|
543 |
+
)
|
544 |
+
|
545 |
+
assert (
|
546 |
+
isinstance(splits, str)
|
547 |
+
or (isinstance(splits, list) and isinstance(splits[0], str))
|
548 |
+
or (isinstance(splits, tuple) and isinstance(splits[0], str))
|
549 |
+
), "`splits` should be a string or list of string or a tuple of string."
|
550 |
+
|
551 |
+
if isinstance(splits, str):
|
552 |
+
splits = [splits]
|
553 |
+
datasets = DatasetTuple(splits)
|
554 |
+
parallel_env = dist.ParallelEnv()
|
555 |
+
unique_endpoints = _get_unique_endpoints(parallel_env.trainer_endpoints[:])
|
556 |
+
# move register hook to first and register togather
|
557 |
+
lock_files = []
|
558 |
+
for split in splits:
|
559 |
+
lock_file = os.path.join(DATA_HOME, self.__class__.__name__)
|
560 |
+
if self.name is not None:
|
561 |
+
lock_file = lock_file + "." + self.name
|
562 |
+
lock_file += "." + split + ".done" + "." + str(os.getppid())
|
563 |
+
lock_files.append(lock_file)
|
564 |
+
# Must register to all procs to make the lock file can be removed
|
565 |
+
# when any proc breaks. Otherwise, the single registered proc may
|
566 |
+
# not receive proper singal send by the parent proc to exit.
|
567 |
+
atexit.register(lambda: remove_if_exit(lock_files))
|
568 |
+
for split in splits:
|
569 |
+
filename = self._get_data(split)
|
570 |
+
lock_file = os.path.join(DATA_HOME, self.__class__.__name__)
|
571 |
+
if self.name is not None:
|
572 |
+
lock_file = lock_file + "." + self.name
|
573 |
+
lock_file += "." + split + ".done" + "." + str(os.getppid())
|
574 |
+
# `lock_file` indicates the finished status of`_get_data`.
|
575 |
+
# `_get_data` only works in the `unique_endpoints` specified
|
576 |
+
# proc since `get_path_from_url` only work for it. The other
|
577 |
+
# procs wait `_get_data` to be finished.
|
578 |
+
if parallel_env.current_endpoint in unique_endpoints:
|
579 |
+
f = open(lock_file, "w")
|
580 |
+
f.close()
|
581 |
+
else:
|
582 |
+
while not os.path.exists(lock_file):
|
583 |
+
time.sleep(1)
|
584 |
+
datasets[split] = self.read(filename=filename, split=split)
|
585 |
+
else:
|
586 |
+
assert (
|
587 |
+
isinstance(data_files, str) or isinstance(data_files, tuple) or isinstance(data_files, list)
|
588 |
+
), "`data_files` should be a string or tuple or list of strings."
|
589 |
+
if isinstance(data_files, str):
|
590 |
+
data_files = [data_files]
|
591 |
+
default_split = "train"
|
592 |
+
if splits:
|
593 |
+
if isinstance(splits, str):
|
594 |
+
splits = [splits]
|
595 |
+
datasets = DatasetTuple(splits)
|
596 |
+
assert len(splits) == len(
|
597 |
+
data_files
|
598 |
+
), "Number of `splits` and number of `data_files` should be the same if you want to specify the split of loacl data file."
|
599 |
+
for i in range(len(data_files)):
|
600 |
+
datasets[splits[i]] = self.read(filename=data_files[i], split=splits[i])
|
601 |
+
else:
|
602 |
+
datasets = DatasetTuple(["split" + str(i) for i in range(len(data_files))])
|
603 |
+
for i in range(len(data_files)):
|
604 |
+
datasets["split" + str(i)] = self.read(filename=data_files[i], split=default_split)
|
605 |
+
|
606 |
+
return datasets if len(datasets) > 1 else datasets[0]
|
607 |
+
|
608 |
+
def read(self, filename, split="train"):
|
609 |
+
"""
|
610 |
+
Returns a dataset containing all the examples that can be read from the file path.
|
611 |
+
|
612 |
+
If `self.lazy` is False, this eagerly reads all instances from `self._read()`
|
613 |
+
and returns a `MapDataset`.
|
614 |
+
|
615 |
+
If `self.lazy` is True, this returns an `IterDataset`, which internally
|
616 |
+
relies on the generator created from `self._read()` to lazily produce examples.
|
617 |
+
In this case your implementation of `_read()` must also be lazy
|
618 |
+
(that is, not load all examples into memory at once).
|
619 |
+
|
620 |
+
Args:
|
621 |
+
filename (str): Path of data file to read, usually provided by `_get_data`
|
622 |
+
function.
|
623 |
+
split (str, optional): The split name of selected dataset. This only makes
|
624 |
+
a different when data files of different splits have different structures.
|
625 |
+
|
626 |
+
Returns:
|
627 |
+
A `MapDataset|IterDataset`.
|
628 |
+
"""
|
629 |
+
|
630 |
+
label_list = self.get_labels()
|
631 |
+
vocab_info = self.get_vocab()
|
632 |
+
|
633 |
+
def _create_dict(labels):
|
634 |
+
# For multiple labels in the form of list.
|
635 |
+
if isinstance(labels[0], list) or isinstance(labels[0], tuple):
|
636 |
+
label_dict = []
|
637 |
+
for sub_labels in labels:
|
638 |
+
sub_dict = {}
|
639 |
+
for i, label in enumerate(sub_labels):
|
640 |
+
sub_dict[label] = i
|
641 |
+
label_dict.append(sub_dict)
|
642 |
+
else:
|
643 |
+
label_dict = {}
|
644 |
+
for i, label in enumerate(labels):
|
645 |
+
label_dict[label] = i
|
646 |
+
return label_dict
|
647 |
+
|
648 |
+
def _convert_label_to_id(labels, label_dict):
|
649 |
+
if isinstance(labels, list) or isinstance(labels, tuple):
|
650 |
+
for label_idx in range(len(labels)):
|
651 |
+
labels[label_idx] = label_dict[labels[label_idx]]
|
652 |
+
else:
|
653 |
+
labels = label_dict[labels]
|
654 |
+
return labels
|
655 |
+
|
656 |
+
if self.lazy:
|
657 |
+
|
658 |
+
def generate_examples():
|
659 |
+
generator = (
|
660 |
+
self._read(filename, split) if self._read.__code__.co_argcount > 2 else self._read(filename)
|
661 |
+
)
|
662 |
+
for example in generator:
|
663 |
+
# We need to check if the example contains label column and confirm its name.
|
664 |
+
# For now we only allow `label` or `labels` to be the name of label column.
|
665 |
+
if "labels" in example.keys():
|
666 |
+
label_col = "labels"
|
667 |
+
elif "label" in example.keys():
|
668 |
+
label_col = "label"
|
669 |
+
else:
|
670 |
+
label_col = None
|
671 |
+
|
672 |
+
# Convert class label to label ids.
|
673 |
+
if label_list is not None and example.get(label_col, None):
|
674 |
+
label_dict = _create_dict(label_list)
|
675 |
+
# For multiple labels in the form of list.
|
676 |
+
if isinstance(label_dict, list):
|
677 |
+
for idx, sub_dict in enumerate(label_dict):
|
678 |
+
example[label_col][idx] = _convert_label_to_id(example[label_col][idx], sub_dict)
|
679 |
+
else:
|
680 |
+
example[label_col] = _convert_label_to_id(example[label_col], label_dict)
|
681 |
+
|
682 |
+
yield example
|
683 |
+
else:
|
684 |
+
yield example
|
685 |
+
|
686 |
+
return IterDataset(generate_examples(), label_list=label_list, vocab_info=vocab_info)
|
687 |
+
else:
|
688 |
+
examples = self._read(filename, split) if self._read.__code__.co_argcount > 2 else self._read(filename)
|
689 |
+
|
690 |
+
# Then some validation.
|
691 |
+
if not isinstance(examples, list):
|
692 |
+
examples = list(examples)
|
693 |
+
|
694 |
+
if not examples:
|
695 |
+
raise ValueError(
|
696 |
+
"No instances were read from the given filepath {}. " "Is the path correct?".format(filename)
|
697 |
+
)
|
698 |
+
|
699 |
+
# We need to check if the example contains label column and confirm its name.
|
700 |
+
# For now we only allow `label` or `labels` to be the name of label column.
|
701 |
+
if isinstance(examples[0], dict):
|
702 |
+
if "labels" in examples[0].keys():
|
703 |
+
label_col = "labels"
|
704 |
+
elif "label" in examples[0].keys():
|
705 |
+
label_col = "label"
|
706 |
+
else:
|
707 |
+
label_col = None
|
708 |
+
|
709 |
+
# Convert class label to label ids.
|
710 |
+
if label_list is not None and examples[0].get(label_col, None):
|
711 |
+
label_dict = _create_dict(label_list)
|
712 |
+
for idx in range(len(examples)):
|
713 |
+
# For multiple labels in the form of list.
|
714 |
+
if isinstance(label_dict, list):
|
715 |
+
for i, sub_dict in enumerate(label_dict):
|
716 |
+
examples[idx][label_col][i] = _convert_label_to_id(examples[idx][label_col][i], sub_dict)
|
717 |
+
else:
|
718 |
+
examples[idx][label_col] = _convert_label_to_id(examples[idx][label_col], label_dict)
|
719 |
+
|
720 |
+
return MapDataset(examples, label_list=label_list, vocab_info=vocab_info)
|
721 |
+
|
722 |
+
def _read(self, filename: str, *args):
|
723 |
+
"""
|
724 |
+
Reads examples from the given file_path and returns them as an
|
725 |
+
`Iterable` (which could be a list or a generator).
|
726 |
+
|
727 |
+
This method must be implemented in self-defined `DatasetBuilder`.
|
728 |
+
"""
|
729 |
+
raise NotImplementedError
|
730 |
+
|
731 |
+
def _get_data(self, mode: str):
|
732 |
+
"""
|
733 |
+
Downloads examples from the given URL and customized split
|
734 |
+
informations and returns a filepath.
|
735 |
+
|
736 |
+
This method must be implemented in self-defined `DatasetBuilder`.
|
737 |
+
"""
|
738 |
+
raise NotImplementedError
|
739 |
+
|
740 |
+
def get_labels(self):
|
741 |
+
"""
|
742 |
+
Returns list of class labels of the dataset if specified.
|
743 |
+
"""
|
744 |
+
return None
|
745 |
+
|
746 |
+
def get_vocab(self):
|
747 |
+
"""
|
748 |
+
Returns vocab file path of the dataset if specified.
|
749 |
+
"""
|
750 |
+
return None
|
751 |
+
|
752 |
+
|
753 |
+
class SimpleBuilder(DatasetBuilder):
|
754 |
+
def __init__(self, lazy, read_func):
|
755 |
+
self._read = read_func
|
756 |
+
self.lazy = lazy
|
757 |
+
|
758 |
+
def read(self, **kwargs):
|
759 |
+
if self.lazy:
|
760 |
+
|
761 |
+
def generate_examples():
|
762 |
+
generator = self._read(**kwargs)
|
763 |
+
for example in generator:
|
764 |
+
yield example
|
765 |
+
|
766 |
+
return IterDataset(generate_examples)
|
767 |
+
else:
|
768 |
+
examples = self._read(**kwargs)
|
769 |
+
if hasattr(examples, "__len__") and hasattr(examples, "__getitem__"):
|
770 |
+
return MapDataset(examples)
|
771 |
+
else:
|
772 |
+
return MapDataset(list(examples))
|
773 |
+
|
774 |
+
|
775 |
+
def has_file_allowed_extension(filename: str, extensions: Union[str, Tuple[str, ...]]) -> bool:
|
776 |
+
"""Checks if a file is an allowed extension.
|
777 |
+
|
778 |
+
Args:
|
779 |
+
filename (string): path to a file
|
780 |
+
extensions (tuple of strings): extensions to consider (lowercase)
|
781 |
+
|
782 |
+
Returns:
|
783 |
+
bool: True if the filename ends with one of given extensions
|
784 |
+
"""
|
785 |
+
return filename.lower().endswith(extensions if isinstance(extensions, str) else tuple(extensions))
|
786 |
+
|
787 |
+
|
788 |
+
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
|
789 |
+
"""Finds the class folders in a dataset.
|
790 |
+
|
791 |
+
See :class:`DatasetFolder` for details.
|
792 |
+
"""
|
793 |
+
classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
|
794 |
+
if not classes:
|
795 |
+
raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
|
796 |
+
|
797 |
+
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
|
798 |
+
return classes, class_to_idx
|
799 |
+
|
800 |
+
|
801 |
+
def make_dataset(
|
802 |
+
directory: str,
|
803 |
+
class_to_idx: Optional[Dict[str, int]] = None,
|
804 |
+
extensions: Optional[Union[str, Tuple[str, ...]]] = None,
|
805 |
+
is_valid_file: Optional[Callable[[str], bool]] = None,
|
806 |
+
) -> List[Tuple[str, int]]:
|
807 |
+
"""Generates a list of samples of a form (path_to_sample, class).
|
808 |
+
|
809 |
+
See :class:`DatasetFolder` for details.
|
810 |
+
|
811 |
+
Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
|
812 |
+
by default.
|
813 |
+
"""
|
814 |
+
directory = os.path.expanduser(directory)
|
815 |
+
|
816 |
+
if class_to_idx is None:
|
817 |
+
_, class_to_idx = find_classes(directory)
|
818 |
+
elif not class_to_idx:
|
819 |
+
raise ValueError("'class_to_index' must have at least one entry to collect any samples.")
|
820 |
+
|
821 |
+
both_none = extensions is None and is_valid_file is None
|
822 |
+
both_something = extensions is not None and is_valid_file is not None
|
823 |
+
if both_none or both_something:
|
824 |
+
raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
|
825 |
+
|
826 |
+
if extensions is not None:
|
827 |
+
|
828 |
+
def is_valid_file(x: str) -> bool:
|
829 |
+
return has_file_allowed_extension(x, extensions) # type: ignore[arg-type]
|
830 |
+
|
831 |
+
is_valid_file = cast(Callable[[str], bool], is_valid_file)
|
832 |
+
|
833 |
+
instances = []
|
834 |
+
available_classes = set()
|
835 |
+
for target_class in sorted(class_to_idx.keys()):
|
836 |
+
class_index = class_to_idx[target_class]
|
837 |
+
target_dir = os.path.join(directory, target_class)
|
838 |
+
if not os.path.isdir(target_dir):
|
839 |
+
continue
|
840 |
+
for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
|
841 |
+
for fname in sorted(fnames):
|
842 |
+
path = os.path.join(root, fname)
|
843 |
+
if is_valid_file(path):
|
844 |
+
item = path, class_index
|
845 |
+
instances.append(item)
|
846 |
+
|
847 |
+
if target_class not in available_classes:
|
848 |
+
available_classes.add(target_class)
|
849 |
+
|
850 |
+
empty_classes = set(class_to_idx.keys()) - available_classes
|
851 |
+
if empty_classes:
|
852 |
+
msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
|
853 |
+
if extensions is not None:
|
854 |
+
msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
|
855 |
+
raise FileNotFoundError(msg)
|
856 |
+
|
857 |
+
return instances
|
858 |
+
|
859 |
+
|
860 |
+
class DatasetFolder(Dataset):
|
861 |
+
"""A generic data loader.
|
862 |
+
|
863 |
+
This default directory structure can be customized by overriding the
|
864 |
+
:meth:`find_classes` method.
|
865 |
+
|
866 |
+
Args:
|
867 |
+
root (string): Root directory path.
|
868 |
+
loader (callable): A function to load a sample given its path.
|
869 |
+
extensions (tuple[string]): A list of allowed extensions.
|
870 |
+
both extensions and is_valid_file should not be passed.
|
871 |
+
transform (callable, optional): A function/transform that takes in
|
872 |
+
a sample and returns a transformed version.
|
873 |
+
E.g, ``transforms.RandomCrop`` for images.
|
874 |
+
target_transform (callable, optional): A function/transform that takes
|
875 |
+
in the target and transforms it.
|
876 |
+
is_valid_file (callable, optional): A function that takes path of a file
|
877 |
+
and check if the file is a valid file (used to check of corrupt files)
|
878 |
+
both extensions and is_valid_file should not be passed.
|
879 |
+
|
880 |
+
Attributes:
|
881 |
+
classes (list): List of the class names sorted alphabetically.
|
882 |
+
class_to_idx (dict): Dict with items (class_name, class_index).
|
883 |
+
samples (list): List of (sample path, class_index) tuples
|
884 |
+
targets (list): The class_index value for each image in the dataset
|
885 |
+
"""
|
886 |
+
|
887 |
+
def __init__(
|
888 |
+
self,
|
889 |
+
root: str,
|
890 |
+
loader: Callable[[str], Any],
|
891 |
+
extensions: Optional[Tuple[str, ...]] = None,
|
892 |
+
transform: Optional[Callable] = None,
|
893 |
+
target_transform: Optional[Callable] = None,
|
894 |
+
is_valid_file: Optional[Callable[[str], bool]] = None,
|
895 |
+
) -> None:
|
896 |
+
self.root = root
|
897 |
+
self.transform = transform
|
898 |
+
self.target_transform = target_transform
|
899 |
+
|
900 |
+
classes, class_to_idx = self.find_classes(self.root)
|
901 |
+
samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)
|
902 |
+
print(f"find total {len(classes)} classes and {len(samples)} images.")
|
903 |
+
|
904 |
+
self.loader = loader
|
905 |
+
self.extensions = extensions
|
906 |
+
|
907 |
+
self.classes = classes
|
908 |
+
self.class_to_idx = class_to_idx
|
909 |
+
self.samples = samples
|
910 |
+
self.targets = [s[1] for s in samples]
|
911 |
+
|
912 |
+
@staticmethod
|
913 |
+
def make_dataset(
|
914 |
+
directory: str,
|
915 |
+
class_to_idx: Dict[str, int],
|
916 |
+
extensions: Optional[Tuple[str, ...]] = None,
|
917 |
+
is_valid_file: Optional[Callable[[str], bool]] = None,
|
918 |
+
) -> List[Tuple[str, int]]:
|
919 |
+
"""Generates a list of samples of a form (path_to_sample, class).
|
920 |
+
|
921 |
+
This can be overridden to e.g. read files from a compressed zip file instead of from the disk.
|
922 |
+
|
923 |
+
Args:
|
924 |
+
directory (str): root dataset directory, corresponding to ``self.root``.
|
925 |
+
class_to_idx (Dict[str, int]): Dictionary mapping class name to class index.
|
926 |
+
extensions (optional): A list of allowed extensions.
|
927 |
+
Either extensions or is_valid_file should be passed. Defaults to None.
|
928 |
+
is_valid_file (optional): A function that takes path of a file
|
929 |
+
and checks if the file is a valid file
|
930 |
+
(used to check of corrupt files) both extensions and
|
931 |
+
is_valid_file should not be passed. Defaults to None.
|
932 |
+
|
933 |
+
Raises:
|
934 |
+
ValueError: In case ``class_to_idx`` is empty.
|
935 |
+
ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
|
936 |
+
FileNotFoundError: In case no valid file was found for any class.
|
937 |
+
|
938 |
+
Returns:
|
939 |
+
List[Tuple[str, int]]: samples of a form (path_to_sample, class)
|
940 |
+
"""
|
941 |
+
if class_to_idx is None:
|
942 |
+
# prevent potential bug since make_dataset() would use the class_to_idx logic of the
|
943 |
+
# find_classes() function, instead of using that of the find_classes() method, which
|
944 |
+
# is potentially overridden and thus could have a different logic.
|
945 |
+
raise ValueError("The class_to_idx parameter cannot be None.")
|
946 |
+
return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)
|
947 |
+
|
948 |
+
def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
|
949 |
+
"""Find the class folders in a dataset structured as follows::
|
950 |
+
|
951 |
+
directory/
|
952 |
+
├── class_x
|
953 |
+
│ ├── xxx.ext
|
954 |
+
│ ├── xxy.ext
|
955 |
+
│ └── ...
|
956 |
+
│ └── xxz.ext
|
957 |
+
└── class_y
|
958 |
+
├── 123.ext
|
959 |
+
├── nsdf3.ext
|
960 |
+
└── ...
|
961 |
+
└── asd932_.ext
|
962 |
+
|
963 |
+
This method can be overridden to only consider
|
964 |
+
a subset of classes, or to adapt to a different dataset directory structure.
|
965 |
+
|
966 |
+
Args:
|
967 |
+
directory(str): Root directory path, corresponding to ``self.root``
|
968 |
+
|
969 |
+
Raises:
|
970 |
+
FileNotFoundError: If ``dir`` has no class folders.
|
971 |
+
|
972 |
+
Returns:
|
973 |
+
(Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
|
974 |
+
"""
|
975 |
+
return find_classes(directory)
|
976 |
+
|
977 |
+
def __getitem__(self, index: int) -> Tuple[Any, Any]:
|
978 |
+
"""
|
979 |
+
Args:
|
980 |
+
index (int): Index
|
981 |
+
|
982 |
+
Returns:
|
983 |
+
tuple: (sample, target) where target is class_index of the target class.
|
984 |
+
"""
|
985 |
+
path, target = self.samples[index]
|
986 |
+
sample = self.loader(path)
|
987 |
+
if self.transform is not None:
|
988 |
+
sample = self.transform(sample)
|
989 |
+
if self.target_transform is not None:
|
990 |
+
target = self.target_transform(target)
|
991 |
+
return sample, np.int32(target)
|
992 |
+
|
993 |
+
def __len__(self) -> int:
|
994 |
+
return len(self.samples)
|
995 |
+
|
996 |
+
@property
|
997 |
+
def class_num(self):
|
998 |
+
return len(set(self.classes))
|
999 |
+
|
1000 |
+
|
1001 |
+
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
|
1002 |
+
|
1003 |
+
_image_backend = "pil"
|
1004 |
+
|
1005 |
+
|
1006 |
+
def pil_loader(path: str) -> Image.Image:
|
1007 |
+
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
|
1008 |
+
with open(path, "rb") as f:
|
1009 |
+
img = Image.open(f)
|
1010 |
+
return img.convert("RGB")
|
1011 |
+
|
1012 |
+
|
1013 |
+
def set_image_backend(backend):
|
1014 |
+
"""
|
1015 |
+
Specifies the package used to load images.
|
1016 |
+
|
1017 |
+
Args:
|
1018 |
+
backend (string): Name of the image backend. one of {'PIL', 'accimage'}.
|
1019 |
+
The :mod:`accimage` package uses the Intel IPP library. It is
|
1020 |
+
generally faster than PIL, but does not support as many operations.
|
1021 |
+
"""
|
1022 |
+
global _image_backend
|
1023 |
+
if backend not in ["pil", "cv2"]:
|
1024 |
+
raise ValueError(f"Invalid backend '{backend}'. Options are 'pil' and 'cv2'")
|
1025 |
+
_image_backend = backend
|
1026 |
+
|
1027 |
+
|
1028 |
+
def get_image_backend():
|
1029 |
+
"""
|
1030 |
+
Gets the name of the package used to load images
|
1031 |
+
"""
|
1032 |
+
return _image_backend
|
1033 |
+
|
1034 |
+
|
1035 |
+
def cv2_loader(path: str):
|
1036 |
+
return cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
|
1037 |
+
|
1038 |
+
|
1039 |
+
def default_loader(path: str) -> Any:
|
1040 |
+
if get_image_backend() == "cv2":
|
1041 |
+
return cv2_loader(path)
|
1042 |
+
else:
|
1043 |
+
return pil_loader(path)
|
1044 |
+
|
1045 |
+
|
1046 |
+
class ImageFolder(DatasetFolder):
|
1047 |
+
"""A generic data loader where the images are arranged in this way by default: ::
|
1048 |
+
|
1049 |
+
root/dog/xxx.png
|
1050 |
+
root/dog/xxy.png
|
1051 |
+
root/dog/[...]/xxz.png
|
1052 |
+
|
1053 |
+
root/cat/123.png
|
1054 |
+
root/cat/nsdf3.png
|
1055 |
+
root/cat/[...]/asd932_.png
|
1056 |
+
|
1057 |
+
This class inherits from :class:`~torchvision.datasets.DatasetFolder` so
|
1058 |
+
the same methods can be overridden to customize the dataset.
|
1059 |
+
|
1060 |
+
Args:
|
1061 |
+
root (string): Root directory path.
|
1062 |
+
transform (callable, optional): A function/transform that takes in an PIL image
|
1063 |
+
and returns a transformed version. E.g, ``transforms.RandomCrop``
|
1064 |
+
target_transform (callable, optional): A function/transform that takes in the
|
1065 |
+
target and transforms it.
|
1066 |
+
loader (callable, optional): A function to load an image given its path.
|
1067 |
+
is_valid_file (callable, optional): A function that takes path of an Image file
|
1068 |
+
and check if the file is a valid file (used to check of corrupt files)
|
1069 |
+
|
1070 |
+
Attributes:
|
1071 |
+
classes (list): List of the class names sorted alphabetically.
|
1072 |
+
class_to_idx (dict): Dict with items (class_name, class_index).
|
1073 |
+
imgs (list): List of (image path, class_index) tuples
|
1074 |
+
"""
|
1075 |
+
|
1076 |
+
def __init__(
|
1077 |
+
self,
|
1078 |
+
root: str,
|
1079 |
+
transform: Optional[Callable] = None,
|
1080 |
+
target_transform: Optional[Callable] = None,
|
1081 |
+
loader: Callable[[str], Any] = default_loader,
|
1082 |
+
is_valid_file: Optional[Callable[[str], bool]] = None,
|
1083 |
+
):
|
1084 |
+
super().__init__(
|
1085 |
+
root,
|
1086 |
+
loader,
|
1087 |
+
IMG_EXTENSIONS if is_valid_file is None else None,
|
1088 |
+
transform=transform,
|
1089 |
+
target_transform=target_transform,
|
1090 |
+
is_valid_file=is_valid_file,
|
1091 |
+
)
|
1092 |
+
self.imgs = self.samples
|
1093 |
+
|
1094 |
+
|
1095 |
+
import bisect
|
1096 |
+
|
1097 |
+
|
1098 |
+
class ConcatDataset(Dataset):
|
1099 |
+
r"""Dataset as a concatenation of multiple datasets.
|
1100 |
+
|
1101 |
+
This class is useful to assemble different existing datasets.
|
1102 |
+
|
1103 |
+
Args:
|
1104 |
+
datasets (sequence): List of datasets to be concatenated
|
1105 |
+
"""
|
1106 |
+
datasets: List[Dataset]
|
1107 |
+
cumulative_sizes: List[int]
|
1108 |
+
|
1109 |
+
@staticmethod
|
1110 |
+
def cumsum(sequence):
|
1111 |
+
r, s = [], 0
|
1112 |
+
for e in sequence:
|
1113 |
+
l = len(e)
|
1114 |
+
r.append(l + s)
|
1115 |
+
s += l
|
1116 |
+
return r
|
1117 |
+
|
1118 |
+
def __init__(self, datasets) -> None:
|
1119 |
+
super().__init__()
|
1120 |
+
self.datasets = list(datasets)
|
1121 |
+
assert len(self.datasets) > 0, "datasets should not be an empty iterable" # type: ignore[arg-type]
|
1122 |
+
for d in self.datasets:
|
1123 |
+
assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
|
1124 |
+
self.cumulative_sizes = self.cumsum(self.datasets)
|
1125 |
+
|
1126 |
+
def __len__(self):
|
1127 |
+
return self.cumulative_sizes[-1]
|
1128 |
+
|
1129 |
+
def __getitem__(self, idx):
|
1130 |
+
if idx < 0:
|
1131 |
+
if -idx > len(self):
|
1132 |
+
raise ValueError("absolute value of index should not exceed dataset length")
|
1133 |
+
idx = len(self) + idx
|
1134 |
+
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
|
1135 |
+
if dataset_idx == 0:
|
1136 |
+
sample_idx = idx
|
1137 |
+
else:
|
1138 |
+
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
|
1139 |
+
return self.datasets[dataset_idx][sample_idx]
|
1140 |
+
|
1141 |
+
@property
|
1142 |
+
def cummulative_sizes(self):
|
1143 |
+
warnings.warn(
|
1144 |
+
"cummulative_sizes attribute is renamed to " "cumulative_sizes", DeprecationWarning, stacklevel=2
|
1145 |
+
)
|
1146 |
+
return self.cumulative_sizes
|
1147 |
+
|
1148 |
+
|
1149 |
+
class MixDataset(Dataset):
|
1150 |
+
datasets_names: List[Dict]
|
1151 |
+
|
1152 |
+
def __init__(self, datasets_names) -> None:
|
1153 |
+
super().__init__()
|
1154 |
+
self.datasets_names = list(datasets_names)
|
1155 |
+
self.datasets = []
|
1156 |
+
for d in self.datasets_names:
|
1157 |
+
name = d["name"]
|
1158 |
+
data_files = d["data_files"] if "data_files" in d else None
|
1159 |
+
splits = d["splits"] if "splits" in d else None
|
1160 |
+
chat_template = d["chat_template"] if "chat_template" in d else None
|
1161 |
+
self.datasets.append(load_dataset(name, data_files=data_files, splits=splits, chat_template=chat_template))
|
1162 |
+
|
1163 |
+
self.datasets = ConcatDataset(self.datasets)
|
1164 |
+
|
1165 |
+
def __len__(self):
|
1166 |
+
return len(self.datasets)
|
1167 |
+
|
1168 |
+
def __getitem__(self, idx):
|
1169 |
+
return self.datasets[idx]
|
PaddleMIX/paddlemix/datasets/got_dataset.py
ADDED
@@ -0,0 +1,439 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import copy
|
16 |
+
import json
|
17 |
+
import logging
|
18 |
+
import random
|
19 |
+
from typing import Dict
|
20 |
+
import paddle
|
21 |
+
from paddle import Tensor
|
22 |
+
import paddlenlp
|
23 |
+
from PIL import Image, ImageFile
|
24 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
25 |
+
from ..models.GOT.utils.conversation import (
|
26 |
+
SeparatorStyle,
|
27 |
+
conv_mpt,
|
28 |
+
)
|
29 |
+
from dataclasses import dataclass
|
30 |
+
from functools import partial
|
31 |
+
from typing import List, Union
|
32 |
+
from megfile import smart_glob
|
33 |
+
from natsort import natsorted
|
34 |
+
|
35 |
+
|
36 |
+
IGNORE_INDEX = -100
|
37 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
38 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
39 |
+
|
40 |
+
LOGDIR = "log"
|
41 |
+
|
42 |
+
IGNORE_INDEX = -100
|
43 |
+
# DEFAULT_PAD_TOKEN = "[PAD]"
|
44 |
+
|
45 |
+
DEFAULT_PAD_TOKEN = "<|endoftext|>"
|
46 |
+
DEFAULT_EOS_TOKEN = "</s>"
|
47 |
+
DEFAULT_BOS_TOKEN = "</s>"
|
48 |
+
DEFAULT_UNK_TOKEN = "<unk>"
|
49 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
50 |
+
DEFAULT_BOX_TOKEN = "<box>"
|
51 |
+
|
52 |
+
DEFAULT_IMAGE_PATCH_TOKEN = "<imgpad>"
|
53 |
+
|
54 |
+
DEFAULT_IM_START_TOKEN = "<img>"
|
55 |
+
DEFAULT_IM_END_TOKEN = "</img>"
|
56 |
+
|
57 |
+
|
58 |
+
class BaseDataset(paddle.io.Dataset):
|
59 |
+
def __init__(self, datasets: str, tokenizer: paddlenlp.transformers.PretrainedTokenizer, multimodal_cfg: dict):
|
60 |
+
super(BaseDataset, self).__init__()
|
61 |
+
self.tokenizer = tokenizer
|
62 |
+
self.multimodal_cfg = multimodal_cfg
|
63 |
+
|
64 |
+
logging.warning(f"Using {multimodal_cfg['image_token_len']} tokens for representing image")
|
65 |
+
|
66 |
+
def image_processor(self, image):
|
67 |
+
# processor = self.multimodal_cfg['image_processor'] # the first processor, usually is the clip pretrained model (vit)
|
68 |
+
processor_high = self.multimodal_cfg[
|
69 |
+
"image_processor_high"
|
70 |
+
] # the second processor, usually is the designed image encoder (sam/swin/cnn)
|
71 |
+
image_high = image.copy()
|
72 |
+
image_high = processor_high(image_high)
|
73 |
+
return image_high
|
74 |
+
|
75 |
+
def __len__(self):
|
76 |
+
return len(self.list_data_dict)
|
77 |
+
|
78 |
+
def __getitem__(self, i) -> Dict[str, paddle.Tensor]:
|
79 |
+
pass
|
80 |
+
|
81 |
+
|
82 |
+
class ConversationDataset(BaseDataset):
|
83 |
+
"""Conversation format dataset stage2 fine-tuning."""
|
84 |
+
|
85 |
+
def __init__(self, meta_path, tokenizer, multimodal_cfg):
|
86 |
+
super(ConversationDataset, self).__init__(meta_path, tokenizer, multimodal_cfg)
|
87 |
+
# v0 version format conversation
|
88 |
+
# default_conversation = conv_templates["mpt"]
|
89 |
+
logging.warning("Formatting inputs into conversation type: mpt-fixed")
|
90 |
+
logging.warning("Loading data...")
|
91 |
+
|
92 |
+
list_data_dict = []
|
93 |
+
list_image_path = []
|
94 |
+
|
95 |
+
# add your data [data1, data2, data3, .....]
|
96 |
+
# got_data_dict = {
|
97 |
+
# "pdf-ocr": ["data1"],
|
98 |
+
# #'scene-ocr': ["data3", "data4"]
|
99 |
+
# # ......
|
100 |
+
# }
|
101 |
+
# for name_all in datasets.split("+"):
|
102 |
+
# for name in got_data_dict[name_all]:
|
103 |
+
ds_collections = json.loads(open(meta_path).read())
|
104 |
+
#ds_collections = json.load(open(meta_path, 'r'))
|
105 |
+
for ds_idx, ds_name in enumerate(ds_collections.keys()):
|
106 |
+
# dataset = CONVERSATION_DATA[ds_name]
|
107 |
+
dataset = ds_collections[ds_name]
|
108 |
+
|
109 |
+
data_path = dataset["annotations"]
|
110 |
+
#image_root = dataset["images"]
|
111 |
+
if data_path.endswith(".json"):
|
112 |
+
data = json.load(open(data_path, "r"))
|
113 |
+
elif data_path.endswith(".jsonl"):
|
114 |
+
with open(data_path, "r") as f:
|
115 |
+
data = f.readlines()
|
116 |
+
for ii in range(len(data)):
|
117 |
+
data[ii] = json.loads(data[ii])
|
118 |
+
else:
|
119 |
+
raise ValueError(f"Unknown file extension: {data_path}")
|
120 |
+
|
121 |
+
list_data_dict.extend(data)
|
122 |
+
|
123 |
+
image_path = dataset["images"] # image_root
|
124 |
+
|
125 |
+
list_image_path.extend([image_path] * len(data))
|
126 |
+
|
127 |
+
logging.warning(f"Data from {data_path} provide {len(data)} conversations.")
|
128 |
+
|
129 |
+
assert len(list_data_dict) == len(list_image_path)
|
130 |
+
logging.warning(f"{len(list_data_dict)} conversations in total.")
|
131 |
+
a_new_list = list(zip(list_data_dict, list_image_path))
|
132 |
+
random.shuffle(a_new_list)
|
133 |
+
list_data_dict_new, list_image_path_new = zip(*a_new_list)
|
134 |
+
self.list_data_dict = list_data_dict_new
|
135 |
+
self.list_image_path = list_image_path_new
|
136 |
+
|
137 |
+
self.im_patch_token = 151859
|
138 |
+
self.im_start_token = 151857
|
139 |
+
self.im_end_token = 151858
|
140 |
+
|
141 |
+
def multimodal_processor(self, sources, flag_num_patches):
|
142 |
+
for source in sources:
|
143 |
+
if self.multimodal_cfg["sep_image_conv_front"]:
|
144 |
+
assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
|
145 |
+
source[0]["value"] = source[0]["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
|
146 |
+
source[0]["value"] = DEFAULT_IMAGE_TOKEN + conv_mpt.sep + conv_mpt.roles[0] + ": " + source[0]["value"]
|
147 |
+
|
148 |
+
for sentence in source:
|
149 |
+
replace_token = DEFAULT_IMAGE_PATCH_TOKEN * self.multimodal_cfg["image_token_len"] * flag_num_patches
|
150 |
+
replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
|
151 |
+
# sentence["value"] = str(sentence["value"]).replace('\qquad', '\quad')
|
152 |
+
sentence["value"] = str(sentence["value"]).replace(DEFAULT_IMAGE_TOKEN, replace_token)
|
153 |
+
return sources
|
154 |
+
|
155 |
+
def _tokenize_fn(self, strings):
|
156 |
+
"""Tokenize a list of strings."""
|
157 |
+
tokenized_list = [
|
158 |
+
self.tokenizer(
|
159 |
+
text,
|
160 |
+
return_tensors="pd",
|
161 |
+
padding="longest",
|
162 |
+
max_length=self.tokenizer.model_max_length,
|
163 |
+
truncation=True,
|
164 |
+
)
|
165 |
+
for text in strings
|
166 |
+
]
|
167 |
+
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
|
168 |
+
input_ids_lens = labels_lens = [
|
169 |
+
tokenized.input_ids.not_equal(paddle.to_tensor(self.tokenizer.pad_token_id)).sum().item()
|
170 |
+
for tokenized in tokenized_list
|
171 |
+
]
|
172 |
+
return dict(
|
173 |
+
input_ids=input_ids,
|
174 |
+
labels=labels,
|
175 |
+
input_ids_lens=input_ids_lens,
|
176 |
+
labels_lens=labels_lens,
|
177 |
+
)
|
178 |
+
|
179 |
+
def _mask_targets(self, target, tokenized_lens, speakers):
|
180 |
+
# cur_idx = 0
|
181 |
+
cur_idx = tokenized_lens[0]
|
182 |
+
tokenized_lens = tokenized_lens[1:]
|
183 |
+
target[:cur_idx] = IGNORE_INDEX
|
184 |
+
for tokenized_len, speaker in zip(tokenized_lens, speakers):
|
185 |
+
if speaker.lower() == "human":
|
186 |
+
target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX
|
187 |
+
cur_idx += tokenized_len
|
188 |
+
|
189 |
+
def token_processor(self, sources, image_name):
|
190 |
+
conv = conv_mpt.copy()
|
191 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
192 |
+
|
193 |
+
# Apply prompt templates
|
194 |
+
conversations = []
|
195 |
+
for i, source in enumerate(sources):
|
196 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
197 |
+
# Skip the first one if it is not from human
|
198 |
+
source = source[1:]
|
199 |
+
|
200 |
+
conv.messages = []
|
201 |
+
for j, sentence in enumerate(source):
|
202 |
+
role = roles[sentence["from"]]
|
203 |
+
assert role == conv.roles[j % 2], f"{i}"
|
204 |
+
conv.append_message(role, sentence["value"])
|
205 |
+
conversations.append(conv.get_prompt())
|
206 |
+
|
207 |
+
# Tokenize conversations
|
208 |
+
input_ids = self.tokenizer(
|
209 |
+
conversations,
|
210 |
+
return_tensors="pd",
|
211 |
+
padding="longest",
|
212 |
+
max_length=self.tokenizer.model_max_length,
|
213 |
+
truncation=True,
|
214 |
+
).input_ids
|
215 |
+
|
216 |
+
# input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
|
217 |
+
targets = input_ids.clone()
|
218 |
+
assert conv.sep_style == SeparatorStyle.MPT
|
219 |
+
|
220 |
+
# Mask targets
|
221 |
+
sep = conv.sep + conv.roles[1]
|
222 |
+
for conversation, target in zip(conversations, targets):
|
223 |
+
total_len = int(target.not_equal(paddle.to_tensor(self.tokenizer.pad_token_id)).sum())
|
224 |
+
|
225 |
+
rounds = conversation.split(conv.sep)
|
226 |
+
re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
|
227 |
+
for conv_idx in range(3, len(rounds), 2):
|
228 |
+
re_rounds.append(conv.sep.join(rounds[conv_idx : conv_idx + 2])) # user + gpt
|
229 |
+
cur_len = 0
|
230 |
+
target[:cur_len] = IGNORE_INDEX
|
231 |
+
for i, rou in enumerate(re_rounds):
|
232 |
+
if rou == "":
|
233 |
+
break
|
234 |
+
|
235 |
+
parts = rou.split(sep)
|
236 |
+
if len(parts) != 2:
|
237 |
+
break
|
238 |
+
parts[0] += sep
|
239 |
+
round_len = len(self.tokenizer(rou).input_ids) + len(self.tokenizer(conv.sep).input_ids)
|
240 |
+
# round_len = len(tokenizer_image_token(rou, self.tokenizer)) + len(tokenizer_image_token(conv.sep, self.tokenizer))
|
241 |
+
# instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
|
242 |
+
instruction_len = len(self.tokenizer(parts[0]).input_ids)
|
243 |
+
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
244 |
+
|
245 |
+
cur_len += round_len
|
246 |
+
target[cur_len:] = IGNORE_INDEX
|
247 |
+
|
248 |
+
if cur_len < self.tokenizer.model_max_length:
|
249 |
+
if cur_len != total_len:
|
250 |
+
target[:] = IGNORE_INDEX
|
251 |
+
print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)")
|
252 |
+
print(image_name)
|
253 |
+
|
254 |
+
return dict(
|
255 |
+
input_ids=input_ids,
|
256 |
+
labels=targets,
|
257 |
+
)
|
258 |
+
|
259 |
+
def __getitem__(self, i) -> Dict[str, paddle.Tensor]:
|
260 |
+
# data = self.list_data_dict[i]
|
261 |
+
data = copy.deepcopy(self.list_data_dict[i])
|
262 |
+
|
263 |
+
if isinstance(data, dict):
|
264 |
+
image_list = []
|
265 |
+
image_high_list = []
|
266 |
+
flag_num_patches = 1
|
267 |
+
if "image" in data:
|
268 |
+
image_path = self.list_image_path[i]
|
269 |
+
image_file = data["image"]
|
270 |
+
|
271 |
+
# multi-crop or multi page, only support .png files
|
272 |
+
if (
|
273 |
+
0
|
274 |
+
): # ('.jpg' not in image_file and '.png' not in image_file and '.jpeg' not in image_file) and ('.jpg' not in image_path and '.png' not in image_path and '.jpeg' not in image_path):
|
275 |
+
if image_file[0] == "/":
|
276 |
+
patch_dir = image_path[:-1] + image_file
|
277 |
+
patches = smart_glob(patch_dir + "*.png")
|
278 |
+
else:
|
279 |
+
patch_dir = image_path + image_file
|
280 |
+
patches = smart_glob(patch_dir + "*.png")
|
281 |
+
|
282 |
+
# print(patches)
|
283 |
+
if not patches:
|
284 |
+
print(f"cannot glob the dir {patch_dir}.")
|
285 |
+
return self.__getitem__(0)
|
286 |
+
|
287 |
+
# sort multi images by name
|
288 |
+
patches = natsorted(patches)
|
289 |
+
flag_num_patches = len(patches)
|
290 |
+
|
291 |
+
for patch in patches:
|
292 |
+
try:
|
293 |
+
image = Image.open(patch).convert("RGB")
|
294 |
+
except:
|
295 |
+
print(f"cannot identify image file {patch}.")
|
296 |
+
return self.__getitem__(0)
|
297 |
+
|
298 |
+
try:
|
299 |
+
img = self.image_processor(image)
|
300 |
+
image_list.append(img)
|
301 |
+
image_high_list.append(img)
|
302 |
+
except:
|
303 |
+
print(
|
304 |
+
f"image {image_path + image_file + patch} are broken or grayscale! we thus select 0-th sample instead!"
|
305 |
+
)
|
306 |
+
return self.__getitem__(0)
|
307 |
+
|
308 |
+
else:
|
309 |
+
flag_num_patches = 1
|
310 |
+
try:
|
311 |
+
image = Image.open(image_path + image_file).convert("RGB")
|
312 |
+
except:
|
313 |
+
print(f"cannot identify image file {image_file}.")
|
314 |
+
return self.__getitem__(0)
|
315 |
+
|
316 |
+
try:
|
317 |
+
image = self.image_processor(image)
|
318 |
+
except:
|
319 |
+
print(f"image {image_file} are broken or grayscale! we thus select 0-th sample instead!")
|
320 |
+
return self.__getitem__(0)
|
321 |
+
|
322 |
+
conversations = self.multimodal_processor([data["conversations"]], flag_num_patches)
|
323 |
+
# print(conversations)
|
324 |
+
# exit()
|
325 |
+
else:
|
326 |
+
conversations = [data]
|
327 |
+
|
328 |
+
# align with fastchat & llava here, put the conversation into a list for tokenization
|
329 |
+
image_name = image_path + image_file
|
330 |
+
data_dict = self.token_processor(conversations, image_name)
|
331 |
+
data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
|
332 |
+
|
333 |
+
if isinstance(data, dict) and "image" in data:
|
334 |
+
if image_list and image_high_list:
|
335 |
+
data_dict["image"] = image_list
|
336 |
+
data_dict["image_high"] = image_high_list
|
337 |
+
else:
|
338 |
+
data_dict["image"] = [image]
|
339 |
+
data_dict["image_high"] = [image]
|
340 |
+
else:
|
341 |
+
# crop_size = self.multimodal_cfg['image_processor'].crop_size
|
342 |
+
# data_dict['image'] = [torch.zeros(3, crop_size['height'], crop_size['width'])]
|
343 |
+
# Vary for two image, GOT does not use the data_dict['image]
|
344 |
+
data_dict["image"] = [paddle.zeros([3, 1024, 1024])]
|
345 |
+
data_dict["image_high"] = [paddle.zeros([3, 1024, 1024])]
|
346 |
+
return data_dict
|
347 |
+
|
348 |
+
|
349 |
+
# helpers
|
350 |
+
def pad_sequence_paddle(sequences, padding_value=0):
|
351 |
+
"""
|
352 |
+
Implement a function similar to PyTorch's pad_sequence in PaddlePaddle.
|
353 |
+
|
354 |
+
Args:
|
355 |
+
- sequences (list of Tensor): The list of sequences to be padded.
|
356 |
+
- padding_value (float, optional): The value used for padding, default is 0.
|
357 |
+
|
358 |
+
Returns:
|
359 |
+
- Tensor: The result of padding all sequences to the same length.
|
360 |
+
"""
|
361 |
+
# Calculate the maximum length
|
362 |
+
max_len = max([seq.shape[0] for seq in sequences])
|
363 |
+
|
364 |
+
# Pad sequences
|
365 |
+
padded_sequences = []
|
366 |
+
for seq in sequences:
|
367 |
+
# Calculate the length to pad
|
368 |
+
padding_len = max_len - seq.shape[0]
|
369 |
+
|
370 |
+
# Create a padding tensor
|
371 |
+
if padding_len > 0:
|
372 |
+
padding_tensor = paddle.full([padding_len] + list(seq.shape[1:]), padding_value, dtype=seq.dtype)
|
373 |
+
# Concatenate the original sequence and the padding tensor
|
374 |
+
padded_seq = paddle.concat([seq, padding_tensor], axis=0)
|
375 |
+
else:
|
376 |
+
padded_seq = seq
|
377 |
+
|
378 |
+
padded_sequences.append(padded_seq)
|
379 |
+
|
380 |
+
# Stack the padded sequences to form a batch
|
381 |
+
padded_batch = paddle.stack(padded_sequences, axis=0)
|
382 |
+
return padded_batch
|
383 |
+
|
384 |
+
|
385 |
+
def orig_pad_sequence(
|
386 |
+
sequences: Union[Tensor, List[Tensor]],
|
387 |
+
batch_first: bool = False,
|
388 |
+
padding_value: float = 0.0,
|
389 |
+
) -> Tensor:
|
390 |
+
if batch_first:
|
391 |
+
return pad_sequence_paddle(sequences, padding_value)
|
392 |
+
else:
|
393 |
+
assert False, "Not implemented"
|
394 |
+
|
395 |
+
|
396 |
+
@dataclass
|
397 |
+
class DataCollatorForSupervisedDataset(object):
|
398 |
+
tokenizer: paddlenlp.transformers.PretrainedTokenizer
|
399 |
+
|
400 |
+
def __call__(self, instances):
|
401 |
+
input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
|
402 |
+
images = [paddle.stack(instance["image"]) for instance in instances]
|
403 |
+
images_high = [paddle.stack(instance["image_high"]) for instance in instances]
|
404 |
+
images = list(zip(images, images_high))
|
405 |
+
|
406 |
+
pad_sequence = partial(orig_pad_sequence, batch_first=True)
|
407 |
+
|
408 |
+
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
|
409 |
+
|
410 |
+
labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
|
411 |
+
|
412 |
+
batch = dict(
|
413 |
+
input_ids=input_ids,
|
414 |
+
labels=labels,
|
415 |
+
attention_mask=input_ids.not_equal(paddle.to_tensor(self.tokenizer.pad_token_id)),
|
416 |
+
images=images,
|
417 |
+
)
|
418 |
+
return batch
|
419 |
+
|
420 |
+
|
421 |
+
def make_supervised_data_module(interleave, with_box, tokenizer, data_args):
|
422 |
+
assert data_args.conversation_version == "mpt"
|
423 |
+
|
424 |
+
train_dataset = ConversationDataset(
|
425 |
+
tokenizer=tokenizer,
|
426 |
+
# datasets=data_args.datasets,
|
427 |
+
meta_path=data_args.meta_path,
|
428 |
+
multimodal_cfg=dict(
|
429 |
+
sep_image_conv_front=data_args.sep_image_conv_front,
|
430 |
+
image_token_len=data_args.image_token_len,
|
431 |
+
image_aspect_ratio=data_args.image_aspect_ratio,
|
432 |
+
use_im_start_end=data_args.use_im_start_end,
|
433 |
+
image_processor=data_args.image_processor,
|
434 |
+
image_processor_high=data_args.image_processor_high,
|
435 |
+
box_limit=data_args.box_limit,
|
436 |
+
),
|
437 |
+
)
|
438 |
+
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
|
439 |
+
return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
|
PaddleMIX/paddlemix/datasets/internvl_dataset.py
ADDED
@@ -0,0 +1,688 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import io
|
16 |
+
import sys
|
17 |
+
IGNORE_TOKEN_ID = -100 # LabelSmoother.ignore_index
|
18 |
+
import random
|
19 |
+
from typing import Dict
|
20 |
+
from collections.abc import Sequence
|
21 |
+
import paddle
|
22 |
+
import paddle.vision.transforms as T
|
23 |
+
from paddlemix.models.internvl2.conversation import get_conv_template
|
24 |
+
from PIL import Image
|
25 |
+
from paddle.io import ConcatDataset, WeightedRandomSampler
|
26 |
+
from paddlemix.models.internvl2.constants import (CLIP_MEAN, CLIP_STD, IMAGENET_MEAN, IMAGENET_STD,
|
27 |
+
IMG_CONTEXT_TOKEN, IMG_END_TOKEN, IMG_START_TOKEN,
|
28 |
+
SIGLIP_MEAN, SIGLIP_STD)
|
29 |
+
|
30 |
+
class WeightedConcatDataset(ConcatDataset):
|
31 |
+
def __init__(self, datasets, weights):
|
32 |
+
super().__init__(datasets)
|
33 |
+
self.weights = paddle.to_tensor(weights, dtype='float32')
|
34 |
+
self.total_size = sum(len(d) for d in datasets)
|
35 |
+
self.sampler = WeightedRandomSampler(weights=self.weights, num_samples=self.total_size, replacement=True)
|
36 |
+
|
37 |
+
def __iter__(self):
|
38 |
+
return iter(self.sampler)
|
39 |
+
|
40 |
+
def __len__(self):
|
41 |
+
return self.total_size
|
42 |
+
|
43 |
+
|
44 |
+
def pil_loader(img_str):
|
45 |
+
buff = io.BytesIO(img_str)
|
46 |
+
img = Image.open(buff)
|
47 |
+
return img.convert('RGB')
|
48 |
+
|
49 |
+
|
50 |
+
def expand2square(pil_img, background_color):
|
51 |
+
width, height = pil_img.size
|
52 |
+
if width == height:
|
53 |
+
return pil_img
|
54 |
+
elif width > height:
|
55 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
56 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
57 |
+
return result
|
58 |
+
else:
|
59 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
60 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
61 |
+
return result
|
62 |
+
|
63 |
+
|
64 |
+
def simulate_jpeg_degradation(quality):
|
65 |
+
def jpeg_degrade(img):
|
66 |
+
with io.BytesIO() as output:
|
67 |
+
img.convert('RGB').save(output, format='JPEG', quality=quality)
|
68 |
+
output.seek(0) # Move the reading cursor to the start of the stream
|
69 |
+
img_jpeg = Image.open(output).copy() # Use .copy() to make sure the image is loaded in memory
|
70 |
+
return img_jpeg
|
71 |
+
return jpeg_degrade
|
72 |
+
|
73 |
+
|
74 |
+
# Define the JPEG compression quality range, pre-create all JPEG compression functions
|
75 |
+
qualities = list(range(75, 101))
|
76 |
+
jpeg_degrade_functions = {quality: simulate_jpeg_degradation(quality) for quality in qualities}
|
77 |
+
|
78 |
+
|
79 |
+
class Lambda:
|
80 |
+
"""Apply a user-defined lambda as a transform. This transform does not support torchscript.
|
81 |
+
|
82 |
+
Args:
|
83 |
+
lambd (function): Lambda/function to be used for transform.
|
84 |
+
"""
|
85 |
+
|
86 |
+
def __init__(self, lambd):
|
87 |
+
#_log_api_usage_once(self)
|
88 |
+
if not callable(lambd):
|
89 |
+
raise TypeError(f"Argument lambd should be callable, got {repr(type(lambd).__name__)}")
|
90 |
+
self.lambd = lambd
|
91 |
+
|
92 |
+
def __call__(self, img):
|
93 |
+
return self.lambd(img)
|
94 |
+
|
95 |
+
def __repr__(self) -> str:
|
96 |
+
return f"{self.__class__.__name__}()"
|
97 |
+
|
98 |
+
|
99 |
+
class RandomTransforms:
|
100 |
+
"""Base class for a list of transformations with randomness
|
101 |
+
|
102 |
+
Args:
|
103 |
+
transforms (sequence): list of transformations
|
104 |
+
"""
|
105 |
+
|
106 |
+
def __init__(self, transforms):
|
107 |
+
#_log_api_usage_once(self)
|
108 |
+
if not isinstance(transforms, Sequence):
|
109 |
+
raise TypeError("Argument transforms should be a sequence")
|
110 |
+
self.transforms = transforms
|
111 |
+
|
112 |
+
def __call__(self, *args, **kwargs):
|
113 |
+
raise NotImplementedError()
|
114 |
+
|
115 |
+
def __repr__(self) -> str:
|
116 |
+
format_string = self.__class__.__name__ + "("
|
117 |
+
for t in self.transforms:
|
118 |
+
format_string += "\n"
|
119 |
+
format_string += f" {t}"
|
120 |
+
format_string += "\n)"
|
121 |
+
return format_string
|
122 |
+
|
123 |
+
|
124 |
+
class RandomChoice(RandomTransforms):
|
125 |
+
"""Apply single transformation randomly picked from a list. This transform does not support torchscript."""
|
126 |
+
|
127 |
+
def __init__(self, transforms, p=None):
|
128 |
+
super().__init__(transforms)
|
129 |
+
if p is not None and not isinstance(p, Sequence):
|
130 |
+
raise TypeError("Argument p should be a sequence")
|
131 |
+
self.p = p
|
132 |
+
|
133 |
+
def __call__(self, *args):
|
134 |
+
t = random.choices(self.transforms, weights=self.p)[0]
|
135 |
+
return t(*args)
|
136 |
+
|
137 |
+
def __repr__(self) -> str:
|
138 |
+
return f"{super().__repr__()}(p={self.p})"
|
139 |
+
|
140 |
+
|
141 |
+
def build_transform(is_train, input_size, pad2square=False, normalize_type='imagenet'):
|
142 |
+
if normalize_type == 'imagenet':
|
143 |
+
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
|
144 |
+
elif normalize_type == 'clip':
|
145 |
+
MEAN, STD = CLIP_MEAN, CLIP_STD
|
146 |
+
elif normalize_type == 'siglip':
|
147 |
+
MEAN, STD = SIGLIP_MEAN, SIGLIP_STD
|
148 |
+
else:
|
149 |
+
raise NotImplementedError
|
150 |
+
if is_train: # use data augumentation
|
151 |
+
transform = T.Compose([
|
152 |
+
Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
153 |
+
RandomChoice([Lambda(jpeg_degrade_functions[quality]) for quality in qualities]),
|
154 |
+
T.Resize((input_size, input_size), interpolation='bicubic'),
|
155 |
+
T.ToTensor(),
|
156 |
+
T.Normalize(mean=MEAN, std=STD)
|
157 |
+
])
|
158 |
+
else:
|
159 |
+
if pad2square is False: # now we use this transform function by default
|
160 |
+
# run this
|
161 |
+
transform = T.Compose([
|
162 |
+
Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
163 |
+
T.Resize((input_size, input_size), interpolation='bicubic'),
|
164 |
+
T.ToTensor(),
|
165 |
+
T.Normalize(mean=MEAN, std=STD)
|
166 |
+
])
|
167 |
+
else:
|
168 |
+
transform = T.Compose([
|
169 |
+
Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
170 |
+
Lambda(lambda img: expand2square(img, tuple(int(x * 255) for x in MEAN))),
|
171 |
+
T.Resize((input_size, input_size), interpolation='bicubic'),
|
172 |
+
T.ToTensor(),
|
173 |
+
T.Normalize(mean=MEAN, std=STD)
|
174 |
+
])
|
175 |
+
|
176 |
+
return transform
|
177 |
+
|
178 |
+
|
179 |
+
def preprocess(
|
180 |
+
template_name,
|
181 |
+
sources,
|
182 |
+
tokenizer,
|
183 |
+
num_image_token_list: list,
|
184 |
+
text_only: bool = False,
|
185 |
+
group_by_length: bool = False,
|
186 |
+
use_packed_ds: bool = False,
|
187 |
+
ds_name: str = None,
|
188 |
+
num_image: int = 1,
|
189 |
+
):
|
190 |
+
conv = get_conv_template(template_name)
|
191 |
+
roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
|
192 |
+
|
193 |
+
# Apply prompt templates
|
194 |
+
conversations = []
|
195 |
+
for i, source in enumerate(sources):
|
196 |
+
if roles[source[0]['from']] != conv.roles[0]:
|
197 |
+
# Skip the first one if it is not from human
|
198 |
+
source = source[1:]
|
199 |
+
|
200 |
+
conv.messages = []
|
201 |
+
for j, sentence in enumerate(source):
|
202 |
+
role = roles[sentence['from']]
|
203 |
+
assert role == conv.roles[j % 2], f'{i}'
|
204 |
+
conv.append_message(role, sentence['value'])
|
205 |
+
conversations.append(conv.get_prompt())
|
206 |
+
|
207 |
+
if not text_only:
|
208 |
+
new_conversations = []
|
209 |
+
for conversation in conversations:
|
210 |
+
for i in range(num_image):
|
211 |
+
image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
|
212 |
+
conversation = conversation.replace('<image>', image_tokens, 1)
|
213 |
+
new_conversations.append(conversation)
|
214 |
+
conversations = new_conversations
|
215 |
+
|
216 |
+
# Tokenize conversations
|
217 |
+
input_ids = tokenizer(
|
218 |
+
conversations,
|
219 |
+
return_tensors='pd',
|
220 |
+
padding=False if group_by_length or use_packed_ds else 'max_length',
|
221 |
+
max_length=tokenizer.model_max_length,
|
222 |
+
truncation=True,
|
223 |
+
).input_ids
|
224 |
+
targets = input_ids.clone()
|
225 |
+
|
226 |
+
# assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO
|
227 |
+
|
228 |
+
# Mask targets. Only compute loss on the assistant outputs.
|
229 |
+
sep = conv.sep + conv.roles[1] + ': '
|
230 |
+
for conversation, target in zip(conversations, targets):
|
231 |
+
total_len = int(target.not_equal(paddle.to_tensor(tokenizer.pad_token_id)).sum())
|
232 |
+
|
233 |
+
turns = conversation.split(conv.sep2)
|
234 |
+
cur_len = 1
|
235 |
+
target[:cur_len] = IGNORE_TOKEN_ID
|
236 |
+
for i, turn in enumerate(turns):
|
237 |
+
if turn == '':
|
238 |
+
break
|
239 |
+
turn_len = len(tokenizer(turn).input_ids)
|
240 |
+
|
241 |
+
parts = turn.split(sep)
|
242 |
+
if len(parts) != 2:
|
243 |
+
break
|
244 |
+
parts[0] += sep
|
245 |
+
# "-2" is hardcoded for the Llama tokenizer to make the offset correct.
|
246 |
+
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
|
247 |
+
|
248 |
+
if i != 0 and not tokenizer.legacy:
|
249 |
+
# The legacy and non-legacy modes handle special tokens differently
|
250 |
+
instruction_len -= 1
|
251 |
+
|
252 |
+
# Ignore the user instructions
|
253 |
+
target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
|
254 |
+
cur_len += turn_len
|
255 |
+
|
256 |
+
if i != 0 and not tokenizer.legacy:
|
257 |
+
# The legacy and non-legacy modes handle special tokens differently
|
258 |
+
cur_len -= 1
|
259 |
+
|
260 |
+
target[cur_len:] = IGNORE_TOKEN_ID
|
261 |
+
|
262 |
+
if False: # Inspect and check the correctness of masking
|
263 |
+
z = target.clone()
|
264 |
+
z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
|
265 |
+
logger.info(tokenizer.decode(z))
|
266 |
+
exit()
|
267 |
+
|
268 |
+
if cur_len < tokenizer.model_max_length:
|
269 |
+
if cur_len != total_len:
|
270 |
+
target[:] = IGNORE_TOKEN_ID
|
271 |
+
print(
|
272 |
+
f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
|
273 |
+
f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
|
274 |
+
)
|
275 |
+
sys.stdout.flush()
|
276 |
+
|
277 |
+
return dict(
|
278 |
+
input_ids=input_ids,
|
279 |
+
labels=targets,
|
280 |
+
attention_mask=input_ids.not_equal(paddle.to_tensor(tokenizer.pad_token_id)),
|
281 |
+
)
|
282 |
+
|
283 |
+
|
284 |
+
def preprocess_mpt(
|
285 |
+
template_name,
|
286 |
+
sources,
|
287 |
+
tokenizer,
|
288 |
+
num_image_token_list: list,
|
289 |
+
text_only: bool = False,
|
290 |
+
group_by_length: bool = False,
|
291 |
+
use_packed_ds: bool = False,
|
292 |
+
ds_name: str = None,
|
293 |
+
num_image: int = 1
|
294 |
+
) -> Dict:
|
295 |
+
conv = get_conv_template(template_name)
|
296 |
+
roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
|
297 |
+
|
298 |
+
# Apply prompt templates
|
299 |
+
conversations = []
|
300 |
+
for i, source in enumerate(sources):
|
301 |
+
if roles[source[0]['from']] != conv.roles[0]:
|
302 |
+
# Skip the first one if it is not from human
|
303 |
+
source = source[1:]
|
304 |
+
|
305 |
+
conv.messages = []
|
306 |
+
for j, sentence in enumerate(source):
|
307 |
+
role = roles[sentence['from']]
|
308 |
+
assert role == conv.roles[j % 2], f'{i}'
|
309 |
+
conv.append_message(role, sentence['value'])
|
310 |
+
conversations.append(conv.get_prompt())
|
311 |
+
|
312 |
+
if not text_only:
|
313 |
+
new_conversations = []
|
314 |
+
for conversation in conversations:
|
315 |
+
for i in range(num_image):
|
316 |
+
image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
|
317 |
+
conversation = conversation.replace('<image>', image_tokens, 1)
|
318 |
+
new_conversations.append(conversation)
|
319 |
+
conversations = new_conversations
|
320 |
+
|
321 |
+
# Tokenize conversations
|
322 |
+
input_ids = tokenizer(
|
323 |
+
conversations,
|
324 |
+
return_tensors='pd',
|
325 |
+
padding=False if group_by_length or use_packed_ds else 'max_length',
|
326 |
+
max_length=tokenizer.model_max_length,
|
327 |
+
truncation=True,
|
328 |
+
).input_ids
|
329 |
+
targets = input_ids.clone()
|
330 |
+
|
331 |
+
# Mask targets. Only compute loss on the assistant outputs.
|
332 |
+
sep = conv.sep + conv.roles[1] # <|im_end|><|im_start|>assistant\n
|
333 |
+
for conversation, target in zip(conversations, targets):
|
334 |
+
total_len = int(target.not_equal(paddle.to_tensor(tokenizer.pad_token_id)).sum())
|
335 |
+
|
336 |
+
turns = conversation.split(conv.sep)
|
337 |
+
re_turns = [conv.sep.join(turns[:3])] # system + user + gpt
|
338 |
+
for conv_idx in range(3, len(turns), 2):
|
339 |
+
re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2])) # user + gpt
|
340 |
+
cur_len = 0
|
341 |
+
target[:cur_len] = IGNORE_TOKEN_ID
|
342 |
+
for i, turn in enumerate(re_turns):
|
343 |
+
if turn == '':
|
344 |
+
break
|
345 |
+
turn_len = len(tokenizer(turn).input_ids) + 1
|
346 |
+
|
347 |
+
parts = turn.split(sep)
|
348 |
+
if len(parts) != 2:
|
349 |
+
break
|
350 |
+
parts[0] += sep
|
351 |
+
instruction_len = len(tokenizer(parts[0]).input_ids)
|
352 |
+
|
353 |
+
# Ignore the user instructions
|
354 |
+
target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
|
355 |
+
# print(f'[question {i}]', tokenizer.decode(input_ids[:, cur_len: cur_len + instruction_len][0]))
|
356 |
+
# print(f'[answer {i}]', tokenizer.decode(input_ids[:, cur_len + instruction_len: cur_len + turn_len][0]))
|
357 |
+
# print(f'[label {i}]', target[cur_len + instruction_len: cur_len + turn_len])
|
358 |
+
cur_len += turn_len
|
359 |
+
|
360 |
+
target[cur_len:] = IGNORE_TOKEN_ID
|
361 |
+
|
362 |
+
if cur_len < tokenizer.model_max_length:
|
363 |
+
if cur_len != total_len:
|
364 |
+
target[:] = IGNORE_TOKEN_ID
|
365 |
+
print(
|
366 |
+
f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
|
367 |
+
f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
|
368 |
+
)
|
369 |
+
sys.stdout.flush()
|
370 |
+
|
371 |
+
return dict(
|
372 |
+
input_ids=input_ids,
|
373 |
+
labels=targets,
|
374 |
+
attention_mask=input_ids.not_equal(paddle.to_tensor(tokenizer.pad_token_id)),
|
375 |
+
)
|
376 |
+
|
377 |
+
|
378 |
+
def preprocess_phi3(
|
379 |
+
template_name,
|
380 |
+
sources,
|
381 |
+
tokenizer,
|
382 |
+
num_image_token_list: list,
|
383 |
+
text_only: bool = False,
|
384 |
+
group_by_length: bool = False,
|
385 |
+
use_packed_ds: bool = False,
|
386 |
+
ds_name: str = None,
|
387 |
+
num_image: int = 1
|
388 |
+
) -> Dict:
|
389 |
+
conv = get_conv_template(template_name)
|
390 |
+
roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
|
391 |
+
|
392 |
+
# Apply prompt templates
|
393 |
+
conversations = []
|
394 |
+
for i, source in enumerate(sources):
|
395 |
+
if roles[source[0]['from']] != conv.roles[0]:
|
396 |
+
# Skip the first one if it is not from human
|
397 |
+
source = source[1:]
|
398 |
+
|
399 |
+
conv.messages = []
|
400 |
+
for j, sentence in enumerate(source):
|
401 |
+
role = roles[sentence['from']]
|
402 |
+
assert role == conv.roles[j % 2], f'{i}'
|
403 |
+
conv.append_message(role, sentence['value'])
|
404 |
+
conversations.append(conv.get_prompt())
|
405 |
+
|
406 |
+
if not text_only:
|
407 |
+
new_conversations = []
|
408 |
+
for conversation in conversations:
|
409 |
+
for i in range(num_image):
|
410 |
+
image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
|
411 |
+
conversation = conversation.replace('<image>', image_tokens, 1)
|
412 |
+
new_conversations.append(conversation)
|
413 |
+
conversations = new_conversations
|
414 |
+
|
415 |
+
# Tokenize conversations
|
416 |
+
tokenizer.padding_side = 'right'
|
417 |
+
input_ids = tokenizer(
|
418 |
+
conversations,
|
419 |
+
return_tensors='pd',
|
420 |
+
padding=False if group_by_length or use_packed_ds else 'max_length',
|
421 |
+
max_length=tokenizer.model_max_length,
|
422 |
+
truncation=True,
|
423 |
+
).input_ids
|
424 |
+
targets = input_ids.clone()
|
425 |
+
|
426 |
+
# Mask targets. Only compute loss on the assistant outputs.
|
427 |
+
sep = conv.sep + conv.roles[1] # <|end|>\n<|assistant|>
|
428 |
+
for conversation, target in zip(conversations, targets):
|
429 |
+
total_len = int(target.not_equal(paddle.to_tensor(tokenizer.pad_token_id)).sum())
|
430 |
+
|
431 |
+
turns = conversation.split(conv.sep)
|
432 |
+
re_turns = [conv.sep.join(turns[:3])] # system + user + gpt
|
433 |
+
for conv_idx in range(3, len(turns), 2):
|
434 |
+
re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2])) # user + gpt
|
435 |
+
cur_len = 1
|
436 |
+
target[:cur_len] = IGNORE_TOKEN_ID
|
437 |
+
endoftext_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')
|
438 |
+
target[target == endoftext_id] = IGNORE_TOKEN_ID
|
439 |
+
|
440 |
+
for i, turn in enumerate(re_turns):
|
441 |
+
if turn == '':
|
442 |
+
break
|
443 |
+
if i == 0:
|
444 |
+
turn_len = len(tokenizer(turn).input_ids)
|
445 |
+
else:
|
446 |
+
turn_len = len(tokenizer(turn).input_ids) - 1
|
447 |
+
parts = turn.split(sep)
|
448 |
+
if len(parts) != 2:
|
449 |
+
break
|
450 |
+
parts[0] += sep
|
451 |
+
|
452 |
+
if i == 0:
|
453 |
+
instruction_len = len(tokenizer(parts[0]).input_ids) - 1
|
454 |
+
else:
|
455 |
+
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
|
456 |
+
|
457 |
+
# Ignore the user instructions
|
458 |
+
target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
|
459 |
+
# print(f'[question {i}]', tokenizer.decode(input_ids[:, cur_len: cur_len + instruction_len][0]))
|
460 |
+
# print(f'[answer {i}]', tokenizer.decode(input_ids[:, cur_len + instruction_len: cur_len + turn_len][0]))
|
461 |
+
# print(f'[label {i}]', target[cur_len + instruction_len: cur_len + turn_len])
|
462 |
+
cur_len += turn_len
|
463 |
+
|
464 |
+
target[cur_len:] = IGNORE_TOKEN_ID
|
465 |
+
|
466 |
+
if False: # Inspect and check the correctness of masking
|
467 |
+
z = target.clone()
|
468 |
+
z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
|
469 |
+
print(repr(tokenizer.decode(z)))
|
470 |
+
|
471 |
+
if cur_len < tokenizer.model_max_length:
|
472 |
+
if cur_len != total_len:
|
473 |
+
target[:] = IGNORE_TOKEN_ID
|
474 |
+
print(
|
475 |
+
f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
|
476 |
+
f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
|
477 |
+
)
|
478 |
+
sys.stdout.flush()
|
479 |
+
|
480 |
+
return dict(
|
481 |
+
input_ids=input_ids,
|
482 |
+
labels=targets,
|
483 |
+
attention_mask=input_ids.not_equal(paddle.to_tensor(tokenizer.pad_token_id)),
|
484 |
+
)
|
485 |
+
|
486 |
+
|
487 |
+
def preprocess_internlm(
|
488 |
+
template_name,
|
489 |
+
sources,
|
490 |
+
tokenizer,
|
491 |
+
num_image_token_list: list,
|
492 |
+
text_only: bool = False,
|
493 |
+
group_by_length: bool = False,
|
494 |
+
use_packed_ds: bool = False,
|
495 |
+
ds_name: str = None,
|
496 |
+
num_image: int = 1
|
497 |
+
) -> Dict:
|
498 |
+
conv = get_conv_template(template_name)
|
499 |
+
roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
|
500 |
+
|
501 |
+
# Apply prompt templates
|
502 |
+
conversations = []
|
503 |
+
for i, source in enumerate(sources):
|
504 |
+
if roles[source[0]['from']] != conv.roles[0]:
|
505 |
+
# Skip the first one if it is not from human
|
506 |
+
source = source[1:]
|
507 |
+
|
508 |
+
conv.messages = []
|
509 |
+
for j, sentence in enumerate(source):
|
510 |
+
role = roles[sentence['from']]
|
511 |
+
assert role == conv.roles[j % 2], f'{i}'
|
512 |
+
sentence['value'] = sentence['value'].strip()
|
513 |
+
conv.append_message(role, sentence['value'])
|
514 |
+
conversations.append(conv.get_prompt())
|
515 |
+
|
516 |
+
if not text_only:
|
517 |
+
new_conversations = []
|
518 |
+
for conversation in conversations:
|
519 |
+
for i in range(num_image):
|
520 |
+
image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
|
521 |
+
conversation = conversation.replace('<image>', image_tokens, 1)
|
522 |
+
new_conversations.append(conversation)
|
523 |
+
conversations = new_conversations
|
524 |
+
|
525 |
+
# Tokenize conversations
|
526 |
+
input_ids = tokenizer(
|
527 |
+
conversations,
|
528 |
+
return_tensors='pd',
|
529 |
+
padding=False if group_by_length or use_packed_ds else 'max_length',
|
530 |
+
max_length=tokenizer.model_max_length,
|
531 |
+
truncation=True,
|
532 |
+
).input_ids
|
533 |
+
targets = input_ids.clone()
|
534 |
+
|
535 |
+
new_targets = []
|
536 |
+
# print('tokenizer.pad_token_id:\n', tokenizer.pad_token_id) # 151643
|
537 |
+
# print('targets', targets, targets.shape, targets.sum().item())
|
538 |
+
# [[151644, 8948 , 198 , ..., 103978, 1773 , 151645]] [1, 1918] 281157253
|
539 |
+
for conversation, target in zip(conversations, targets):
|
540 |
+
total_len = int(target.not_equal(paddle.to_tensor(tokenizer.pad_token_id)).sum()) # 浦语里面 pad_token_id = eos_token_id
|
541 |
+
cur_len = 1
|
542 |
+
target[:cur_len] = IGNORE_TOKEN_ID # <s>
|
543 |
+
parts = conversation.split(conv.roles[1]) # [UNUSED_TOKEN_146]assistant\n
|
544 |
+
info = parts[0] + conv.roles[1]
|
545 |
+
temp_len = len(tokenizer(info).input_ids) - 1 # 去除tokenizer的<s>
|
546 |
+
target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
|
547 |
+
cur_len = cur_len + temp_len
|
548 |
+
|
549 |
+
for index in range(1, len(parts) - 1):
|
550 |
+
info = parts[index]
|
551 |
+
part1, part2 = info.split(conv.roles[0])
|
552 |
+
temp_len = len(tokenizer(part1).input_ids) - 1
|
553 |
+
cur_len = cur_len + temp_len
|
554 |
+
part = conv.roles[0] + part2 + conv.roles[1]
|
555 |
+
temp_len = len(tokenizer(part).input_ids) - 1
|
556 |
+
target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
|
557 |
+
cur_len = cur_len + temp_len
|
558 |
+
last_info = parts[-1]
|
559 |
+
temp_len = len(tokenizer(last_info).input_ids) - 1
|
560 |
+
cur_len = cur_len + temp_len
|
561 |
+
|
562 |
+
target[cur_len:] = IGNORE_TOKEN_ID
|
563 |
+
if False: # Inspect and check the correctness of masking
|
564 |
+
z = target.clone()
|
565 |
+
z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
|
566 |
+
print(repr(tokenizer.decode(z)))
|
567 |
+
|
568 |
+
if cur_len < tokenizer.model_max_length:
|
569 |
+
if cur_len != total_len:
|
570 |
+
target[:] = IGNORE_TOKEN_ID
|
571 |
+
print(f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}. This dataset is {ds_name}.')
|
572 |
+
sys.stdout.flush()
|
573 |
+
|
574 |
+
new_targets.append(target)
|
575 |
+
|
576 |
+
new_targets = paddle.stack(new_targets, axis=0)
|
577 |
+
|
578 |
+
return dict(
|
579 |
+
input_ids=input_ids,
|
580 |
+
labels=new_targets,
|
581 |
+
attention_mask=input_ids.not_equal(paddle.to_tensor(tokenizer.pad_token_id)),
|
582 |
+
)
|
583 |
+
|
584 |
+
|
585 |
+
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
|
586 |
+
best_ratio_diff = float('inf')
|
587 |
+
best_ratio = (1, 1)
|
588 |
+
area = width * height
|
589 |
+
for ratio in target_ratios:
|
590 |
+
target_aspect_ratio = ratio[0] / ratio[1]
|
591 |
+
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
592 |
+
if ratio_diff < best_ratio_diff:
|
593 |
+
best_ratio_diff = ratio_diff
|
594 |
+
best_ratio = ratio
|
595 |
+
elif ratio_diff == best_ratio_diff:
|
596 |
+
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
|
597 |
+
best_ratio = ratio
|
598 |
+
# print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
|
599 |
+
return best_ratio
|
600 |
+
|
601 |
+
|
602 |
+
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False, return_target_aspect_ratio=False):
|
603 |
+
orig_width, orig_height = image.size
|
604 |
+
aspect_ratio = orig_width / orig_height
|
605 |
+
|
606 |
+
# calculate the existing image aspect ratio
|
607 |
+
target_ratios = set(
|
608 |
+
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
|
609 |
+
i * j <= max_num and i * j >= min_num)
|
610 |
+
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
611 |
+
|
612 |
+
# find the closest aspect ratio to the target
|
613 |
+
target_aspect_ratio = find_closest_aspect_ratio(
|
614 |
+
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
|
615 |
+
|
616 |
+
# calculate the target width and height
|
617 |
+
target_width = image_size * target_aspect_ratio[0]
|
618 |
+
target_height = image_size * target_aspect_ratio[1]
|
619 |
+
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
620 |
+
|
621 |
+
# resize the image
|
622 |
+
resized_img = image.resize((target_width, target_height))
|
623 |
+
processed_images = []
|
624 |
+
for i in range(blocks):
|
625 |
+
box = (
|
626 |
+
(i % (target_width // image_size)) * image_size,
|
627 |
+
(i // (target_width // image_size)) * image_size,
|
628 |
+
((i % (target_width // image_size)) + 1) * image_size,
|
629 |
+
((i // (target_width // image_size)) + 1) * image_size
|
630 |
+
)
|
631 |
+
# split the image
|
632 |
+
split_img = resized_img.crop(box)
|
633 |
+
processed_images.append(split_img)
|
634 |
+
assert len(processed_images) == blocks
|
635 |
+
if use_thumbnail and len(processed_images) != 1:
|
636 |
+
thumbnail_img = image.resize((image_size, image_size))
|
637 |
+
processed_images.append(thumbnail_img)
|
638 |
+
if return_target_aspect_ratio:
|
639 |
+
return processed_images, target_aspect_ratio
|
640 |
+
else:
|
641 |
+
return processed_images
|
642 |
+
|
643 |
+
|
644 |
+
def dynamic_preprocess2(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False, prior_aspect_ratio=None):
|
645 |
+
orig_width, orig_height = image.size
|
646 |
+
aspect_ratio = orig_width / orig_height
|
647 |
+
|
648 |
+
# calculate the existing image aspect ratio
|
649 |
+
target_ratios = set(
|
650 |
+
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
|
651 |
+
i * j <= max_num and i * j >= min_num)
|
652 |
+
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
653 |
+
|
654 |
+
new_target_ratios = []
|
655 |
+
if prior_aspect_ratio is not None:
|
656 |
+
for i in target_ratios:
|
657 |
+
if prior_aspect_ratio[0]%i[0] != 0 and prior_aspect_ratio[1]%i[1] != 0:
|
658 |
+
new_target_ratios.append(i)
|
659 |
+
else:
|
660 |
+
continue
|
661 |
+
|
662 |
+
# find the closest aspect ratio to the target
|
663 |
+
target_aspect_ratio = find_closest_aspect_ratio(
|
664 |
+
aspect_ratio, new_target_ratios, orig_width, orig_height, image_size)
|
665 |
+
|
666 |
+
# calculate the target width and height
|
667 |
+
target_width = image_size * target_aspect_ratio[0]
|
668 |
+
target_height = image_size * target_aspect_ratio[1]
|
669 |
+
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
670 |
+
|
671 |
+
# resize the image
|
672 |
+
resized_img = image.resize((target_width, target_height))
|
673 |
+
processed_images = []
|
674 |
+
for i in range(blocks):
|
675 |
+
box = (
|
676 |
+
(i % (target_width // image_size)) * image_size,
|
677 |
+
(i // (target_width // image_size)) * image_size,
|
678 |
+
((i % (target_width // image_size)) + 1) * image_size,
|
679 |
+
((i // (target_width // image_size)) + 1) * image_size
|
680 |
+
)
|
681 |
+
# split the image
|
682 |
+
split_img = resized_img.crop(box)
|
683 |
+
processed_images.append(split_img)
|
684 |
+
assert len(processed_images) == blocks
|
685 |
+
if use_thumbnail and len(processed_images) != 1:
|
686 |
+
thumbnail_img = image.resize((image_size, image_size))
|
687 |
+
processed_images.append(thumbnail_img)
|
688 |
+
return processed_images
|
PaddleMIX/paddlemix/datasets/laiondata.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import base64
|
16 |
+
import gzip
|
17 |
+
import io
|
18 |
+
import json
|
19 |
+
import os
|
20 |
+
import random
|
21 |
+
|
22 |
+
from paddle.io import IterableDataset, get_worker_info
|
23 |
+
from PIL import Image
|
24 |
+
|
25 |
+
|
26 |
+
def paddle_worker_info(group=None):
|
27 |
+
"""Return node and worker info for paddle and some distributed environments."""
|
28 |
+
rank = 0
|
29 |
+
world_size = 1
|
30 |
+
worker = 0
|
31 |
+
num_workers = 1
|
32 |
+
|
33 |
+
if "WORKER" in os.environ and "NUM_WORKERS" in os.environ:
|
34 |
+
worker = int(os.environ["WORKER"])
|
35 |
+
num_workers = int(os.environ["NUM_WORKERS"])
|
36 |
+
else:
|
37 |
+
try:
|
38 |
+
worker_info = get_worker_info()
|
39 |
+
if worker_info is not None:
|
40 |
+
worker = worker_info.id
|
41 |
+
num_workers = worker_info.num_workers
|
42 |
+
except ModuleNotFoundError:
|
43 |
+
pass
|
44 |
+
return rank, world_size, worker, num_workers
|
45 |
+
|
46 |
+
|
47 |
+
class LaionDataset(IterableDataset):
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
file_list,
|
51 |
+
get_text_emb="",
|
52 |
+
data_world_rank=0,
|
53 |
+
data_world_size=1,
|
54 |
+
buffer_size=1,
|
55 |
+
shuffle_every_n_samples=1000,
|
56 |
+
total_seen_samples=None,
|
57 |
+
):
|
58 |
+
with open(file_list, "r", encoding="utf-8") as f:
|
59 |
+
self.file_list = f.read().strip().split("\n")
|
60 |
+
self.get_text_emb = get_text_emb
|
61 |
+
self.buffer_size = buffer_size
|
62 |
+
self.shuffle_every_n_samples = shuffle_every_n_samples
|
63 |
+
self.min_size = 5
|
64 |
+
self.total_seen_samples = total_seen_samples
|
65 |
+
self.data_world_rank = data_world_rank
|
66 |
+
self.data_world_size = data_world_size
|
67 |
+
|
68 |
+
def parse_line(self, line, filename):
|
69 |
+
try:
|
70 |
+
vec = line.strip().split("\t")
|
71 |
+
text_json = json.loads(vec[2])
|
72 |
+
img_b64 = vec[5]
|
73 |
+
caption = text_json.get("caption_en", text_json.get("blip_caption_en", ""))
|
74 |
+
|
75 |
+
image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
|
76 |
+
return dict(image=image, text=caption)
|
77 |
+
except Exception:
|
78 |
+
print(f"error when parse file {filename}")
|
79 |
+
return None
|
80 |
+
|
81 |
+
def get_data(self, data):
|
82 |
+
w, h = data["image"].size
|
83 |
+
if w < self.min_size or h < self.min_size:
|
84 |
+
return None
|
85 |
+
return data
|
86 |
+
|
87 |
+
def __len__(self):
|
88 |
+
return self.total_seen_samples
|
89 |
+
|
90 |
+
def sample(self):
|
91 |
+
_, _, worker, num_workers = paddle_worker_info()
|
92 |
+
total_num_workers = num_workers * self.data_world_size
|
93 |
+
global_worker_id = self.data_world_rank * num_workers + worker
|
94 |
+
|
95 |
+
print("[CHECK ME] LaionDataset", global_worker_id, total_num_workers)
|
96 |
+
while True:
|
97 |
+
random.shuffle(self.file_list)
|
98 |
+
for i in range(len(self.file_list)):
|
99 |
+
if i % total_num_workers == global_worker_id:
|
100 |
+
filename = self.file_list[i].strip("\n")
|
101 |
+
|
102 |
+
with gzip.open(filename, "rb") if filename.endswith(".gz") else open(filename, "rb") as f:
|
103 |
+
while True:
|
104 |
+
line = f.readline()
|
105 |
+
|
106 |
+
if line == b"":
|
107 |
+
break
|
108 |
+
try:
|
109 |
+
try:
|
110 |
+
line = line.decode(encoding="utf-8")
|
111 |
+
except:
|
112 |
+
line = line.decode(encoding="gb18030")
|
113 |
+
except:
|
114 |
+
print(f"error on file {filename}")
|
115 |
+
continue
|
116 |
+
data = self.parse_line(line, filename)
|
117 |
+
|
118 |
+
if data is None:
|
119 |
+
continue
|
120 |
+
else:
|
121 |
+
data = self.get_data(data)
|
122 |
+
if data is None:
|
123 |
+
continue
|
124 |
+
yield data
|
125 |
+
|
126 |
+
def shuffle(self, iterator):
|
127 |
+
buffer_list = []
|
128 |
+
for _ in range(self.buffer_size):
|
129 |
+
buffer_list.append(next(iterator))
|
130 |
+
i = 0
|
131 |
+
while True:
|
132 |
+
if i % self.shuffle_every_n_samples == 0:
|
133 |
+
random.shuffle(buffer_list)
|
134 |
+
yield buffer_list.pop()
|
135 |
+
buffer_list.append(next(iterator))
|
136 |
+
i += 1
|
137 |
+
|
138 |
+
def __iter__(self):
|
139 |
+
return self.shuffle(iter(self.sample()))
|
PaddleMIX/paddlemix/datasets/mixtoken_dataset.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
from paddle.io import Dataset
|
17 |
+
from scipy.linalg import block_diag
|
18 |
+
from tqdm import tqdm
|
19 |
+
|
20 |
+
|
21 |
+
class MIXToken:
|
22 |
+
required_input_keys = ["input_ids", "labels"]
|
23 |
+
required_output_keys = ["input_ids", "labels", "attention_mask"]
|
24 |
+
# Only supported the following keys for MIXToken. Keys outside of the set will be ignored.
|
25 |
+
supported_input_keys = ["input_ids", "labels", "attention_mask", "position_ids", "images"]
|
26 |
+
|
27 |
+
@classmethod
|
28 |
+
def _pad_batch_records(cls, batch_records):
|
29 |
+
# Only consider supported input keys
|
30 |
+
input_keys = [key for key in batch_records[0].keys() if key in cls.supported_input_keys]
|
31 |
+
|
32 |
+
# Check required_keys
|
33 |
+
for key in cls.required_input_keys:
|
34 |
+
if key not in input_keys:
|
35 |
+
raise ValueError(f"feature `{key}` is required for MIXTokenDataset")
|
36 |
+
# Output features must include all required output keys
|
37 |
+
for key in cls.required_output_keys:
|
38 |
+
if key not in input_keys:
|
39 |
+
input_keys.append(key)
|
40 |
+
|
41 |
+
batched_features = {key: [] for key in input_keys}
|
42 |
+
|
43 |
+
for record in batch_records:
|
44 |
+
batched_features["input_ids"].extend(record["input_ids"])
|
45 |
+
batched_features["labels"].extend(record["labels"])
|
46 |
+
seq_length = len(record["input_ids"])
|
47 |
+
# If attention_mask is not given, assume it's causal mask
|
48 |
+
attention_mask = record.get("attention_mask", np.tril(np.ones([seq_length, seq_length], dtype=bool)))
|
49 |
+
batched_features["attention_mask"].append(attention_mask)
|
50 |
+
# NOTE: position_ids is optional and not required by every model
|
51 |
+
# We append instead of extend here to accomodate 2D position ids
|
52 |
+
if "position_ids" in record:
|
53 |
+
batched_features["position_ids"].append(record["position_ids"])
|
54 |
+
if "images" in record:
|
55 |
+
batched_features["images"].append(record["images"])
|
56 |
+
|
57 |
+
block_attention_mask = block_diag(*batched_features["attention_mask"])
|
58 |
+
# convert to 3-D [batch_size(1), seq_length, seq_length]
|
59 |
+
batched_features["attention_mask"] = np.expand_dims(block_attention_mask, axis=0)
|
60 |
+
if "position_ids" in batched_features:
|
61 |
+
# Accommodate both 1D and 2D position ids
|
62 |
+
batched_features["position_ids"] = np.concatenate(batched_features["position_ids"], axis=-1).tolist()
|
63 |
+
return batched_features
|
64 |
+
|
65 |
+
|
66 |
+
class MIXTokenMapDataset(MIXToken, Dataset):
|
67 |
+
"""
|
68 |
+
MIXToken is a unique feature of PaddleMix training, which replaces traditional pad tokens by
|
69 |
+
concatenating effective tokens to increase the throughput of a single sample and improve training speed.
|
70 |
+
|
71 |
+
traditional pad tokens:
|
72 |
+
len( imageToken + query + paddingToken ) = max_length
|
73 |
+
|
74 |
+
MIXToken:
|
75 |
+
len( imageToken1 + query1 + imageToken2 + query2 + ... + paddingToken ) = max_length
|
76 |
+
|
77 |
+
"""
|
78 |
+
|
79 |
+
def __init__(self, data, max_length, processor=None, tokenizer=None, mode="train"):
|
80 |
+
self.max_length = max_length
|
81 |
+
self.processor = processor
|
82 |
+
self.tokenizer = tokenizer
|
83 |
+
self.mode = mode
|
84 |
+
self.new_data = self._create_intokens_data(data)
|
85 |
+
|
86 |
+
def _create_intokens_data(self, data):
|
87 |
+
batch_records, max_len = [], 0
|
88 |
+
cur_len_so_far = 0
|
89 |
+
|
90 |
+
total_data = []
|
91 |
+
|
92 |
+
for i in tqdm(range(len(data))):
|
93 |
+
record = data[i]
|
94 |
+
|
95 |
+
if self.processor:
|
96 |
+
record = self.processor(record=record, mode=self.mode)
|
97 |
+
|
98 |
+
if getattr(self.tokenizer, "image_token_span", None) is not None and record["images"] is not None:
|
99 |
+
image_token_span = self.tokenizer.image_token_span - 1 # image token
|
100 |
+
else:
|
101 |
+
image_token_span = 0
|
102 |
+
|
103 |
+
max_len = max(max_len, len(record["input_ids"]))
|
104 |
+
to_append = (cur_len_so_far + int(image_token_span) + len(record["input_ids"])) <= self.max_length
|
105 |
+
|
106 |
+
if to_append:
|
107 |
+
batch_records.append(record)
|
108 |
+
cur_len_so_far += len(record["input_ids"]) + image_token_span
|
109 |
+
else:
|
110 |
+
# exceed max length
|
111 |
+
padded_list = self._pad_batch_records(batch_records)
|
112 |
+
total_data.append(padded_list)
|
113 |
+
# reset
|
114 |
+
batch_records, max_len = [], 0
|
115 |
+
cur_len_so_far = 0
|
116 |
+
# append current data
|
117 |
+
batch_records.append(record)
|
118 |
+
cur_len_so_far += len(record["input_ids"]) + image_token_span
|
119 |
+
|
120 |
+
# remaining data
|
121 |
+
if batch_records:
|
122 |
+
padded_list = self._pad_batch_records(batch_records)
|
123 |
+
total_data.append(padded_list)
|
124 |
+
|
125 |
+
return total_data
|
126 |
+
|
127 |
+
def __getitem__(self, idx):
|
128 |
+
return self.new_data[idx]
|
129 |
+
|
130 |
+
def __len__(self):
|
131 |
+
return len(self.new_data)
|
PaddleMIX/paddlemix/datasets/vg_caption.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import collections
|
16 |
+
import os
|
17 |
+
|
18 |
+
__all__ = ["VGCaption"]
|
19 |
+
from paddlemix.datasets.caption_dataset import CaptionDataset
|
20 |
+
|
21 |
+
|
22 |
+
class VGCaption(CaptionDataset):
|
23 |
+
"""
|
24 |
+
VG Caption dataset.
|
25 |
+
"""
|
26 |
+
|
27 |
+
URL = "https://bj.bcebos.com/paddlemix/datasets/vg.tar.gz"
|
28 |
+
META_INFO = collections.namedtuple("META_INFO", ("images", "annotations", "images_md5", "annotations_md5"))
|
29 |
+
MD5 = ""
|
30 |
+
SPLITS = {
|
31 |
+
"train": META_INFO(
|
32 |
+
os.path.join("coco", "images"),
|
33 |
+
os.path.join("coco", "annotations/vg_caption.json"),
|
34 |
+
"",
|
35 |
+
"",
|
36 |
+
),
|
37 |
+
}
|
PaddleMIX/paddlemix/demo_images/critic_img_seven.png
ADDED
![]() |
PaddleMIX/paddlemix/external_ops/setup.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import multiprocessing
|
16 |
+
import os
|
17 |
+
|
18 |
+
|
19 |
+
def get_gencode_flags():
|
20 |
+
import paddle
|
21 |
+
|
22 |
+
prop = paddle.device.cuda.get_device_properties()
|
23 |
+
cc = prop.major * 10 + prop.minor
|
24 |
+
return ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)]
|
25 |
+
|
26 |
+
|
27 |
+
def run(func):
|
28 |
+
p = multiprocessing.Process(target=func)
|
29 |
+
p.start()
|
30 |
+
p.join()
|
31 |
+
|
32 |
+
|
33 |
+
def change_pwd():
|
34 |
+
path = os.path.dirname(__file__)
|
35 |
+
if path:
|
36 |
+
os.chdir(path)
|
37 |
+
|
38 |
+
|
39 |
+
def setup_fast_ln():
|
40 |
+
from paddle.utils.cpp_extension import CUDAExtension, setup
|
41 |
+
|
42 |
+
gencode_flags = get_gencode_flags()
|
43 |
+
change_pwd()
|
44 |
+
setup(
|
45 |
+
name="fast_ln",
|
46 |
+
ext_modules=CUDAExtension(
|
47 |
+
sources=[
|
48 |
+
"fast_ln/ln_api.cpp",
|
49 |
+
"fast_ln/ln_bwd_semi_cuda_kernel.cu",
|
50 |
+
"fast_ln/ln_fwd_cuda_kernel.cu",
|
51 |
+
],
|
52 |
+
extra_compile_args={
|
53 |
+
"cxx": ["-O3"],
|
54 |
+
"nvcc": [
|
55 |
+
"-O3",
|
56 |
+
"-U__CUDA_NO_HALF_OPERATORS__",
|
57 |
+
"-U__CUDA_NO_HALF_CONVERSIONS__",
|
58 |
+
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
|
59 |
+
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
|
60 |
+
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
|
61 |
+
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
|
62 |
+
"-I./apex/contrib/csrc/layer_norm/",
|
63 |
+
"--expt-relaxed-constexpr",
|
64 |
+
"--expt-extended-lambda",
|
65 |
+
"--use_fast_math",
|
66 |
+
]
|
67 |
+
+ gencode_flags,
|
68 |
+
},
|
69 |
+
),
|
70 |
+
)
|
71 |
+
|
72 |
+
|
73 |
+
def setup_fused_ln():
|
74 |
+
from paddle.utils.cpp_extension import CUDAExtension, setup
|
75 |
+
|
76 |
+
gencode_flags = get_gencode_flags()
|
77 |
+
change_pwd()
|
78 |
+
setup(
|
79 |
+
name="fused_ln",
|
80 |
+
ext_modules=CUDAExtension(
|
81 |
+
sources=[
|
82 |
+
"fused_ln/layer_norm_cuda.cu",
|
83 |
+
],
|
84 |
+
extra_compile_args={
|
85 |
+
"cxx": ["-O3"],
|
86 |
+
"nvcc": [
|
87 |
+
"-O3",
|
88 |
+
"-U__CUDA_NO_HALF_OPERATORS__",
|
89 |
+
"-U__CUDA_NO_HALF_CONVERSIONS__",
|
90 |
+
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
|
91 |
+
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
|
92 |
+
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
|
93 |
+
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
|
94 |
+
"-I./apex/contrib/csrc/layer_norm/",
|
95 |
+
"--expt-relaxed-constexpr",
|
96 |
+
"--expt-extended-lambda",
|
97 |
+
"--use_fast_math",
|
98 |
+
"-maxrregcount=50",
|
99 |
+
]
|
100 |
+
+ gencode_flags,
|
101 |
+
},
|
102 |
+
),
|
103 |
+
)
|
104 |
+
|
105 |
+
|
106 |
+
run(setup_fast_ln)
|
107 |
+
run(setup_fused_ln)
|
PaddleMIX/paddlemix/metrics/clip_zero_shot.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
|
17 |
+
import paddle
|
18 |
+
import paddle.nn.functional as F
|
19 |
+
from tqdm import tqdm
|
20 |
+
|
21 |
+
from paddlemix.processors.tokenizer import tokenize
|
22 |
+
|
23 |
+
|
24 |
+
def zero_shot_classifier(model, classnames_filename, templates_filename, args, text_tower=None):
|
25 |
+
classnames = [i.strip() for i in open(classnames_filename).readlines()]
|
26 |
+
templates = [i.strip() for i in open(templates_filename).readlines()]
|
27 |
+
|
28 |
+
if text_tower is None:
|
29 |
+
if hasattr(model, "_layers"):
|
30 |
+
text_tower = model._layers.encode_text
|
31 |
+
else:
|
32 |
+
text_tower = model.encode_text
|
33 |
+
tokenizer = tokenize
|
34 |
+
with paddle.no_grad():
|
35 |
+
zeroshot_weights = []
|
36 |
+
for classname in tqdm(classnames):
|
37 |
+
texts = [template.format(classname) for template in templates] # format with class
|
38 |
+
texts = tokenizer(texts) # tokenize
|
39 |
+
class_embeddings = text_tower(texts)
|
40 |
+
class_embedding = F.normalize(class_embeddings, axis=-1).mean(0)
|
41 |
+
class_embedding /= class_embedding.norm()
|
42 |
+
zeroshot_weights.append(class_embedding)
|
43 |
+
zeroshot_weights = paddle.stack(zeroshot_weights, axis=1)
|
44 |
+
return zeroshot_weights
|
45 |
+
|
46 |
+
|
47 |
+
def accuracy(output, target, topk=(1,)):
|
48 |
+
"""Computes the accuracy over the k top predictions for the specified values of k"""
|
49 |
+
maxk = min(max(topk), output.shape[1])
|
50 |
+
pred = output.topk(maxk, 1, True, True)[1].t()
|
51 |
+
correct = pred == target.reshape([1, -1]).expand_as(pred)
|
52 |
+
return [
|
53 |
+
float(correct[: min(k, maxk)].reshape([-1]).astype(paddle.float32).sum(0, keepdim=True).numpy() * 100.0)
|
54 |
+
for k in topk
|
55 |
+
]
|
56 |
+
|
57 |
+
|
58 |
+
class DummyAutocast:
|
59 |
+
def __init__(self, *args, **kwargs):
|
60 |
+
return
|
61 |
+
|
62 |
+
def __enter__(self, *args, **kwargs):
|
63 |
+
return
|
64 |
+
|
65 |
+
def __exit__(self, *args, **kwargs):
|
66 |
+
return
|
67 |
+
|
68 |
+
|
69 |
+
def get_autocast(precision):
|
70 |
+
if precision == "float16":
|
71 |
+
return paddle.amp.auto_cast
|
72 |
+
elif precision == "bfloat16":
|
73 |
+
return lambda: paddle.amp.auto_cast(dtype="bfloat16")
|
74 |
+
else:
|
75 |
+
return DummyAutocast
|
76 |
+
|
77 |
+
|
78 |
+
def get_cast_dtype(args):
|
79 |
+
cast_dtype = None
|
80 |
+
if args.bf16:
|
81 |
+
cast_dtype = "bfloat16"
|
82 |
+
elif args.fp16:
|
83 |
+
cast_dtype = "float16"
|
84 |
+
return cast_dtype
|
85 |
+
|
86 |
+
|
87 |
+
class ClipZeroShot:
|
88 |
+
def __init__(self, model, args):
|
89 |
+
data_path = args.classification_eval.strip()
|
90 |
+
classname_filename = f"{data_path}/labels.txt"
|
91 |
+
template_filename = f"{data_path}/templates.txt"
|
92 |
+
|
93 |
+
self.data_name = os.path.basename(args.classification_eval)
|
94 |
+
classifier_filename = (
|
95 |
+
f"{os.path.dirname(classname_filename)}/{args.pretrained_text_model}_{self.data_name}_classifier.pdparams"
|
96 |
+
)
|
97 |
+
if os.path.exists(classifier_filename):
|
98 |
+
print("load classifier from disk")
|
99 |
+
classifier = paddle.load(classifier_filename)
|
100 |
+
else:
|
101 |
+
print("constructing classifier: {}.".format(classifier_filename))
|
102 |
+
classifier = zero_shot_classifier(model, classname_filename, template_filename, args)
|
103 |
+
paddle.save(classifier, classifier_filename)
|
104 |
+
print(f"zero-shot evaluating classification task: {self.data_name}")
|
105 |
+
if args.bf16:
|
106 |
+
self.classifier = classifier.astype(paddle.bfloat16)
|
107 |
+
elif args.fp16:
|
108 |
+
self.classifier = classifier.astype(paddle.float16)
|
109 |
+
else:
|
110 |
+
self.classifier = classifier
|
111 |
+
self.batch_size = args.per_device_eval_batch_size
|
112 |
+
self.cast_dtype = get_cast_dtype(args)
|
113 |
+
|
114 |
+
def zero_shot_eval(self, evalres):
|
115 |
+
results = {}
|
116 |
+
print("Extract features done, starting zero-shot classification evaluation.")
|
117 |
+
predictions, labels = evalres.predictions, evalres.label_ids
|
118 |
+
n = predictions.shape[0]
|
119 |
+
top1, top5 = 0.0, 0.0
|
120 |
+
|
121 |
+
autocast = get_autocast(self.cast_dtype)
|
122 |
+
with paddle.no_grad():
|
123 |
+
for step in tqdm(range((predictions.shape[0] + self.batch_size - 1) // self.batch_size)):
|
124 |
+
with autocast():
|
125 |
+
image_features = paddle.to_tensor(
|
126 |
+
predictions[step * self.batch_size : (step + 1) * self.batch_size]
|
127 |
+
)
|
128 |
+
target = paddle.to_tensor(labels[step * self.batch_size : (step + 1) * self.batch_size])
|
129 |
+
logits = 100.0 * image_features @ self.classifier
|
130 |
+
if logits.shape[-1] < 5:
|
131 |
+
(acc1,) = accuracy(logits, target, topk=(1,))
|
132 |
+
acc5 = -1
|
133 |
+
else:
|
134 |
+
acc1, acc5 = accuracy(logits, target, topk=(1, 5))
|
135 |
+
top1 += acc1
|
136 |
+
top5 += acc5
|
137 |
+
top1 = top1 / n
|
138 |
+
top5 = top5 / n
|
139 |
+
results["val/imagenet-zeroshot-val-top1"] = top1
|
140 |
+
results["val/imagenet-zeroshot-val-top5"] = top5
|
141 |
+
|
142 |
+
results["top1"] = top1
|
143 |
+
print(f"zero-shot classification task: {self.data_name}: top1: {top1}, top5: {top5}")
|
144 |
+
print("Finished zero-shot evaluation.")
|
145 |
+
|
146 |
+
return results
|