tuandunghcmut committed (verified)
Commit a60bdd6 · Parent(s): f435a72

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes.
Files changed (50):
  1. PaddleMIX/.travis/precommit.sh +21 -0
  2. PaddleMIX/applications/gradio_text2image.py +54 -0
  3. PaddleMIX/comfyui/README.md +49 -0
  4. PaddleMIX/deploy/README_en.md +108 -0
  5. PaddleMIX/docs/train_tutorial.md +10 -0
  6. PaddleMIX/paddlemix/activations.py +174 -0
  7. PaddleMIX/paddlemix/checkpoint.py +216 -0
  8. PaddleMIX/ppdiffusers/LICENSE +203 -0
  9. PaddleMIX/ppdiffusers/Makefile +30 -0
  10. PaddleMIX/ppdiffusers/deploy-deprecated/export.md +67 -0
  11. PaddleMIX/ppdiffusers/deploy-deprecated/export.sh +17 -0
  12. PaddleMIX/ppdiffusers/deploy-deprecated/export_model.py +201 -0
  13. PaddleMIX/ppdiffusers/deploy-deprecated/gradio_demo.py +683 -0
  14. PaddleMIX/ppdiffusers/deploy-deprecated/infer.py +742 -0
  15. PaddleMIX/ppdiffusers/deploy-deprecated/infer_dygraph.py +380 -0
  16. PaddleMIX/ppdiffusers/deploy-deprecated/infer_dygraph_torch.py +447 -0
  17. PaddleMIX/ppdiffusers/deploy-deprecated/requirements.txt +2 -0
  18. PaddleMIX/ppdiffusers/deploy/README.md +65 -0
  19. PaddleMIX/ppdiffusers/ppdiffusers/__init__.py +814 -0
  20. PaddleMIX/ppdiffusers/ppdiffusers/accelerate/__init__.py +30 -0
  21. PaddleMIX/ppdiffusers/ppdiffusers/accelerate/logging.py +123 -0
  22. PaddleMIX/ppdiffusers/ppdiffusers/accelerate/optimizer.py +180 -0
  23. PaddleMIX/ppdiffusers/ppdiffusers/accelerate/scheduler.py +96 -0
  24. PaddleMIX/ppdiffusers/ppdiffusers/accelerate/tracking.py +1103 -0
  25. PaddleMIX/ppdiffusers/ppdiffusers/callbacks.py +156 -0
  26. PaddleMIX/ppdiffusers/ppdiffusers/configuration_utils.py +695 -0
  27. PaddleMIX/ppdiffusers/ppdiffusers/image_processor.py +671 -0
  28. PaddleMIX/ppdiffusers/ppdiffusers/initializer.py +20 -0
  29. PaddleMIX/ppdiffusers/ppdiffusers/models/attention_processor.py +0 -0
  30. PaddleMIX/ppdiffusers/ppdiffusers/models/autoencoder_kl_cogvideox.py +1190 -0
  31. PaddleMIX/ppdiffusers/ppdiffusers/models/autoencoder_kl_temporal_decoder.py +396 -0
  32. PaddleMIX/ppdiffusers/ppdiffusers/models/autoencoder_tiny.py +363 -0
  33. PaddleMIX/ppdiffusers/ppdiffusers/models/cogvideox_transformer_3d.py +394 -0
  34. PaddleMIX/ppdiffusers/ppdiffusers/models/consistency_decoder_vae.py +445 -0
  35. PaddleMIX/ppdiffusers/ppdiffusers/models/controlnet.py +889 -0
  36. PaddleMIX/ppdiffusers/ppdiffusers/models/dit_llama_t2i.py +582 -0
  37. PaddleMIX/ppdiffusers/ppdiffusers/models/downsampling.py +383 -0
  38. PaddleMIX/ppdiffusers/ppdiffusers/models/dual_transformer_2d.py +158 -0
  39. PaddleMIX/ppdiffusers/ppdiffusers/models/lora.py +462 -0
  40. PaddleMIX/ppdiffusers/ppdiffusers/models/lvdm_attention_temporal.py +462 -0
  41. PaddleMIX/ppdiffusers/ppdiffusers/models/lvdm_unet_3d.py +713 -0
  42. PaddleMIX/ppdiffusers/ppdiffusers/models/lvdm_util.py +296 -0
  43. PaddleMIX/ppdiffusers/ppdiffusers/models/modeling_pytorch_paddle_utils.py +117 -0
  44. PaddleMIX/ppdiffusers/ppdiffusers/models/modeling_utils.py +1356 -0
  45. PaddleMIX/ppdiffusers/ppdiffusers/models/modelscope_gaussion_sdedit.py +451 -0
  46. PaddleMIX/ppdiffusers/ppdiffusers/models/modelscope_st_unet_video2video.py +409 -0
  47. PaddleMIX/ppdiffusers/ppdiffusers/models/prior_transformer.py +398 -0
  48. PaddleMIX/ppdiffusers/ppdiffusers/models/simplified_sd3.py +216 -0
  49. PaddleMIX/ppdiffusers/ppdiffusers/models/transformer_2d.py +538 -0
  50. PaddleMIX/ppdiffusers/ppdiffusers/models/unet_1d_blocks.py +752 -0
PaddleMIX/.travis/precommit.sh ADDED
@@ -0,0 +1,21 @@
+ #!/bin/bash
+
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ function abort(){
+     echo "Your commit not fit PaddlePaddle code style" 1>&2
+     echo "Please use pre-commit scripts to auto-format your code" 1>&2
+     exit 1
+ }
PaddleMIX/applications/gradio_text2image.py ADDED
@@ -0,0 +1,54 @@
+ from paddlemix.appflow import Appflow
+ from ppdiffusers.utils import load_image
+ import paddle
+ import imageio
+
+ from PIL import Image
+ import gradio as gr
+ import traceback
+
+ # upscaling
+ def ups_fun(low_res_img, prompt):
+     low_res_img = Image.fromarray(low_res_img.astype('uint8')).convert('RGB')
+     app = Appflow(app='image2image_text_guided_upscaling',models=['stabilityai/stable-diffusion-x4-upscaler'])
+     image = app(prompt=prompt,image=low_res_img)['result']
+     return image
+
+ # text_guided_generation
+ def tge_fun(image, prompt_pos, prompt_neg):
+     image = Image.fromarray(image.astype('uint8')).convert('RGB')
+     app = Appflow(app='image2image_text_guided_generation',models=['Linaqruf/anything-v3.0'])
+     image = app(prompt=prompt_pos,negative_prompt=prompt_neg,image=image)['result']
+     return image
+
+ # video_generation
+ def vge_fun(prompt):
+     app = Appflow(app='text_to_video_generation',models=['damo-vilab/text-to-video-ms-1.7b'])
+     video_frames = app(prompt=prompt,num_inference_steps=25)['result']
+     imageio.mimsave("gen_video.gif", video_frames, duration=8)
+     return "gen_video.gif"
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Appflow应用:text2image")
+     with gr.Tab("文本引导的图像放大"):
+         with gr.Row():
+             ups_image_in = gr.Image(label = "输入图片")
+             ups_image_out = gr.Image(label = "输出图片")
+         ups_text_in = gr.Text(label = "Prompt")
+         ups_button = gr.Button()
+         ups_button.click(fn=ups_fun, inputs = [ups_image_in, ups_text_in], outputs = [ups_image_out])
+     with gr.Tab("文本引导的图像变换"):
+         with gr.Row():
+             tge_image_in = gr.Image(label = "输入图片")
+             tge_image_out = gr.Image(label = "输出图片")
+         tge_text_pos_in = gr.Text(label = "Positive Prompt")
+         tge_text_neg_in = gr.Text(label = "Negative Prompt")
+         tge_button = gr.Button()
+         tge_button.click(fn=tge_fun, inputs = [tge_image_in, tge_text_pos_in, tge_text_neg_in], outputs = [tge_image_out])
+     with gr.Tab("文本条件的视频生成"):
+         vge_text_in = gr.Text(label = "Prompt")
+         vge_video_out = gr.Video(label = "输出视频")
+         vge_button = gr.Button()
+         vge_button.click(fn=vge_fun, inputs = [vge_text_in], outputs = [vge_video_out])
+
+ demo.launch()
PaddleMIX/comfyui/README.md ADDED
@@ -0,0 +1,49 @@
+ # PaddleMIX Extension Plugins for ComfyUI
+
+ ## Introduction
+ [ComfyUI](https://github.com/comfyanonymous/ComfyUI/) is an AIGC program that is very popular in the open-source community. By splitting models into nodes and combining them into workflows, it lets different models work together to complete complex, advanced production tasks. This directory contains node extensions developed by PaddleMIX for ComfyUI, supporting multimodal capabilities such as text-to-image generation, image segmentation, and image captioning.
+
+ ## Installation and Usage Guide
+
+ ### 1. Prepare the ComfyUI environment
+
+ #### Deploy from source
+ Visit the [ComfyUI GitHub repository](https://github.com/comfyanonymous/ComfyUI) to get the source code.
+
+ #### Deploy with Docker
+ 1. **Pull the image archive and load it** (or simply use `docker pull` to fetch any ComfyUI image available online):
+ ```shell
+ wget https://paddlenlp.bj.bcebos.com/models/community/aistudio/comfyui_docker/comfyui_aistudio_v1.tar
+ docker load -i comfyui_aistudio_v1.tar
+ ```
+ 2. **Create a Docker container**, replacing the paths and image name as needed:
+ ```shell
+ nvidia-docker run --name comfyui_env -it -e HOME="/root" -w "/root" -v </path/to/temp_data_dir>:/root --ipc=host --net=host <docker-image-name> /bin/bash --login
+ ```
+ 3. **Enter the Docker environment**:
+ ```shell
+ docker exec -it comfyui_env /bin/bash
+ ```
+ 4. **Start ComfyUI**:
+ ```shell
+ cd /comfyui_env
+ ./python_env/bin/python ComfyUI/main.py --listen 0.0.0.0 --port 8889 &
+ ```
+
+ ### 2. Install the PaddleMIX ComfyUI extensions
+
+ Copy the desired plugin folder under PaddleMIX/comfyui/ into the ComfyUI/custom_nodes/ folder and install its requirements.txt; the extension is then ready to use.
+
+ #### Example: installing the text-to-image extension nodes
+ ```shell
+ # Copy the extension folder into the ComfyUI/custom_nodes/ directory
+ cp -r PaddleMIX/comfyui/ComfyUI_ppdiffusers /path/to/your/ComfyUI/custom_nodes/
+ # Install the dependencies required by the extension
+ pip install -r PaddleMIX/comfyui/ComfyUI_ppdiffusers/requirements.txt
+ ```
+
+ ### 3. Load a workflow
+
+ Each extension directory contains a workflows folder; load the JSON files in it through your browser to use the corresponding workflow. For concrete examples, see the [PaddleMIX ComfyUI extension examples](https://github.com/PaddlePaddle/PaddleMIX/tree/develop/comfyui/ComfyUI_ppdiffusers).
+
PaddleMIX/deploy/README_en.md ADDED
@@ -0,0 +1,108 @@
+ # PaddleMIX Inference Deployment
+
+ [[中文文档](README.md)]
+
+ PaddleMIX uses Paddle Inference and provides a Python-based deployment solution. There are two deployment methods:
+
+ 1. **APPflow deployment**: by setting the `static_mode = True` variable in APPflow, you can enable static graph inference and optionally accelerate inference with TensorRT. Note that not all models support static graph mode or TensorRT; please refer to the [Multi Modal And Scenario](../applications/README_en.md/#multi-modal-and-scenario) section for specific model support.
+
+ 2. **Single model deployment**: export a single model and run prediction with Python (see section 2).
+
+ ## 1. APPflow Deployment
+
+ For APPflow usage, set the `static_mode = True` variable to enable static graph inference and optionally accelerate inference using TensorRT.
+
+ ### 1.1 Example
+
+ ```python
+ >>> from paddlemix.appflow import Appflow
+ >>> from PIL import Image
+
+ >>> task = Appflow(app="openset_det_sam",
+                    models=["GroundingDino/groundingdino-swint-ogc", "Sam/SamVitH-1024"],
+                    static_mode=True,
+                    precision="fp32")
+ >>> image_pil = Image.open("beauty.png").convert("RGB")
+ >>> result = task(image=image_pil, prompt="women")
+ ```
+
+ ### 1.2 Parameter Explanation
+
+ | Parameter | Required? | Meaning |
+ |-------|-------|---------------------------------------------------------------------------------------------|
+ | --app | Yes | Application name |
+ | --models | Yes | Model(s) used; can be a single model or multiple models |
+ | --static_mode | Optional | Whether to use static graph inference; defaults to False |
+ | --precision | Optional | When `static_mode == True`, FP32 is used by default; `trt_fp32` or `trt_fp16` can be selected instead |
+
+ Notes:
+ - Some models do not support static graph mode or TensorRT. For specific information, please refer to [Multi Modal And Scenario](../applications/README_en.md/#multi-modal-and-scenario).
+
+ - The generated static graph will be located in the folder corresponding to the model name, for example: `GroundingDino/groundingdino-swint-ogc/`.
+
+ ## 2. Single Model Prediction Deployment
+
+ Python-based prediction deployment mainly involves two steps:
+ - exporting the predictive model
+ - performing prediction using Python
+
+ Currently supported models:
+ - [blip2](./blip2/README.md)
+ - [groundingdino](./groundingdino/README.md)
+ - [sam](./sam/README.md)
+ - [qwen_vl](./qwen_vl/README.md)
+
+ The following uses groundingdino as an example.
+
+ ### 2.1 Exporting the Predictive Model
+
+ ```bash
+ cd deploy/groundingdino
+ # export the groundingdino model
+ python export.py \
+ --dino_type GroundingDino/groundingdino-swint-ogc
+ ```
+ The model will be exported to the corresponding directory, including the `model_state.pdiparams`, `model_state.pdiparams.info`, `model_state.pdmodel` and other files.
+
+ ### 2.2 Python-based Inference
+
+ ```bash
+ python predict.py \
+ --text_encoder_type GroundingDino/groundingdino-swint-ogc \
+ --model_path output_groundingdino/GroundingDino/groundingdino-swint-ogc \
+ --input_image https://bj.bcebos.com/v1/paddlenlp/models/community/GroundingDino/000000004505.jpg \
+ --output_dir ./groundingdino_predict_output \
+ --prompt "bus"
+ ```
+
+ ## 3. Benchmark
+
+ > Note: environment
+ > Paddle 3.0, PaddleMIX release/2.0, PaddleNLP 2.7.2, A100 80G.
+
+ ### 3.1 Benchmark Command
+
+ Add `--benchmark` to the run command in the corresponding model directory under `deploy` to obtain the model's running time.
+ Example: GroundingDino benchmark:
+
+ ```bash
+ cd deploy/groundingdino
+ python predict.py \
+ --text_encoder_type GroundingDino/groundingdino-swint-ogc \
+ --model_path output_groundingdino/GroundingDino/groundingdino-swint-ogc \
+ --input_image https://bj.bcebos.com/v1/paddlenlp/models/community/GroundingDino/000000004505.jpg \
+ --output_dir ./groundingdino_predict_output \
+ --prompt "bus" \
+ --benchmark True
+ ```
+
+ | Model | Image size | dtype | Paddle Deploy |
+ |-|-|-|-|
+ | qwen-vl-7b | 448*448 | fp16 | 669.8 ms |
+ | llava-1.5-7b | 336*336 | fp16 | 981.2 ms |
+ | llava-1.6-7b | 336*336 | fp16 | 778.7 ms |
+ | groundingDino/groundingdino-swint-ogc | 800*1193 | fp32 | 100 ms |
+ | Sam/SamVitH-1024 | 1024*1024 | fp32 | 121 ms |
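As a supplement to the deployment README above, the sketch below shows how the `static_mode` and `precision` parameters from section 1.2 combine to request TensorRT FP16 inference. It reuses the names from the 1.1 example and is only a hedged illustration; whether TensorRT acceleration actually applies depends on the support matrix referenced in the notes.

```python
from paddlemix.appflow import Appflow
from PIL import Image

# Same application as in section 1.1, but asking for TensorRT FP16 acceleration.
# Only valid for models that support static graph + TensorRT (see the support matrix).
task = Appflow(
    app="openset_det_sam",
    models=["GroundingDino/groundingdino-swint-ogc", "Sam/SamVitH-1024"],
    static_mode=True,       # run static graph inference
    precision="trt_fp16",   # or "trt_fp32"; plain "fp32" is the default
)
image_pil = Image.open("beauty.png").convert("RGB")
result = task(image=image_pil, prompt="women")
```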
PaddleMIX/docs/train_tutorial.md ADDED
@@ -0,0 +1,10 @@
+ # Train Tutorial
+
+
+ ## Training and Fine-tuning Examples
+ - [Blip2](../paddlemix/examples/blip2/README.md)
+ - [clip](../paddlemix/examples/clip/README.md)
+ - [coca](../paddlemix/examples/coca/README.md)
+ - [eva02](../paddlemix/examples/eva02/README.md)
+ - [evaclip](../paddlemix/examples/evaclip/README.md)
+ - [Stable Diffusion](../ppdiffusers/examples/text_to_image/README.md)
PaddleMIX/paddlemix/activations.py ADDED
@@ -0,0 +1,174 @@
1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from collections import OrderedDict
18
+
19
+ import paddle
20
+ import paddle.nn.functional as F
21
+ from paddle import Tensor, nn
22
+
23
+
24
+ class NewGELUActivation(nn.Layer):
25
+ """
26
+ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
27
+ the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
28
+ """
29
+
30
+ def forward(self, input: Tensor) -> Tensor:
31
+ return (
32
+ 0.5 * input * (1.0 + paddle.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * paddle.pow(input, 3.0))))
33
+ )
34
+
35
+
36
+ class GELUActivation(nn.Layer):
37
+ """
38
+ Original Implementation of the GELU activation function in Google BERT repo when initially created. For
39
+ information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
40
+ torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
41
+ Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
42
+ """
43
+
44
+ def __init__(self, use_gelu_python: bool = False):
45
+ super().__init__()
46
+ if use_gelu_python:
47
+ self.act = self._gelu_python
48
+ else:
49
+ self.act = nn.functional.gelu
50
+
51
+ def _gelu_python(self, input: Tensor) -> Tensor:
52
+ return input * 0.5 * (1.0 + paddle.erf(input / math.sqrt(2.0)))
53
+
54
+ def forward(self, input: Tensor) -> Tensor:
55
+ return self.act(input)
56
+
57
+
58
+ class FastGELUActivation(nn.Layer):
59
+ """
60
+ Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
61
+ """
62
+
63
+ def forward(self, input: Tensor) -> Tensor:
64
+ return 0.5 * input * (1.0 + paddle.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
65
+
66
+
67
+ class QuickGELUActivation(nn.Layer):
68
+ """
69
+ Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
70
+ """
71
+
72
+ def forward(self, input: Tensor) -> Tensor:
73
+ return input * F.sigmoid(1.702 * input)
74
+
75
+
76
+ class ClippedGELUActivation(nn.Layer):
77
+ """
78
+ Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as
79
+ it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to
80
+ https://arxiv.org/abs/2004.09602.
81
+
82
+ Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
83
+ initially created.
84
+
85
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
86
+ torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
87
+ """
88
+
89
+ def __init__(self, min: float, max: float):
90
+ if min > max:
91
+ raise ValueError(f"min should be < max (got min: {min}, max: {max})")
92
+
93
+ super().__init__()
94
+ self.min = min
95
+ self.max = max
96
+
97
+ def forward(self, x: Tensor) -> Tensor:
98
+ return paddle.clip(gelu(x), self.min, self.max)
99
+
100
+
101
+ class SiLUActivation(nn.Layer):
102
+ """
103
+ See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
104
+ Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
105
+ Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
106
+ Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
107
+ later.
108
+ """
109
+
110
+ def forward(self, input: Tensor) -> Tensor:
111
+ return F.silu(input)
112
+
113
+
114
+ class MishActivation(nn.Layer):
115
+ """
116
+ See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
117
+ visit the official repository for the paper: https://github.com/digantamisra98/Mish
118
+ """
119
+
120
+ def forward(self, input: Tensor) -> Tensor:
121
+ return F.mish(input)
122
+
123
+
124
+ class LinearActivation(nn.Layer):
125
+ """
126
+ Applies the linear activation function, i.e. forwarding input directly to output.
127
+ """
128
+
129
+ def forward(self, input: Tensor) -> Tensor:
130
+ return input
131
+
132
+
133
+ class ClassInstantier(OrderedDict):
134
+ def __getitem__(self, key):
135
+ content = super().__getitem__(key)
136
+ cls, kwargs = content if isinstance(content, tuple) else (content, {})
137
+ return cls(**kwargs)
138
+
139
+
140
+ ACT2CLS = {
141
+ "gelu": GELUActivation,
142
+ "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
143
+ "gelu_fast": FastGELUActivation,
144
+ "gelu_new": NewGELUActivation,
145
+ "gelu_python": (GELUActivation, {"use_gelu_python": True}),
146
+ "linear": LinearActivation,
147
+ "mish": MishActivation,
148
+ "quick_gelu": QuickGELUActivation,
149
+ "relu": nn.ReLU,
150
+ "relu6": nn.ReLU6,
151
+ "sigmoid": nn.Sigmoid,
152
+ "silu": SiLUActivation,
153
+ "swish": SiLUActivation,
154
+ "tanh": nn.Tanh,
155
+ }
156
+ ACT2FN = ClassInstantier(ACT2CLS)
157
+
158
+
159
+ def get_activation(activation_string):
160
+ if activation_string in ACT2FN:
161
+ return ACT2FN[activation_string]
162
+ else:
163
+ raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
164
+
165
+
166
+ # For backwards compatibility with: from activations import gelu_python
167
+ gelu_python = get_activation("gelu_python")
168
+ gelu_new = get_activation("gelu_new")
169
+ gelu = get_activation("gelu")
170
+ gelu_fast = get_activation("gelu_fast")
171
+ quick_gelu = get_activation("quick_gelu")
172
+ silu = get_activation("silu")
173
+ mish = get_activation("mish")
174
+ linear_act = get_activation("linear")
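A brief usage sketch of the activation registry defined in this file may help; it only uses names visible above (`get_activation`, `ACT2FN`) and assumes a working paddlemix installation.

```python
import paddle
from paddlemix.activations import ACT2FN, get_activation

# get_activation looks the name up in ACT2FN and returns a freshly built layer.
act = get_activation("quick_gelu")          # QuickGELUActivation instance
x = paddle.randn([2, 8])
y = act(x)                                  # equivalent to x * sigmoid(1.702 * x)

# Entries registered as (cls, kwargs) tuples are instantiated with those kwargs,
# e.g. "gelu_10" yields a ClippedGELUActivation clipping outputs to [-10, 10].
clipped = ACT2FN["gelu_10"]
print(type(clipped).__name__)               # ClippedGELUActivation
```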
PaddleMIX/paddlemix/checkpoint.py ADDED
@@ -0,0 +1,216 @@
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import shutil
17
+
18
+ import paddle
19
+ import paddle.nn.functional as F
20
+
21
+
22
+ def save(args, model, optimizer, epoch=0, step=0, output_dir="", is_best=False):
23
+ """
24
+ save the state dicts of model and optimizer into an checkpoint.
25
+ """
26
+ if args.dp_rank != 0:
27
+ return
28
+
29
+ if output_dir and isinstance(output_dir, str):
30
+ output_dir = os.path.join(output_dir, "epoch_%d_step_%d" % (epoch, step))
31
+ if not os.path.exists(output_dir):
32
+ os.makedirs(output_dir, exist_ok=True)
33
+ print("Save model to %s" % output_dir)
34
+
35
+ save_dir = "{}/mp_{:0>2d}_sharding_{:0>2d}".format(output_dir, args.mp_rank, args.sharding_rank)
36
+
37
+ # if args.sharding_stage == 3:
38
+ # model.get_all_parameters(convert2cpu=False)
39
+ paddle.save(model.state_dict(), os.path.join(save_dir, "model.pdparams"))
40
+ paddle.save(optimizer.state_dict(), os.path.join(save_dir, "model_state.pdopt"))
41
+ if is_best:
42
+ shutil.copyfile(os.path.join(save_dir, "model.pdparams"), os.path.join(save_dir, "model_best.pdparams"))
43
+ meta_dict = {
44
+ "epoch": epoch,
45
+ "step": step,
46
+ "cuda_rng_state": paddle.get_cuda_rng_state(),
47
+ }
48
+ paddle.save(meta_dict, os.path.join(save_dir, "meta_state.pdopt"))
49
+
50
+ else:
51
+ raise TypeError("`save` requires a valid value of `output_dir`.")
52
+
53
+
54
+ def load_model(args, model, optimizer=None, ckpt_dir=""):
55
+ """
56
+ load the saved checkpoint file and update the state dicts of model and optimizer.
57
+ """
58
+ if ckpt_dir and isinstance(ckpt_dir, str) and os.path.isdir(ckpt_dir):
59
+ print("Try to load checkpoint from %s " % ckpt_dir)
60
+
61
+ load_dir = "{}/mp_{:0>2d}_sharding_{:0>2d}".format(ckpt_dir, args.mp_rank, args.sharding_rank)
62
+ model_path = os.path.join(load_dir, "model.pdparams")
63
+ opt_path = os.path.join(load_dir, "model_state.pdopt")
64
+ # meta_path = os.path.join(load_dir, "meta_state.pdopt")
65
+
66
+ if os.path.exists(model_path):
67
+ model_dict = paddle.load(model_path)
68
+ for name, param in model.state_dict().items():
69
+ assert name in model_dict.keys(), "No param named `{}` was found in checkpoint file.".format(name)
70
+
71
+ if param.dtype != model_dict[name].dtype:
72
+ model_dict[name] = model_dict[name].cast(param.dtype)
73
+
74
+ model.set_state_dict(model_dict)
75
+ del model_dict
76
+ else:
77
+ raise ValueError("No checkpoint file found in %s" % model_path)
78
+
79
+ if os.path.exists(opt_path):
80
+ opt_dict = paddle.load(opt_path)
81
+ optimizer.set_state_dict(opt_dict)
82
+ del opt_dict
83
+ else:
84
+ print("No optimizer checkpoint file found in %s." % opt_path)
85
+
86
+ # if os.path.exists(meta_path):
87
+ # meta_dict = paddle.load(meta_path)
88
+ # load_recovery = {
89
+ # 'step': meta_dict['step'],
90
+ # 'epoch': meta_dict['epoch'],
91
+ # 'rng_state': meta_dict['cuda_rng_state']
92
+ # }
93
+ # del meta_dict
94
+ # else:
95
+ # raise ValueError("No meta checkpoint file found in %s." %
96
+ # meta_path)
97
+
98
+ print("successfully load checkpoints")
99
+ elif ckpt_dir and os.path.isfile(ckpt_dir):
100
+ print("Try to load a whole checkpoint from %s " % ckpt_dir)
101
+ embedding_list = ["token_embedding"]
102
+ collinear_list = [
103
+ "proj",
104
+ "w1",
105
+ "w2",
106
+ "w3",
107
+ "head",
108
+ "c_fc",
109
+ "c_proj",
110
+ "q_bias",
111
+ "v_bias",
112
+ "q_proj",
113
+ "k_proj",
114
+ "v_proj",
115
+ "qkv",
116
+ "c_fc",
117
+ "c_proj",
118
+ "lm_head",
119
+ "fc1",
120
+ "fc2",
121
+ "fc3",
122
+ ]
123
+ rowlinear_list = ["out_proj"] # in eva_text_model.py, but evaclip do not use text model
124
+ all_list = collinear_list + rowlinear_list + embedding_list
125
+ skip_list = [
126
+ "visual.patch_embed.proj.weight",
127
+ "visual.patch_embed.proj.bias",
128
+ "patch_embed.proj.weight",
129
+ "patch_embed.proj.bias",
130
+ ]
131
+
132
+ col_list = []
133
+ row_list = []
134
+ emb_list = []
135
+
136
+ mp_rank = args.mp_rank
137
+ mp_size = max(args.tensor_parallel_degree, 1)
138
+
139
+ def col_split_modeldict(model_dict):
140
+ if len(model_dict.shape) == 2:
141
+ subbatch = model_dict.shape[1] // mp_size
142
+ return model_dict[:, mp_rank * subbatch : (mp_rank + 1) * subbatch]
143
+ elif len(model_dict.shape) == 1:
144
+ subbatch = model_dict.shape[0] // mp_size
145
+ return model_dict[mp_rank * subbatch : (mp_rank + 1) * subbatch]
146
+
147
+ def row_split_modeldict(model_dict):
148
+ if len(model_dict.shape) == 2:
149
+ subbatch = model_dict.shape[0] // mp_size
150
+ return model_dict[mp_rank * subbatch : (mp_rank + 1) * subbatch]
151
+ else:
152
+ return model_dict
153
+
154
+ def emb_split_modeldict(model_dict):
155
+ subbatch = model_dict.shape[0] // mp_size
156
+ return model_dict[mp_rank * subbatch : (mp_rank + 1) * subbatch]
157
+
158
+ model_dict = paddle.load(ckpt_dir)
159
+ modelkeys = model_dict.keys()
160
+ for whole_key in modelkeys:
161
+ if "." not in whole_key:
162
+ continue
163
+
164
+ key = whole_key.split(".")[-2]
165
+ if whole_key in skip_list:
166
+ continue
167
+ if key in all_list:
168
+ if key in collinear_list:
169
+ col_list.append((key, model_dict[whole_key].shape))
170
+ model_dict[whole_key] = col_split_modeldict(model_dict[whole_key])
171
+ elif key in rowlinear_list:
172
+ row_list.append((key, model_dict[whole_key].shape))
173
+ model_dict[whole_key] = row_split_modeldict(model_dict[whole_key])
174
+ else:
175
+ emb_list.append((key, model_dict[whole_key].shape))
176
+ model_dict[whole_key] = emb_split_modeldict(model_dict[whole_key])
177
+
178
+ if hasattr(args, "context_length") and args.context_length != 77:
179
+ model_dict["text.positional_embedding"] = model_dict["text.positional_embedding"][: args.context_length, :]
180
+
181
+ # interpolate position embedding, only in eva02 finetune large size training
182
+ if "pos_embed" in model_dict and hasattr(model, "patch_embed"):
183
+ pos_embed_checkpoint = model_dict["pos_embed"] #
184
+ embedding_size = pos_embed_checkpoint.shape[-1]
185
+ num_patches = model.patch_embed.num_patches
186
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
187
+ # height (== width) for the checkpoint position embedding
188
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
189
+ # height (== width) for the new position embedding
190
+ new_size = int(num_patches**0.5)
191
+ # class_token and dist_token are kept unchanged
192
+ if orig_size != new_size:
193
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
194
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
195
+ # only the position tokens are interpolated
196
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
197
+ pos_tokens = pos_tokens.reshape([-1, orig_size, orig_size, embedding_size]).transpose(
198
+ perm=[0, 3, 1, 2]
199
+ )
200
+ pos_tokens = F.interpolate(
201
+ pos_tokens.astype(dtype="float32"), size=(new_size, new_size), mode="bicubic", align_corners=False
202
+ )
203
+ pos_tokens = pos_tokens.transpose(perm=[0, 2, 3, 1]).flatten(start_axis=1, stop_axis=2)
204
+ new_pos_embed = paddle.concat((extra_tokens, pos_tokens), axis=1)
205
+ model_dict["pos_embed"] = new_pos_embed
206
+
207
+ print("cast state_dict to default dtype:{}".format(paddle.get_default_dtype()))
208
+ for key, value in model_dict.items():
209
+ if "freqs_cos" in key or "freqs_sin" in key:
210
+ continue
211
+ model_dict[key] = paddle.cast(value, dtype=paddle.get_default_dtype())
212
+ model.set_state_dict(model_dict)
213
+ del model_dict
214
+ else:
215
+ print("`load` requires a valid value of `ckpt_dir`.")
216
+ raise TypeError("`load` requires a valid value of `ckpt_dir`.")
PaddleMIX/ppdiffusers/LICENSE ADDED
@@ -0,0 +1,203 @@
1
+
2
+
3
+ Apache License
4
+ Version 2.0, January 2004
5
+ http://www.apache.org/licenses/
6
+
7
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8
+
9
+ 1. Definitions.
10
+
11
+ "License" shall mean the terms and conditions for use, reproduction,
12
+ and distribution as defined by Sections 1 through 9 of this document.
13
+
14
+ "Licensor" shall mean the copyright owner or entity authorized by
15
+ the copyright owner that is granting the License.
16
+
17
+ "Legal Entity" shall mean the union of the acting entity and all
18
+ other entities that control, are controlled by, or are under common
19
+ control with that entity. For the purposes of this definition,
20
+ "control" means (i) the power, direct or indirect, to cause the
21
+ direction or management of such entity, whether by contract or
22
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
23
+ outstanding shares, or (iii) beneficial ownership of such entity.
24
+
25
+ "You" (or "Your") shall mean an individual or Legal Entity
26
+ exercising permissions granted by this License.
27
+
28
+ "Source" form shall mean the preferred form for making modifications,
29
+ including but not limited to software source code, documentation
30
+ source, and configuration files.
31
+
32
+ "Object" form shall mean any form resulting from mechanical
33
+ transformation or translation of a Source form, including but
34
+ not limited to compiled object code, generated documentation,
35
+ and conversions to other media types.
36
+
37
+ "Work" shall mean the work of authorship, whether in Source or
38
+ Object form, made available under the License, as indicated by a
39
+ copyright notice that is included in or attached to the work
40
+ (an example is provided in the Appendix below).
41
+
42
+ "Derivative Works" shall mean any work, whether in Source or Object
43
+ form, that is based on (or derived from) the Work and for which the
44
+ editorial revisions, annotations, elaborations, or other modifications
45
+ represent, as a whole, an original work of authorship. For the purposes
46
+ of this License, Derivative Works shall not include works that remain
47
+ separable from, or merely link (or bind by name) to the interfaces of,
48
+ the Work and Derivative Works thereof.
49
+
50
+ "Contribution" shall mean any work of authorship, including
51
+ the original version of the Work and any modifications or additions
52
+ to that Work or Derivative Works thereof, that is intentionally
53
+ submitted to Licensor for inclusion in the Work by the copyright owner
54
+ or by an individual or Legal Entity authorized to submit on behalf of
55
+ the copyright owner. For the purposes of this definition, "submitted"
56
+ means any form of electronic, verbal, or written communication sent
57
+ to the Licensor or its representatives, including but not limited to
58
+ communication on electronic mailing lists, source code control systems,
59
+ and issue tracking systems that are managed by, or on behalf of, the
60
+ Licensor for the purpose of discussing and improving the Work, but
61
+ excluding communication that is conspicuously marked or otherwise
62
+ designated in writing by the copyright owner as "Not a Contribution."
63
+
64
+ "Contributor" shall mean Licensor and any individual or Legal Entity
65
+ on behalf of whom a Contribution has been received by Licensor and
66
+ subsequently incorporated within the Work.
67
+
68
+ 2. Grant of Copyright License. Subject to the terms and conditions of
69
+ this License, each Contributor hereby grants to You a perpetual,
70
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71
+ copyright license to reproduce, prepare Derivative Works of,
72
+ publicly display, publicly perform, sublicense, and distribute the
73
+ Work and such Derivative Works in Source or Object form.
74
+
75
+ 3. Grant of Patent License. Subject to the terms and conditions of
76
+ this License, each Contributor hereby grants to You a perpetual,
77
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78
+ (except as stated in this section) patent license to make, have made,
79
+ use, offer to sell, sell, import, and otherwise transfer the Work,
80
+ where such license applies only to those patent claims licensable
81
+ by such Contributor that are necessarily infringed by their
82
+ Contribution(s) alone or by combination of their Contribution(s)
83
+ with the Work to which such Contribution(s) was submitted. If You
84
+ institute patent litigation against any entity (including a
85
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
86
+ or a Contribution incorporated within the Work constitutes direct
87
+ or contributory patent infringement, then any patent licenses
88
+ granted to You under this License for that Work shall terminate
89
+ as of the date such litigation is filed.
90
+
91
+ 4. Redistribution. You may reproduce and distribute copies of the
92
+ Work or Derivative Works thereof in any medium, with or without
93
+ modifications, and in Source or Object form, provided that You
94
+ meet the following conditions:
95
+
96
+ (a) You must give any other recipients of the Work or
97
+ Derivative Works a copy of this License; and
98
+
99
+ (b) You must cause any modified files to carry prominent notices
100
+ stating that You changed the files; and
101
+
102
+ (c) You must retain, in the Source form of any Derivative Works
103
+ that You distribute, all copyright, patent, trademark, and
104
+ attribution notices from the Source form of the Work,
105
+ excluding those notices that do not pertain to any part of
106
+ the Derivative Works; and
107
+
108
+ (d) If the Work includes a "NOTICE" text file as part of its
109
+ distribution, then any Derivative Works that You distribute must
110
+ include a readable copy of the attribution notices contained
111
+ within such NOTICE file, excluding those notices that do not
112
+ pertain to any part of the Derivative Works, in at least one
113
+ of the following places: within a NOTICE text file distributed
114
+ as part of the Derivative Works; within the Source form or
115
+ documentation, if provided along with the Derivative Works; or,
116
+ within a display generated by the Derivative Works, if and
117
+ wherever such third-party notices normally appear. The contents
118
+ of the NOTICE file are for informational purposes only and
119
+ do not modify the License. You may add Your own attribution
120
+ notices within Derivative Works that You distribute, alongside
121
+ or as an addendum to the NOTICE text from the Work, provided
122
+ that such additional attribution notices cannot be construed
123
+ as modifying the License.
124
+
125
+ You may add Your own copyright statement to Your modifications and
126
+ may provide additional or different license terms and conditions
127
+ for use, reproduction, or distribution of Your modifications, or
128
+ for any such Derivative Works as a whole, provided Your use,
129
+ reproduction, and distribution of the Work otherwise complies with
130
+ the conditions stated in this License.
131
+
132
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
133
+ any Contribution intentionally submitted for inclusion in the Work
134
+ by You to the Licensor shall be under the terms and conditions of
135
+ this License, without any additional terms or conditions.
136
+ Notwithstanding the above, nothing herein shall supersede or modify
137
+ the terms of any separate license agreement you may have executed
138
+ with Licensor regarding such Contributions.
139
+
140
+ 6. Trademarks. This License does not grant permission to use the trade
141
+ names, trademarks, service marks, or product names of the Licensor,
142
+ except as required for reasonable and customary use in describing the
143
+ origin of the Work and reproducing the content of the NOTICE file.
144
+
145
+ 7. Disclaimer of Warranty. Unless required by applicable law or
146
+ agreed to in writing, Licensor provides the Work (and each
147
+ Contributor provides its Contributions) on an "AS IS" BASIS,
148
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149
+ implied, including, without limitation, any warranties or conditions
150
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151
+ PARTICULAR PURPOSE. You are solely responsible for determining the
152
+ appropriateness of using or redistributing the Work and assume any
153
+ risks associated with Your exercise of permissions under this License.
154
+
155
+ 8. Limitation of Liability. In no event and under no legal theory,
156
+ whether in tort (including negligence), contract, or otherwise,
157
+ unless required by applicable law (such as deliberate and grossly
158
+ negligent acts) or agreed to in writing, shall any Contributor be
159
+ liable to You for damages, including any direct, indirect, special,
160
+ incidental, or consequential damages of any character arising as a
161
+ result of this License or out of the use or inability to use the
162
+ Work (including but not limited to damages for loss of goodwill,
163
+ work stoppage, computer failure or malfunction, or any and all
164
+ other commercial damages or losses), even if such Contributor
165
+ has been advised of the possibility of such damages.
166
+
167
+ 9. Accepting Warranty or Additional Liability. While redistributing
168
+ the Work or Derivative Works thereof, You may choose to offer,
169
+ and charge a fee for, acceptance of support, warranty, indemnity,
170
+ or other liability obligations and/or rights consistent with this
171
+ License. However, in accepting such obligations, You may act only
172
+ on Your own behalf and on Your sole responsibility, not on behalf
173
+ of any other Contributor, and only if You agree to indemnify,
174
+ defend, and hold each Contributor harmless for any liability
175
+ incurred by, or claims asserted against, such Contributor by reason
176
+ of your accepting any such warranty or additional liability.
177
+
178
+ END OF TERMS AND CONDITIONS
179
+
180
+ APPENDIX: How to apply the Apache License to your work.
181
+
182
+ To apply the Apache License to your work, attach the following
183
+ boilerplate notice, with the fields enclosed by brackets "[]"
184
+ replaced with your own identifying information. (Don't include
185
+ the brackets!) The text should be enclosed in the appropriate
186
+ comment syntax for the file format. We also recommend that a
187
+ file or class name and description of purpose be included on the
188
+ same "printed page" as the copyright notice for easier
189
+ identification within third-party archives.
190
+
191
+ Copyright [yyyy] [name of copyright owner]
192
+
193
+ Licensed under the Apache License, Version 2.0 (the "License");
194
+ you may not use this file except in compliance with the License.
195
+ You may obtain a copy of the License at
196
+
197
+ http://www.apache.org/licenses/LICENSE-2.0
198
+
199
+ Unless required by applicable law or agreed to in writing, software
200
+ distributed under the License is distributed on an "AS IS" BASIS,
201
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202
+ See the License for the specific language governing permissions and
203
+ limitations under the License.
PaddleMIX/ppdiffusers/Makefile ADDED
@@ -0,0 +1,30 @@
+
+ .DEFAULT_GOAL := all
+
+ .PHONY: all
+ all: deploy-version build deploy
+
+ .PHONY: build
+ build:
+ 	python3 setup.py sdist bdist_wheel
+
+ .PHONY: deploy
+ deploy:
+ 	make deploy-version
+ 	twine upload --skip-existing dist/*
+
+ .PHONY: deploy-version
+ deploy-version:
+ 	echo "VERSION = '$$(cat VERSION)'" > ppdiffusers/version.py
+
+ .PHONY: install
+ install:
+ 	pip install -r requirements.txt
+
+ .PHONY: version
+ version:
+ 	@newVersion=$$(awk -F. '{print $$1"."$$2"."$$3+1}' < VERSION) \
+ 	&& echo $${newVersion} > VERSION \
+ 	&& git add VERSION \
+ 	&& git commit -m "🔥 update version to $${newVersion}" > /dev/null \
+ 	&& echo "Bumped version to $${newVersion}"
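A short usage sketch of the Makefile targets defined above, run from the ppdiffusers directory; `twine` credentials are only needed for the `deploy` target, so it is omitted here.

```shell
cd PaddleMIX/ppdiffusers
make install    # pip install -r requirements.txt
make build      # build sdist + wheel via setup.py
make version    # bump VERSION and commit the change
```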
PaddleMIX/ppdiffusers/deploy-deprecated/export.md ADDED
@@ -0,0 +1,67 @@
+ # Diffusion Model Export Tutorial
+
+
+ [PPDiffusers](https://github.com/PaddlePaddle/PaddleMIX/tree/develop/ppdiffusers) is a diffusion model toolbox that supports cross-modal (e.g. image and speech) training and inference. It borrows the excellent design of the [Diffusers](https://github.com/huggingface/diffusers) library from the 🤗 Huggingface team and is built on the [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) framework and the [PaddleNLP](https://github.com/PaddlePaddle/PaddleNLP) natural language processing library. The following describes how to export the pretrained models provided by PPDiffusers.
+
+ ### Model Export
+
+ ___Note: during model export the StableDiffusion model needs to be downloaded. To use the model and its weights you must accept the License it requires: visit the HuggingFace [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the License carefully, and then sign the agreement.___
+
+ ___Tips: Stable Diffusion is based on the following License: The CreativeML OpenRAIL M license is an Open RAIL M license, adapted from the work that BigScience and the RAIL Initiative are jointly carrying in the area of responsible AI licensing. See also the article about the BLOOM Open RAIL license on which this license is based.___
+
+ Run the following command to export the model.
+
+ ```shell
+ # Disable ppxformers; otherwise model export will fail
+ export USE_PPXFORMERS=False
+ python export_model.py --pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 --output_path stable-diffusion-v1-5
+ ```
+ Note: the command above does not export a fixed-size model. A fixed-size exported model helps optimize inference performance but sacrifices some flexibility. To export a fixed-size model, specify the `--height` and `--width` arguments.
+
+ The exported model directory has the following structure:
+
+ ```shell
+ stable-diffusion-v1-5/
+ ├── model_index.json
+ ├── scheduler
+ │   └── scheduler_config.json
+ ├── tokenizer
+ │   ├── tokenizer_config.json
+ │   ├── merges.txt
+ │   ├── vocab.json
+ │   └── special_tokens_map.json
+ ├── text_encoder
+ │   ├── inference.pdiparams
+ │   ├── inference.pdiparams.info
+ │   └── inference.pdmodel
+ ├── unet
+ │   ├── inference.pdiparams
+ │   ├── inference.pdiparams.info
+ │   └── inference.pdmodel
+ ├── vae_decoder
+ │   ├── inference.pdiparams
+ │   ├── inference.pdiparams.info
+ │   └── inference.pdmodel
+ └── vae_encoder
+     ├── inference.pdiparams
+     ├── inference.pdiparams.info
+     └── inference.pdmodel
+ ```
+
+ #### Exporting the Inpaint Model
+
+ In addition to exporting models for the regular StableDiffusion text-to-image and image-to-image tasks, exporting the inpaint task model is also supported (note: this is not the legacy inpaint version). To export the inpaint model, run:
+
+ ```shell
+ python export_model.py --pretrained_model_name_or_path runwayml/stable-diffusion-inpainting --output_path stable-diffusion-v1-5-inpainting
+ ```
+
+ #### Argument Description
+
+ Description of the command-line arguments of `export_model.py`.
+
+ | Argument | Description |
+ |----------|--------------|
+ | <span style="display:inline-block;width: 230pt"> --pretrained_model_name_or_path </span> | The pretrained diffusion model provided by ppdiffusers. Default: "CompVis/stable-diffusion-v1-4". For more StableDiffusion pretrained models, see the [ppdiffusers model list](../README.md#ppdiffusers模型支持的权重). |
+ | --output_path | The directory of the exported model. |
+ | --sample | Whether the vae encoder output uses sample mode. Note: sample mode introduces randomness; default is False. |
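Building on the note about fixed-size export in the tutorial above, a hedged example command follows. The 512x512 size and the output directory name are only illustrative; the `--height`/`--width` flags are the ones documented above and defined in `export_model.py`.

```shell
# Hypothetical fixed-size export (512x512 output images), trading flexibility for speed.
export USE_PPXFORMERS=False
python export_model.py \
    --pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 \
    --output_path stable-diffusion-v1-5-static-512 \
    --height 512 --width 512
```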
PaddleMIX/ppdiffusers/deploy-deprecated/export.sh ADDED
@@ -0,0 +1,17 @@
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ export USE_PPXFORMERS=False
+ export CUDA_VISIBLE_DEVICES=1
+ python export_model.py --pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 --output_path stable-diffusion-v1-5
PaddleMIX/ppdiffusers/deploy-deprecated/export_model.py ADDED
@@ -0,0 +1,201 @@
1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import os
17
+
18
+ # set USE_PPXFORMERS=False to avoid using ppxformers
19
+ os.environ["USE_PPXFORMERS"] = "False"
20
+ from pathlib import Path
21
+ from types import MethodType
22
+
23
+ import paddle
24
+
25
+ from ppdiffusers import (
26
+ FastDeployRuntimeModel,
27
+ FastDeployStableDiffusionInpaintPipeline,
28
+ FastDeployStableDiffusionMegaPipeline,
29
+ StableDiffusionPipeline,
30
+ UNet2DConditionModel,
31
+ )
32
+
33
+
34
+ def convert_ppdiffusers_pipeline_to_fastdeploy_pipeline(
35
+ model_path: str,
36
+ output_path: str,
37
+ sample: bool = False,
38
+ height: int = None,
39
+ width: int = None,
40
+ ):
41
+ # specify unet model with unet pre_temb_act opt enabled.
42
+ unet_model = UNet2DConditionModel.from_pretrained(model_path, resnet_pre_temb_non_linearity=True, subfolder="unet")
43
+ pipeline = StableDiffusionPipeline.from_pretrained(
44
+ model_path, unet=unet_model, safety_checker=None, feature_extractor=None
45
+ )
46
+ output_path = Path(output_path)
47
+ # calculate latent's H and W
48
+ latent_height = height // 8 if height is not None else None
49
+ latent_width = width // 8 if width is not None else None
50
+ # get arguments
51
+ cross_attention_dim = pipeline.unet.config.cross_attention_dim # 768 or 1024 or 1280
52
+ unet_channels = pipeline.unet.config.in_channels # 4 or 9
53
+ vae_in_channels = pipeline.vae.config.in_channels # 3
54
+ vae_latent_channels = pipeline.vae.config.latent_channels # 4
55
+ print(
56
+ f"cross_attention_dim: {cross_attention_dim}\n",
57
+ f"unet_in_channels: {unet_channels}\n",
58
+ f"vae_encoder_in_channels: {vae_in_channels}\n",
59
+ f"vae_decoder_latent_channels: {vae_latent_channels}",
60
+ )
61
+ # 1. Convert text_encoder
62
+ text_encoder = paddle.jit.to_static(
63
+ pipeline.text_encoder,
64
+ input_spec=[paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids")], # input_ids
65
+ )
66
+ save_path = os.path.join(args.output_path, "text_encoder", "inference")
67
+ paddle.jit.save(text_encoder, save_path)
68
+ print(f"Save text_encoder model in {save_path} successfully.")
69
+ del pipeline.text_encoder
70
+
71
+ # 2. Convert unet
72
+ unet = paddle.jit.to_static(
73
+ pipeline.unet,
74
+ input_spec=[
75
+ paddle.static.InputSpec(
76
+ shape=[None, unet_channels, latent_height, latent_width],
77
+ dtype="float32",
78
+ name="sample",
79
+ ), # sample
80
+ paddle.static.InputSpec(shape=[1], dtype="float32", name="timestep"), # timestep
81
+ paddle.static.InputSpec(
82
+ shape=[None, None, cross_attention_dim],
83
+ dtype="float32",
84
+ name="encoder_hidden_states",
85
+ ), # encoder_hidden_states
86
+ ],
87
+ )
88
+ save_path = os.path.join(args.output_path, "unet", "inference")
89
+ paddle.jit.save(unet, save_path)
90
+ print(f"Save unet model in {save_path} successfully.")
91
+ del pipeline.unet
92
+
93
+ def forward_vae_encoder_mode(self, z):
94
+ return self.encode(z, True).latent_dist.mode()
95
+
96
+ def forward_vae_encoder_sample(self, z):
97
+ return self.encode(z, True).latent_dist.sample()
98
+
99
+ # 3. Convert vae encoder
100
+ vae_encoder = pipeline.vae
101
+ if sample:
102
+ vae_encoder.forward = MethodType(forward_vae_encoder_sample, vae_encoder)
103
+ else:
104
+ vae_encoder.forward = MethodType(forward_vae_encoder_mode, vae_encoder)
105
+
106
+ vae_encoder = paddle.jit.to_static(
107
+ vae_encoder,
108
+ input_spec=[
109
+ paddle.static.InputSpec(
110
+ shape=[None, vae_in_channels, height, width],
111
+ dtype="float32",
112
+ name="sample", # N, C, H, W
113
+ ), # latent
114
+ ],
115
+ )
116
+ # Save vae_encoder in static graph model.
117
+ save_path = os.path.join(args.output_path, "vae_encoder", "inference")
118
+ paddle.jit.save(vae_encoder, save_path)
119
+ print(f"Save vae_encoder model in {save_path} successfully.")
120
+
121
+ # 4. Convert vae encoder
122
+ vae_decoder = pipeline.vae
123
+
124
+ def forward_vae_decoder(self, z):
125
+ return self.decode(z, True).sample
126
+
127
+ vae_decoder.forward = MethodType(forward_vae_decoder, vae_decoder)
128
+ vae_decoder = paddle.jit.to_static(
129
+ vae_decoder,
130
+ input_spec=[
131
+ paddle.static.InputSpec(
132
+ shape=[None, vae_latent_channels, latent_height, latent_width],
133
+ dtype="float32",
134
+ name="latent_sample",
135
+ ), # latent_sample
136
+ ],
137
+ )
138
+ # Save vae_decoder in static graph model.
139
+ save_path = os.path.join(args.output_path, "vae_decoder", "inference")
140
+ paddle.jit.save(vae_decoder, save_path)
141
+ print(f"Save vae_decoder model in {save_path} successfully.")
142
+ del pipeline.vae
143
+
144
+ if "inpainting" in model_path:
145
+ fd_pipe_cls = FastDeployStableDiffusionInpaintPipeline
146
+ else:
147
+ fd_pipe_cls = FastDeployStableDiffusionMegaPipeline
148
+
149
+ fastdeploy_pipeline = fd_pipe_cls(
150
+ vae_encoder=FastDeployRuntimeModel.from_pretrained(output_path / "vae_encoder"),
151
+ vae_decoder=FastDeployRuntimeModel.from_pretrained(output_path / "vae_decoder"),
152
+ text_encoder=FastDeployRuntimeModel.from_pretrained(output_path / "text_encoder"),
153
+ unet=FastDeployRuntimeModel.from_pretrained(output_path / "unet"),
154
+ tokenizer=pipeline.tokenizer,
155
+ scheduler=pipeline.scheduler,
156
+ safety_checker=None,
157
+ feature_extractor=None,
158
+ image_encoder=None,
159
+ requires_safety_checker=False,
160
+ )
161
+ fastdeploy_pipeline.save_pretrained(str(output_path))
162
+ print("FastDeploy pipeline saved to", output_path)
163
+
164
+
165
+ if __name__ == "__main__":
166
+ parser = argparse.ArgumentParser()
167
+
168
+ parser.add_argument(
169
+ "--pretrained_model_name_or_path",
170
+ type=str,
171
+ required=True,
172
+ help="Path to the `ppdiffusers` checkpoint to convert (either a local directory or on the bos).",
173
+ )
174
+ parser.add_argument("--output_path", type=str, required=True, help="Path to the output model.")
175
+ parser.add_argument(
176
+ "--sample",
177
+ action="store_true",
178
+ default=False,
179
+ help="Export the vae encoder in mode or sample",
180
+ )
181
+ parser.add_argument(
182
+ "--height",
183
+ type=int,
184
+ default=None,
185
+ help="The height of output images. Default: None",
186
+ )
187
+ parser.add_argument(
188
+ "--width",
189
+ type=int,
190
+ default=None,
191
+ help="The width of output images. Default: None",
192
+ )
193
+ args = parser.parse_args()
194
+
195
+ convert_ppdiffusers_pipeline_to_fastdeploy_pipeline(
196
+ args.pretrained_model_name_or_path,
197
+ args.output_path,
198
+ args.sample,
199
+ args.height,
200
+ args.width,
201
+ )
PaddleMIX/ppdiffusers/deploy-deprecated/gradio_demo.py ADDED
@@ -0,0 +1,683 @@
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import argparse
15
+ import os
16
+
17
+ import cv2
18
+ import fastdeploy as fd
19
+ import gradio as gr
20
+ import numpy as np
21
+ import paddle
22
+ from paddlenlp.trainer.argparser import strtobool
23
+ from PIL import Image
24
+
25
+ from ppdiffusers import FastDeployStableDiffusionMegaPipeline
26
+
27
+
28
+ def create_paddle_inference_runtime(
29
+ use_trt=False,
30
+ dynamic_shape=None,
31
+ use_fp16=False,
32
+ use_bf16=False,
33
+ device_id=0,
34
+ disable_paddle_trt_ops=[],
35
+ disable_paddle_pass=[],
36
+ paddle_stream=None,
37
+ workspace=None,
38
+ ):
39
+ assert not use_fp16 or not use_bf16, "use_fp16 and use_bf16 are mutually exclusive"
40
+ option = fd.RuntimeOption()
41
+ option.use_paddle_backend()
42
+ if device_id == -1:
43
+ option.use_cpu()
44
+ else:
45
+ option.use_gpu(device_id)
46
+ if paddle_stream is not None and use_trt:
47
+ option.set_external_raw_stream(paddle_stream)
48
+ for pass_name in disable_paddle_pass:
49
+ option.paddle_infer_option.delete_pass(pass_name)
50
+ if use_bf16:
51
+ option.paddle_infer_option.inference_precision = "bfloat16"
52
+ if use_trt:
53
+ option.paddle_infer_option.disable_trt_ops(disable_paddle_trt_ops)
54
+ option.paddle_infer_option.enable_trt = True
55
+ if workspace is not None:
56
+ option.set_trt_max_workspace_size(workspace)
57
+ if use_fp16:
58
+ option.trt_option.enable_fp16 = True
59
+ else:
60
+ # Note(zhoushunjie): These four passes don't support fp32 now.
61
+ # Remove this line of code in future.
62
+ only_fp16_passes = [
63
+ "trt_cross_multihead_matmul_fuse_pass",
64
+ "trt_flash_multihead_matmul_fuse_pass",
65
+ "preln_elementwise_groupnorm_act_pass",
66
+ "elementwise_groupnorm_act_pass",
67
+ ]
68
+ for curr_pass in only_fp16_passes:
69
+ option.paddle_infer_option.delete_pass(curr_pass)
70
+
71
+ # Need to enable collect shape
72
+ if dynamic_shape is not None:
73
+ option.paddle_infer_option.collect_trt_shape = True
74
+ for key, shape_dict in dynamic_shape.items():
75
+ option.trt_option.set_shape(
76
+ key,
77
+ shape_dict["min_shape"],
78
+ shape_dict.get("opt_shape", None),
79
+ shape_dict.get("max_shape", None),
80
+ )
81
+ return option
82
+
83
+
84
+ def create_trt_runtime(workspace=(1 << 31), dynamic_shape=None, use_fp16=False, device_id=0):
85
+ option = fd.RuntimeOption()
86
+ option.use_trt_backend()
87
+ option.use_gpu(device_id)
88
+ if use_fp16:
89
+ option.enable_trt_fp16()
90
+ if workspace is not None:
91
+ option.set_trt_max_workspace_size(workspace)
92
+ if dynamic_shape is not None:
93
+ for key, shape_dict in dynamic_shape.items():
94
+ option.set_trt_input_shape(
95
+ key,
96
+ min_shape=shape_dict["min_shape"],
97
+ opt_shape=shape_dict.get("opt_shape", None),
98
+ max_shape=shape_dict.get("max_shape", None),
99
+ )
100
+ return option
101
+
102
+
103
+ def pipe_init(args):
104
+ paddle.set_device(f"gpu:{args.device_id}")
105
+ paddle_stream = paddle.device.cuda.current_stream(args.device_id).cuda_stream
106
+ vae_in_channels = 4
107
+ text_encoder_max_length = 77
108
+ unet_max_length = text_encoder_max_length * 3 # lpw support max_length is 77x3
109
+ min_image_size = 384
110
+ max_image_size = 768
111
+ hidden_states = 1024 if args.is_sd2_0 else 768
112
+ unet_in_channels = 9 if args.task_name == "inpaint" else 4
113
+ bs = 2
114
+
115
+ text_encoder_dynamic_shape = {
116
+ "input_ids": {
117
+ "min_shape": [1, text_encoder_max_length],
118
+ "max_shape": [1, text_encoder_max_length],
119
+ "opt_shape": [1, text_encoder_max_length],
120
+ }
121
+ }
122
+
123
+ vae_encoder_dynamic_shape = {
124
+ "sample": {
125
+ "min_shape": [1, 3, min_image_size, min_image_size],
126
+ "max_shape": [1, 3, max_image_size, max_image_size],
127
+ "opt_shape": [1, 3, min_image_size, min_image_size],
128
+ }
129
+ }
130
+
131
+ vae_decoder_dynamic_shape = {
132
+ "latent_sample": {
133
+ "min_shape": [1, vae_in_channels, min_image_size // 8, min_image_size // 8],
134
+ "max_shape": [1, vae_in_channels, max_image_size // 8, max_image_size // 8],
135
+ "opt_shape": [1, vae_in_channels, min_image_size // 8, min_image_size // 8],
136
+ }
137
+ }
138
+
139
+ unet_dynamic_shape = {
140
+ "sample": {
141
+ "min_shape": [
142
+ 1,
143
+ unet_in_channels,
144
+ min_image_size // 8,
145
+ min_image_size // 8,
146
+ ],
147
+ "max_shape": [
148
+ bs,
149
+ unet_in_channels,
150
+ max_image_size // 8,
151
+ max_image_size // 8,
152
+ ],
153
+ "opt_shape": [
154
+ 2,
155
+ unet_in_channels,
156
+ min_image_size // 8,
157
+ min_image_size // 8,
158
+ ],
159
+ },
160
+ "timestep": {
161
+ "min_shape": [1],
162
+ "max_shape": [1],
163
+ "opt_shape": [1],
164
+ },
165
+ "encoder_hidden_states": {
166
+ "min_shape": [1, text_encoder_max_length, hidden_states],
167
+ "max_shape": [bs, unet_max_length, hidden_states],
168
+ "opt_shape": [2, text_encoder_max_length, hidden_states],
169
+ },
170
+ }
171
+ # 4. Init runtime
172
+ if args.backend == "tensorrt":
173
+ runtime_options = dict(
174
+ text_encoder=create_trt_runtime(
175
+ dynamic_shape=text_encoder_dynamic_shape,
176
+ use_fp16=args.use_fp16,
177
+ device_id=args.device_id,
178
+ ),
179
+ vae_encoder=create_trt_runtime(
180
+ dynamic_shape=vae_encoder_dynamic_shape,
181
+ use_fp16=args.use_fp16,
182
+ device_id=args.device_id,
183
+ ),
184
+ vae_decoder=create_trt_runtime(
185
+ dynamic_shape=vae_decoder_dynamic_shape,
186
+ use_fp16=args.use_fp16,
187
+ device_id=args.device_id,
188
+ ),
189
+ unet=create_trt_runtime(
190
+ dynamic_shape=unet_dynamic_shape,
191
+ use_fp16=args.use_fp16,
192
+ device_id=args.device_id,
193
+ ),
194
+ )
195
+ elif args.backend == "paddle" or args.backend == "paddle_tensorrt":
196
+ args.use_trt = args.backend == "paddle_tensorrt"
197
+ runtime_options = dict(
198
+ text_encoder=create_paddle_inference_runtime(
199
+ use_trt=args.use_trt,
200
+ dynamic_shape=text_encoder_dynamic_shape,
201
+ use_fp16=args.use_fp16,
202
+ use_bf16=args.use_bf16,
203
+ device_id=args.device_id,
204
+ disable_paddle_trt_ops=["arg_max", "range", "lookup_table_v2"],
205
+ paddle_stream=paddle_stream,
206
+ ),
207
+ vae_encoder=create_paddle_inference_runtime(
208
+ use_trt=args.use_trt,
209
+ dynamic_shape=vae_encoder_dynamic_shape,
210
+ use_fp16=args.use_fp16,
211
+ use_bf16=args.use_bf16,
212
+ device_id=args.device_id,
213
+ paddle_stream=paddle_stream,
214
+ ),
215
+ vae_decoder=create_paddle_inference_runtime(
216
+ use_trt=args.use_trt,
217
+ dynamic_shape=vae_decoder_dynamic_shape,
218
+ use_fp16=args.use_fp16,
219
+ use_bf16=args.use_bf16,
220
+ device_id=args.device_id,
221
+ paddle_stream=paddle_stream,
222
+ ),
223
+ unet=create_paddle_inference_runtime(
224
+ use_trt=args.use_trt,
225
+ dynamic_shape=unet_dynamic_shape,
226
+ use_fp16=args.use_fp16,
227
+ use_bf16=args.use_bf16,
228
+ device_id=args.device_id,
229
+ paddle_stream=paddle_stream,
230
+ ),
231
+ )
232
+ pipe = FastDeployStableDiffusionMegaPipeline.from_pretrained(
233
+ args.model_dir,
234
+ runtime_options=runtime_options,
235
+ )
236
+ pipe.set_progress_bar_config(disable=True)
237
+ return pipe
238
+
239
+
240
+ def parse_arguments():
241
+ parser = argparse.ArgumentParser()
242
+ parser.add_argument(
243
+ "--model_dir",
244
+ default="stable-diffusion-v1-5",
245
+ help="The model directory of diffusion_model.",
246
+ )
247
+ parser.add_argument(
248
+ "--task_name",
249
+ type=str,
250
+ default="text2img_img2img_inpaint_legacy",
251
+ choices=[
252
+ "text2img_img2img_inpaint_legacy",
253
+ "inpaint",
254
+ "controlnet_canny",
255
+ ],
256
+ help="The task can be one of [text2img_img2img_inpaint_legacy, inpaint, controlnet_canny]. ",
257
+ )
258
+ parser.add_argument(
259
+ "--backend",
260
+ type=str,
261
+ default="paddle",
262
+ # Note(zhoushunjie): Will support 'tensorrt' soon.
263
+ choices=["paddle", "paddle_tensorrt"],
264
+ help="The inference runtime backend of unet model and text encoder model.",
265
+ )
266
+ parser.add_argument("--use_fp16", type=strtobool, default=True, help="Wheter to use FP16 mode")
267
+ parser.add_argument("--use_bf16", type=strtobool, default=False, help="Wheter to use BF16 mode")
268
+ parser.add_argument("--device_id", type=int, default=0, help="The selected gpu id.")
269
+ parser.add_argument(
270
+ "--parse_prompt_type",
271
+ type=str,
272
+ default="lpw",
273
+ choices=[
274
+ "raw",
275
+ "lpw",
276
+ ],
277
+ help="The parse_prompt_type can be one of [raw, lpw]. ",
278
+ )
279
+ parser.add_argument("--is_sd2_0", type=strtobool, default=False, help="Is sd2_0 model?")
280
+ return parser.parse_args()
281
+
282
+
283
+ def get_canny_image(image):
284
+ if image is not None:
285
+ low_threshold = 100
286
+ high_threshold = 200
287
+ image = cv2.Canny(np.array(image), low_threshold, high_threshold)
288
+ image = image[:, :, None]
289
+ image = np.concatenate([image, image, image], axis=2)
290
+ return image
291
+
292
+
293
+ def infer(
294
+ taskname,
295
+ image,
296
+ mask,
297
+ prompt,
298
+ negative_prompt,
299
+ steps,
300
+ height,
301
+ width,
302
+ seed,
303
+ strength,
304
+ guidance_scale,
305
+ scheduler,
306
+ conditioning_scale,
307
+ ):
308
+ task_name = taskname
309
+ fd_pipe.change_scheduler(scheduler)
310
+
311
+ if int(seed) != -1:
312
+ generator = paddle.Generator("cuda").manual_seed(int(seed))
313
+ else:
314
+ generator = None
315
+
316
+ if image is not None:
317
+ if isinstance(image, dict):
318
+ image["image"] = cv2.resize(image["image"], (width, height))
319
+ image["mask"] = cv2.resize(image["mask"], (width, height))
320
+ else:
321
+ image = cv2.resize(image, (width, height))
322
+ if mask is not None:
323
+ mask = cv2.resize(mask, (width, height))
324
+
325
+ if task_name == "text2img":
326
+ images = fd_pipe.text2img(
327
+ prompt=prompt,
328
+ negative_prompt=negative_prompt,
329
+ num_inference_steps=steps,
330
+ height=height,
331
+ width=width,
332
+ guidance_scale=guidance_scale,
333
+ parse_prompt_type=parse_prompt_type,
334
+ infer_op_dict=infer_op_dict,
335
+ generator=generator,
336
+ )
337
+ elif task_name == "img2img":
338
+ images = fd_pipe.img2img(
339
+ prompt=prompt,
340
+ negative_prompt=negative_prompt,
341
+ image=Image.fromarray(np.array(image)).convert("RGB"),
342
+ num_inference_steps=steps,
343
+ height=height,
344
+ width=width,
345
+ strength=strength,
346
+ guidance_scale=guidance_scale,
347
+ parse_prompt_type=parse_prompt_type,
348
+ infer_op_dict=infer_op_dict,
349
+ generator=generator,
350
+ )
351
+ elif task_name == "inpaint_legacy":
352
+ if mask is not None:
353
+ mask_image = mask
354
+ else:
355
+ mask_image = image["mask"]
356
+ image = image["image"]
357
+ images = fd_pipe.inpaint_legacy(
358
+ prompt=prompt,
359
+ negative_prompt=negative_prompt,
360
+ image=Image.fromarray(np.array(image)).convert("RGB"),
361
+ mask_image=Image.fromarray(mask_image).convert("RGB"),
362
+ num_inference_steps=steps,
363
+ height=height,
364
+ width=width,
365
+ strength=strength,
366
+ guidance_scale=guidance_scale,
367
+ parse_prompt_type=parse_prompt_type,
368
+ infer_op_dict=infer_op_dict,
369
+ generator=generator,
370
+ )
371
+ elif task_name == "inpaint":
372
+ if mask is not None:
373
+ mask_image = mask
374
+ else:
375
+ mask_image = image["mask"]
376
+ image = image["image"]
377
+ images = fd_pipe.inpaint(
378
+ prompt=prompt,
379
+ negative_prompt=negative_prompt,
380
+ image=Image.fromarray(np.array(image)).convert("RGB"),
381
+ mask_image=Image.fromarray(mask_image).convert("RGB"),
382
+ num_inference_steps=steps,
383
+ height=height,
384
+ width=width,
385
+ strength=strength,
386
+ guidance_scale=guidance_scale,
387
+ parse_prompt_type=parse_prompt_type,
388
+ infer_op_dict=infer_op_dict,
389
+ generator=generator,
390
+ )
391
+
392
+ elif task_name == "controlnet_canny":
393
+ canny_image = Image.fromarray(mask)
394
+
395
+ images = fd_pipe.text2img(
396
+ prompt=prompt,
397
+ negative_prompt=negative_prompt,
398
+ num_inference_steps=steps,
399
+ height=height,
400
+ width=width,
401
+ guidance_scale=guidance_scale,
402
+ parse_prompt_type=parse_prompt_type,
403
+ controlnet_cond=canny_image,
404
+ controlnet_conditioning_scale=conditioning_scale,
405
+ infer_op_dict=infer_op_dict,
406
+ generator=generator,
407
+ )
408
+ else:
409
+ raise gr.Error(f"task error! {task_name} not found")
410
+
411
+ return images[0][0]
412
+
413
+
414
+ scheduler_choices = [
415
+ "pndm",
416
+ "lms",
417
+ "euler",
418
+ "euler-ancestral",
419
+ "preconfig-euler-ancestral",
420
+ "dpm-multi",
421
+ "dpm-single",
422
+ "unipc-multi",
423
+ "ddim",
424
+ "ddpm",
425
+ "deis-multi",
426
+ "heun",
427
+ "kdpm2-ancestral",
428
+ "kdpm2",
429
+ ]
430
+
431
+ # some param init
432
+ args = parse_arguments()
433
+ if "model_dir" and "task_name" in os.environ:
434
+ args.model_dir = os.environ["model_dir"]
435
+ args.task_name = os.environ["task_name"]
436
+
437
+ fd_pipe = pipe_init(args)
438
+ parse_prompt_type = args.parse_prompt_type
439
+ if args.backend == "paddle":
440
+ print("When device is kunlunxin_xpu or backend is paddle, we will use `raw` infer op.")
441
+ infer_op_mode = "raw"
442
+ else:
443
+ infer_op_mode = "zero_copy_infer"
444
+ infer_op_dict = {
445
+ "vae_encoder": infer_op_mode,
446
+ "vae_decoder": infer_op_mode,
447
+ "text_encoder": infer_op_mode,
448
+ "unet": infer_op_mode,
449
+ }
450
+
451
+ with gr.Blocks() as demo:
452
+ gr.Markdown("# FastDeploy Stable Diffusion")
453
+ if args.task_name == "text2img_img2img_inpaint_legacy":
454
+ with gr.Tab("text2img"):
455
+ with gr.Row():
456
+ with gr.Column():
457
+ text2img_taskname = gr.State(value="text2img")
458
+ text2img_img = gr.State(value=None)
459
+ text2img_mask = gr.State(value=None)
460
+ text2img_prompt = gr.Textbox(label="正向描述词", lines=2)
461
+ text2img_negative_prompt = gr.Textbox(label="负向描述词", lines=2)
462
+ text2img_steps = gr.Slider(label="steps", minimum=1, maximum=60, step=1, value=20)
463
+ with gr.Row():
464
+ text2img_height = gr.Slider(label="height", minimum=384, maximum=768, step=8, value=512)
465
+ text2img_width = gr.Slider(label="width", minimum=384, maximum=768, step=8, value=512)
466
+ text2img_seed = gr.Textbox(label="seed", value="-1")
467
+ text2img_strength = gr.State(value=None)
468
+ text2img_guidance_scale = gr.Slider(
469
+ label="guidance_scale", minimum=1, maximum=30, step=0.5, value=7.5
470
+ )
471
+ text2img_scheduler = gr.Radio(label="采样方法", choices=scheduler_choices, value="ddim")
472
+ text2img_conditioning_scale = gr.State(value=None)
473
+ with gr.Column():
474
+ text2img_output = gr.Image(type="numpy", label="result")
475
+ text2img_button = gr.Button("生成")
476
+ text2img_button.click(
477
+ fn=infer,
478
+ inputs=[
479
+ text2img_taskname,
480
+ text2img_img,
481
+ text2img_mask,
482
+ text2img_prompt,
483
+ text2img_negative_prompt,
484
+ text2img_steps,
485
+ text2img_height,
486
+ text2img_width,
487
+ text2img_seed,
488
+ text2img_strength,
489
+ text2img_guidance_scale,
490
+ text2img_scheduler,
491
+ text2img_conditioning_scale,
492
+ ],
493
+ outputs=[text2img_output],
494
+ )
495
+
496
+ with gr.Tab("img2img"):
497
+ with gr.Row():
498
+ with gr.Column():
499
+ img2img_taskname = gr.State(value="img2img")
500
+ img2img_img = gr.Image(label="原图")
501
+ img2img_mask = gr.State(value=None)
502
+ img2img_prompt = gr.Textbox(label="请输入描述词", lines=2)
503
+ img2img_negative_prompt = gr.Textbox(label="负向描述词", lines=2)
504
+ img2img_steps = gr.Slider(label="steps", minimum=1, maximum=60, step=1, value=20)
505
+ with gr.Row():
506
+ img2img_height = gr.Slider(label="height", minimum=384, maximum=768, step=8, value=512)
507
+ img2img_width = gr.Slider(label="width", minimum=384, maximum=768, step=8, value=512)
508
+ img2img_seed = gr.Textbox(label="seed", value="-1")
509
+ img2img_strength = gr.Slider(
510
+ label="Denoising strength", minimum=0, maximum=1, step=0.01, value=0.75
511
+ )
512
+ img2img_guidance_scale = gr.Slider(
513
+ label="guidance_scale", minimum=1, maximum=30, step=0.5, value=7.5
514
+ )
515
+ img2img_scheduler = gr.Radio(label="采样方法", choices=scheduler_choices, value="ddim")
516
+ img2img_conditioning_scale = gr.State(value=None)
517
+ with gr.Column():
518
+ img2img_output = gr.Image(type="numpy", label="result")
519
+ img2img_button = gr.Button("生成")
520
+ img2img_button.click(
521
+ fn=infer,
522
+ inputs=[
523
+ img2img_taskname,
524
+ img2img_img,
525
+ img2img_mask,
526
+ img2img_prompt,
527
+ img2img_negative_prompt,
528
+ img2img_steps,
529
+ img2img_height,
530
+ img2img_width,
531
+ img2img_seed,
532
+ img2img_strength,
533
+ img2img_guidance_scale,
534
+ img2img_scheduler,
535
+ img2img_conditioning_scale,
536
+ ],
537
+ outputs=[img2img_output],
538
+ )
539
+
540
+ with gr.Tab("inpaint_legacy"):
541
+ with gr.Row():
542
+ with gr.Column():
543
+ inpaint_legacy_taskname = gr.State(value="inpaint_legacy")
544
+ inpaint_legacy_img = gr.ImageMask(label="传入原图并涂鸦mask")
545
+ inpaint_legacy_mask = gr.Image(label="重绘mask(可选,若不涂鸦则需要传入)", image_mode="L")
546
+ inpaint_legacy_prompt = gr.Textbox(label="请输入正向描述词", lines=2)
547
+ inpaint_legacy_negative_prompt = gr.Textbox(label="负向描述词", lines=2)
548
+ inpaint_legacy_steps = gr.Slider(label="steps", minimum=1, maximum=60, step=1, value=20)
549
+ with gr.Row():
550
+ inpaint_legacy_height = gr.Slider(label="height", minimum=384, maximum=768, step=8, value=512)
551
+ inpaint_legacy_width = gr.Slider(label="width", minimum=384, maximum=768, step=8, value=512)
552
+ inpaint_legacy_seed = gr.Textbox(label="seed", value="-1")
553
+ inpaint_legacy_strength = gr.Slider(
554
+ label="Denoising strength", minimum=0, maximum=1, step=0.01, value=0.75
555
+ )
556
+ inpaint_legacy_guidance_scale = gr.Slider(
557
+ label="guidance_scale", minimum=1, maximum=30, step=0.5, value=7.5
558
+ )
559
+ inpaint_legacy_scheduler = gr.Radio(label="采样方法", choices=scheduler_choices, value="ddim")
560
+ inpaint_legacy_conditioning_scale = gr.State(value=None)
561
+ with gr.Column():
562
+ inpaint_legacy_output = gr.Image(type="numpy", label="result")
563
+ inpaint_legacy_button = gr.Button("生成")
564
+ inpaint_legacy_button.click(
565
+ fn=infer,
566
+ inputs=[
567
+ inpaint_legacy_taskname,
568
+ inpaint_legacy_img,
569
+ inpaint_legacy_mask,
570
+ inpaint_legacy_prompt,
571
+ inpaint_legacy_negative_prompt,
572
+ inpaint_legacy_steps,
573
+ inpaint_legacy_height,
574
+ inpaint_legacy_width,
575
+ inpaint_legacy_seed,
576
+ inpaint_legacy_strength,
577
+ inpaint_legacy_guidance_scale,
578
+ inpaint_legacy_scheduler,
579
+ inpaint_legacy_conditioning_scale,
580
+ ],
581
+ outputs=[inpaint_legacy_output],
582
+ )
583
+
584
+ elif args.task_name == "inpaint":
585
+ with gr.Tab("inpaint"):
586
+ with gr.Row():
587
+ with gr.Column():
588
+ inpaint_taskname = gr.State(value="inpaint")
589
+ inpaint_img = gr.ImageMask(label="传入原图并涂鸦mask")
590
+ inpaint_mask = gr.Image(label="重绘mask(可选,若不涂鸦则需要传入)", image_mode="L")
591
+ inpaint_prompt = gr.Textbox(label="请输入正向描述词", lines=2)
592
+ inpaint_negative_prompt = gr.Textbox(label="负向描述词", lines=2)
593
+ inpaint_steps = gr.Slider(label="steps", minimum=1, maximum=60, step=1, value=20)
594
+ with gr.Row():
595
+ inpaint_height = gr.Slider(label="height", minimum=384, maximum=768, step=8, value=512)
596
+ inpaint_width = gr.Slider(label="width", minimum=384, maximum=768, step=8, value=512)
597
+ inpaint_seed = gr.Textbox(label="seed", value="-1")
598
+ inpaint_strength = gr.Slider(
599
+ label="Denoising strength", minimum=0, maximum=1, step=0.01, value=0.75
600
+ )
601
+ inpaint_guidance_scale = gr.Slider(
602
+ label="guidance_scale", minimum=1, maximum=30, step=0.5, value=7.5
603
+ )
604
+ inpaint_scheduler = gr.Radio(label="采样方法", choices=scheduler_choices, value="ddim")
605
+ inpaint_conditioning_scale = gr.State(value=None)
606
+ with gr.Column():
607
+ inpaint_output = gr.Image(type="numpy", label="result")
608
+ inpaint_button = gr.Button("生成")
609
+
610
+ inpaint_button.click(
611
+ fn=infer,
612
+ inputs=[
613
+ inpaint_taskname,
614
+ inpaint_img,
615
+ inpaint_mask,
616
+ inpaint_prompt,
617
+ inpaint_negative_prompt,
618
+ inpaint_steps,
619
+ inpaint_height,
620
+ inpaint_width,
621
+ inpaint_seed,
622
+ inpaint_strength,
623
+ inpaint_guidance_scale,
624
+ inpaint_scheduler,
625
+ inpaint_conditioning_scale,
626
+ ],
627
+ outputs=[inpaint_output],
628
+ )
629
+
630
+ elif args.task_name == "controlnet_canny":
631
+ with gr.Tab("controlnet_canny"):
632
+ with gr.Row():
633
+ with gr.Column():
634
+ controlnet_canny_taskname = gr.State(value="controlnet_canny")
635
+ controlnet_canny_img = gr.Image(label="canny参考图")
636
+ controlnet_canny_mask = gr.Image(label="canny图(可选传入)")
637
+ controlnet_canny_prompt = gr.Textbox(label="请输入正向描述词", lines=2)
638
+ controlnet_canny_negative_prompt = gr.Textbox(label="负向描述词", lines=2)
639
+ controlnet_canny_steps = gr.Slider(label="steps", minimum=1, maximum=60, step=1, value=20)
640
+ with gr.Row():
641
+ controlnet_canny_height = gr.Slider(
642
+ label="height", minimum=384, maximum=768, step=8, value=512
643
+ )
644
+ controlnet_canny_width = gr.Slider(label="width", minimum=384, maximum=768, step=8, value=512)
645
+ controlnet_canny_seed = gr.Textbox(label="seed", value="-1")
646
+ controlnet_canny_strength = gr.Slider(
647
+ label="Denoising strength", minimum=0, maximum=1, step=0.01, value=0.75
648
+ )
649
+ controlnet_canny_guidance_scale = gr.Slider(
650
+ label="guidance_scale", minimum=1, maximum=30, step=0.5, value=7.5
651
+ )
652
+ controlnet_canny_scheduler = gr.Radio(label="采样方法", choices=scheduler_choices, value="ddim")
653
+ controlnet_canny_conditioning_scale = gr.Slider(
654
+ label="conditioning_scale", minimum=0, maximum=2, step=0.05, value=1
655
+ )
656
+ with gr.Column():
657
+ controlnet_canny_output = gr.Image(type="numpy", label="result")
658
+ controlnet_canny_button = gr.Button("生成")
659
+ controlnet_canny_img.change(
660
+ fn=get_canny_image, inputs=[controlnet_canny_img], outputs=[controlnet_canny_mask]
661
+ )
662
+ controlnet_canny_button.click(
663
+ fn=infer,
664
+ inputs=[
665
+ controlnet_canny_taskname,
666
+ controlnet_canny_img,
667
+ controlnet_canny_mask,
668
+ controlnet_canny_prompt,
669
+ controlnet_canny_negative_prompt,
670
+ controlnet_canny_steps,
671
+ controlnet_canny_height,
672
+ controlnet_canny_width,
673
+ controlnet_canny_seed,
674
+ controlnet_canny_strength,
675
+ controlnet_canny_guidance_scale,
676
+ controlnet_canny_scheduler,
677
+ controlnet_canny_conditioning_scale,
678
+ ],
679
+ outputs=[controlnet_canny_output],
680
+ )
681
+
682
+ if __name__ == "__main__":
683
+ demo.launch(show_error=True)
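For reference, a hedged sketch of launching the demo defined above (the model directory is an assumed export location, not part of the commit):

# Hedged sketch: gradio_demo.py is configured through CLI flags, and it also reads
# `model_dir` / `task_name` from the environment when both are set.
import os
import subprocess

os.environ["model_dir"] = "./stable-diffusion-v1-5@fastdeploy"   # assumed export output
os.environ["task_name"] = "text2img_img2img_inpaint_legacy"
subprocess.run(["python", "gradio_demo.py", "--backend", "paddle"], check=True)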
PaddleMIX/ppdiffusers/deploy-deprecated/infer.py ADDED
@@ -0,0 +1,742 @@
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import os
17
+ import time
18
+
19
+ # isort: split
20
+ import paddle
21
+
22
+ # isort: split
23
+ import fastdeploy as fd
24
+ import numpy as np
25
+ from paddlenlp.trainer.argparser import strtobool
26
+ from tqdm.auto import trange
27
+
28
+ from ppdiffusers import DiffusionPipeline, FastDeployStableDiffusionMegaPipeline
29
+ from ppdiffusers.utils import load_image
30
+
31
+
32
+ def parse_arguments():
33
+
34
+ parser = argparse.ArgumentParser()
35
+ parser.add_argument(
36
+ "--model_dir",
37
+ default="runwayml/stable-diffusion-v1-5@fastdeploy",
38
+ help="The model directory of diffusion_model.",
39
+ )
40
+ parser.add_argument(
41
+ "--inference_steps",
42
+ type=int,
43
+ default=50,
44
+ help="The number of unet inference steps.",
45
+ )
46
+ parser.add_argument(
47
+ "--benchmark_steps",
48
+ type=int,
49
+ default=1,
50
+ help="The number of performance benchmark steps.",
51
+ )
52
+ parser.add_argument(
53
+ "--backend",
54
+ type=str,
55
+ default="paddle_tensorrt",
56
+ # Note(zhoushunjie): Will support 'tensorrt' soon.
57
+ choices=["onnx_runtime", "paddle", "paddlelite", "paddle_tensorrt"],
58
+ help="The inference runtime backend of unet model and text encoder model.",
59
+ )
60
+ parser.add_argument(
61
+ "--device",
62
+ type=str,
63
+ default="gpu",
64
+ # Note(shentanyue): Will support more devices.
65
+ choices=[
66
+ "cpu",
67
+ "gpu",
68
+ "huawei_ascend_npu",
69
+ "kunlunxin_xpu",
70
+ ],
71
+ help="The inference runtime device of models.",
72
+ )
73
+ parser.add_argument(
74
+ "--task_name",
75
+ type=str,
76
+ default="text2img",
77
+ choices=[
78
+ "text2img",
79
+ "img2img",
80
+ "inpaint",
81
+ "inpaint_legacy",
82
+ "cycle_diffusion",
83
+ "hiresfix",
84
+ "mixture_tiling",
85
+ "all",
86
+ ],
87
+ help="The task can be one of [text2img, img2img, inpaint, inpaint_legacy, cycle_diffusion, hiresfix, mixture_tiling, all]. ",
88
+ )
89
+ parser.add_argument(
90
+ "--parse_prompt_type",
91
+ type=str,
92
+ default="lpw",
93
+ choices=[
94
+ "raw",
95
+ "lpw",
96
+ ],
97
+ help="The parse_prompt_type can be one of [raw, lpw]. ",
98
+ )
99
+ parser.add_argument("--use_fp16", type=strtobool, default=True, help="Wheter to use FP16 mode")
100
+ parser.add_argument("--use_bf16", type=strtobool, default=False, help="Wheter to use BF16 mode")
101
+ parser.add_argument("--device_id", type=int, default=0, help="The selected gpu id. -1 means use cpu")
102
+ parser.add_argument(
103
+ "--scheduler",
104
+ type=str,
105
+ default="preconfig-euler-ancestral",
106
+ choices=[
107
+ "pndm",
108
+ "lms",
109
+ "euler",
110
+ "euler-ancestral",
111
+ "preconfig-euler-ancestral",
112
+ "dpm-multi",
113
+ "dpm-single",
114
+ "unipc-multi",
115
+ "ddim",
116
+ "ddpm",
117
+ "deis-multi",
118
+ "heun",
119
+ "kdpm2-ancestral",
120
+ "kdpm2",
121
+ ],
122
+ help="The scheduler type of stable diffusion.",
123
+ )
124
+ parser.add_argument(
125
+ "--infer_op",
126
+ type=str,
127
+ default="zero_copy_infer",
128
+ choices=[
129
+ "zero_copy_infer",
130
+ "raw",
131
+ "all",
132
+ ],
133
+ help="The type of infer op.",
134
+ )
135
+ parser.add_argument("--height", type=int, default=512, help="Height of input image")
136
+ parser.add_argument("--width", type=int, default=512, help="Width of input image")
137
+ parser.add_argument("--hr_resize_height", type=int, default=768, help="HR Height of input image")
138
+ parser.add_argument("--hr_resize_width", type=int, default=768, help="HR Width of input image")
139
+ parser.add_argument("--is_sd2_0", type=strtobool, default=False, help="Is sd2_0 model?")
140
+
141
+ return parser.parse_args()
142
+
143
+
144
+ def create_ort_runtime(device_id=0):
145
+ option = fd.RuntimeOption()
146
+ option.use_ort_backend()
147
+ if device_id == -1:
148
+ option.use_cpu()
149
+ else:
150
+ option.use_gpu(device_id)
151
+ return option
152
+
153
+
154
+ def create_paddle_inference_runtime(
155
+ use_trt=False,
156
+ dynamic_shape=None,
157
+ use_fp16=False,
158
+ use_bf16=False,
159
+ device_id=0,
160
+ disable_paddle_trt_ops=[],
161
+ disable_paddle_pass=[],
162
+ paddle_stream=None,
163
+ workspace=None,
164
+ ):
165
+ assert not use_fp16 or not use_bf16, "use_fp16 and use_bf16 are mutually exclusive"
166
+ option = fd.RuntimeOption()
167
+ option.use_paddle_backend()
168
+ if device_id == -1:
169
+ option.use_cpu()
170
+ else:
171
+ option.use_gpu(device_id)
172
+ if paddle_stream is not None and use_trt:
173
+ option.set_external_raw_stream(paddle_stream)
174
+ for pass_name in disable_paddle_pass:
175
+ option.paddle_infer_option.delete_pass(pass_name)
176
+ if use_bf16:
177
+ option.paddle_infer_option.inference_precision = "bfloat16"
178
+ if use_trt:
179
+ option.paddle_infer_option.disable_trt_ops(disable_paddle_trt_ops)
180
+ option.paddle_infer_option.enable_trt = True
181
+ if workspace is not None:
182
+ option.set_trt_max_workspace_size(workspace)
183
+ if use_fp16:
184
+ option.trt_option.enable_fp16 = True
185
+ else:
186
+ # Note(zhoushunjie): These four passes don't support fp32 now.
187
+ # Remove this line of code in future.
188
+ only_fp16_passes = [
189
+ "trt_cross_multihead_matmul_fuse_pass",
190
+ "trt_flash_multihead_matmul_fuse_pass",
191
+ "preln_elementwise_groupnorm_act_pass",
192
+ "elementwise_groupnorm_act_pass",
193
+ ]
194
+ for curr_pass in only_fp16_passes:
195
+ option.paddle_infer_option.delete_pass(curr_pass)
196
+
197
+ # Need to enable collect shape
198
+ if dynamic_shape is not None:
199
+ option.paddle_infer_option.collect_trt_shape = True
200
+ for key, shape_dict in dynamic_shape.items():
201
+ option.trt_option.set_shape(
202
+ key,
203
+ shape_dict["min_shape"],
204
+ shape_dict.get("opt_shape", None),
205
+ shape_dict.get("max_shape", None),
206
+ )
207
+ return option
208
+
209
+
210
+ def create_paddle_lite_runtime(device="cpu", device_id=0, use_fp16=False):
211
+ option = fd.RuntimeOption()
212
+ option.use_paddle_lite_backend()
213
+ if device == "huawei_ascend_npu":
214
+ option.use_ascend()
215
+ option.set_lite_device_names(["huawei_ascend_npu"])
216
+ option.set_lite_context_properties(
217
+ "HUAWEI_ASCEND_NPU_SELECTED_DEVICE_IDS={};HUAWEI_ASCEND_NPU_PRECISION_MODE=allow_mix_precision".format(
218
+ device_id
219
+ )
220
+ )
221
+ elif device == "kunlunxin_xpu":
222
+ # TODO(shentanyue): Add kunlunxin_xpu code
223
+ # https://github.com/PaddlePaddle/FastDeploy/blob/4c3e7030e151528d304619901c794481bb2f6037/examples/multimodal/stable_diffusion/infer.py#L178-L195
224
+ option.use_kunlunxin(
225
+ device_id,
226
+ l3_workspace_size=(64 * 1024 * 1024 - 4 * 1024),
227
+ locked=False,
228
+ autotune=False,
229
+ autotune_file="",
230
+ precision="int16",
231
+ adaptive_seqlen=True,
232
+ enable_multi_stream=True,
233
+ )
234
+ if use_fp16:
235
+ option.enable_lite_fp16()
236
+ else:
237
+ pass
238
+ return option
239
+
240
+
241
+ def create_trt_runtime(workspace=(1 << 31), dynamic_shape=None, use_fp16=False, device_id=0):
242
+ option = fd.RuntimeOption()
243
+ option.use_trt_backend()
244
+ option.use_gpu(device_id)
245
+ if use_fp16:
246
+ option.enable_trt_fp16()
247
+ if workspace is not None:
248
+ option.set_trt_max_workspace_size(workspace)
249
+ if dynamic_shape is not None:
250
+ for key, shape_dict in dynamic_shape.items():
251
+ option.set_trt_input_shape(
252
+ key,
253
+ min_shape=shape_dict["min_shape"],
254
+ opt_shape=shape_dict.get("opt_shape", None),
255
+ max_shape=shape_dict.get("max_shape", None),
256
+ )
257
+ return option
258
+
259
+
260
+ def main(args):
261
+ if args.device_id == -1:
262
+ paddle.set_device("cpu")
263
+ paddle_stream = None
264
+ else:
265
+ paddle.set_device(f"gpu:{args.device_id}")
266
+ paddle_stream = paddle.device.cuda.current_stream(args.device_id).cuda_stream
267
+
268
+ seed = 1024
269
+ vae_in_channels = 4
270
+ text_encoder_max_length = 77
271
+ unet_max_length = text_encoder_max_length * 3 # lpw support max_length is 77x3
272
+ min_image_size = 512
273
+ max_image_size = 768
274
+ max_image_size = max(min_image_size, max_image_size)
275
+ hidden_states = 1024 if args.is_sd2_0 else 768
276
+ unet_in_channels = 9 if args.task_name == "inpaint" else 4
277
+
278
+ if args.task_name == "cycle_diffusion":
279
+ bs = 4
280
+ min_image_size = max_image_size = 512
281
+ else:
282
+ bs = 2
283
+
284
+ text_encoder_dynamic_shape = {
285
+ "input_ids": {
286
+ "min_shape": [1, text_encoder_max_length],
287
+ "max_shape": [1, text_encoder_max_length],
288
+ "opt_shape": [1, text_encoder_max_length],
289
+ }
290
+ }
291
+
292
+ vae_encoder_dynamic_shape = {
293
+ "sample": {
294
+ "min_shape": [1, 3, min_image_size, min_image_size],
295
+ "max_shape": [1, 3, max_image_size, max_image_size],
296
+ "opt_shape": [1, 3, min_image_size, min_image_size],
297
+ }
298
+ }
299
+
300
+ vae_decoder_dynamic_shape = {
301
+ "latent_sample": {
302
+ "min_shape": [1, vae_in_channels, min_image_size // 8, min_image_size // 8],
303
+ "max_shape": [1, vae_in_channels, max_image_size // 8, max_image_size // 8],
304
+ "opt_shape": [1, vae_in_channels, min_image_size // 8, min_image_size // 8],
305
+ }
306
+ }
307
+
308
+ unet_dynamic_shape = {
309
+ "sample": {
310
+ "min_shape": [
311
+ 1,
312
+ unet_in_channels,
313
+ min_image_size // 8,
314
+ min_image_size // 8,
315
+ ],
316
+ "max_shape": [
317
+ bs,
318
+ unet_in_channels,
319
+ max_image_size // 8,
320
+ max_image_size // 8,
321
+ ],
322
+ "opt_shape": [
323
+ 2,
324
+ unet_in_channels,
325
+ min_image_size // 8,
326
+ min_image_size // 8,
327
+ ],
328
+ },
329
+ "timestep": {
330
+ "min_shape": [1],
331
+ "max_shape": [1],
332
+ "opt_shape": [1],
333
+ },
334
+ "encoder_hidden_states": {
335
+ "min_shape": [1, text_encoder_max_length, hidden_states],
336
+ "max_shape": [bs, unet_max_length, hidden_states],
337
+ "opt_shape": [2, text_encoder_max_length, hidden_states],
338
+ },
339
+ }
340
+ # 4. Init runtime
341
+ if args.backend == "onnx_runtime":
342
+ runtime_options = dict(
343
+ text_encoder=create_ort_runtime(device_id=args.device_id),
344
+ vae_encoder=create_ort_runtime(device_id=args.device_id),
345
+ vae_decoder=create_ort_runtime(device_id=args.device_id),
346
+ unet=create_ort_runtime(device_id=args.device_id),
347
+ )
348
+ elif args.backend == "paddlelite":
349
+ runtime_options = dict(
350
+ text_encoder=create_paddle_lite_runtime(device=args.device, device_id=args.device_id, use_fp16=False),
351
+ vae_encoder=create_paddle_lite_runtime(device=args.device, device_id=args.device_id, use_fp16=False),
352
+ vae_decoder=create_paddle_lite_runtime(device=args.device, device_id=args.device_id, use_fp16=False),
353
+ unet=create_paddle_lite_runtime(device=args.device, device_id=args.device_id, use_fp16=args.use_fp16),
354
+ )
355
+ elif args.backend == "tensorrt":
356
+ runtime_options = dict(
357
+ text_encoder=create_trt_runtime(
358
+ dynamic_shape=text_encoder_dynamic_shape,
359
+ use_fp16=args.use_fp16,
360
+ device_id=args.device_id,
361
+ ),
362
+ vae_encoder=create_trt_runtime(
363
+ dynamic_shape=vae_encoder_dynamic_shape,
364
+ use_fp16=args.use_fp16,
365
+ device_id=args.device_id,
366
+ ),
367
+ vae_decoder=create_trt_runtime(
368
+ dynamic_shape=vae_decoder_dynamic_shape,
369
+ use_fp16=args.use_fp16,
370
+ device_id=args.device_id,
371
+ ),
372
+ unet=create_trt_runtime(
373
+ dynamic_shape=unet_dynamic_shape,
374
+ use_fp16=args.use_fp16,
375
+ device_id=args.device_id,
376
+ ),
377
+ )
378
+ elif args.backend == "paddle" or args.backend == "paddle_tensorrt":
379
+ args.use_trt = args.backend == "paddle_tensorrt"
380
+ runtime_options = dict(
381
+ text_encoder=create_paddle_inference_runtime(
382
+ use_trt=args.use_trt,
383
+ dynamic_shape=text_encoder_dynamic_shape,
384
+ use_fp16=args.use_fp16,
385
+ use_bf16=args.use_bf16,
386
+ device_id=args.device_id,
387
+ disable_paddle_trt_ops=["arg_max", "range", "lookup_table_v2"],
388
+ paddle_stream=paddle_stream,
389
+ ),
390
+ vae_encoder=create_paddle_inference_runtime(
391
+ use_trt=args.use_trt,
392
+ dynamic_shape=vae_encoder_dynamic_shape,
393
+ use_fp16=args.use_fp16,
394
+ use_bf16=args.use_bf16,
395
+ device_id=args.device_id,
396
+ paddle_stream=paddle_stream,
397
+ ),
398
+ vae_decoder=create_paddle_inference_runtime(
399
+ use_trt=args.use_trt,
400
+ dynamic_shape=vae_decoder_dynamic_shape,
401
+ use_fp16=args.use_fp16,
402
+ use_bf16=args.use_bf16,
403
+ device_id=args.device_id,
404
+ paddle_stream=paddle_stream,
405
+ ),
406
+ unet=create_paddle_inference_runtime(
407
+ use_trt=args.use_trt,
408
+ dynamic_shape=unet_dynamic_shape,
409
+ use_fp16=args.use_fp16,
410
+ use_bf16=args.use_bf16,
411
+ device_id=args.device_id,
412
+ paddle_stream=paddle_stream,
413
+ ),
414
+ )
415
+ pipe = FastDeployStableDiffusionMegaPipeline.from_pretrained(
416
+ args.model_dir,
417
+ runtime_options=runtime_options,
418
+ )
419
+ pipe.set_progress_bar_config(disable=True)
420
+ pipe.change_scheduler(args.scheduler)
421
+ parse_prompt_type = args.parse_prompt_type
422
+ width = args.width
423
+ height = args.height
424
+ hr_resize_width = args.hr_resize_width
425
+ hr_resize_height = args.hr_resize_height
426
+
427
+ if args.infer_op == "all":
428
+ infer_op_list = ["zero_copy_infer", "raw"]
429
+ else:
430
+ infer_op_list = [args.infer_op]
431
+ if args.device == "kunlunxin_xpu" or args.backend == "paddle":
432
+ print("When device is kunlunxin_xpu or backend is paddle, we will use `raw` infer op.")
433
+ infer_op_list = ["raw"]
434
+
435
+ for infer_op in infer_op_list:
436
+ infer_op_dict = {
437
+ "vae_encoder": infer_op,
438
+ "vae_decoder": infer_op,
439
+ "text_encoder": infer_op,
440
+ "unet": infer_op,
441
+ }
442
+ folder = f"infer_op_{infer_op}_fp16" if args.use_fp16 else f"infer_op_{infer_op}_fp32"
443
+ os.makedirs(folder, exist_ok=True)
444
+ if args.task_name in ["text2img", "all"]:
445
+ # text2img
446
+ prompt = "a photo of an astronaut riding a horse on mars"
447
+ time_costs = []
448
+ # warmup
449
+ pipe.text2img(
450
+ prompt,
451
+ num_inference_steps=10,
452
+ height=height,
453
+ width=width,
454
+ parse_prompt_type=parse_prompt_type,
455
+ infer_op_dict=infer_op_dict,
456
+ )
457
+ print("==> Test text2img performance.")
458
+ for step in trange(args.benchmark_steps):
459
+ start = time.time()
460
+ paddle.seed(seed)
461
+ images = pipe.text2img(
462
+ prompt,
463
+ num_inference_steps=args.inference_steps,
464
+ height=height,
465
+ width=width,
466
+ parse_prompt_type=parse_prompt_type,
467
+ infer_op_dict=infer_op_dict,
468
+ ).images
469
+ latency = time.time() - start
470
+ time_costs += [latency]
471
+ # print(f"No {step:3d} time cost: {latency:2f} s")
472
+ print(
473
+ f"Mean latency: {np.mean(time_costs):2f} s, p50 latency: {np.percentile(time_costs, 50):2f} s, "
474
+ f"p90 latency: {np.percentile(time_costs, 90):2f} s, p95 latency: {np.percentile(time_costs, 95):2f} s."
475
+ )
476
+ images[0].save(f"{folder}/text2img.png")
477
+
478
+ if args.task_name in ["img2img", "all"]:
479
+ # img2img
480
+ img_url = "https://paddlenlp.bj.bcebos.com/models/community/CompVis/stable-diffusion-v1-4/sketch-mountains-input.png"
481
+ init_image = load_image(img_url)
482
+ prompt = "A fantasy landscape, trending on artstation"
483
+ time_costs = []
484
+ # warmup
485
+ pipe.img2img(
486
+ prompt,
487
+ image=init_image,
488
+ num_inference_steps=20,
489
+ height=height,
490
+ width=width,
491
+ parse_prompt_type=parse_prompt_type,
492
+ infer_op_dict=infer_op_dict,
493
+ )
494
+ print("==> Test img2img performance.")
495
+ for step in trange(args.benchmark_steps):
496
+ start = time.time()
497
+ paddle.seed(seed)
498
+ images = pipe.img2img(
499
+ prompt,
500
+ image=init_image,
501
+ num_inference_steps=args.inference_steps,
502
+ height=height,
503
+ width=width,
504
+ parse_prompt_type=parse_prompt_type,
505
+ infer_op_dict=infer_op_dict,
506
+ ).images
507
+ latency = time.time() - start
508
+ time_costs += [latency]
509
+ # print(f"No {step:3d} time cost: {latency:2f} s")
510
+ print(
511
+ f"Mean latency: {np.mean(time_costs):2f} s, p50 latency: {np.percentile(time_costs, 50):2f} s, "
512
+ f"p90 latency: {np.percentile(time_costs, 90):2f} s, p95 latency: {np.percentile(time_costs, 95):2f} s."
513
+ )
514
+ images[0].save(f"{folder}/img2img.png")
515
+
516
+ if args.task_name in ["inpaint", "inpaint_legacy", "all"]:
517
+ img_url = (
518
+ "https://paddlenlp.bj.bcebos.com/models/community/CompVis/stable-diffusion-v1-4/overture-creations.png"
519
+ )
520
+ mask_url = "https://paddlenlp.bj.bcebos.com/models/community/CompVis/stable-diffusion-v1-4/overture-creations-mask.png"
521
+ init_image = load_image(img_url)
522
+ mask_image = load_image(mask_url)
523
+ prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
524
+ time_costs = []
525
+ # warmup
526
+ if args.task_name in ["inpaint_legacy", "all"]:
527
+ call_fn = pipe.inpaint_legacy
528
+ task_name = "inpaint_legacy"
529
+ else:
530
+ call_fn = pipe.inpaint
531
+ task_name = "inpaint"
532
+ call_fn(
533
+ prompt,
534
+ image=init_image,
535
+ mask_image=mask_image,
536
+ num_inference_steps=20,
537
+ height=height,
538
+ width=width,
539
+ parse_prompt_type=parse_prompt_type,
540
+ infer_op_dict=infer_op_dict,
541
+ )
542
+ print(f"==> Test {task_name} performance.")
543
+ for step in trange(args.benchmark_steps):
544
+ start = time.time()
545
+ paddle.seed(seed)
546
+ images = call_fn(
547
+ prompt,
548
+ image=init_image,
549
+ mask_image=mask_image,
550
+ num_inference_steps=args.inference_steps,
551
+ height=height,
552
+ width=width,
553
+ parse_prompt_type=parse_prompt_type,
554
+ infer_op_dict=infer_op_dict,
555
+ ).images
556
+ latency = time.time() - start
557
+ time_costs += [latency]
558
+ # print(f"No {step:3d} time cost: {latency:2f} s")
559
+ print(
560
+ f"Mean latency: {np.mean(time_costs):2f} s, p50 latency: {np.percentile(time_costs, 50):2f} s, "
561
+ f"p90 latency: {np.percentile(time_costs, 90):2f} s, p95 latency: {np.percentile(time_costs, 95):2f} s."
562
+ )
563
+
564
+ images[0].save(f"{folder}/{task_name}.png")
565
+
566
+ if args.task_name in ["hiresfix", "all"]:
567
+ hiresfix_pipe = DiffusionPipeline.from_pretrained(
568
+ args.model_dir,
569
+ vae_encoder=pipe.vae_encoder,
570
+ vae_decoder=pipe.vae_decoder,
571
+ text_encoder=pipe.text_encoder,
572
+ tokenizer=pipe.tokenizer,
573
+ unet=pipe.unet,
574
+ scheduler=pipe.scheduler,
575
+ safety_checker=pipe.safety_checker,
576
+ feature_extractor=pipe.feature_extractor,
577
+ requires_safety_checker=pipe.requires_safety_checker,
578
+ custom_pipeline="pipeline_fastdeploy_stable_diffusion_hires_fix",
579
+ )
580
+ # custom_pipeline
581
+ # https://github.com/PaddlePaddle/PaddleNLP/blob/develop/ppdiffusers/examples/community/pipeline_fastdeploy_stable_diffusion_hires_fix.py
582
+ hiresfix_pipe._progress_bar_config = pipe._progress_bar_config
583
+ # hiresfix
584
+ prompt = "a photo of an astronaut riding a horse on mars"
585
+ time_costs = []
586
+ # warmup
587
+ hiresfix_pipe(
588
+ prompt,
589
+ height=height,
590
+ width=width,
591
+ num_inference_steps=20,
592
+ hires_ratio=0.5,
593
+ hr_resize_width=hr_resize_width,
594
+ hr_resize_height=hr_resize_height,
595
+ enable_hr=True,
596
+ parse_prompt_type=parse_prompt_type,
597
+ infer_op_dict=infer_op_dict,
598
+ )
599
+ print("==> Test hiresfix performance.")
600
+ for step in trange(args.benchmark_steps):
601
+ start = time.time()
602
+ paddle.seed(seed)
603
+ images = hiresfix_pipe(
604
+ prompt,
605
+ height=height,
606
+ width=width,
607
+ num_inference_steps=args.inference_steps,
608
+ hires_ratio=0.5,
609
+ hr_resize_width=hr_resize_width,
610
+ hr_resize_height=hr_resize_height,
611
+ enable_hr=True,
612
+ infer_op_dict=infer_op_dict,
613
+ ).images
614
+ latency = time.time() - start
615
+ time_costs += [latency]
616
+ # print(f"No {step:3d} time cost: {latency:2f} s")
617
+ print(
618
+ f"Mean latency: {np.mean(time_costs):2f} s, p50 latency: {np.percentile(time_costs, 50):2f} s, "
619
+ f"p90 latency: {np.percentile(time_costs, 90):2f} s, p95 latency: {np.percentile(time_costs, 95):2f} s."
620
+ )
621
+ images[0].save(f"{folder}/hiresfix.png")
622
+
623
+ if args.task_name in ["cycle_diffusion"]:
624
+ pipe.change_scheduler("ddim")
625
+ image_url = (
626
+ "https://paddlenlp.bj.bcebos.com/models/community/CompVis/stable-diffusion-v1-4/ride_on_horse.png"
627
+ )
628
+ init_image = load_image(image_url)
629
+ source_prompt = "An astronaut riding a horse"
630
+ prompt = "An astronaut riding an elephant"
631
+ time_costs = []
632
+ # warmup
633
+ pipe.cycle_diffusion(
634
+ prompt=prompt,
635
+ source_prompt=source_prompt,
636
+ image=init_image,
637
+ num_inference_steps=10,
638
+ eta=0.1,
639
+ strength=0.8,
640
+ guidance_scale=2,
641
+ source_guidance_scale=1,
642
+ height=height,
643
+ width=width,
644
+ parse_prompt_type=parse_prompt_type,
645
+ infer_op_dict=infer_op_dict,
646
+ ).images[0]
647
+ print("==> Test cycle diffusion performance.")
648
+ for step in trange(args.benchmark_steps):
649
+ start = time.time()
650
+ paddle.seed(seed)
651
+ images = pipe.cycle_diffusion(
652
+ prompt=prompt,
653
+ source_prompt=source_prompt,
654
+ image=init_image,
655
+ num_inference_steps=args.inference_steps,
656
+ eta=0.1,
657
+ strength=0.8,
658
+ guidance_scale=2,
659
+ source_guidance_scale=1,
660
+ height=height,
661
+ width=width,
662
+ parse_prompt_type=parse_prompt_type,
663
+ infer_op_dict=infer_op_dict,
664
+ ).images
665
+ latency = time.time() - start
666
+ time_costs += [latency]
667
+ # print(f"No {step:3d} time cost: {latency:2f} s")
668
+ print(
669
+ f"Mean latency: {np.mean(time_costs):2f} s, p50 latency: {np.percentile(time_costs, 50):2f} s, "
670
+ f"p90 latency: {np.percentile(time_costs, 90):2f} s, p95 latency: {np.percentile(time_costs, 95):2f} s."
671
+ )
672
+ images[0].save(f"{folder}/cycle_diffusion.png")
673
+
674
+ if args.task_name in ["mixture_tiling"]:
675
+ mixture_tiling_pipe = DiffusionPipeline.from_pretrained(
676
+ args.model_dir,
677
+ vae_encoder=pipe.vae_encoder,
678
+ vae_decoder=pipe.vae_decoder,
679
+ text_encoder=pipe.text_encoder,
680
+ tokenizer=pipe.tokenizer,
681
+ unet=pipe.unet,
682
+ scheduler=pipe.scheduler,
683
+ safety_checker=pipe.safety_checker,
684
+ feature_extractor=pipe.feature_extractor,
685
+ requires_safety_checker=pipe.requires_safety_checker,
686
+ custom_pipeline="pipeline_fastdeploy_stable_diffusion_mixture_tiling",
687
+ )
688
+ # custom_pipeline
689
+ mixture_tiling_pipe._progress_bar_config = pipe._progress_bar_config
690
+ # mixture_tiling
691
+ time_costs = []
692
+ # warmup
693
+ mixture_tiling_pipe(
694
+ prompt=[
695
+ [
696
+ "A charming house in the countryside, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
697
+ # "A dirt road in the countryside crossing pastures, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
698
+ # "An old and rusty giant robot lying on a dirt road, by jakub rozalski, dark sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
699
+ ]
700
+ ],
701
+ tile_height=512,
702
+ tile_width=512,
703
+ tile_row_overlap=0,
704
+ tile_col_overlap=0,
705
+ guidance_scale=8,
706
+ seed=7178915308,
707
+ num_inference_steps=50,
708
+ infer_op_dict=None,
709
+ )
710
+ print("==> Test mixture tiling.")
711
+ for step in trange(args.benchmark_steps):
712
+ start = time.time()
713
+ images = mixture_tiling_pipe(
714
+ prompt=[
715
+ [
716
+ "A charming house in the countryside, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
717
+ # "A dirt road in the countryside crossing pastures, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
718
+ # "An old and rusty giant robot lying on a dirt road, by jakub rozalski, dark sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
719
+ ]
720
+ ],
721
+ tile_height=512,
722
+ tile_width=512,
723
+ tile_row_overlap=0,
724
+ tile_col_overlap=0,
725
+ guidance_scale=8,
726
+ seed=7178915308,
727
+ num_inference_steps=50,
728
+ infer_op_dict=None,
729
+ )["images"]
730
+ latency = time.time() - start
731
+ time_costs += [latency]
732
+ # print(f"No {step:3d} time cost: {latency:2f} s")
733
+ print(
734
+ f"Mean latency: {np.mean(time_costs):2f} s, p50 latency: {np.percentile(time_costs, 50):2f} s, "
735
+ f"p90 latency: {np.percentile(time_costs, 90):2f} s, p95 latency: {np.percentile(time_costs, 95):2f} s."
736
+ )
737
+ images[0].save(f"{folder}/mixture_tiling.png")
738
+
739
+
740
+ if __name__ == "__main__":
741
+ args = parse_arguments()
742
+ main(args)
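For reference, a hedged sketch of driving the benchmark above programmatically (the overrides are assumptions; the defaults come from parse_arguments defined in this file):

# Hedged sketch: reuse parse_arguments() so the defaults declared above apply, then
# override a few fields before calling main().
import sys

import infer

sys.argv = ["infer.py"]              # parse only the defaults declared above
args = infer.parse_arguments()
args.task_name = "text2img"          # benchmark the text2img path only
args.benchmark_steps = 10            # assumed: more repeats for a stabler latency estimate
infer.main(args)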
PaddleMIX/ppdiffusers/deploy-deprecated/infer_dygraph.py ADDED
@@ -0,0 +1,380 @@
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import os
17
+ import time
18
+ import warnings
19
+
20
+ import numpy as np
21
+ import paddle
22
+ from paddlenlp.trainer.argparser import strtobool
23
+ from paddlenlp.utils.log import logger
24
+ from tqdm.auto import trange
25
+
26
+ from ppdiffusers import DiffusionPipeline
27
+ from ppdiffusers.utils import load_image
28
+
29
+ logger.set_level("WARNING")
30
+
31
+
32
+ def parse_arguments():
33
+
34
+ parser = argparse.ArgumentParser()
35
+ parser.add_argument(
36
+ "--model_dir",
37
+ default="runwayml/stable-diffusion-v1-5",
38
+ help="The model directory of diffusion_model.",
39
+ )
40
+ parser.add_argument(
41
+ "--inference_steps",
42
+ type=int,
43
+ default=50,
44
+ help="The number of unet inference steps.",
45
+ )
46
+ parser.add_argument(
47
+ "--benchmark_steps",
48
+ type=int,
49
+ default=1,
50
+ help="The number of performance benchmark steps.",
51
+ )
52
+ parser.add_argument(
53
+ "--task_name",
54
+ type=str,
55
+ default="text2img",
56
+ choices=[
57
+ "text2img",
58
+ "img2img",
59
+ "inpaint",
60
+ "inpaint_legacy",
61
+ "cycle_diffusion",
62
+ "hiresfix",
63
+ "all",
64
+ ],
65
+ help="The task can be one of [text2img, img2img, inpaint, inpaint_legacy, cycle_diffusion, hiresfix, all]. ",
66
+ )
67
+ parser.add_argument(
68
+ "--parse_prompt_type",
69
+ type=str,
70
+ default="lpw",
71
+ choices=[
72
+ "raw",
73
+ "lpw",
74
+ ],
75
+ help="The parse_prompt_type can be one of [raw, lpw]. ",
76
+ )
77
+ parser.add_argument("--use_fp16", type=strtobool, default=True, help="Wheter to use FP16 mode")
78
+ parser.add_argument(
79
+ "--attention_type",
80
+ type=str,
81
+ default="raw",
82
+ choices=["raw", "cutlass", "flash", "all"],
83
+ help="attention_type.",
84
+ )
85
+ parser.add_argument("--device_id", type=int, default=0, help="The selected gpu id. -1 means use cpu")
86
+ parser.add_argument(
87
+ "--scheduler",
88
+ type=str,
89
+ default="euler-ancestral",
90
+ choices=[
91
+ "pndm",
92
+ "lms",
93
+ "euler",
94
+ "euler-ancestral",
95
+ "dpm-multi",
96
+ "dpm-single",
97
+ "unipc-multi",
98
+ "ddim",
99
+ "ddpm",
100
+ "deis-multi",
101
+ "heun",
102
+ "kdpm2-ancestral",
103
+ "kdpm2",
104
+ ],
105
+ help="The scheduler type of stable diffusion.",
106
+ )
107
+ parser.add_argument("--height", type=int, default=512, help="Height of input image")
108
+ parser.add_argument("--width", type=int, default=512, help="Width of input image")
109
+ parser.add_argument("--hr_resize_height", type=int, default=768, help="HR Height of input image")
110
+ parser.add_argument("--hr_resize_width", type=int, default=768, help="HR Width of input image")
111
+ return parser.parse_args()
112
+
113
+
114
+ def main(args):
115
+ if args.device_id == -1:
116
+ paddle.set_device("cpu")
117
+ else:
118
+ paddle.set_device(f"gpu:{args.device_id}")
119
+
120
+ seed = 1024
121
+ paddle_dtype = paddle.float16 if args.use_fp16 else paddle.float32
122
+ print(
123
+ os.path.join(os.path.abspath(os.path.join(os.getcwd(), "..")), "examples/community/stable_diffusion_mega.py")
124
+ )
125
+ pipe = DiffusionPipeline.from_pretrained(
126
+ args.model_dir,
127
+ safety_checker=None,
128
+ feature_extractor=None,
129
+ requires_safety_checker=False,
130
+ paddle_dtype=paddle_dtype,
131
+ custom_pipeline=os.path.join(
132
+ os.path.abspath(os.path.join(os.getcwd(), "..")), "examples/community/stable_diffusion_mega.py"
133
+ ),
134
+ )
135
+ pipe.set_progress_bar_config(disable=True)
136
+ pipe.change_scheduler(args.scheduler)
137
+ parse_prompt_type = args.parse_prompt_type
138
+ if args.attention_type == "all":
139
+ args.attention_type = ["raw", "cutlass", "flash"]
140
+ else:
141
+ args.attention_type = [args.attention_type]
142
+
143
+ for attention_type in args.attention_type:
144
+ if attention_type == "raw":
145
+ pipe.disable_xformers_memory_efficient_attention()
146
+ else:
147
+ try:
148
+ pipe.enable_xformers_memory_efficient_attention(attention_type)
149
+ except Exception as e:
150
+ if attention_type == "flash":
151
+ warnings.warn(
152
+ "Attention type flash is not supported on your GPU! We need to use 3060、3070、3080、3090、4060、4070、4080、4090、A30、A100 etc."
153
+ )
154
+ continue
155
+ else:
156
+ raise ValueError(e)
157
+
158
+ if not args.use_fp16 and attention_type == "flash":
159
+ print("Flash attention is not supported dtype=float32! Please use float16 or bfloat16. We will skip this!")
160
+ continue
161
+ width = args.width
162
+ height = args.height
163
+ hr_resize_width = args.hr_resize_width
164
+ hr_resize_height = args.hr_resize_height
165
+ folder = f"attn_{attention_type}_fp16" if args.use_fp16 else f"attn_{attention_type}_fp32"
166
+ os.makedirs(folder, exist_ok=True)
167
+ if args.task_name in ["text2img", "all"]:
168
+ # text2img
169
+ prompt = "a photo of an astronaut riding a horse on mars"
170
+ time_costs = []
171
+ # warmup
172
+ pipe.text2img(
173
+ prompt,
174
+ num_inference_steps=10,
175
+ height=height,
176
+ width=width,
177
+ parse_prompt_type=parse_prompt_type,
178
+ )
179
+ print("==> Test text2img performance.")
180
+ paddle.seed(seed)
181
+ for step in trange(args.benchmark_steps):
182
+ start = time.time()
183
+ images = pipe.text2img(
184
+ prompt,
185
+ num_inference_steps=args.inference_steps,
186
+ height=height,
187
+ width=width,
188
+ parse_prompt_type=parse_prompt_type,
189
+ ).images
190
+ latency = time.time() - start
191
+ time_costs += [latency]
192
+ # print(f"No {step:3d} time cost: {latency:2f} s")
193
+ print(
194
+ f"Mean latency: {np.mean(time_costs):2f} s, p50 latency: {np.percentile(time_costs, 50):2f} s, "
195
+ f"p90 latency: {np.percentile(time_costs, 90):2f} s, p95 latency: {np.percentile(time_costs, 95):2f} s."
196
+ )
197
+ images[0].save(f"{folder}/text2img.png")
198
+
199
+ if args.task_name in ["img2img", "all"]:
200
+ # img2img
201
+ img_url = "https://paddlenlp.bj.bcebos.com/models/community/CompVis/stable-diffusion-v1-4/sketch-mountains-input.png"
202
+ init_image = load_image(img_url)
203
+ prompt = "A fantasy landscape, trending on artstation"
204
+ time_costs = []
205
+ # warmup
206
+ pipe.img2img(
207
+ prompt,
208
+ image=init_image,
209
+ num_inference_steps=20,
210
+ height=height,
211
+ width=width,
212
+ parse_prompt_type=parse_prompt_type,
213
+ )
214
+ print("==> Test img2img performance.")
215
+ for step in trange(args.benchmark_steps):
216
+ start = time.time()
217
+ paddle.seed(seed)
218
+ images = pipe.img2img(
219
+ prompt,
220
+ image=init_image,
221
+ num_inference_steps=args.inference_steps,
222
+ height=height,
223
+ width=width,
224
+ parse_prompt_type=parse_prompt_type,
225
+ ).images
226
+ latency = time.time() - start
227
+ time_costs += [latency]
228
+ # print(f"No {step:3d} time cost: {latency:2f} s")
229
+ print(
230
+ f"Mean latency: {np.mean(time_costs):2f} s, p50 latency: {np.percentile(time_costs, 50):2f} s, "
231
+ f"p90 latency: {np.percentile(time_costs, 90):2f} s, p95 latency: {np.percentile(time_costs, 95):2f} s."
232
+ )
233
+ images[0].save(f"{folder}/img2img.png")
234
+
235
+ if args.task_name in ["inpaint", "inpaint_legacy", "all"]:
236
+ img_url = (
237
+ "https://paddlenlp.bj.bcebos.com/models/community/CompVis/stable-diffusion-v1-4/overture-creations.png"
238
+ )
239
+ mask_url = "https://paddlenlp.bj.bcebos.com/models/community/CompVis/stable-diffusion-v1-4/overture-creations-mask.png"
240
+ init_image = load_image(img_url)
241
+ mask_image = load_image(mask_url)
242
+ prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
243
+ time_costs = []
244
+ # warmup
245
+ if args.task_name in ["inpaint_legacy", "all"]:
246
+ call_fn = pipe.inpaint_legacy
247
+ task_name = "inpaint_legacy"
248
+ else:
249
+ call_fn = pipe.inpaint
250
+ task_name = args.task_name
251
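+ # The UNet's input channel count determines which variant actually runs: a 4-channel UNet is the base SD model (legacy inpainting), while a 9-channel UNet is the dedicated inpainting model.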
+ if pipe.unet.config.in_channels == 4:
252
+ task_name = "inpaint_legacy"
253
+ elif pipe.unet.config.in_channels == 9:
254
+ task_name = "inpaint"
255
+
256
+ call_fn(
257
+ prompt,
258
+ image=init_image,
259
+ mask_image=mask_image,
260
+ num_inference_steps=20,
261
+ height=height,
262
+ width=width,
263
+ parse_prompt_type=parse_prompt_type,
264
+ )
265
+ print(f"==> Test {task_name} performance.")
266
+ for step in trange(args.benchmark_steps):
267
+ start = time.time()
268
+ paddle.seed(seed)
269
+ images = call_fn(
270
+ prompt,
271
+ image=init_image,
272
+ mask_image=mask_image,
273
+ num_inference_steps=args.inference_steps,
274
+ height=height,
275
+ width=width,
276
+ parse_prompt_type=parse_prompt_type,
277
+ ).images
278
+ latency = time.time() - start
279
+ time_costs += [latency]
280
+ # print(f"No {step:3d} time cost: {latency:2f} s")
281
+ print(
282
+ f"Mean latency: {np.mean(time_costs):2f} s, p50 latency: {np.percentile(time_costs, 50):2f} s, "
283
+ f"p90 latency: {np.percentile(time_costs, 90):2f} s, p95 latency: {np.percentile(time_costs, 95):2f} s."
284
+ )
285
+
286
+ images[0].save(f"{folder}/{task_name}.png")
287
+
288
+ if args.task_name in ["cycle_diffusion", "all"]:
289
+ pipe.change_scheduler("ddim")
290
+ image_url = (
291
+ "https://paddlenlp.bj.bcebos.com/models/community/CompVis/stable-diffusion-v1-4/ride_on_horse.png"
292
+ )
293
+ init_image = load_image(image_url)
294
+ source_prompt = "An astronaut riding a horse"
295
+ prompt = "An astronaut riding an elephant"
296
+ time_costs = []
297
+ # warmup
298
+ pipe.cycle_diffusion(
299
+ prompt=prompt,
300
+ source_prompt=source_prompt,
301
+ image=init_image,
302
+ num_inference_steps=10,
303
+ eta=0.1,
304
+ strength=0.8,
305
+ guidance_scale=2,
306
+ source_guidance_scale=1,
307
+ height=height,
308
+ width=width,
309
+ parse_prompt_type=parse_prompt_type,
310
+ ).images[0]
311
+ print("==> Test cycle diffusion performance.")
312
+ for step in trange(args.benchmark_steps):
313
+ start = time.time()
314
+ paddle.seed(seed)
315
+ images = pipe.cycle_diffusion(
316
+ prompt=prompt,
317
+ source_prompt=source_prompt,
318
+ image=init_image,
319
+ num_inference_steps=args.inference_steps,
320
+ eta=0.1,
321
+ strength=0.8,
322
+ guidance_scale=2,
323
+ source_guidance_scale=1,
324
+ height=height,
325
+ width=width,
326
+ parse_prompt_type=parse_prompt_type,
327
+ ).images
328
+ latency = time.time() - start
329
+ time_costs += [latency]
330
+ # print(f"No {step:3d} time cost: {latency:2f} s")
331
+ print(
332
+ f"Mean latency: {np.mean(time_costs):2f} s, p50 latency: {np.percentile(time_costs, 50):2f} s, "
333
+ f"p90 latency: {np.percentile(time_costs, 90):2f} s, p95 latency: {np.percentile(time_costs, 95):2f} s."
334
+ )
335
+ images[0].save(f"{folder}/cycle_diffusion.png")
336
+
337
+ if args.task_name in ["hiresfix", "all"]:
338
+ # hiresfix
339
+ prompt = "a photo of an astronaut riding a horse on mars"
340
+ time_costs = []
341
+ # warmup
342
+ pipe.hires_fix(
343
+ prompt,
344
+ height=height,
345
+ width=width,
346
+ num_inference_steps=20,
347
+ hires_ratio=0.5,
348
+ hr_resize_width=hr_resize_width,
349
+ hr_resize_height=hr_resize_height,
350
+ enable_hr=True,
351
+ parse_prompt_type=parse_prompt_type,
352
+ )
353
+ print("==> Test hiresfix performance.")
354
+ for step in trange(args.benchmark_steps):
355
+ start = time.time()
356
+ paddle.seed(seed)
357
+ images = pipe.hires_fix(
358
+ prompt,
359
+ height=height,
360
+ width=width,
361
+ num_inference_steps=args.inference_steps,
362
+ hires_ratio=0.5,
363
+ hr_resize_width=hr_resize_width,
364
+ hr_resize_height=hr_resize_height,
365
+ enable_hr=True,
366
+ parse_prompt_type=parse_prompt_type,
367
+ ).images
368
+ latency = time.time() - start
369
+ time_costs += [latency]
370
+ # print(f"No {step:3d} time cost: {latency:2f} s")
371
+ print(
372
+ f"Mean latency: {np.mean(time_costs):2f} s, p50 latency: {np.percentile(time_costs, 50):2f} s, "
373
+ f"p90 latency: {np.percentile(time_costs, 90):2f} s, p95 latency: {np.percentile(time_costs, 95):2f} s."
374
+ )
375
+ images[0].save(f"{folder}/hiresfix.png")
376
+
377
+
378
+ if __name__ == "__main__":
379
+ args = parse_arguments()
380
+ main(args)
PaddleMIX/ppdiffusers/deploy-deprecated/infer_dygraph_torch.py ADDED
@@ -0,0 +1,447 @@
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import os
17
+ import time
18
+
19
+ import torch
20
+
21
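+ # Stash torch's native scaled_dot_product_attention and remove it so diffusers does not pick the SDP attention path by default; it is restored below when attention_type == "sdp" is benchmarked.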
+ torch.nn.functional.scaled_dot_product_attention_ = torch.nn.functional.scaled_dot_product_attention
22
+ delattr(torch.nn.functional, "scaled_dot_product_attention")
23
+ import numpy as np
24
+ from diffusers import (
25
+ CycleDiffusionPipeline,
26
+ DDIMScheduler,
27
+ DDPMScheduler,
28
+ DEISMultistepScheduler,
29
+ DiffusionPipeline,
30
+ DPMSolverMultistepScheduler,
31
+ DPMSolverSinglestepScheduler,
32
+ EulerAncestralDiscreteScheduler,
33
+ EulerDiscreteScheduler,
34
+ HeunDiscreteScheduler,
35
+ KDPM2AncestralDiscreteScheduler,
36
+ KDPM2DiscreteScheduler,
37
+ LMSDiscreteScheduler,
38
+ PNDMScheduler,
39
+ UniPCMultistepScheduler,
40
+ )
41
+ from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0
42
+ from diffusers.utils import load_image
43
+ from tqdm.auto import trange
44
+
45
+
46
+ def strtobool(v):
47
+ if isinstance(v, bool):
48
+ return v
49
+ if v.lower() in ("yes", "true", "t", "y", "1"):
50
+ return True
51
+ elif v.lower() in ("no", "false", "f", "n", "0"):
52
+ return False
53
+ else:
54
+ raise ValueError(
55
+ f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
56
+ )
57
+
58
+
59
+ def change_scheduler(self, scheduler_type="ddim"):
60
+ self.original_scheduler_config = self.scheduler.config
61
+ scheduler_type = scheduler_type.lower()
62
+ if scheduler_type == "pndm":
63
+ scheduler = PNDMScheduler.from_config(self.original_scheduler_config, skip_prk_steps=True)
64
+ elif scheduler_type == "lms":
65
+ scheduler = LMSDiscreteScheduler.from_config(self.original_scheduler_config)
66
+ elif scheduler_type == "heun":
67
+ scheduler = HeunDiscreteScheduler.from_config(self.original_scheduler_config)
68
+ elif scheduler_type == "euler":
69
+ scheduler = EulerDiscreteScheduler.from_config(self.original_scheduler_config)
70
+ elif scheduler_type == "euler-ancestral":
71
+ scheduler = EulerAncestralDiscreteScheduler.from_config(self.original_scheduler_config)
72
+ elif scheduler_type == "dpm-multi":
73
+ scheduler = DPMSolverMultistepScheduler.from_config(self.original_scheduler_config)
74
+ elif scheduler_type == "dpm-single":
75
+ scheduler = DPMSolverSinglestepScheduler.from_config(self.original_scheduler_config)
76
+ elif scheduler_type == "kdpm2-ancestral":
77
+ scheduler = KDPM2AncestralDiscreteScheduler.from_config(self.original_scheduler_config)
78
+ elif scheduler_type == "kdpm2":
79
+ scheduler = KDPM2DiscreteScheduler.from_config(self.original_scheduler_config)
80
+ elif scheduler_type == "unipc-multi":
81
+ scheduler = UniPCMultistepScheduler.from_config(self.original_scheduler_config)
82
+ elif scheduler_type == "ddim":
83
+ scheduler = DDIMScheduler.from_config(
84
+ self.original_scheduler_config,
85
+ steps_offset=1,
86
+ clip_sample=False,
87
+ set_alpha_to_one=False,
88
+ )
89
+ elif scheduler_type == "ddpm":
90
+ scheduler = DDPMScheduler.from_config(
91
+ self.original_scheduler_config,
92
+ )
93
+ elif scheduler_type == "deis-multi":
94
+ scheduler = DEISMultistepScheduler.from_config(
95
+ self.original_scheduler_config,
96
+ )
97
+ else:
98
+ raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
99
+ return scheduler
100
+
101
+
102
+ def parse_arguments():
103
+
104
+ parser = argparse.ArgumentParser()
105
+ parser.add_argument(
106
+ "--pretrained_model_name_or_path",
107
+ default="runwayml/stable-diffusion-v1-5",
108
+ help="The model directory of diffusion_model.",
109
+ )
110
+ parser.add_argument(
111
+ "--inference_steps",
112
+ type=int,
113
+ default=50,
114
+ help="The number of unet inference steps.",
115
+ )
116
+ parser.add_argument(
117
+ "--benchmark_steps",
118
+ type=int,
119
+ default=10,
120
+ help="The number of performance benchmark steps.",
121
+ )
122
+ parser.add_argument(
123
+ "--task_name",
124
+ type=str,
125
+ default="all",
126
+ choices=[
127
+ "text2img",
128
+ "img2img",
129
+ "inpaint",
130
+ "inpaint_legacy",
131
+ "cycle_diffusion",
132
+ "all",
133
+ ],
134
+ help="The task can be one of [text2img, img2img, inpaint, inpaint_legacy, cycle_diffusion, hiresfix, all]. ",
135
+ )
136
+ parser.add_argument(
137
+ "--parse_prompt_type",
138
+ type=str,
139
+ default="raw",
140
+ choices=[
141
+ "raw",
142
+ "lpw",
143
+ ],
144
+ help="The parse_prompt_type can be one of [raw, lpw]. ",
145
+ )
146
+ parser.add_argument(
147
+ "--channels_last",
148
+ type=strtobool,
149
+ default=False,
150
+ help="Wheter to use channels_last",
151
+ )
152
+ parser.add_argument("--use_fp16", type=strtobool, default=True, help="Wheter to use FP16 mode")
153
+ parser.add_argument("--tf32", type=strtobool, default=True, help="tf32")
154
+ parser.add_argument("--compile", type=strtobool, default=False, help="compile")
155
+ parser.add_argument(
156
+ "--attention_type",
157
+ type=str,
158
+ default="sdp",
159
+ choices=[
160
+ "raw",
161
+ "sdp",
162
+ ],
163
+ help="attention_type.",
164
+ )
165
+ parser.add_argument("--device_id", type=int, default=0, help="The selected gpu id. -1 means use cpu")
166
+ parser.add_argument(
167
+ "--scheduler",
168
+ type=str,
169
+ default="euler-ancestral",
170
+ choices=[
171
+ "pndm",
172
+ "lms",
173
+ "euler",
174
+ "euler-ancestral",
175
+ "dpm-multi",
176
+ "dpm-single",
177
+ "unipc-multi",
178
+ "ddim",
179
+ "ddpm",
180
+ "deis-multi",
181
+ "heun",
182
+ "kdpm2-ancestral",
183
+ "kdpm2",
184
+ ],
185
+ help="The scheduler type of stable diffusion.",
186
+ )
187
+ parser.add_argument("--height", type=int, default=512, help="Height of input image")
188
+ parser.add_argument("--width", type=int, default=512, help="Width of input image")
189
+ return parser.parse_args()
190
+
191
+
192
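+ # Standalone copies of diffusers' attn_processors / set_attn_processor helpers, so the benchmark can swap attention processors on the UNet and VAE regardless of the installed diffusers version.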
+ def attn_processors(self):
193
+ processors = {}
194
+
195
+ def fn_recursive_add_processors(name: str, module, processors):
196
+ if hasattr(module, "set_processor"):
197
+ processors[f"{name}.processor"] = module.processor
198
+
199
+ for sub_name, child in module.named_children():
200
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
201
+
202
+ return processors
203
+
204
+ for name, module in self.named_children():
205
+ fn_recursive_add_processors(name, module, processors)
206
+
207
+ return processors
208
+
209
+
210
+ def set_attn_processor(self, processor):
211
+ count = len(attn_processors(self).keys())
212
+
213
+ if isinstance(processor, dict) and len(processor) != count:
214
+ raise ValueError(
215
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
216
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
217
+ )
218
+
219
+ def fn_recursive_attn_processor(name: str, module, processor):
220
+ if hasattr(module, "set_processor"):
221
+ if not isinstance(processor, dict):
222
+ module.set_processor(processor)
223
+ else:
224
+ module.set_processor(processor.pop(f"{name}.processor"))
225
+
226
+ for sub_name, child in module.named_children():
227
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
228
+
229
+ for name, module in self.named_children():
230
+ fn_recursive_attn_processor(name, module, processor)
231
+
232
+
233
+ def main(args):
234
+ if args.tf32:
235
+ torch.backends.cuda.matmul.allow_tf32 = True
236
+ else:
237
+ torch.backends.cuda.matmul.allow_tf32 = False
238
+
239
+ seed = 1024
240
+ torch_dtype = torch.float16 if args.use_fp16 else torch.float32
241
+ pipe = DiffusionPipeline.from_pretrained(
242
+ args.pretrained_model_name_or_path,
243
+ safety_checker=None,
244
+ feature_extractor=None,
245
+ requires_safety_checker=False,
246
+ torch_dtype=torch_dtype,
247
+ custom_pipeline="stable_diffusion_mega" if args.parse_prompt_type == "raw" else "lpw_stable_diffusion",
248
+ )
249
+ scheduler = change_scheduler(pipe, args.scheduler)
250
+ pipe.scheduler = scheduler
251
+ if args.device_id >= 0:
252
+ pipe.to(f"cuda:{args.device_id}")
253
+
254
+ if args.attention_type == "all":
255
+ args.attention_type = ["raw", "sdp"]
256
+ else:
257
+ args.attention_type = [args.attention_type]
258
+
259
+ for attention_type in args.attention_type:
260
+ attn_processor_cls = AttnProcessor if attention_type == "raw" else AttnProcessor2_0
261
+ if attention_type == "sdp":
262
+ torch.nn.functional.scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention_
263
+
264
+ set_attn_processor(pipe.unet, attn_processor_cls())
265
+ set_attn_processor(pipe.vae, attn_processor_cls())
266
+ if args.channels_last:
267
+ pipe.unet.to(memory_format=torch.channels_last)
268
+
269
+ if args.compile:
270
+ print("Run torch compile")
271
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
272
+
273
+ width = args.width
274
+ height = args.height
275
+ pipe.set_progress_bar_config(disable=True)
276
+
277
+ folder = f"torch_attn_{attention_type}_fp16" if args.use_fp16 else f"torch_attn_{attention_type}_fp32"
278
+ os.makedirs(folder, exist_ok=True)
279
+ if args.task_name in ["text2img", "all"]:
280
+ # text2img
281
+ prompt = "a photo of an astronaut riding a horse on mars"
282
+ time_costs = []
283
+ # warmup
284
+ pipe.text2img(
285
+ prompt,
286
+ num_inference_steps=10,
287
+ height=height,
288
+ width=width,
289
+ )
290
+ print("==> Test text2img performance.")
291
+ torch.cuda.manual_seed(seed)
292
+ for step in trange(args.benchmark_steps):
293
+ start = time.time()
294
+ images = pipe.text2img(
295
+ prompt,
296
+ num_inference_steps=args.inference_steps,
297
+ height=height,
298
+ width=width,
299
+ ).images
300
+ latency = time.time() - start
301
+ time_costs += [latency]
302
+ # print(f"No {step:3d} time cost: {latency:2f} s")
303
+ print(
304
+ f"Mean latency: {np.mean(time_costs):2f} s, p50 latency: {np.percentile(time_costs, 50):2f} s, "
305
+ f"p90 latency: {np.percentile(time_costs, 90):2f} s, p95 latency: {np.percentile(time_costs, 95):2f} s."
306
+ )
307
+ images[0].save(f"{folder}/text2img.png")
308
+
309
+ if args.task_name in ["img2img", "all"]:
310
+ # img2img
311
+ img_url = "https://paddlenlp.bj.bcebos.com/models/community/CompVis/stable-diffusion-v1-4/sketch-mountains-input.png"
312
+ init_image = load_image(img_url).resize((width, height))
313
+ prompt = "A fantasy landscape, trending on artstation"
314
+ time_costs = []
315
+ # warmup
316
+ pipe.img2img(
317
+ prompt,
318
+ image=init_image,
319
+ num_inference_steps=20,
320
+ height=height,
321
+ width=width,
322
+ )
323
+ print("==> Test img2img performance.")
324
+ for step in trange(args.benchmark_steps):
325
+ start = time.time()
326
+ torch.cuda.manual_seed(seed)
327
+ images = pipe.img2img(
328
+ prompt,
329
+ image=init_image,
330
+ num_inference_steps=args.inference_steps,
331
+ height=height,
332
+ width=width,
333
+ ).images
334
+ latency = time.time() - start
335
+ time_costs += [latency]
336
+ # print(f"No {step:3d} time cost: {latency:2f} s")
337
+ print(
338
+ f"Mean latency: {np.mean(time_costs):2f} s, p50 latency: {np.percentile(time_costs, 50):2f} s, "
339
+ f"p90 latency: {np.percentile(time_costs, 90):2f} s, p95 latency: {np.percentile(time_costs, 95):2f} s."
340
+ )
341
+ images[0].save(f"{folder}/img2img.png")
342
+
343
+ if args.task_name in ["inpaint", "inpaint_legacy", "all"]:
344
+ img_url = (
345
+ "https://paddlenlp.bj.bcebos.com/models/community/CompVis/stable-diffusion-v1-4/overture-creations.png"
346
+ )
347
+ mask_url = "https://paddlenlp.bj.bcebos.com/models/community/CompVis/stable-diffusion-v1-4/overture-creations-mask.png"
348
+ init_image = load_image(img_url).resize((width, height))
349
+ mask_image = load_image(mask_url).resize((width, height))
350
+ prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
351
+ time_costs = []
352
+ # warmup
353
+ if args.task_name in ["inpaint_legacy", "all"]:
354
+ call_fn = pipe.inpaint
355
+ task_name = "inpaint_legacy"
356
+ else:
357
+ call_fn = pipe.inpaint
358
+ task_name = args.task_name
359
+ if pipe.unet.config.in_channels == 4:
360
+ task_name = "inpaint_legacy"
361
+ elif pipe.unet.config.in_channels == 9:
362
+ task_name = "inpaint"
363
+
364
+ call_fn(
365
+ prompt,
366
+ image=init_image,
367
+ mask_image=mask_image,
368
+ num_inference_steps=20,
369
+ )
370
+ print(f"==> Test {task_name} performance.")
371
+ for step in trange(args.benchmark_steps):
372
+ start = time.time()
373
+ torch.cuda.manual_seed(seed)
374
+ images = call_fn(
375
+ prompt,
376
+ image=init_image,
377
+ mask_image=mask_image,
378
+ num_inference_steps=args.inference_steps,
379
+ ).images
380
+ latency = time.time() - start
381
+ time_costs += [latency]
382
+ # print(f"No {step:3d} time cost: {latency:2f} s")
383
+ print(
384
+ f"Mean latency: {np.mean(time_costs):2f} s, p50 latency: {np.percentile(time_costs, 50):2f} s, "
385
+ f"p90 latency: {np.percentile(time_costs, 90):2f} s, p95 latency: {np.percentile(time_costs, 95):2f} s."
386
+ )
387
+
388
+ images[0].save(f"{folder}/{task_name}.png")
389
+
390
+ if args.task_name in ["cycle_diffusion", "all"]:
391
+ # needs a fix for diffusers==0.17.1, self.unet return_dict=False!
392
+ cycle_pipe = CycleDiffusionPipeline(
393
+ vae=pipe.vae,
394
+ text_encoder=pipe.text_encoder,
395
+ tokenizer=pipe.tokenizer,
396
+ unet=pipe.unet,
397
+ scheduler=scheduler,
398
+ safety_checker=None,
399
+ feature_extractor=None,
400
+ requires_safety_checker=False,
401
+ )
402
+ cycle_pipe.set_progress_bar_config(disable=True)
403
+ scheduler = change_scheduler(cycle_pipe, "ddim")
404
+ cycle_pipe.scheduler = scheduler
405
+ image_url = "ride_on_horse.png"
406
+ init_image = load_image(image_url).resize((width, height))
407
+ source_prompt = "An astronaut riding a horse"
408
+ prompt = "An astronaut riding an elephant"
409
+ time_costs = []
410
+ # warmup
411
+ cycle_pipe(
412
+ prompt=prompt,
413
+ source_prompt=source_prompt,
414
+ image=init_image,
415
+ num_inference_steps=10,
416
+ eta=0.1,
417
+ strength=0.8,
418
+ guidance_scale=2,
419
+ source_guidance_scale=1,
420
+ ).images[0]
421
+ print("==> Test cycle diffusion performance.")
422
+ for step in trange(args.benchmark_steps):
423
+ start = time.time()
424
+ torch.cuda.manual_seed(seed)
425
+ images = cycle_pipe(
426
+ prompt=prompt,
427
+ source_prompt=source_prompt,
428
+ image=init_image,
429
+ num_inference_steps=args.inference_steps,
430
+ eta=0.1,
431
+ strength=0.8,
432
+ guidance_scale=2,
433
+ source_guidance_scale=1,
434
+ ).images
435
+ latency = time.time() - start
436
+ time_costs += [latency]
437
+ # print(f"No {step:3d} time cost: {latency:2f} s")
438
+ print(
439
+ f"Mean latency: {np.mean(time_costs):2f} s, p50 latency: {np.percentile(time_costs, 50):2f} s, "
440
+ f"p90 latency: {np.percentile(time_costs, 90):2f} s, p95 latency: {np.percentile(time_costs, 95):2f} s."
441
+ )
442
+ images[0].save(f"{folder}/cycle_diffusion.png")
443
+
444
+
445
+ if __name__ == "__main__":
446
+ args = parse_arguments()
447
+ main(args)
PaddleMIX/ppdiffusers/deploy-deprecated/requirements.txt ADDED
@@ -0,0 +1,2 @@
1
+ ppdiffusers>=0.16.3
2
+ ligo-segments
PaddleMIX/ppdiffusers/deploy/README.md ADDED
@@ -0,0 +1,65 @@
+ # PPDiffusers Inference Deployment
+
+ Built on Paddle Inference, PPDiffusers provides deployment solutions for the following key diffusion models:
+ - ControlNet
+ - IP-Adapter-SD15
+ - IP-Adapter-SDXL
+ - SD15
+ - SDXL
+
+
+ # V100 Performance Data
+ |Model|Paddle Deploy TensorRT / ips|Torch Dynamic / ips|
+ |-|-|-|
+ |IP-Adapter-SD15 text2img|18.30|18.18|
+ |IP-Adapter-SD15 img2img|18.11|17.87|
+ |IP-Adapter-SD15 inpaint|17.93|17.44|
+ |IP-Adapter-SDXL text2img|12.01|11.47|
+ |IP-Adapter-SDXL img2img|12.00|10.95|
+ |IP-Adapter-SDXL inpaint|11.67|10.79|
+ |SD15 text2img|19.68|18.27|
+ |SD15 img2img|19.68|17.90|
+ |SD15 inpaint|19.44|17.56|
+ |SDXL text2img|13.91|11.50|
+ |SDXL img2img|13.86|11.60|
+ |SDXL inpaint|13.45|11.28|
+
+ <!-- |SD15 text2img|11.87|6.68|6.32|
+ |SD15 img2img|14.47|8.09|7.63|
+ |SD15 inpaint|14.30|6.42|6.06| -->
+
+ > Note:
+ > Test environment/configuration: Paddle 3.0 beta, single V100 32G GPU, FP16.
+ > Inference parameters: Image Width = 512, Image Height = 512, Num Inference Steps = 50.
+
+ # A100 Performance Data
+ |Model|Paddle Deploy TensorRT / ips|Torch Dynamic / ips|
+ |-|-|-|
+ |IP-Adapter-SD15 text2img|38.52|32.75|
+ |IP-Adapter-SD15 img2img|37.91|32.50|
+ |IP-Adapter-SD15 inpaint|37.80|31.78|
+ |IP-Adapter-SDXL text2img|22.88|17.26|
+ |IP-Adapter-SDXL img2img|22.79|17.24|
+ |IP-Adapter-SDXL inpaint|22.30|17.06|
+ |SD15 text2img|47.22|33.74|
+ |SD15 img2img|46.59|32.96|
+ |SD15 inpaint|46.05|32.14|
+ |SDXL text2img|31.98|17.73|
+ |SDXL img2img|31.80|17.40|
+ |SDXL inpaint|30.58|16.98|
+
+ <!-- |SD15 text2img|26.37|10.49||
+ |SD15 img2img|30.81|12.70||
+ |SD15 inpaint|30.55|9.67|| -->
+
+ > Note: Test environment/configuration: Paddle 3.0 beta, single A100 80G GPU, FP16.
+ > Inference parameters: Image Width = 512, Image Height = 512, Num Inference Steps = 50.
+
+ <!-- |SDXL text2img||||
+ |SDXL img2img||||
+ |SDXL inpaint|||| -->
+
+ <!-- |-|-|-|-|
+ |ControlNet text2img|3.360597|||
+ |ControlNet img2img|3.360597|||
+ |ControlNet inpaint|3.360597||| -->
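+
+ Throughput figures like those above are typically collected with a warmup pass followed by timed runs. The snippet below is a minimal measurement sketch using the dynamic-graph `StableDiffusionPipeline`; the model id, image size and step count are placeholders, and the deployed Paddle Inference pipelines (e.g. `PaddleInferStableDiffusionPipeline`) are loaded from exported inference models instead, but can be timed the same way.
+
+ ```python
+ # Minimal latency-measurement sketch (assumptions: dynamic-graph pipeline, placeholder model id).
+ import time
+
+ import numpy as np
+ import paddle
+
+ from ppdiffusers import StableDiffusionPipeline
+
+ pipe = StableDiffusionPipeline.from_pretrained(
+     "runwayml/stable-diffusion-v1-5", paddle_dtype=paddle.float16, safety_checker=None
+ )
+ pipe.set_progress_bar_config(disable=True)
+
+ prompt = "a photo of an astronaut riding a horse on mars"
+ pipe(prompt, num_inference_steps=10, height=512, width=512)  # warmup
+
+ latencies = []
+ for _ in range(10):
+     start = time.time()
+     pipe(prompt, num_inference_steps=50, height=512, width=512)
+     latencies.append(time.time() - start)
+ print(f"mean latency: {np.mean(latencies):.2f} s, p90: {np.percentile(latencies, 90):.2f} s")
+ ```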
PaddleMIX/ppdiffusers/ppdiffusers/__init__.py ADDED
@@ -0,0 +1,814 @@
1
+ # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ from .patches import *
18
+ from .utils import (
19
+ PPDIFFUSERS_SLOW_IMPORT,
20
+ OptionalDependencyNotAvailable,
21
+ _LazyModule,
22
+ is_einops_available,
23
+ is_fastdeploy_available,
24
+ is_inflect_available,
25
+ is_k_diffusion_available,
26
+ is_k_diffusion_version,
27
+ is_librosa_available,
28
+ is_note_seq_available,
29
+ is_onnx_available,
30
+ is_paddle_available,
31
+ is_paddle_version,
32
+ is_paddlenlp_available,
33
+ is_paddlenlp_version,
34
+ is_paddlesde_available,
35
+ is_pp_invisible_watermark_available,
36
+ is_ppxformers_available,
37
+ is_scipy_available,
38
+ is_torch_available,
39
+ is_transformers_available,
40
+ is_unidecode_available,
41
+ logging,
42
+ )
43
+ from .version import VERSION as __version__
44
+
45
+ # Lazy Import based on
46
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/__init__.py
47
+
48
+ # When adding a new object to this init, please add it to `_import_structure`. The `_import_structure` is a dictionary submodule to list of object names,
49
+ # and is used to defer the actual importing for when the objects are requested.
50
+ # This way `import ppdiffusers` provides the names in the namespace without actually importing anything (and especially none of the backends).
51
+
52
+ _import_structure = {
53
+ "configuration_utils": ["ConfigMixin"],
54
+ "models": [],
55
+ "pipelines": [],
56
+ "schedulers": [],
57
+ "utils": [
58
+ "OptionalDependencyNotAvailable",
59
+ "is_inflect_available",
60
+ "is_pp_invisible_watermark_available",
61
+ "is_k_diffusion_available",
62
+ "is_k_diffusion_version",
63
+ "is_librosa_available",
64
+ "is_note_seq_available",
65
+ "is_onnx_available",
66
+ "is_fastdeploy_available",
67
+ "is_scipy_available",
68
+ "is_paddle_available",
69
+ "is_paddle_version",
70
+ "is_paddlesde_available",
71
+ "is_paddlenlp_available",
72
+ "is_paddlenlp_version",
73
+ "is_unidecode_available",
74
+ # NEW ADD
75
+ "is_ppxformers_available",
76
+ "is_einops_available",
77
+ "is_torch_available",
78
+ "is_transformers_available",
79
+ "logging",
80
+ ],
81
+ }
82
+
83
+ try:
84
+ if not is_fastdeploy_available():
85
+ raise OptionalDependencyNotAvailable()
86
+ except OptionalDependencyNotAvailable:
87
+ from .utils import dummy_fastdeploy_objects # noqa F403
88
+
89
+ _import_structure["utils.dummy_fastdeploy_objects"] = [
90
+ name for name in dir(dummy_fastdeploy_objects) if not name.startswith("_")
91
+ ]
92
+
93
+ else:
94
+ _import_structure["pipelines"].extend(
95
+ ["FastDeployRuntimeModel", "FastDeployDiffusionPipelineMixin", "FastDeployDiffusionXLPipelineMixin"]
96
+ )
97
+
98
+ try:
99
+ if not is_paddle_available():
100
+ raise OptionalDependencyNotAvailable()
101
+ except OptionalDependencyNotAvailable:
102
+ from .utils import dummy_paddle_objects # noqa F403
103
+
104
+ _import_structure["utils.dummy_paddle_objects"] = [
105
+ name for name in dir(dummy_paddle_objects) if not name.startswith("_")
106
+ ]
107
+
108
+ else:
109
+ _import_structure["models"].extend(
110
+ [
111
+ "AsymmetricAutoencoderKL",
112
+ "AutoencoderKL",
113
+ "AutoencoderKLCogVideoX",
114
+ "AutoencoderKLTemporalDecoder",
115
+ "AutoencoderTiny",
116
+ "CogVideoXTransformer3DModel",
117
+ "CogVideoXTransformer3DVCtrlModel",
118
+ "ConsistencyDecoderVAE",
119
+ "ControlNetModel",
120
+ "Kandinsky3UNet",
121
+ "ModelMixin",
122
+ "MotionAdapter",
123
+ "MultiAdapter",
124
+ "PriorTransformer",
125
+ "SD3Transformer2DModel",
126
+ "T2IAdapter",
127
+ "T5FilmDecoder",
128
+ "Transformer2DModel",
129
+ "UNet1DModel",
130
+ "UNet2DConditionModel",
131
+ "UNet2DModel",
132
+ "UNet3DConditionModel",
133
+ "UNetMotionModel",
134
+ "UNetSpatioTemporalConditionModel",
135
+ "VQModel",
136
+ "UViTT2IModel",
137
+ "DiTLLaMA2DModel",
138
+ "DiTLLaMAT2IModel",
139
+ # new add
140
+ "LVDMAutoencoderKL",
141
+ "LVDMUNet3DModel",
142
+ "PaddleInferRuntimeModel",
143
+ # new add
144
+ "AutoencoderKL_imgtovideo",
145
+ "GaussianDiffusion",
146
+ "GaussianDiffusion_SDEdit",
147
+ "STUNetModel",
148
+ "Vid2VidSTUNet",
149
+ # new add
150
+ "SD3ControlNetModel",
151
+ "SD3MultiControlNetModel",
152
+ # new add
153
+ "VCtrlModel",
154
+ ]
155
+ )
156
+
157
+ _import_structure["optimization"] = [
158
+ "get_constant_schedule",
159
+ "get_constant_schedule_with_warmup",
160
+ "get_cosine_schedule_with_warmup",
161
+ "get_cosine_with_hard_restarts_schedule_with_warmup",
162
+ "get_linear_schedule_with_warmup",
163
+ "get_polynomial_decay_schedule_with_warmup",
164
+ "get_scheduler",
165
+ ]
166
+ _import_structure["pipelines"].extend(
167
+ [
168
+ "AudioPipelineOutput",
169
+ "AutoPipelineForImage2Image",
170
+ "AutoPipelineForInpainting",
171
+ "AutoPipelineForText2Image",
172
+ "ConsistencyModelPipeline",
173
+ "CogVideoXVCtrlPipeline",
174
+ "CogVideoXVCtrlImageToVideoPipeline",
175
+ "DanceDiffusionPipeline",
176
+ "DDIMPipeline",
177
+ "DDPMPipeline",
178
+ "DiffusionPipeline",
179
+ "DiTPipeline",
180
+ "ImagePipelineOutput",
181
+ "KarrasVePipeline",
182
+ "LDMPipeline",
183
+ "LDMSuperResolutionPipeline",
184
+ "PNDMPipeline",
185
+ "RePaintPipeline",
186
+ "ScoreSdeVePipeline",
187
+ ]
188
+ )
189
+ _import_structure["schedulers"].extend(
190
+ [
191
+ "CMStochasticIterativeScheduler",
192
+ "CogVideoXDDIMScheduler",
193
+ "CogVideoXDPMScheduler",
194
+ "DDIMInverseScheduler",
195
+ "DDIMParallelScheduler",
196
+ "DDIMScheduler",
197
+ "DDPMParallelScheduler",
198
+ "DDPMScheduler",
199
+ "DDPMWuerstchenScheduler",
200
+ "DEISMultistepScheduler",
201
+ "DPMSolverMultistepInverseScheduler",
202
+ "DPMSolverMultistepScheduler",
203
+ "DPMSolverSinglestepScheduler",
204
+ "EulerAncestralDiscreteScheduler",
205
+ "EulerDiscreteScheduler",
206
+ "FlowMatchEulerDiscreteScheduler",
207
+ "HeunDiscreteScheduler",
208
+ "IPNDMScheduler",
209
+ "KarrasVeScheduler",
210
+ "ScoreSdeVpScheduler", # new add
211
+ "PreconfigEulerAncestralDiscreteScheduler", # new add
212
+ "PreconfigLMSDiscreteScheduler", # new add
213
+ "KDPM2AncestralDiscreteScheduler",
214
+ "KDPM2DiscreteScheduler",
215
+ "LCMScheduler",
216
+ "PNDMScheduler",
217
+ "RePaintScheduler",
218
+ "SchedulerMixin",
219
+ "ScoreSdeVeScheduler",
220
+ "UnCLIPScheduler",
221
+ "UniPCMultistepScheduler",
222
+ "VQDiffusionScheduler",
223
+ "EDMDPMSolverMultistepScheduler",
224
+ "EDMEulerScheduler",
225
+ ]
226
+ )
227
+ _import_structure["training_utils"] = ["EMAModel"]
228
+
229
+ try:
230
+ if not (is_paddle_available() and is_scipy_available()):
231
+ raise OptionalDependencyNotAvailable()
232
+ except OptionalDependencyNotAvailable:
233
+ from .utils import dummy_paddle_and_scipy_objects # noqa F403
234
+
235
+ _import_structure["utils.dummy_paddle_and_scipy_objects"] = [
236
+ name for name in dir(dummy_paddle_and_scipy_objects) if not name.startswith("_")
237
+ ]
238
+
239
+ else:
240
+ _import_structure["schedulers"].extend(["LMSDiscreteScheduler"])
241
+
242
+ try:
243
+ if not (is_paddle_available() and is_paddlesde_available()):
244
+ raise OptionalDependencyNotAvailable()
245
+ except OptionalDependencyNotAvailable:
246
+ from .utils import dummy_paddle_and_paddlesde_objects # noqa F403
247
+
248
+ _import_structure["utils.dummy_paddle_and_paddlesde_objects"] = [
249
+ name for name in dir(dummy_paddle_and_paddlesde_objects) if not name.startswith("_")
250
+ ]
251
+
252
+ else:
253
+ _import_structure["schedulers"].extend(["DPMSolverSDEScheduler"])
254
+
255
+ try:
256
+ if not (is_paddle_available() and is_paddlenlp_available()):
257
+ raise OptionalDependencyNotAvailable()
258
+ except OptionalDependencyNotAvailable:
259
+ from .utils import dummy_paddle_and_paddlenlp_objects # noqa F403
260
+
261
+ _import_structure["utils.dummy_paddle_and_paddlenlp_objects"] = [
262
+ name for name in dir(dummy_paddle_and_paddlenlp_objects) if not name.startswith("_")
263
+ ]
264
+
265
+ else:
266
+ _import_structure["pipelines"].extend(
267
+ [
268
+ "AltDiffusionImg2ImgPipeline",
269
+ "AltDiffusionPipeline",
270
+ "AnimateDiffPipeline",
271
+ "AudioLDM2Pipeline",
272
+ "AudioLDM2ProjectionModel",
273
+ "AudioLDM2UNet2DConditionModel",
274
+ "AudioLDMPipeline",
275
+ "BlipDiffusionControlNetPipeline",
276
+ "BlipDiffusionPipeline",
277
+ "CLIPImageProjection",
278
+ "CogVideoXPipeline",
279
+ "CycleDiffusionPipeline",
280
+ "IFImg2ImgPipeline",
281
+ "IFImg2ImgSuperResolutionPipeline",
282
+ "IFInpaintingPipeline",
283
+ "IFInpaintingSuperResolutionPipeline",
284
+ "IFPipeline",
285
+ "IFSuperResolutionPipeline",
286
+ "ImageTextPipelineOutput",
287
+ "Kandinsky3Img2ImgPipeline",
288
+ "Kandinsky3Pipeline",
289
+ "KandinskyCombinedPipeline",
290
+ "KandinskyImg2ImgCombinedPipeline",
291
+ "KandinskyImg2ImgPipeline",
292
+ "KandinskyInpaintCombinedPipeline",
293
+ "KandinskyInpaintPipeline",
294
+ "KandinskyPipeline",
295
+ "KandinskyPriorPipeline",
296
+ "KandinskyV22CombinedPipeline",
297
+ "KandinskyV22ControlnetImg2ImgPipeline",
298
+ "KandinskyV22ControlnetPipeline",
299
+ "KandinskyV22Img2ImgCombinedPipeline",
300
+ "KandinskyV22Img2ImgPipeline",
301
+ "KandinskyV22InpaintCombinedPipeline",
302
+ "KandinskyV22InpaintPipeline",
303
+ "KandinskyV22Pipeline",
304
+ "KandinskyV22PriorEmb2EmbPipeline",
305
+ "KandinskyV22PriorPipeline",
306
+ "LatentConsistencyModelImg2ImgPipeline",
307
+ "LatentConsistencyModelPipeline",
308
+ "LDMTextToImagePipeline",
309
+ "LDMTextToImageUViTPipeline",
310
+ "LDMTextToImageLargeDiTPipeline",
311
+ "MusicLDMPipeline",
312
+ "PaintByExamplePipeline",
313
+ "PixArtAlphaPipeline",
314
+ "SemanticStableDiffusionPipeline",
315
+ "ShapEImg2ImgPipeline",
316
+ "ShapEPipeline",
317
+ "StableDiffusion3ControlNetInpaintingPipeline",
318
+ "StableDiffusion3ControlNetPipeline",
319
+ "StableDiffusion3Img2ImgPipeline",
320
+ "StableDiffusion3Pipeline",
321
+ "StableDiffusionAdapterPipeline",
322
+ "StableDiffusionAttendAndExcitePipeline",
323
+ "StableDiffusionControlNetImg2ImgPipeline",
324
+ "StableDiffusionControlNetInpaintPipeline",
325
+ "StableDiffusionControlNetPipeline",
326
+ "StableDiffusionDepth2ImgPipeline",
327
+ "StableDiffusionDiffEditPipeline",
328
+ "StableDiffusionGLIGENPipeline",
329
+ "StableDiffusionGLIGENTextImagePipeline",
330
+ "StableDiffusionImageVariationPipeline",
331
+ "StableDiffusionImg2ImgPipeline",
332
+ "StableDiffusionInpaintPipeline",
333
+ "StableDiffusionInpaintPipelineLegacy",
334
+ "StableDiffusionInstructPix2PixPipeline",
335
+ "StableDiffusionLatentUpscalePipeline",
336
+ "StableDiffusionLDM3DPipeline",
337
+ "StableDiffusionModelEditingPipeline",
338
+ "StableDiffusionPanoramaPipeline",
339
+ "StableDiffusionParadigmsPipeline",
340
+ "StableDiffusionPipeline",
341
+ "StableDiffusionPipelineSafe",
342
+ "StableDiffusionPix2PixZeroPipeline",
343
+ "StableDiffusionSAGPipeline",
344
+ "StableDiffusionUpscalePipeline",
345
+ "StableDiffusionXLAdapterPipeline",
346
+ "StableDiffusionXLControlNetImg2ImgPipeline",
347
+ "StableDiffusionXLControlNetInpaintPipeline",
348
+ "StableDiffusionXLControlNetPipeline",
349
+ "StableDiffusionXLImg2ImgPipeline",
350
+ "StableDiffusionXLInpaintPipeline",
351
+ "StableDiffusionXLInstructPix2PixPipeline",
352
+ "StableDiffusionXLPipeline",
353
+ "StableUnCLIPImg2ImgPipeline",
354
+ "StableUnCLIPPipeline",
355
+ "StableDiffusionSafetyChecker",
356
+ "StableVideoDiffusionPipeline",
357
+ "TextToVideoSDPipeline",
358
+ "TextToVideoZeroPipeline",
359
+ "TextToVideoZeroSDXLPipeline",
360
+ "UnCLIPImageVariationPipeline",
361
+ "UnCLIPPipeline",
362
+ "UniDiffuserModel",
363
+ "UniDiffuserPipeline",
364
+ "UniDiffuserTextDecoder",
365
+ "VersatileDiffusionDualGuidedPipeline",
366
+ "VersatileDiffusionImageVariationPipeline",
367
+ "VersatileDiffusionPipeline",
368
+ "VersatileDiffusionTextToImagePipeline",
369
+ "VideoToVideoSDPipeline",
370
+ "VQDiffusionPipeline",
371
+ "WuerstchenCombinedPipeline",
372
+ "WuerstchenDecoderPipeline",
373
+ "WuerstchenPriorPipeline",
374
+ # new add
375
+ "LVDMTextToVideoPipeline",
376
+ "LVDMUncondPipeline",
377
+ "PaddleInferCycleDiffusionPipeline",
378
+ "PaddleInferStableDiffusionImg2ImgPipeline",
379
+ "PaddleInferStableDiffusionInpaintPipeline",
380
+ "PaddleInferStableDiffusionInpaintPipelineLegacy",
381
+ "PaddleInferStableDiffusionMegaPipeline",
382
+ "PaddleInferStableDiffusionPipeline",
383
+ "PaddleInferStableDiffusionUpscalePipeline",
384
+ "PaddleInferStableDiffusionXLPipeline",
385
+ "PaddleInferStableDiffusionXLImg2ImgPipeline",
386
+ "PaddleInferStableDiffusionXLInpaintPipeline",
387
+ "PaddleInferStableDiffusionXLInstructPix2PixPipeline",
388
+ "PaddleInferStableDiffusionXLMegaPipeline",
389
+ "PaddleInferStableDiffusionControlNetPipeline",
390
+ "PaddleInferStableVideoDiffusionPipeline",
391
+ # new add
392
+ "ImgToVideoSDPipeline",
393
+ "VideoToVideoModelscopePipeline",
394
+ ]
395
+ )
396
+
397
+ try:
398
+ if not (is_paddle_available() and is_paddlenlp_available() and is_k_diffusion_available()):
399
+ raise OptionalDependencyNotAvailable()
400
+ except OptionalDependencyNotAvailable:
401
+ from .utils import dummy_paddle_and_paddlenlp_and_k_diffusion_objects # noqa F403
402
+
403
+ _import_structure["utils.dummy_paddle_and_paddlenlp_and_k_diffusion_objects"] = [
404
+ name for name in dir(dummy_paddle_and_paddlenlp_and_k_diffusion_objects) if not name.startswith("_")
405
+ ]
406
+
407
+ else:
408
+ _import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline"])
409
+
410
+ try:
411
+ if not (is_paddle_available() and is_paddlenlp_available() and is_fastdeploy_available()):
412
+ raise OptionalDependencyNotAvailable()
413
+ except OptionalDependencyNotAvailable:
414
+ from .utils import dummy_paddle_and_paddlenlp_and_fastdeploy_objects # noqa F403
415
+
416
+ _import_structure["utils.dummy_paddle_and_paddlenlp_and_fastdeploy_objects"] = [
417
+ name for name in dir(dummy_paddle_and_paddlenlp_and_fastdeploy_objects) if not name.startswith("_")
418
+ ]
419
+
420
+ else:
421
+ _import_structure["pipelines"].extend(
422
+ [
423
+ "FastDeployStableDiffusionImg2ImgPipeline",
424
+ "FastDeployStableDiffusionInpaintPipeline",
425
+ "FastDeployStableDiffusionInpaintPipelineLegacy",
426
+ "FastDeployStableDiffusionPipeline",
427
+ "FastDeployStableDiffusionMegaPipeline",
428
+ "FastDeployCycleDiffusionPipeline",
429
+ "FastDeployStableDiffusionControlNetPipeline",
430
+ "FastDeployStableDiffusionUpscalePipeline",
431
+ ]
432
+ )
433
+
434
+ try:
435
+ if not (is_paddle_available() and is_librosa_available()):
436
+ raise OptionalDependencyNotAvailable()
437
+ except OptionalDependencyNotAvailable:
438
+ from .utils import dummy_paddle_and_librosa_objects # noqa F403
439
+
440
+ _import_structure["utils.dummy_paddle_and_librosa_objects"] = [
441
+ name for name in dir(dummy_paddle_and_librosa_objects) if not name.startswith("_")
442
+ ]
443
+
444
+ else:
445
+ _import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"])
446
+
447
+ try:
448
+ if not (is_paddlenlp_available() and is_paddle_available() and is_note_seq_available()):
449
+ raise OptionalDependencyNotAvailable()
450
+ except OptionalDependencyNotAvailable:
451
+ from .utils import dummy_paddle_and_paddlenlp_and_note_seq_objects # noqa F403
452
+
453
+ _import_structure["utils.dummy_paddle_and_paddlenlp_and_note_seq_objects"] = [
454
+ name for name in dir(dummy_paddle_and_paddlenlp_and_note_seq_objects) if not name.startswith("_")
455
+ ]
456
+
457
+
458
+ else:
459
+ _import_structure["pipelines"].extend(["SpectrogramDiffusionPipeline"])
460
+
461
+ try:
462
+ if not (is_note_seq_available()):
463
+ raise OptionalDependencyNotAvailable()
464
+ except OptionalDependencyNotAvailable:
465
+ from .utils import dummy_note_seq_objects # noqa F403
466
+
467
+ _import_structure["utils.dummy_note_seq_objects"] = [
468
+ name for name in dir(dummy_note_seq_objects) if not name.startswith("_")
469
+ ]
470
+
471
+
472
+ else:
473
+ _import_structure["pipelines"].extend(["MidiProcessor"])
474
+
475
+ if TYPE_CHECKING or PPDIFFUSERS_SLOW_IMPORT:
476
+ from .configuration_utils import ConfigMixin
477
+
478
+ try:
479
+ if not is_fastdeploy_available():
480
+ raise OptionalDependencyNotAvailable()
481
+ except OptionalDependencyNotAvailable:
482
+ from .utils.dummy_fastdeploy_objects import * # noqa F403
483
+ else:
484
+ from .pipelines import (
485
+ FastDeployDiffusionPipelineMixin,
486
+ FastDeployDiffusionXLPipelineMixin,
487
+ FastDeployRuntimeModel,
488
+ )
489
+
490
+ try:
491
+ if not is_paddle_available():
492
+ raise OptionalDependencyNotAvailable()
493
+ except OptionalDependencyNotAvailable:
494
+ from .utils.dummy_paddle_objects import * # noqa F403
495
+ else:
496
+ from .models import ( # new add
497
+ AsymmetricAutoencoderKL,
498
+ AutoencoderKL,
499
+ AutoencoderKL_imgtovideo,
500
+ AutoencoderKLCogVideoX,
501
+ AutoencoderKLTemporalDecoder,
502
+ AutoencoderTiny,
503
+ CogVideoXTransformer3DModel,
504
+ CogVideoXTransformer3DVCtrlModel,
505
+ ConsistencyDecoderVAE,
506
+ ControlNetModel,
507
+ DiTLLaMA2DModel,
508
+ DiTLLaMAT2IModel,
509
+ GaussianDiffusion,
510
+ GaussianDiffusion_SDEdit,
511
+ Kandinsky3UNet,
512
+ LVDMAutoencoderKL,
513
+ LVDMUNet3DModel,
514
+ ModelMixin,
515
+ MotionAdapter,
516
+ MultiAdapter,
517
+ PaddleInferRuntimeModel,
518
+ PriorTransformer,
519
+ SD3ControlNetModel,
520
+ SD3MultiControlNetModel,
521
+ SD3Transformer2DModel,
522
+ STUNetModel,
523
+ T2IAdapter,
524
+ T5FilmDecoder,
525
+ Transformer2DModel,
526
+ UNet1DModel,
527
+ UNet2DConditionModel,
528
+ UNet2DModel,
529
+ UNet3DConditionModel,
530
+ UNetMotionModel,
531
+ UNetSpatioTemporalConditionModel,
532
+ UViTT2IModel,
533
+ VCtrlModel,
534
+ Vid2VidSTUNet,
535
+ VQModel,
536
+ )
537
+ from .optimization import (
538
+ get_constant_schedule,
539
+ get_constant_schedule_with_warmup,
540
+ get_cosine_schedule_with_warmup,
541
+ get_cosine_with_hard_restarts_schedule_with_warmup,
542
+ get_linear_schedule_with_warmup,
543
+ get_polynomial_decay_schedule_with_warmup,
544
+ get_scheduler,
545
+ )
546
+ from .pipelines import ( # new add
547
+ AudioPipelineOutput,
548
+ AutoPipelineForImage2Image,
549
+ AutoPipelineForInpainting,
550
+ AutoPipelineForText2Image,
551
+ BlipDiffusionControlNetPipeline,
552
+ BlipDiffusionPipeline,
553
+ CogVideoXVCtrlImageToVideoPipeline,
554
+ CogVideoXVCtrlPipeline,
555
+ ConsistencyModelPipeline,
556
+ DanceDiffusionPipeline,
557
+ DDIMPipeline,
558
+ DDPMPipeline,
559
+ DiffusionPipeline,
560
+ DiTPipeline,
561
+ ImagePipelineOutput,
562
+ ImgToVideoSDPipeline,
563
+ KarrasVePipeline,
564
+ LDMPipeline,
565
+ LDMSuperResolutionPipeline,
566
+ PNDMPipeline,
567
+ RePaintPipeline,
568
+ ScoreSdeVePipeline,
569
+ VideoToVideoModelscopePipeline,
570
+ )
571
+ from .schedulers import (
572
+ CMStochasticIterativeScheduler,
573
+ CogVideoXDDIMScheduler,
574
+ CogVideoXDPMScheduler,
575
+ DDIMInverseScheduler,
576
+ DDIMParallelScheduler,
577
+ DDIMScheduler,
578
+ DDPMParallelScheduler,
579
+ DDPMScheduler,
580
+ DDPMWuerstchenScheduler,
581
+ DEISMultistepScheduler,
582
+ DPMSolverMultistepInverseScheduler,
583
+ DPMSolverMultistepScheduler,
584
+ DPMSolverSinglestepScheduler,
585
+ EDMDPMSolverMultistepScheduler,
586
+ EDMEulerScheduler,
587
+ EulerAncestralDiscreteScheduler,
588
+ EulerDiscreteScheduler,
589
+ FlowMatchEulerDiscreteScheduler,
590
+ HeunDiscreteScheduler,
591
+ IPNDMScheduler,
592
+ KarrasVeScheduler,
593
+ KDPM2AncestralDiscreteScheduler,
594
+ KDPM2DiscreteScheduler,
595
+ LCMScheduler,
596
+ PNDMScheduler,
597
+ PreconfigEulerAncestralDiscreteScheduler,
598
+ PreconfigLMSDiscreteScheduler,
599
+ RePaintScheduler,
600
+ SchedulerMixin,
601
+ ScoreSdeVeScheduler,
602
+ ScoreSdeVpScheduler,
603
+ UnCLIPScheduler,
604
+ UniPCMultistepScheduler,
605
+ VQDiffusionScheduler,
606
+ )
607
+ from .training_utils import EMAModel
608
+
609
+ try:
610
+ if not (is_paddle_available() and is_scipy_available()):
611
+ raise OptionalDependencyNotAvailable()
612
+ except OptionalDependencyNotAvailable:
613
+ from .utils.dummy_paddle_and_scipy_objects import * # noqa F403
614
+ else:
615
+ from .schedulers import LMSDiscreteScheduler
616
+
617
+ try:
618
+ if not (is_paddle_available() and is_paddlesde_available()):
619
+ raise OptionalDependencyNotAvailable()
620
+ except OptionalDependencyNotAvailable:
621
+ from .utils.dummy_paddle_and_paddlesde_objects import * # noqa F403
622
+ else:
623
+ from .schedulers import DPMSolverSDEScheduler
624
+
625
+ try:
626
+ if not (is_paddle_available() and is_paddlenlp_available()):
627
+ raise OptionalDependencyNotAvailable()
628
+ except OptionalDependencyNotAvailable:
629
+ from .utils.dummy_paddle_and_paddlenlp_objects import * # noqa F403
630
+ else:
631
+ from .pipelines import ( # new add
632
+ AltDiffusionImg2ImgPipeline,
633
+ AltDiffusionPipeline,
634
+ AnimateDiffPipeline,
635
+ AudioLDM2Pipeline,
636
+ AudioLDM2ProjectionModel,
637
+ AudioLDM2UNet2DConditionModel,
638
+ AudioLDMPipeline,
639
+ CLIPImageProjection,
640
+ CogVideoXPipeline,
641
+ CycleDiffusionPipeline,
642
+ IFImg2ImgPipeline,
643
+ IFImg2ImgSuperResolutionPipeline,
644
+ IFInpaintingPipeline,
645
+ IFInpaintingSuperResolutionPipeline,
646
+ IFPipeline,
647
+ IFSuperResolutionPipeline,
648
+ ImageTextPipelineOutput,
649
+ Kandinsky3Img2ImgPipeline,
650
+ Kandinsky3Pipeline,
651
+ KandinskyCombinedPipeline,
652
+ KandinskyImg2ImgCombinedPipeline,
653
+ KandinskyImg2ImgPipeline,
654
+ KandinskyInpaintCombinedPipeline,
655
+ KandinskyInpaintPipeline,
656
+ KandinskyPipeline,
657
+ KandinskyPriorPipeline,
658
+ KandinskyV22CombinedPipeline,
659
+ KandinskyV22ControlnetImg2ImgPipeline,
660
+ KandinskyV22ControlnetPipeline,
661
+ KandinskyV22Img2ImgCombinedPipeline,
662
+ KandinskyV22Img2ImgPipeline,
663
+ KandinskyV22InpaintCombinedPipeline,
664
+ KandinskyV22InpaintPipeline,
665
+ KandinskyV22Pipeline,
666
+ KandinskyV22PriorEmb2EmbPipeline,
667
+ KandinskyV22PriorPipeline,
668
+ LatentConsistencyModelImg2ImgPipeline,
669
+ LatentConsistencyModelPipeline,
670
+ LDMTextToImageLargeDiTPipeline,
671
+ LDMTextToImagePipeline,
672
+ LDMTextToImageUViTPipeline,
673
+ LVDMTextToVideoPipeline,
674
+ LVDMUncondPipeline,
675
+ MusicLDMPipeline,
676
+ PaddleInferCycleDiffusionPipeline,
677
+ PaddleInferStableDiffusionControlNetPipeline,
678
+ PaddleInferStableDiffusionImg2ImgPipeline,
679
+ PaddleInferStableDiffusionInpaintPipeline,
680
+ PaddleInferStableDiffusionInpaintPipelineLegacy,
681
+ PaddleInferStableDiffusionMegaPipeline,
682
+ PaddleInferStableDiffusionPipeline,
683
+ PaddleInferStableDiffusionXLImg2ImgPipeline,
684
+ PaddleInferStableDiffusionXLInpaintPipeline,
685
+ PaddleInferStableDiffusionXLInstructPix2PixPipeline,
686
+ PaddleInferStableDiffusionXLMegaPipeline,
687
+ PaddleInferStableDiffusionXLPipeline,
688
+ PaddleInferStableVideoDiffusionPipeline,
689
+ PaintByExamplePipeline,
690
+ PixArtAlphaPipeline,
691
+ SemanticStableDiffusionPipeline,
692
+ ShapEImg2ImgPipeline,
693
+ ShapEPipeline,
694
+ StableDiffusion3ControlNetPipeline,
695
+ StableDiffusion3Img2ImgPipeline,
696
+ StableDiffusion3Pipeline,
697
+ StableDiffusionAdapterPipeline,
698
+ StableDiffusionAttendAndExcitePipeline,
699
+ StableDiffusionControlNetImg2ImgPipeline,
700
+ StableDiffusionControlNetInpaintPipeline,
701
+ StableDiffusionControlNetPipeline,
702
+ StableDiffusionDepth2ImgPipeline,
703
+ StableDiffusionDiffEditPipeline,
704
+ StableDiffusionGLIGENPipeline,
705
+ StableDiffusionGLIGENTextImagePipeline,
706
+ StableDiffusionImageVariationPipeline,
707
+ StableDiffusionImg2ImgPipeline,
708
+ StableDiffusionInpaintPipeline,
709
+ StableDiffusionInpaintPipelineLegacy,
710
+ StableDiffusionInstructPix2PixPipeline,
711
+ StableDiffusionLatentUpscalePipeline,
712
+ StableDiffusionLDM3DPipeline,
713
+ StableDiffusionModelEditingPipeline,
714
+ StableDiffusionPanoramaPipeline,
715
+ StableDiffusionParadigmsPipeline,
716
+ StableDiffusionPipeline,
717
+ StableDiffusionPipelineSafe,
718
+ StableDiffusionPix2PixZeroPipeline,
719
+ StableDiffusionSafetyChecker,
720
+ StableDiffusionSAGPipeline,
721
+ StableDiffusionUpscalePipeline,
722
+ StableDiffusionXLAdapterPipeline,
723
+ StableDiffusionXLControlNetImg2ImgPipeline,
724
+ StableDiffusionXLControlNetInpaintPipeline,
725
+ StableDiffusionXLControlNetPipeline,
726
+ StableDiffusionXLImg2ImgPipeline,
727
+ StableDiffusionXLInpaintPipeline,
728
+ StableDiffusionXLInstructPix2PixPipeline,
729
+ StableDiffusionXLPipeline,
730
+ StableUnCLIPImg2ImgPipeline,
731
+ StableUnCLIPPipeline,
732
+ StableVideoDiffusionPipeline,
733
+ TextToVideoSDPipeline,
734
+ TextToVideoZeroPipeline,
735
+ TextToVideoZeroSDXLPipeline,
736
+ UnCLIPImageVariationPipeline,
737
+ UnCLIPPipeline,
738
+ UniDiffuserModel,
739
+ UniDiffuserPipeline,
740
+ UniDiffuserTextDecoder,
741
+ VersatileDiffusionDualGuidedPipeline,
742
+ VersatileDiffusionImageVariationPipeline,
743
+ VersatileDiffusionPipeline,
744
+ VersatileDiffusionTextToImagePipeline,
745
+ VideoToVideoSDPipeline,
746
+ VQDiffusionPipeline,
747
+ WuerstchenCombinedPipeline,
748
+ WuerstchenDecoderPipeline,
749
+ WuerstchenPriorPipeline,
750
+ )
751
+
752
+ try:
753
+ if not (is_paddle_available() and is_paddlenlp_available() and is_k_diffusion_available()):
754
+ raise OptionalDependencyNotAvailable()
755
+ except OptionalDependencyNotAvailable:
756
+ from .utils.dummy_paddle_and_paddlenlp_and_k_diffusion_objects import * # noqa F403
757
+ else:
758
+ from .pipelines import StableDiffusionKDiffusionPipeline
759
+
760
+ try:
761
+ if not (is_paddle_available() and is_paddlenlp_available() and is_fastdeploy_available()):
762
+ raise OptionalDependencyNotAvailable()
763
+ except OptionalDependencyNotAvailable:
764
+ from .utils.dummy_paddle_and_paddlenlp_and_fastdeploy_objects import * # noqa F403
765
+ else:
766
+ from .pipelines import (
767
+ FastDeployCycleDiffusionPipeline,
768
+ FastDeployStableDiffusionControlNetPipeline,
769
+ FastDeployStableDiffusionImg2ImgPipeline,
770
+ FastDeployStableDiffusionInpaintPipeline,
771
+ FastDeployStableDiffusionInpaintPipelineLegacy,
772
+ FastDeployStableDiffusionMegaPipeline,
773
+ FastDeployStableDiffusionPipeline,
774
+ FastDeployStableDiffusionUpscalePipeline,
775
+ FastDeployStableDiffusionXLImg2ImgPipeline,
776
+ FastDeployStableDiffusionXLInpaintPipeline,
777
+ FastDeployStableDiffusionXLInstructPix2PixPipeline,
778
+ FastDeployStableDiffusionXLPipeline,
779
+ )
780
+
781
+ try:
782
+ if not (is_paddle_available() and is_librosa_available()):
783
+ raise OptionalDependencyNotAvailable()
784
+ except OptionalDependencyNotAvailable:
785
+ from .utils.dummy_paddle_and_librosa_objects import * # noqa F403
786
+ else:
787
+ from .pipelines import AudioDiffusionPipeline, Mel
788
+
789
+ try:
790
+ if not (is_paddlenlp_available() and is_paddle_available() and is_note_seq_available()):
791
+ raise OptionalDependencyNotAvailable()
792
+ except OptionalDependencyNotAvailable:
793
+ from .utils.dummy_paddle_and_paddlenlp_and_note_seq_objects import * # noqa F403
794
+ else:
795
+ from .pipelines import SpectrogramDiffusionPipeline
796
+
797
+ try:
798
+ if not (is_note_seq_available()):
799
+ raise OptionalDependencyNotAvailable()
800
+ except OptionalDependencyNotAvailable:
801
+ from .utils.dummy_note_seq_objects import * # noqa F403
802
+ else:
803
+ from .pipelines import MidiProcessor
804
+
805
+ else:
806
+ import sys
807
+
808
+ sys.modules[__name__] = _LazyModule(
809
+ __name__,
810
+ globals()["__file__"],
811
+ _import_structure,
812
+ module_spec=__spec__,
813
+ extra_objects={"__version__": __version__},
814
+ )
PaddleMIX/ppdiffusers/ppdiffusers/accelerate/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ __version__ = "0.25.0"
16
+
17
+ from .accelerator import Accelerator
18
+ from .state import PartialState
19
+ from .utils import (
20
+ AutocastKwargs,
21
+ DistributedDataParallelKwargs,
22
+ DistributedType,
23
+ GradScalerKwargs,
24
+ InitProcessGroupKwargs,
25
+ find_executable_batch_size,
26
+ is_rich_available,
27
+ )
28
+
29
+ if is_rich_available():
30
+ from .utils import rich
PaddleMIX/ppdiffusers/ppdiffusers/accelerate/logging.py ADDED
@@ -0,0 +1,123 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import functools
16
+ import logging
17
+ import os
18
+
19
+ from .state import PartialState
20
+
21
+
22
+ class MultiProcessAdapter(logging.LoggerAdapter):
23
+ """
24
+ An adapter to assist with logging in multiprocess.
25
+
26
+ `log` takes in an additional `main_process_only` kwarg, which dictates whether it should be called on all processes
27
+ or only the main executed one. Default is `main_process_only=True`.
28
+
29
+ Does not require an `Accelerator` object to be created first.
30
+ """
31
+
32
+ @staticmethod
33
+ def _should_log(main_process_only):
34
+ "Check if log should be performed"
35
+ state = PartialState()
36
+ return not main_process_only or (main_process_only and state.is_main_process)
37
+
38
+ def log(self, level, msg, *args, **kwargs):
39
+ """
40
+ Delegates logger call after checking if we should log.
41
+
42
+ Accepts a new kwarg of `main_process_only`, which will dictate whether it will be logged across all processes
43
+ or only the main executed one. Default is `True` if not passed
44
+
45
+ Also accepts "in_order", which if `True` makes the processes log one by one, in order. This is much easier to
46
+ read, but comes at the cost of sometimes needing to wait for the other processes. Default is `False` to not
47
+ break with the previous behavior.
48
+
49
+ `in_order` is ignored if `main_process_only` is passed.
50
+ """
51
+ if PartialState._shared_state == {}:
52
+ raise RuntimeError(
53
+ "You must initialize the accelerate state by calling either `PartialState()` or `Accelerator()` before using the logging utility."
54
+ )
55
+ main_process_only = kwargs.pop("main_process_only", True)
56
+ in_order = kwargs.pop("in_order", False)
57
+
58
+ if self.isEnabledFor(level):
59
+ if self._should_log(main_process_only):
60
+ msg, kwargs = self.process(msg, kwargs)
61
+ self.logger.log(level, msg, *args, **kwargs)
62
+
63
+ elif in_order:
64
+ state = PartialState()
65
+ for i in range(state.num_processes):
66
+ if i == state.process_index:
67
+ msg, kwargs = self.process(msg, kwargs)
68
+ self.logger.log(level, msg, *args, **kwargs)
69
+ state.wait_for_everyone()
70
+
71
+ @functools.lru_cache(None)
72
+ def warning_once(self, *args, **kwargs):
73
+ """
74
+ This method is identical to `logger.warning()`, but will emit the warning with the same message only once
75
+
76
+ Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the
77
+ cache. The assumption here is that all warning messages are unique across the code. If they aren't, we would need to
78
+ switch to another type of cache that includes the caller frame information in the hashing function.
79
+ """
80
+ self.warning(*args, **kwargs)
81
+
82
+
83
+ def get_logger(name: str, log_level: str = None):
84
+ """
85
+ Returns a `logging.Logger` for `name` that can handle multiprocessing.
86
+
87
+ If a log should be called on all processes, pass `main_process_only=False` If a log should be called on all
88
+ processes and in order, also pass `in_order=True`
89
+
90
+ Args:
91
+ name (`str`):
92
+ The name for the logger, such as `__file__`
93
+ log_level (`str`, *optional*):
94
+ The log level to use. If not passed, defaults to the `ACCELERATE_LOG_LEVEL` environment variable; if that is also unset, the logger's level is left unchanged.
95
+
96
+ Example:
97
+
98
+ ```python
99
+ >>> from accelerate.logging import get_logger
100
+ >>> from accelerate import Accelerator
101
+
102
+ >>> logger = get_logger(__name__)
103
+
104
+ >>> accelerator = Accelerator()
105
+ >>> logger.info("My log", main_process_only=False)
106
+ >>> logger.debug("My log", main_process_only=True)
107
+
108
+ >>> logger = get_logger(__name__, log_level="DEBUG")
109
+ >>> logger.info("My log")
110
+ >>> logger.debug("My second log")
111
+
112
+ >>> array = ["a", "b", "c", "d"]
113
+ >>> letter_at_rank = array[accelerator.process_index]
114
+ >>> logger.info(letter_at_rank, in_order=True)
115
+ ```
116
+ """
117
+ if log_level is None:
118
+ log_level = os.environ.get("ACCELERATE_LOG_LEVEL", None)
119
+ logger = logging.getLogger(name)
120
+ if log_level is not None:
121
+ logger.setLevel(log_level.upper())
122
+ logger.root.setLevel(log_level.upper())
123
+ return MultiProcessAdapter(logger, {})
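A short usage sketch for the adapter above. The accelerate state must be initialized (for example by constructing `Accelerator()`) before the first log call, otherwise `log` raises the `RuntimeError` defined earlier in this file:

from ppdiffusers.accelerate import Accelerator
from ppdiffusers.accelerate.logging import get_logger

accelerator = Accelerator()                      # initializes PartialState
logger = get_logger(__name__, log_level="INFO")

logger.info("main process only")                 # main_process_only defaults to True
logger.info("every process", main_process_only=False)
logger.warning_once("emitted a single time, even inside a training loop")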
PaddleMIX/ppdiffusers/ppdiffusers/accelerate/optimizer.py ADDED
@@ -0,0 +1,180 @@
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ import warnings
17
+
18
+ import paddle
19
+ import paddle.optimizer
20
+
21
+ from .state import AcceleratorState, GradientState
22
+ from .utils import honor_type
23
+
24
+
25
+ def move_to_device(state, device):
26
+ if isinstance(state, (list, tuple)):
27
+ return honor_type(state, (move_to_device(t, device) for t in state))
28
+ elif isinstance(state, dict):
29
+ return type(state)({k: move_to_device(v, device) for k, v in state.items()})
30
+ elif isinstance(state, paddle.Tensor):
31
+ return state.to(device)
32
+ return state
33
+
34
+
35
+ class AcceleratedOptimizer(paddle.optimizer.Optimizer):
36
+ """
37
+ Internal wrapper around a Paddle optimizer.
38
+
39
+ Conditionally will perform `step` and `zero_grad` if gradients should be synchronized when performing gradient
40
+ accumulation.
41
+
42
+ Args:
43
+ optimizer (`paddle.optimizer.Optimizer`):
44
+ The optimizer to wrap.
45
+ device_placement (`bool`, *optional*, defaults to `True`):
46
+ Whether or not the optimizer should handle device placement. If so, it will place the state dictionary of
47
+ `optimizer` on the right device.
48
+ scaler (`paddle.amp.GradScaler`, *optional*):
49
+ The scaler to use in the step function if training with mixed precision.
50
+ """
51
+
52
+ def __init__(self, optimizer, device_placement=True, scaler=None):
53
+ self.optimizer = optimizer
54
+ self.scaler = scaler
55
+ self.accelerator_state = AcceleratorState()
56
+ self.gradient_state = GradientState()
57
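+ # NOTE: device placement is unconditionally disabled in this Paddle port; the "device_placement" argument is effectively ignored.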
+ device_placement = False
58
+ self.device_placement = device_placement
59
+ self._is_overflow = False
60
+
61
+ if self.scaler is not None:
62
+ self._accelerate_step_called = False
63
+ self._optimizer_original_step_method = self.optimizer.step
64
+ self._optimizer_patched_step_method = patch_optimizer_step(self, self.optimizer.step)
65
+
66
+ # Handle device placement
67
+ if device_placement:
68
+ state_dict = self.optimizer.state_dict()
69
+ self.optimizer.set_state_dict(state_dict)
70
+
71
+ @property
72
+ def state(self):
73
+ return self.optimizer.state
74
+
75
+ @state.setter
76
+ def state(self, state):
77
+ self.optimizer.state = state
78
+
79
+ @property
80
+ def param_groups(self):
81
+ return self.optimizer._param_groups
82
+
83
+ @param_groups.setter
84
+ def param_groups(self, param_groups):
85
+ self.optimizer._param_groups = param_groups
86
+
87
+ @property
88
+ def defaults(self):
89
+ return self.optimizer.defaults
90
+
91
+ @defaults.setter
92
+ def defaults(self, defaults):
93
+ self.optimizer.defaults = defaults
94
+
95
+ def add_param_group(self, param_group):
96
+ self.optimizer.add_param_group(param_group)
97
+
98
+ def load_state_dict(self, state_dict):
99
+ self.optimizer.set_state_dict(state_dict)
100
+
101
+ set_state_dict = load_state_dict
102
+
103
+ def state_dict(self):
104
+ return self.optimizer.state_dict()
105
+
106
+ def zero_grad(self, set_to_zero=None):
107
+ if self.gradient_state.sync_gradients:
108
+ accept_arg = "set_to_zero" in inspect.signature(self.optimizer.clear_grad).parameters
109
+ if accept_arg:
110
+ if set_to_zero is None:
111
+ set_to_zero = True
112
+ self.optimizer.clear_grad(set_to_zero=set_to_zero)
113
+ else:
114
+ if set_to_zero is not None:
115
+ raise ValueError("`set_to_zero` for Optimizer.clear_grad` is not supported by this optimizer.")
116
+ self.optimizer.clear_grad()
117
+
118
+ clear_grad = zero_grad
119
+
120
+ def step(self, closure=None):
121
+ if self.gradient_state.sync_gradients:
122
+ if self.scaler is not None:
123
+ self.optimizer.step = self._optimizer_patched_step_method
124
+
125
+ self.scaler.step(self.optimizer)
126
+ self.scaler.update()
127
+
128
+ if not self._accelerate_step_called:
129
+ # If the optimizer step was skipped, gradient overflow was detected.
130
+ self._is_overflow = True
131
+ else:
132
+ self._is_overflow = False
133
+ # Reset the step method to the original one
134
+ self.optimizer.step = self._optimizer_original_step_method
135
+ # Reset the indicator
136
+ self._accelerate_step_called = False
137
+ else:
138
+ self.optimizer.step()
139
+
140
+ def _switch_parameters(self, parameters_map):
141
+ for param_group in self.param_groups:
142
+ param_group["params"] = [parameters_map.get(p, p) for p in param_group["params"]]
143
+
144
+ @property
145
+ def is_overflow(self):
146
+ """Whether or not the optimizer step was done, or skipped because of gradient overflow."""
147
+ warnings.warn(
148
+ "The `is_overflow` property is deprecated and will be removed in version 1.0 of Accelerate use "
149
+ "`optimizer.step_was_skipped` instead.",
150
+ FutureWarning,
151
+ )
152
+ return self._is_overflow
153
+
154
+ @property
155
+ def step_was_skipped(self):
156
+ """Whether or not the optimizer step was skipped."""
157
+ return self._is_overflow
158
+
159
+ def __getstate__(self):
160
+ _ignored_keys = [
161
+ "_accelerate_step_called",
162
+ "_optimizer_original_step_method",
163
+ "_optimizer_patched_step_method",
164
+ ]
165
+ return {k: v for k, v in self.__dict__.items() if k not in _ignored_keys}
166
+
167
+ def __setstate__(self, state):
168
+ self.__dict__.update(state)
169
+ if self.scaler is not None:
170
+ self._accelerate_step_called = False
171
+ self._optimizer_original_step_method = self.optimizer.step
172
+ self._optimizer_patched_step_method = patch_optimizer_step(self, self.optimizer.step)
173
+
174
+
175
+ def patch_optimizer_step(accelerated_optimizer: AcceleratedOptimizer, method):
176
+ def patched_step(*args, **kwargs):
177
+ accelerated_optimizer._accelerate_step_called = True
178
+ return method(*args, **kwargs)
179
+
180
+ return patched_step
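A hedged end-to-end sketch of the wrapper above with mixed precision, assuming a GPU environment for AMP and that the vendored `Accelerator` can be constructed directly to set up `AcceleratorState`/`GradientState`:

import paddle
from ppdiffusers.accelerate import Accelerator
from ppdiffusers.accelerate.optimizer import AcceleratedOptimizer

accelerator = Accelerator()                                   # sets up the shared state singletons
model = paddle.nn.Linear(8, 8)
scaler = paddle.amp.GradScaler()
opt = AcceleratedOptimizer(paddle.optimizer.AdamW(parameters=model.parameters()), scaler=scaler)

loss = model(paddle.randn([4, 8])).mean()
scaler.scale(loss).backward()
opt.step()                                                    # routes through scaler.step / scaler.update
if opt.step_was_skipped:                                      # True when the scaler detected an overflow
    print("gradient overflow detected; optimizer step skipped")
opt.zero_grad()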
PaddleMIX/ppdiffusers/ppdiffusers/accelerate/scheduler.py ADDED
@@ -0,0 +1,96 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # We ignore warnings about stepping the scheduler since we step it ourselves during gradient accumulation
16
+
17
+
18
+ from .state import AcceleratorState, GradientState
19
+
20
+
21
+ class AcceleratedScheduler:
22
+ """
23
+ A wrapper around a learning rate scheduler that will only step when the optimizer(s) have a training step. Useful
24
+ to avoid making a scheduler step too fast when gradients overflowed and there was no training step (in mixed
25
+ precision training)
26
+
27
+ When performing gradient accumulation, scheduler lengths should not be changed accordingly; Accelerate will always
28
+ step the scheduler to account for it.
29
+
30
+ Args:
31
+ scheduler (`paddle.optimizer.lr.LRScheduler`):
32
+ The scheduler to wrap.
33
+ optimizers (one or a list of `paddle.optimizer.Optimizer`):
34
+ The optimizers used.
35
+ step_with_optimizer (`bool`, *optional*, defaults to `True`):
36
+ Whether or not the scheduler should be stepped at each optimizer step.
37
+ split_batches (`bool`, *optional*, defaults to `False`):
38
+ Whether or not the dataloaders split one batch across the different processes (so batch size is the same
39
+ regardless of the number of processes) or create batches on each process (so batch size is the original
40
+ batch size multiplied by the number of processes).
41
+ """
42
+
43
+ def __init__(self, scheduler, optimizers, step_with_optimizer: bool = True, split_batches: bool = False):
44
+ self.scheduler = scheduler
45
+ self.optimizers = optimizers if isinstance(optimizers, (list, tuple)) else [optimizers]
46
+ self.split_batches = split_batches
47
+ self.step_with_optimizer = step_with_optimizer
48
+ self.gradient_state = GradientState()
49
+
50
+ def step(self, *args, **kwargs):
51
+ if not self.step_with_optimizer:
52
+ # No link between scheduler and optimizer -> just step
53
+ self.scheduler.step(*args, **kwargs)
54
+ return
55
+
56
+ # Otherwise, first make sure the optimizer was stepped.
57
+ if not self.gradient_state.sync_gradients:
58
+ if self.gradient_state.adjust_scheduler:
59
+ self.scheduler._step_count += 1
60
+ return
61
+
62
+ for opt in self.optimizers:
63
+ if opt.step_was_skipped:
64
+ return
65
+ if self.split_batches:
66
+ # Split batches -> the training dataloader batch size is not changed so one step per training step
67
+ self.scheduler.step(*args, **kwargs)
68
+ else:
69
+ # Otherwise the training dataloader batch size was multiplied by `num_processes`, so we need to do
70
+ # num_processes steps per training step
71
+ num_processes = AcceleratorState().num_processes
72
+ for _ in range(num_processes):
73
+ # Special case when using OneCycle and `drop_last` was not used
74
+ if hasattr(self.scheduler, "total_steps"):
75
+ if self.scheduler._step_count <= self.scheduler.total_steps:
76
+ self.scheduler.step(*args, **kwargs)
77
+ else:
78
+ self.scheduler.step(*args, **kwargs)
79
+
80
+ # Passthroughs
81
+ def get_last_lr(self):
82
+ return self.scheduler.get_lr()
83
+
84
+ def state_dict(self):
85
+ return self.scheduler.state_dict()
86
+
87
+ def load_state_dict(self, state_dict):
88
+ self.scheduler.set_state_dict(state_dict)
89
+
90
+ set_state_dict = load_state_dict
91
+
92
+ def get_lr(self):
93
+ return self.scheduler.get_lr()
94
+
95
+ def print_lr(self, *args, **kwargs):
96
+ return self.scheduler.print_lr(*args, **kwargs)
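A hedged sketch pairing the scheduler wrapper with the optimizer wrapper from the previous file; `scheduler.step()` becomes a no-op whenever the optimizer step was skipped (for example after an AMP overflow). Constructing `Accelerator()` first is assumed to initialize the shared state:

import paddle
from ppdiffusers.accelerate import Accelerator
from ppdiffusers.accelerate.optimizer import AcceleratedOptimizer
from ppdiffusers.accelerate.scheduler import AcceleratedScheduler

accelerator = Accelerator()
model = paddle.nn.Linear(8, 8)
lr_sched = paddle.optimizer.lr.StepDecay(learning_rate=1e-4, step_size=10)
optimizer = AcceleratedOptimizer(paddle.optimizer.AdamW(learning_rate=lr_sched, parameters=model.parameters()))
scheduler = AcceleratedScheduler(lr_sched, optimizer, step_with_optimizer=True, split_batches=True)

model(paddle.randn([4, 8])).mean().backward()
optimizer.step()
scheduler.step()                 # silently skipped if optimizer.step_was_skipped is True
print(scheduler.get_lr())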
PaddleMIX/ppdiffusers/ppdiffusers/accelerate/tracking.py ADDED
@@ -0,0 +1,1103 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Expectation:
16
+ # Provide a project dir name, then each type of logger gets stored in project/{`logging_dir`}
17
+
18
+ import json
19
+ import os
20
+ import time
21
+ from functools import wraps
22
+ from typing import Any, Dict, List, Optional, Union
23
+
24
+ import yaml
25
+
26
+ from .logging import get_logger
27
+ from .state import PartialState
28
+ from .utils import (
29
+ LoggerType,
30
+ is_aim_available,
31
+ is_clearml_available,
32
+ is_comet_ml_available,
33
+ is_dvclive_available,
34
+ is_mlflow_available,
35
+ is_tensorboard_available,
36
+ is_visualdl_available,
37
+ is_wandb_available,
38
+ listify,
39
+ )
40
+
41
+ _available_trackers = []
42
+
43
+ if is_tensorboard_available():
44
+ _available_trackers.append(LoggerType.TENSORBOARD)
45
+
46
+ if is_wandb_available():
47
+ _available_trackers.append(LoggerType.WANDB)
48
+
49
+ if is_comet_ml_available():
50
+ _available_trackers.append(LoggerType.COMETML)
51
+
52
+ if is_aim_available():
53
+ _available_trackers.append(LoggerType.AIM)
54
+
55
+ if is_mlflow_available():
56
+ _available_trackers.append(LoggerType.MLFLOW)
57
+
58
+ if is_clearml_available():
59
+ _available_trackers.append(LoggerType.CLEARML)
60
+
61
+ if is_dvclive_available():
62
+ _available_trackers.append(LoggerType.DVCLIVE)
63
+
64
+ if is_visualdl_available():
65
+ _available_trackers.append(LoggerType.VISUALDL)
66
+
67
+ logger = get_logger(__name__)
68
+
69
+
70
+ def on_main_process(function):
71
+ """
72
+ Decorator to selectively run the decorated function on the main process only based on the `main_process_only`
73
+ attribute in a class.
74
+
75
+ Checks at function execution rather than initialization time, not triggering the initialization of the
76
+ `PartialState`.
77
+ """
78
+
79
+ @wraps(function)
80
+ def execute_on_main_process(self, *args, **kwargs):
81
+ if getattr(self, "main_process_only", False):
82
+ return PartialState().on_main_process(function)(self, *args, **kwargs)
83
+ else:
84
+ return function(self, *args, **kwargs)
85
+
86
+ return execute_on_main_process
87
+
88
+
89
+ def get_available_trackers():
90
+ "Returns a list of all supported available trackers in the system"
91
+ return _available_trackers
92
+
93
+
94
+ class GeneralTracker:
95
+ """
96
+ A base Tracker class to be used for all logging integration implementations.
97
+
98
+ Each function should take in `**kwargs` that will automatically be passed in from a base dictionary provided to
99
+ [`Accelerator`].
100
+
101
+ Should implement `name`, `requires_logging_directory`, and `tracker` properties such that:
102
+
103
+ `name` (`str`): String representation of the tracker class name, such as "TensorBoard" `requires_logging_directory`
104
+ (`bool`): Whether the logger requires a directory to store its logs. `tracker` (`object`): Should return internal
105
+ tracking mechanism used by a tracker class (such as the `run` for wandb)
106
+
107
+ Implementations can also include a `main_process_only` (`bool`) attribute to toggle whether relevant logging, init, and
108
+ other functions should occur on the main process or across all processes (by default will use `True`)
109
+ """
110
+
111
+ main_process_only = True
112
+
113
+ def __init__(self, _blank=False):
114
+ if not _blank:
115
+ err = ""
116
+ if not hasattr(self, "name"):
117
+ err += "`name`"
118
+ if not hasattr(self, "requires_logging_directory"):
119
+ if len(err) > 0:
120
+ err += ", "
121
+ err += "`requires_logging_directory`"
122
+
123
+ # as tracker is a @property that relies on post-init
124
+ if "tracker" not in dir(self):
125
+ if len(err) > 0:
126
+ err += ", "
127
+ err += "`tracker`"
128
+ if len(err) > 0:
129
+ raise NotImplementedError(
130
+ f"The implementation for this tracker class is missing the following "
131
+ f"required attributes. Please define them in the class definition: "
132
+ f"{err}"
133
+ )
134
+
135
+ def store_init_configuration(self, values: dict):
136
+ """
137
+ Logs `values` as hyperparameters for the run. Implementations should use the experiment configuration
138
+ functionality of a tracking API.
139
+
140
+ Args:
141
+ values (Dictionary `str` to `bool`, `str`, `float` or `int`):
142
+ Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
143
+ `str`, `float`, `int`, or `None`.
144
+ """
145
+ pass
146
+
147
+ def log(self, values: dict, step: Optional[int], **kwargs):
148
+ """
149
+ Logs `values` to the current run. Base `log` implementations of a tracking API should go in here, along with
150
+ special behavior for the `step` parameter.
151
+
152
+ Args:
153
+ values (Dictionary `str` to `str`, `float`, or `int`):
154
+ Values to be logged as key-value pairs. The values need to have type `str`, `float`, or `int`.
155
+ step (`int`, *optional*):
156
+ The run step. If included, the log will be affiliated with this step.
157
+ """
158
+ pass
159
+
160
+ def finish(self):
161
+ """
162
+ Should run any finalizing functions within the tracking API. If the API should not have one, just don't
163
+ overwrite that method.
164
+ """
165
+ pass
166
+
167
+
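A minimal custom tracker sketch based on the contract described in the `GeneralTracker` docstring above; the JSON-lines format and class name are arbitrary choices for illustration:

import json
import os
from ppdiffusers.accelerate.tracking import GeneralTracker

class JSONLinesTracker(GeneralTracker):
    name = "jsonl"
    requires_logging_directory = True

    def __init__(self, run_name: str, logging_dir: str):
        super().__init__()
        os.makedirs(logging_dir, exist_ok=True)
        self._path = os.path.join(logging_dir, f"{run_name}.jsonl")

    @property
    def tracker(self):
        return self._path

    def store_init_configuration(self, values: dict):
        self.log({"hparams": values}, step=None)

    def log(self, values: dict, step=None, **kwargs):
        with open(self._path, "a") as f:
            f.write(json.dumps({"step": step, **values}) + "\n")

An instance of such a tracker can be passed in `log_with` alongside the built-in tracker names, since `filter_trackers` below accepts `GeneralTracker` subclasses directly.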
168
+ class TensorBoardTracker(GeneralTracker):
169
+ """
170
+ A `Tracker` class that supports `tensorboard`. Should be initialized at the start of your script.
171
+
172
+ Args:
173
+ run_name (`str`):
174
+ The name of the experiment run
175
+ logging_dir (`str`, `os.PathLike`):
176
+ Location for TensorBoard logs to be stored.
177
+ kwargs:
178
+ Additional key word arguments passed along to the `tensorboard.SummaryWriter.__init__` method.
179
+ """
180
+
181
+ name = "tensorboard"
182
+ requires_logging_directory = True
183
+
184
+ @on_main_process
185
+ def __init__(self, run_name: str, logging_dir: Union[str, os.PathLike], **kwargs):
186
+ try:
187
+ from torch.utils import tensorboard
188
+ except ModuleNotFoundError:
189
+ import tensorboardX as tensorboard
190
+ super().__init__()
191
+ self.run_name = run_name
192
+ self.logging_dir = os.path.join(logging_dir, run_name)
193
+ self.writer = tensorboard.SummaryWriter(self.logging_dir, **kwargs)
194
+ logger.debug(f"Initialized TensorBoard project {self.run_name} logging to {self.logging_dir}")
195
+ logger.debug(
196
+ "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
197
+ )
198
+
199
+ @property
200
+ def tracker(self):
201
+ return self.writer
202
+
203
+ @on_main_process
204
+ def store_init_configuration(self, values: dict):
205
+ """
206
+ Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment. Stores the
207
+ hyperparameters in a yaml file for future use.
208
+
209
+ Args:
210
+ values (Dictionary `str` to `bool`, `str`, `float` or `int`):
211
+ Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
212
+ `str`, `float`, `int`, or `None`.
213
+ """
214
+ self.writer.add_hparams(values, metric_dict={})
215
+ self.writer.flush()
216
+ project_run_name = time.time()
217
+ dir_name = os.path.join(self.logging_dir, str(project_run_name))
218
+ os.makedirs(dir_name, exist_ok=True)
219
+ with open(os.path.join(dir_name, "hparams.yml"), "w") as outfile:
220
+ try:
221
+ yaml.dump(values, outfile)
222
+ except yaml.representer.RepresenterError:
223
+ logger.error("Serialization to store hyperparameters failed")
224
+ raise
225
+ logger.debug("Stored initial configuration hyperparameters to TensorBoard and hparams yaml file")
226
+
227
+ @on_main_process
228
+ def log(self, values: dict, step: Optional[int] = None, **kwargs):
229
+ """
230
+ Logs `values` to the current run.
231
+
232
+ Args:
233
+ values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
234
+ Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of
235
+ `str` to `float`/`int`.
236
+ step (`int`, *optional*):
237
+ The run step. If included, the log will be affiliated with this step.
238
+ kwargs:
239
+ Additional key word arguments passed along to either `SummaryWriter.add_scalar`,
240
+ `SummaryWriter.add_text`, or `SummaryWriter.add_scalars` method based on the contents of `values`.
241
+ """
242
+ values = listify(values)
243
+ for k, v in values.items():
244
+ if isinstance(v, (int, float)):
245
+ self.writer.add_scalar(k, v, global_step=step, **kwargs)
246
+ elif isinstance(v, str):
247
+ self.writer.add_text(k, v, global_step=step, **kwargs)
248
+ elif isinstance(v, dict):
249
+ self.writer.add_scalars(k, v, global_step=step, **kwargs)
250
+ self.writer.flush()
251
+ logger.debug("Successfully logged to TensorBoard")
252
+
253
+ @on_main_process
254
+ def log_images(self, values: dict, step: Optional[int], **kwargs):
255
+ """
256
+ Logs `images` to the current run.
257
+
258
+ Args:
259
+ values (Dictionary `str` to `List` of `np.ndarray` or `PIL.Image`):
260
+ Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or `PIL.Image`.
261
+ step (`int`, *optional*):
262
+ The run step. If included, the log will be affiliated with this step.
263
+ kwargs:
264
+ Additional key word arguments passed along to the `SummaryWriter.add_images` method.
265
+ """
266
+ for k, v in values.items():
267
+ self.writer.add_images(k, v, global_step=step, **kwargs)
268
+ logger.debug("Successfully logged images to TensorBoard")
269
+
270
+ @on_main_process
271
+ def finish(self):
272
+ """
273
+ Closes `TensorBoard` writer
274
+ """
275
+ self.writer.close()
276
+ logger.debug("TensorBoard writer closed")
277
+
278
+
279
+ class VisualdlTracker(GeneralTracker):
280
+ """
281
+ A `Tracker` class that supports `visualdl`. Should be initialized at the start of your script.
282
+
283
+ Args:
284
+ run_name (`str`):
285
+ The name of the experiment run
286
+ logging_dir (`str`, `os.PathLike`):
287
+ Location for Visualdl logs to be stored.
288
+ kwargs:
289
+ Additional key word arguments passed along to the `visualdl.LogWriter.__init__` method.
290
+ """
291
+
292
+ name = "visualdl"
293
+ requires_logging_directory = True
294
+
295
+ @on_main_process
296
+ def __init__(self, run_name: str, logging_dir: Union[str, os.PathLike], **kwargs):
297
+ super().__init__()
298
+ from visualdl import LogWriter
299
+
300
+ self.run_name = run_name
301
+ self.logging_dir = os.path.join(logging_dir, run_name)
302
+ self.writer = LogWriter(self.logging_dir, **kwargs)
303
+ logger.debug(f"Initialized VisualDL project {self.run_name} logging to {self.logging_dir}")
304
+ logger.debug(
305
+ "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
306
+ )
307
+
308
+ @property
309
+ def tracker(self):
310
+ return self.writer
311
+
312
+ @on_main_process
313
+ def store_init_configuration(self, values: dict):
314
+ """
315
+ Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment. Stores the
316
+ hyperparameters in a yaml file for future use.
317
+
318
+ Args:
319
+ values (Dictionary `str` to `bool`, `str`, `float` or `int`):
320
+ Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
321
+ `str`, `float`, `int`, or `None`.
322
+ """
323
+ self.writer.add_hparams(hparams_dict=values, metrics_list=[])
324
+ self.writer.flush()
325
+ project_run_name = time.time()
326
+ dir_name = os.path.join(self.logging_dir, str(project_run_name))
327
+ os.makedirs(dir_name, exist_ok=True)
328
+ with open(os.path.join(dir_name, "hparams.yml"), "w") as outfile:
329
+ try:
330
+ yaml.dump(values, outfile)
331
+ except yaml.representer.RepresenterError:
332
+ logger.error("Serialization to store hyperparameters failed")
333
+ raise
334
+ logger.debug("Stored initial configuration hyperparameters to VisualDL and hparams yaml file")
335
+
336
+ @on_main_process
337
+ def log(self, values: dict, step: Optional[int] = None, **kwargs):
338
+ """
339
+ Logs `values` to the current run.
340
+
341
+ Args:
342
+ values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
343
+ Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of
344
+ `str` to `float`/`int`.
345
+ step (`int`, *optional*):
346
+ The run step. If included, the log will be affiliated with this step.
347
+ kwargs:
348
+ Additional key word arguments passed along to either `LogWriter.add_scalar`,
349
+ `LogWriter.add_text`, or `LogWriter.add_scalars` method based on the contents of `values`.
350
+ """
351
+ values = listify(values)
352
+ for k, v in values.items():
353
+ if isinstance(v, (int, float)):
354
+ self.writer.add_scalar(k, v, step=step, **kwargs)
355
+ elif isinstance(v, str):
356
+ self.writer.add_text(k, v, step=step, **kwargs)
357
+ elif isinstance(v, dict):
358
+ self.writer.add_scalars(k, v, step=step, **kwargs)
359
+ self.writer.flush()
360
+ logger.debug("Successfully logged to Visualdl")
361
+
362
+ @on_main_process
363
+ def log_images(self, values: dict, step: Optional[int], **kwargs):
364
+ """
365
+ Logs `images` to the current run.
366
+
367
+ Args:
368
+ values (Dictionary `str` to `List` of `np.ndarray` or `PIL.Image`):
369
+ Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or `PIL.Image`.
370
+ step (`int`, *optional*):
371
+ The run step. If included, the log will be affiliated with this step.
372
+ kwargs:
373
+ Additional key word arguments passed along to the `LogWriter.add_image` method.
374
+ """
375
+ for k, v in values.items():
376
+ self.writer.add_image(k, v, step=step, **kwargs)
377
+ logger.debug("Successfully logged images to Visualdl")
378
+
379
+ @on_main_process
380
+ def finish(self):
381
+ """
382
+ Closes `VisualDL` writer
383
+ """
384
+ self.writer.close()
385
+ logger.debug("VisualDL writer closed")
386
+
387
+
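A hedged sketch of driving the VisualDL tracker above through the higher-level tracking API, assuming the vendored `Accelerator` exposes the same `init_trackers`/`log`/`end_training` methods as upstream accelerate:

from ppdiffusers.accelerate import Accelerator

accelerator = Accelerator(log_with="visualdl", project_dir="./vdl_logs")
accelerator.init_trackers("my_run", config={"lr": 1e-4, "batch_size": 16})
for step in range(3):
    accelerator.log({"train_loss": 1.0 / (step + 1)}, step=step)
accelerator.end_training()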
388
+ class WandBTracker(GeneralTracker):
389
+ """
390
+ A `Tracker` class that supports `wandb`. Should be initialized at the start of your script.
391
+
392
+ Args:
393
+ run_name (`str`):
394
+ The name of the experiment run.
395
+ kwargs:
396
+ Additional key word arguments passed along to the `wandb.init` method.
397
+ """
398
+
399
+ name = "wandb"
400
+ requires_logging_directory = False
401
+ main_process_only = False
402
+
403
+ @on_main_process
404
+ def __init__(self, run_name: str, **kwargs):
405
+ super().__init__()
406
+ self.run_name = run_name
407
+
408
+ import wandb
409
+
410
+ self.run = wandb.init(project=self.run_name, **kwargs)
411
+ logger.debug(f"Initialized WandB project {self.run_name}")
412
+ logger.debug(
413
+ "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
414
+ )
415
+
416
+ @property
417
+ def tracker(self):
418
+ return self.run
419
+
420
+ @on_main_process
421
+ def store_init_configuration(self, values: dict):
422
+ """
423
+ Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
424
+
425
+ Args:
426
+ values (Dictionary `str` to `bool`, `str`, `float` or `int`):
427
+ Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
428
+ `str`, `float`, `int`, or `None`.
429
+ """
430
+ import wandb
431
+
432
+ wandb.config.update(values, allow_val_change=True)
433
+ logger.debug("Stored initial configuration hyperparameters to WandB")
434
+
435
+ @on_main_process
436
+ def log(self, values: dict, step: Optional[int] = None, **kwargs):
437
+ """
438
+ Logs `values` to the current run.
439
+
440
+ Args:
441
+ values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
442
+ Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of
443
+ `str` to `float`/`int`.
444
+ step (`int`, *optional*):
445
+ The run step. If included, the log will be affiliated with this step.
446
+ kwargs:
447
+ Additional key word arguments passed along to the `wandb.log` method.
448
+ """
449
+ self.run.log(values, step=step, **kwargs)
450
+ logger.debug("Successfully logged to WandB")
451
+
452
+ @on_main_process
453
+ def log_images(self, values: dict, step: Optional[int] = None, **kwargs):
454
+ """
455
+ Logs `images` to the current run.
456
+
457
+ Args:
458
+ values (Dictionary `str` to `List` of `np.ndarray` or `PIL.Image`):
459
+ Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or `PIL.Image`.
460
+ step (`int`, *optional*):
461
+ The run step. If included, the log will be affiliated with this step.
462
+ kwargs:
463
+ Additional key word arguments passed along to the `wandb.log` method.
464
+ """
465
+ import wandb
466
+
467
+ for k, v in values.items():
468
+ self.log({k: [wandb.Image(image) for image in v]}, step=step, **kwargs)
469
+ logger.debug("Successfully logged images to WandB")
470
+
471
+ @on_main_process
472
+ def log_table(
473
+ self,
474
+ table_name: str,
475
+ columns: List[str] = None,
476
+ data: List[List[Any]] = None,
477
+ dataframe: Any = None,
478
+ step: Optional[int] = None,
479
+ **kwargs,
480
+ ):
481
+ """
482
+ Log a Table containing any object type (text, image, audio, video, molecule, html, etc). Can be defined either
483
+ with `columns` and `data` or with `dataframe`.
484
+
485
+ Args:
486
+ table_name (`str`):
487
+ The name to give to the logged table on the wandb workspace
488
+ columns (list of `str`, *optional*):
489
+ The name of the columns on the table
490
+ data (List of List of Any data type, *optional*):
491
+ The data to be logged in the table
492
+ dataframe (Any data type, *optional*):
493
+ The data to be logged in the table
494
+ step (`int`, *optional*):
495
+ The run step. If included, the log will be affiliated with this step.
496
+ """
497
+ import wandb
498
+
499
+ values = {table_name: wandb.Table(columns=columns, data=data, dataframe=dataframe)}
500
+ self.log(values, step=step, **kwargs)
501
+
502
+ @on_main_process
503
+ def finish(self):
504
+ """
505
+ Closes `wandb` writer
506
+ """
507
+ self.run.finish()
508
+ logger.debug("WandB run closed")
509
+
510
+
511
+ class CometMLTracker(GeneralTracker):
512
+ """
513
+ A `Tracker` class that supports `comet_ml`. Should be initialized at the start of your script.
514
+
515
+ API keys must be stored in a Comet config file.
516
+
517
+ Args:
518
+ run_name (`str`):
519
+ The name of the experiment run.
520
+ kwargs:
521
+ Additional key word arguments passed along to the `Experiment.__init__` method.
522
+ """
523
+
524
+ name = "comet_ml"
525
+ requires_logging_directory = False
526
+
527
+ @on_main_process
528
+ def __init__(self, run_name: str, **kwargs):
529
+ super().__init__()
530
+ self.run_name = run_name
531
+
532
+ from comet_ml import Experiment
533
+
534
+ self.writer = Experiment(project_name=run_name, **kwargs)
535
+ logger.debug(f"Initialized CometML project {self.run_name}")
536
+ logger.debug(
537
+ "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
538
+ )
539
+
540
+ @property
541
+ def tracker(self):
542
+ return self.writer
543
+
544
+ @on_main_process
545
+ def store_init_configuration(self, values: dict):
546
+ """
547
+ Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
548
+
549
+ Args:
550
+ values (Dictionary `str` to `bool`, `str`, `float` or `int`):
551
+ Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
552
+ `str`, `float`, `int`, or `None`.
553
+ """
554
+ self.writer.log_parameters(values)
555
+ logger.debug("Stored initial configuration hyperparameters to CometML")
556
+
557
+ @on_main_process
558
+ def log(self, values: dict, step: Optional[int] = None, **kwargs):
559
+ """
560
+ Logs `values` to the current run.
561
+
562
+ Args:
563
+ values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
564
+ Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of
565
+ `str` to `float`/`int`.
566
+ step (`int`, *optional*):
567
+ The run step. If included, the log will be affiliated with this step.
568
+ kwargs:
569
+ Additional key word arguments passed along to either `Experiment.log_metric`, `Experiment.log_other`,
570
+ or `Experiment.log_metrics` method based on the contents of `values`.
571
+ """
572
+ if step is not None:
573
+ self.writer.set_step(step)
574
+ for k, v in values.items():
575
+ if isinstance(v, (int, float)):
576
+ self.writer.log_metric(k, v, step=step, **kwargs)
577
+ elif isinstance(v, str):
578
+ self.writer.log_other(k, v, **kwargs)
579
+ elif isinstance(v, dict):
580
+ self.writer.log_metrics(v, step=step, **kwargs)
581
+ logger.debug("Successfully logged to CometML")
582
+
583
+ @on_main_process
584
+ def finish(self):
585
+ """
586
+ Closes `comet-ml` writer
587
+ """
588
+ self.writer.end()
589
+ logger.debug("CometML run closed")
590
+
591
+
592
+ class AimTracker(GeneralTracker):
593
+ """
594
+ A `Tracker` class that supports `aim`. Should be initialized at the start of your script.
595
+
596
+ Args:
597
+ run_name (`str`):
598
+ The name of the experiment run.
599
+ kwargs:
600
+ Additional key word arguments passed along to the `Run.__init__` method.
601
+ """
602
+
603
+ name = "aim"
604
+ requires_logging_directory = True
605
+
606
+ @on_main_process
607
+ def __init__(self, run_name: str, logging_dir: Optional[Union[str, os.PathLike]] = ".", **kwargs):
608
+ self.run_name = run_name
609
+
610
+ from aim import Run
611
+
612
+ self.writer = Run(repo=logging_dir, **kwargs)
613
+ self.writer.name = self.run_name
614
+ logger.debug(f"Initialized Aim project {self.run_name}")
615
+ logger.debug(
616
+ "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
617
+ )
618
+
619
+ @property
620
+ def tracker(self):
621
+ return self.writer
622
+
623
+ @on_main_process
624
+ def store_init_configuration(self, values: dict):
625
+ """
626
+ Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
627
+
628
+ Args:
629
+ values (`dict`):
630
+ Values to be stored as initial hyperparameters as key-value pairs.
631
+ """
632
+ self.writer["hparams"] = values
633
+
634
+ @on_main_process
635
+ def log(self, values: dict, step: Optional[int], **kwargs):
636
+ """
637
+ Logs `values` to the current run.
638
+
639
+ Args:
640
+ values (`dict`):
641
+ Values to be logged as key-value pairs.
642
+ step (`int`, *optional*):
643
+ The run step. If included, the log will be affiliated with this step.
644
+ kwargs:
645
+ Additional key word arguments passed along to the `Run.track` method.
646
+ """
647
+ # Note: replace this with the dictionary support when merged
648
+ for key, value in values.items():
649
+ self.writer.track(value, name=key, step=step, **kwargs)
650
+
651
+ @on_main_process
652
+ def finish(self):
653
+ """
654
+ Closes `aim` writer
655
+ """
656
+ self.writer.close()
657
+
658
+
659
+ class MLflowTracker(GeneralTracker):
660
+ """
661
+ A `Tracker` class that supports `mlflow`. Should be initialized at the start of your script.
662
+
663
+ Args:
664
+ experiment_name (`str`, *optional*):
665
+ Name of the experiment. Environment variable MLFLOW_EXPERIMENT_NAME has priority over this argument.
666
+ logging_dir (`str` or `os.PathLike`, defaults to `"."`):
667
+ Location for mlflow logs to be stored.
668
+ run_id (`str`, *optional*):
669
+ If specified, get the run with the specified UUID and log parameters and metrics under that run. The run’s
670
+ end time is unset and its status is set to running, but the run’s other attributes (source_version,
671
+ source_type, etc.) are not changed. Environment variable MLFLOW_RUN_ID has priority over this argument.
672
+ tags (`Dict[str, str]`, *optional*):
673
+ An optional `dict` of `str` keys and values, or a `str` dump from a `dict`, to set as tags on the run. If a
674
+ run is being resumed, these tags are set on the resumed run. If a new run is being created, these tags are
675
+ set on the new run. Environment variable MLFLOW_TAGS has priority over this argument.
676
+ nested_run (`bool`, *optional*, defaults to `False`):
677
+ Controls whether run is nested in parent run. True creates a nested run. Environment variable
678
+ MLFLOW_NESTED_RUN has priority over this argument.
679
+ run_name (`str`, *optional*):
680
+ Name of new run (stored as a mlflow.runName tag). Used only when `run_id` is unspecified.
681
+ description (`str`, *optional*):
682
+ An optional string that populates the description box of the run. If a run is being resumed, the
683
+ description is set on the resumed run. If a new run is being created, the description is set on the new
684
+ run.
685
+ """
686
+
687
+ name = "mlflow"
688
+ requires_logging_directory = False
689
+
690
+ @on_main_process
691
+ def __init__(
692
+ self,
693
+ experiment_name: str = None,
694
+ logging_dir: Optional[Union[str, os.PathLike]] = None,
695
+ run_id: Optional[str] = None,
696
+ tags: Optional[Union[Dict[str, Any], str]] = None,
697
+ nested_run: Optional[bool] = False,
698
+ run_name: Optional[str] = None,
699
+ description: Optional[str] = None,
700
+ ):
701
+ experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", experiment_name)
702
+ run_id = os.getenv("MLFLOW_RUN_ID", run_id)
703
+ tags = os.getenv("MLFLOW_TAGS", tags)
704
+ if isinstance(tags, str):
705
+ tags = json.loads(tags)
706
+
707
+ nested_run = os.getenv("MLFLOW_NESTED_RUN", nested_run)
708
+
709
+ import mlflow
710
+
711
+ exps = mlflow.search_experiments(filter_string=f"name = '{experiment_name}'")
712
+ if len(exps) > 0:
713
+ if len(exps) > 1:
714
+ logger.warning("Multiple experiments with the same name found. Using first one.")
715
+ experiment_id = exps[0].experiment_id
716
+ else:
717
+ experiment_id = mlflow.create_experiment(
718
+ name=experiment_name,
719
+ artifact_location=logging_dir,
720
+ tags=tags,
721
+ )
722
+
723
+ self.active_run = mlflow.start_run(
724
+ run_id=run_id,
725
+ experiment_id=experiment_id,
726
+ run_name=run_name,
727
+ nested=nested_run,
728
+ tags=tags,
729
+ description=description,
730
+ )
731
+
732
+ logger.debug(f"Initialized mlflow experiment {experiment_name}")
733
+ logger.debug(
734
+ "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
735
+ )
736
+
737
+ @property
738
+ def tracker(self):
739
+ return self.active_run
740
+
741
+ @on_main_process
742
+ def store_init_configuration(self, values: dict):
743
+ """
744
+ Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
745
+
746
+ Args:
747
+ values (`dict`):
748
+ Values to be stored as initial hyperparameters as key-value pairs.
749
+ """
750
+ import mlflow
751
+
752
+ for name, value in list(values.items()):
753
+ # internally, all values are converted to str in MLflow
754
+ if len(str(value)) > mlflow.utils.validation.MAX_PARAM_VAL_LENGTH:
755
+ logger.warning_once(
756
+ f'Accelerate is attempting to log a value of "{value}" for key "{name}" as a parameter. MLflow\'s'
757
+ f" log_param() only accepts values no longer than {mlflow.utils.validation.MAX_PARAM_VAL_LENGTH} characters so we dropped this attribute."
758
+ )
759
+ del values[name]
760
+
761
+ values_list = list(values.items())
762
+
763
+ # MLflow cannot log more than 100 values in one go, so we have to split it
764
+ for i in range(0, len(values_list), mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH):
765
+ mlflow.log_params(dict(values_list[i : i + mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH]))
766
+
767
+ logger.debug("Stored initial configuration hyperparameters to MLflow")
768
+
769
+ @on_main_process
770
+ def log(self, values: dict, step: Optional[int]):
771
+ """
772
+ Logs `values` to the current run.
773
+
774
+ Args:
775
+ values (`dict`):
776
+ Values to be logged as key-value pairs.
777
+ step (`int`, *optional*):
778
+ The run step. If included, the log will be affiliated with this step.
779
+ """
780
+ metrics = {}
781
+ for k, v in values.items():
782
+ if isinstance(v, (int, float)):
783
+ metrics[k] = v
784
+ else:
785
+ logger.warning_once(
786
+ f'MLflowTracker is attempting to log a value of "{v}" of type {type(v)} for key "{k}" as a metric. '
787
+ "MLflow's log_metric() only accepts float and int types so we dropped this attribute."
788
+ )
789
+ import mlflow
790
+
791
+ mlflow.log_metrics(metrics, step=step)
792
+ logger.debug("Successfully logged to mlflow")
793
+
794
+ @on_main_process
795
+ def finish(self):
796
+ """
797
+ End the active MLflow run.
798
+ """
799
+ import mlflow
800
+
801
+ mlflow.end_run()
802
+
803
+
804
+ class ClearMLTracker(GeneralTracker):
805
+ """
806
+ A `Tracker` class that supports `clearml`. Should be initialized at the start of your script.
807
+
808
+ Args:
809
+ run_name (`str`, *optional*):
810
+ Name of the experiment. Environment variables `CLEARML_PROJECT` and `CLEARML_TASK` have priority over this
811
+ argument.
812
+ kwargs:
813
+ Kwargs passed along to the `Task.__init__` method.
814
+ """
815
+
816
+ name = "clearml"
817
+ requires_logging_directory = False
818
+
819
+ @on_main_process
820
+ def __init__(self, run_name: str = None, **kwargs):
821
+ from clearml import Task
822
+
823
+ current_task = Task.current_task()
824
+ self._initialized_externally = False
825
+ if current_task:
826
+ self._initialized_externally = True
827
+ self.task = current_task
828
+ return
829
+
830
+ kwargs.setdefault("project_name", os.environ.get("CLEARML_PROJECT", run_name))
831
+ kwargs.setdefault("task_name", os.environ.get("CLEARML_TASK", run_name))
832
+ self.task = Task.init(**kwargs)
833
+
834
+ @property
835
+ def tracker(self):
836
+ return self.task
837
+
838
+ @on_main_process
839
+ def store_init_configuration(self, values: dict):
840
+ """
841
+ Connect configuration dictionary to the Task object. Should be run at the beginning of your experiment.
842
+
843
+ Args:
844
+ values (`dict`):
845
+ Values to be stored as initial hyperparameters as key-value pairs.
846
+ """
847
+ return self.task.connect_configuration(values)
848
+
849
+ @on_main_process
850
+ def log(self, values: Dict[str, Union[int, float]], step: Optional[int] = None, **kwargs):
851
+ """
852
+ Logs `values` dictionary to the current run. The dictionary keys must be strings. The dictionary values must be
853
+ ints or floats
854
+
855
+ Args:
856
+ values (`Dict[str, Union[int, float]]`):
857
+ Values to be logged as key-value pairs. If the key starts with 'eval_'/'test_'/'train_', the value will
858
+ be reported under the 'eval'/'test'/'train' series and the respective prefix will be removed.
859
+ Otherwise, the value will be reported under the 'train' series, and no prefix will be removed.
860
+ step (`int`, *optional*):
861
+ If specified, the values will be reported as scalars, with the iteration number equal to `step`.
862
+ Otherwise they will be reported as single values.
863
+ kwargs:
864
+ Additional key word arguments passed along to the `clearml.Logger.report_single_value` or
865
+ `clearml.Logger.report_scalar` methods.
866
+ """
867
+ clearml_logger = self.task.get_logger()
868
+ for k, v in values.items():
869
+ if not isinstance(v, (int, float)):
870
+ logger.warning_once(
871
+ "Accelerator is attempting to log a value of "
872
+ f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
873
+ "This invocation of ClearML logger's report_scalar() "
874
+ "is incorrect so we dropped this attribute."
875
+ )
876
+ continue
877
+ if step is None:
878
+ clearml_logger.report_single_value(name=k, value=v, **kwargs)
879
+ continue
880
+ title, series = ClearMLTracker._get_title_series(k)
881
+ clearml_logger.report_scalar(title=title, series=series, value=v, iteration=step, **kwargs)
882
+
883
+ @on_main_process
884
+ def log_images(self, values: dict, step: Optional[int] = None, **kwargs):
885
+ """
886
+ Logs `images` to the current run.
887
+
888
+ Args:
889
+ values (`Dict[str, List[Union[np.ndarray, PIL.Image]]]`):
890
+ Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or `PIL.Image`.
891
+ step (`int`, *optional*):
892
+ The run step. If included, the log will be affiliated with this step.
893
+ kwargs:
894
+ Additional key word arguments passed along to the `clearml.Logger.report_image` method.
895
+ """
896
+ clearml_logger = self.task.get_logger()
897
+ for k, v in values.items():
898
+ title, series = ClearMLTracker._get_title_series(k)
899
+ clearml_logger.report_image(title=title, series=series, iteration=step, image=v, **kwargs)
900
+
901
+ @on_main_process
902
+ def log_table(
903
+ self,
904
+ table_name: str,
905
+ columns: List[str] = None,
906
+ data: List[List[Any]] = None,
907
+ dataframe: Any = None,
908
+ step: Optional[int] = None,
909
+ **kwargs,
910
+ ):
911
+ """
912
+ Log a Table to the task. Can be defined either with `columns` and `data` or with `dataframe`.
913
+
914
+ Args:
915
+ table_name (`str`):
916
+ The name of the table
917
+ columns (list of `str`, *optional*):
918
+ The name of the columns on the table
919
+ data (List of List of Any data type, *optional*):
920
+ The data to be logged in the table. If `columns` is not specified, then the first entry in data will be
921
+ the name of the columns of the table
922
+ dataframe (Any data type, *optional*):
923
+ The data to be logged in the table
924
+ step (`int`, *optional*):
925
+ The run step. If included, the log will be affiliated with this step.
926
+ kwargs:
927
+ Additional key word arguments passed along to the `clearml.Logger.report_table` method.
928
+ """
929
+ to_report = dataframe
930
+ if dataframe is None:
931
+ if data is None:
932
+ raise ValueError(
933
+ "`ClearMLTracker.log_table` requires that `data` be supplied if `dataframe` is `None`"
934
+ )
935
+ to_report = [columns] + data if columns else data
936
+ title, series = ClearMLTracker._get_title_series(table_name)
937
+ self.task.get_logger().report_table(title=title, series=series, table_plot=to_report, iteration=step, **kwargs)
938
+
939
+ @on_main_process
940
+ def finish(self):
941
+ """
942
+ Close the ClearML task. If the task was initialized externally (e.g. by manually calling `Task.init`), this
943
+ function is a noop
944
+ """
945
+ if self.task and not self._initialized_externally:
946
+ self.task.close()
947
+
948
+ @staticmethod
949
+ def _get_title_series(name):
950
+ for prefix in ["eval", "test", "train"]:
951
+ if name.startswith(prefix + "_"):
952
+ return name[len(prefix) + 1 :], prefix
953
+ return name, "train"
954
+
955
+
956
+ class DVCLiveTracker(GeneralTracker):
957
+ """
958
+ A `Tracker` class that supports `dvclive`. Should be initialized at the start of your script.
959
+
960
+ Args:
961
+ run_name (`str`, *optional*):
962
+ Ignored for dvclive. See `kwargs` instead.
963
+ kwargs:
964
+ Additional key word arguments passed along to [`dvclive.Live()`](https://dvc.org/doc/dvclive/live).
965
+
966
+ Example:
967
+
968
+ ```py
969
+ from accelerate import Accelerator
970
+
971
+ accelerator = Accelerator(log_with="dvclive")
972
+ accelerator.init_trackers(project_name="my_project", init_kwargs={"dvclive": {"dir": "my_directory"}})
973
+ ```
974
+ """
975
+
976
+ name = "dvclive"
977
+ requires_logging_directory = False
978
+
979
+ @on_main_process
980
+ def __init__(self, run_name: Optional[str] = None, live: Optional[Any] = None, **kwargs):
981
+ from dvclive import Live
982
+
983
+ super().__init__()
984
+ self.live = live if live is not None else Live(**kwargs)
985
+
986
+ @property
987
+ def tracker(self):
988
+ return self.live
989
+
990
+ @on_main_process
991
+ def store_init_configuration(self, values: dict):
992
+ """
993
+ Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment. Stores the
994
+ hyperparameters in a yaml file for future use.
995
+
996
+ Args:
997
+ values (Dictionary `str` to `bool`, `str`, `float`, `int`, or a List or Dict of those types):
998
+ Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
999
+ `str`, `float`, or `int`.
1000
+ """
1001
+ self.live.log_params(values)
1002
+
1003
+ @on_main_process
1004
+ def log(self, values: dict, step: Optional[int] = None, **kwargs):
1005
+ """
1006
+ Logs `values` to the current run.
1007
+
1008
+ Args:
1009
+ values (Dictionary `str` to `str`, `float`, or `int`):
1010
+ Values to be logged as key-value pairs. The values need to have type `str`, `float`, or `int`.
1011
+ step (`int`, *optional*):
1012
+ The run step. If included, the log will be affiliated with this step.
1013
+ kwargs:
1014
+ Additional key word arguments passed along to `dvclive.Live.log_metric()`.
1015
+ """
1016
+ from dvclive.plots import Metric
1017
+
1018
+ if step is not None:
1019
+ self.live.step = step
1020
+ for k, v in values.items():
1021
+ if Metric.could_log(v):
1022
+ self.live.log_metric(k, v, **kwargs)
1023
+ else:
1024
+ logger.warning_once(
1025
+ "Accelerator attempted to log a value of "
1026
+ f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
1027
+ "This invocation of DVCLive's Live.log_metric() "
1028
+ "is incorrect so we dropped this attribute."
1029
+ )
1030
+
1031
+ @on_main_process
1032
+ def finish(self):
1033
+ """
1034
+ Closes `dvclive.Live()`.
1035
+ """
1036
+ self.live.end()
1037
+
1038
+
1039
+ LOGGER_TYPE_TO_CLASS = {
1040
+ "aim": AimTracker,
1041
+ "comet_ml": CometMLTracker,
1042
+ "mlflow": MLflowTracker,
1043
+ "tensorboard": TensorBoardTracker,
1044
+ "wandb": WandBTracker,
1045
+ "clearml": ClearMLTracker,
1046
+ "dvclive": DVCLiveTracker,
1047
+ "visualdl": VisualdlTracker,
1048
+ }
1049
+
1050
+
1051
+ def filter_trackers(
1052
+ log_with: List[Union[str, LoggerType, GeneralTracker]], logging_dir: Union[str, os.PathLike] = None
1053
+ ):
1054
+ """
1055
+ Takes in a list of potential tracker types and checks that:
1056
+ - The tracker wanted is available in that environment
1057
+ - Filters out repeats of tracker types
1058
+ - If `all` is in `log_with`, will return all trackers in the environment
1059
+ - If a tracker requires a `logging_dir`, ensures that `logging_dir` is not `None`
1060
+
1061
+ Args:
1062
+ log_with (list of `str`, [`~utils.LoggerType`] or [`~tracking.GeneralTracker`], *optional*):
1063
+ A list of loggers to be setup for experiment tracking. Should be one or several of:
1064
+
1065
+ - `"all"`
1066
+ - `"tensorboard"`
1067
+ - `"wandb"`
1068
+ - `"comet_ml"`
1069
+ - `"mlflow"`
1070
+ - `"dvclive"`
1071
+ - `"visualdl"`
1072
+ If `"all"` is selected, will pick up all available trackers in the environment and initialize them. Can
1073
+ also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `"all"`.
1074
+ logging_dir (`str`, `os.PathLike`, *optional*):
1075
+ A path to a directory for storing logs of locally-compatible loggers.
1076
+ """
1077
+ loggers = []
1078
+ if log_with is not None:
1079
+ if not isinstance(log_with, (list, tuple)):
1080
+ log_with = [log_with]
1081
+ if "all" in log_with or LoggerType.ALL in log_with:
1082
+ loggers = [o for o in log_with if issubclass(type(o), GeneralTracker)] + get_available_trackers()
1083
+ else:
1084
+ for log_type in log_with:
1085
+ if log_type not in LoggerType and not issubclass(type(log_type), GeneralTracker):
1086
+ raise ValueError(f"Unsupported logging capability: {log_type}. Choose between {LoggerType.list()}")
1087
+ if issubclass(type(log_type), GeneralTracker):
1088
+ loggers.append(log_type)
1089
+ else:
1090
+ log_type = LoggerType(log_type)
1091
+ if log_type not in loggers:
1092
+ if log_type in get_available_trackers():
1093
+ tracker_init = LOGGER_TYPE_TO_CLASS[str(log_type)]
1094
+ if getattr(tracker_init, "requires_logging_directory"):
1095
+ if logging_dir is None:
1096
+ raise ValueError(
1097
+ f"Logging with `{log_type}` requires a `logging_dir` to be passed in."
1098
+ )
1099
+ loggers.append(log_type)
1100
+ else:
1101
+ logger.debug(f"Tried adding logger {log_type}, but package is unavailable in the system.")
1102
+
1103
+ return loggers
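`filter_trackers` only resolves names to `LoggerType` members; instantiating the matching tracker class from `LOGGER_TYPE_TO_CLASS` is left to the caller. A minimal usage sketch, assuming the module is importable as `ppdiffusers.accelerate.tracking` and that the corresponding tracking packages are installed:

```python
import os

from ppdiffusers.accelerate.tracking import LOGGER_TYPE_TO_CLASS, filter_trackers

# Directory-based trackers such as "tensorboard" and "visualdl" require a logging_dir.
logging_dir = os.path.join("output", "logs")

# Trackers whose packages are missing are silently dropped; unknown names raise a ValueError.
resolved = filter_trackers(["tensorboard", "visualdl", "wandb"], logging_dir=logging_dir)
print([LOGGER_TYPE_TO_CLASS[str(logger_type)].__name__ for logger_type in resolved])
```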
PaddleMIX/ppdiffusers/ppdiffusers/callbacks.py ADDED
@@ -0,0 +1,156 @@
1
+ from typing import Any, Dict, List
2
+
3
+ from .configuration_utils import ConfigMixin, register_to_config
4
+ from .utils import CONFIG_NAME
5
+
6
+
7
+ class PipelineCallback(ConfigMixin):
8
+ """
9
+ Base class for all the official callbacks used in a pipeline. This class provides a structure for implementing
10
+ custom callbacks and ensures that all callbacks have a consistent interface.
11
+
12
+ Please implement the following:
13
+ `tensor_inputs`: This should return a list of tensor inputs specific to your callback. You will only be able to
14
+ include variables listed in the
15
+ `._callback_tensor_inputs` attribute of your pipeline class.
16
+ `callback_fn`: This method defines the core functionality of your callback.
17
+ """
18
+
19
+ config_name = CONFIG_NAME
20
+
21
+ @register_to_config
22
+ def __init__(self, cutoff_step_ratio=1.0, cutoff_step_index=None):
23
+ super().__init__()
24
+
25
+ if (cutoff_step_ratio is None and cutoff_step_index is None) or (
26
+ cutoff_step_ratio is not None and cutoff_step_index is not None
27
+ ):
28
+ raise ValueError("Either cutoff_step_ratio or cutoff_step_index should be provided, not both or none.")
29
+
30
+ if cutoff_step_ratio is not None and (
31
+ not isinstance(cutoff_step_ratio, float) or not (0.0 <= cutoff_step_ratio <= 1.0)
32
+ ):
33
+ raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.")
34
+
35
+ @property
36
+ def tensor_inputs(self) -> List[str]:
37
+ raise NotImplementedError(f"You need to set the attribute `tensor_inputs` for {self.__class__}")
38
+
39
+ def callback_fn(self, pipeline, step_index, timesteps, callback_kwargs) -> Dict[str, Any]:
40
+ raise NotImplementedError(f"You need to implement the method `callback_fn` for {self.__class__}")
41
+
42
+ def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
43
+ return self.callback_fn(pipeline, step_index, timestep, callback_kwargs)
44
+
45
+
46
+ class MultiPipelineCallbacks:
47
+ """
48
+ This class is designed to handle multiple pipeline callbacks. It accepts a list of PipelineCallback objects and
49
+ provides a unified interface for calling all of them.
50
+ """
51
+
52
+ def __init__(self, callbacks: List[PipelineCallback]):
53
+ self.callbacks = callbacks
54
+
55
+ @property
56
+ def tensor_inputs(self) -> List[str]:
57
+ return [input for callback in self.callbacks for input in callback.tensor_inputs]
58
+
59
+ def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
60
+ """
61
+ Calls all the callbacks in order with the given arguments and returns the final callback_kwargs.
62
+ """
63
+ for callback in self.callbacks:
64
+ callback_kwargs = callback(pipeline, step_index, timestep, callback_kwargs)
65
+
66
+ return callback_kwargs
67
+
68
+
69
+ class SDCFGCutoffCallback(PipelineCallback):
70
+ """
71
+ Callback function for Stable Diffusion Pipelines. After a certain number of steps (set by `cutoff_step_ratio` or
72
+ `cutoff_step_index`), this callback will disable the CFG.
73
+
74
+ Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
75
+ """
76
+
77
+ tensor_inputs = ["prompt_embeds"]
78
+
79
+ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
80
+ cutoff_step_ratio = self.config.cutoff_step_ratio
81
+ cutoff_step_index = self.config.cutoff_step_index
82
+
83
+ # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
84
+ cutoff_step = (
85
+ cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
86
+ )
87
+
88
+ if step_index == cutoff_step:
89
+ prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
90
+ prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens.
91
+
92
+ pipeline._guidance_scale = 0.0
93
+
94
+ callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
95
+ return callback_kwargs
96
+
97
+
98
+ class SDXLCFGCutoffCallback(PipelineCallback):
99
+ """
100
+ Callback function for Stable Diffusion XL Pipelines. After a certain number of steps (set by `cutoff_step_ratio` or
101
+ `cutoff_step_index`), this callback will disable the CFG.
102
+
103
+ Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
104
+ """
105
+
106
+ tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]
107
+
108
+ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
109
+ cutoff_step_ratio = self.config.cutoff_step_ratio
110
+ cutoff_step_index = self.config.cutoff_step_index
111
+
112
+ # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
113
+ cutoff_step = (
114
+ cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
115
+ )
116
+
117
+ if step_index == cutoff_step:
118
+ prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
119
+ prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens.
120
+
121
+ add_text_embeds = callback_kwargs[self.tensor_inputs[1]]
122
+ add_text_embeds = add_text_embeds[-1:] # "-1" denotes the embeddings for conditional pooled text tokens
123
+
124
+ add_time_ids = callback_kwargs[self.tensor_inputs[2]]
125
+ add_time_ids = add_time_ids[-1:] # "-1" denotes the embeddings for conditional added time vector
126
+
127
+ pipeline._guidance_scale = 0.0
128
+
129
+ callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
130
+ callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
131
+ callback_kwargs[self.tensor_inputs[2]] = add_time_ids
132
+ return callback_kwargs
133
+
134
+
135
+ class IPAdapterScaleCutoffCallback(PipelineCallback):
136
+ """
137
+ Callback function for any pipeline that inherits from `IPAdapterMixin`. After a certain number of steps (set by
138
+ `cutoff_step_ratio` or `cutoff_step_index`), this callback will set the IP Adapter scale to `0.0`.
139
+
140
+ Note: This callback mutates the IP Adapter attention processors by setting the scale to 0.0 after the cutoff step.
141
+ """
142
+
143
+ tensor_inputs = []
144
+
145
+ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
146
+ cutoff_step_ratio = self.config.cutoff_step_ratio
147
+ cutoff_step_index = self.config.cutoff_step_index
148
+
149
+ # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
150
+ cutoff_step = (
151
+ cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
152
+ )
153
+
154
+ if step_index == cutoff_step:
155
+ pipeline.set_ip_adapter_scale(0.0)
156
+ return callback_kwargs
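A minimal sketch of how these callbacks are wired into a pipeline call. The `callback_on_step_end` and `callback_on_step_end_tensor_inputs` arguments and the model id follow the upstream diffusers convention and are assumptions here; any ppdiffusers pipeline that exposes them should work the same way:

```python
from ppdiffusers import StableDiffusionPipeline
from ppdiffusers.callbacks import MultiPipelineCallbacks, SDCFGCutoffCallback

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Disable classifier-free guidance after 40% of the denoising steps.
callbacks = MultiPipelineCallbacks([SDCFGCutoffCallback(cutoff_step_ratio=0.4)])

image = pipe(
    "a photo of an astronaut riding a horse",
    num_inference_steps=30,
    callback_on_step_end=callbacks,
    callback_on_step_end_tensor_inputs=callbacks.tensor_inputs,
).images[0]
```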
PaddleMIX/ppdiffusers/ppdiffusers/configuration_utils.py ADDED
@@ -0,0 +1,695 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ ConfigMixin base class and utilities."""
17
+ import functools
18
+ import importlib
19
+ import inspect
20
+ import json
21
+ import os
22
+ import re
23
+ from collections import OrderedDict
24
+ from pathlib import PosixPath
25
+ from typing import Any, Dict, Tuple, Union
26
+
27
+ from .utils.download_utils import SaveToAistudioMixin
28
+ from .utils.hub_utils import PushToHubMixin
29
+
30
+ try:
31
+ from omegaconf.listconfig import ListConfig
32
+
33
+ _omegaconf_available = True
34
+ except:
35
+ _omegaconf_available = False
36
+ import numpy as np
37
+ from huggingface_hub import create_repo
38
+
39
+ from .utils import (
40
+ DIFFUSERS_CACHE,
41
+ FROM_AISTUDIO,
42
+ FROM_HF_HUB,
43
+ PPDIFFUSERS_CACHE,
44
+ DummyObject,
45
+ bos_aistudio_hf_download,
46
+ deprecate,
47
+ extract_commit_hash,
48
+ http_user_agent,
49
+ logging,
50
+ )
51
+ from .version import VERSION as __version__
52
+
53
+ logger = logging.get_logger(__name__)
54
+
55
+ _re_configuration_file = re.compile(r"config\.(.*)\.json")
56
+
57
+
58
+ class FrozenDict(OrderedDict):
59
+ def __init__(self, *args, **kwargs):
60
+ super().__init__(*args, **kwargs)
61
+
62
+ for key, value in self.items():
63
+ setattr(self, key, value)
64
+
65
+ self.__frozen = True
66
+
67
+ def __delitem__(self, *args, **kwargs):
68
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
69
+
70
+ def setdefault(self, *args, **kwargs):
71
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
72
+
73
+ def pop(self, *args, **kwargs):
74
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
75
+
76
+ def update(self, *args, **kwargs):
77
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
78
+
79
+ def __setattr__(self, name, value):
80
+ if hasattr(self, "__frozen") and self.__frozen:
81
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
82
+ super().__setattr__(name, value)
83
+
84
+ def __setitem__(self, name, value):
85
+ if hasattr(self, "__frozen") and self.__frozen:
86
+ raise Exception(f"You cannot use ``__setitem__`` on a {self.__class__.__name__} instance.")
87
+ super().__setitem__(name, value)
88
+
89
+
90
+ class ConfigMixin(PushToHubMixin, SaveToAistudioMixin):
91
+ r"""
92
+ Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
93
+ provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
94
+ saving classes that inherit from [`ConfigMixin`].
95
+
96
+ Class attributes:
97
+ - **config_name** (`str`) -- A filename under which the config should be stored when calling
98
+ [`~ConfigMixin.save_config`] (should be overridden by parent class).
99
+ - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
100
+ overridden by subclass).
101
+ - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
102
+ - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
103
+ should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
104
+ subclass).
105
+ """
106
+
107
+ config_name = None
108
+ ignore_for_config = []
109
+ has_compatibles = False
110
+
111
+ _deprecated_kwargs = []
112
+
113
+ def register_to_config(self, **kwargs):
114
+ if self.config_name is None:
115
+ raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
116
+ # Special case for `kwargs` used in deprecation warning added to schedulers
117
+ # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
118
+ # or solve in a more general way.
119
+ kwargs.pop("kwargs", None)
120
+
121
+ if not hasattr(self, "_internal_dict"):
122
+ internal_dict = kwargs
123
+ else:
124
+ previous_dict = dict(self._internal_dict)
125
+ internal_dict = {**self._internal_dict, **kwargs}
126
+ logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
127
+
128
+ self._internal_dict = FrozenDict(internal_dict)
129
+
130
+ def __getattr__(self, name: str) -> Any:
131
+ """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
132
+ config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
133
+
134
+ This function is mostly copied from PyTorch's __getattr__ overwrite:
135
+ https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
136
+ """
137
+
138
+ is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
139
+ is_attribute = name in self.__dict__
140
+
141
+ if is_in_config and not is_attribute:
142
+ deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
143
+ deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
144
+ return self._internal_dict[name]
145
+
146
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
147
+
148
+ def save_config(
149
+ self,
150
+ save_directory: Union[str, os.PathLike],
151
+ push_to_hub: bool = False,
152
+ save_to_aistudio: bool = False,
153
+ to_diffusers: bool = False,
154
+ **kwargs,
155
+ ):
156
+ """
157
+ Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
158
+ [`~ConfigMixin.from_config`] class method.
159
+
160
+ Args:
161
+ save_directory (`str` or `os.PathLike`):
162
+ Directory where the configuration JSON file is saved (will be created if it does not exist).
163
+ push_to_hub (`bool`, *optional*, defaults to `False`):
164
+ Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
165
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
166
+ namespace).
167
+ kwargs (`Dict[str, Any]`, *optional*):
168
+ Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
169
+ """
170
+ if os.path.isfile(save_directory):
171
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
172
+
173
+ os.makedirs(save_directory, exist_ok=True)
174
+
175
+ # If we save using the predefined names, we can load using `from_config`
176
+ output_config_file = os.path.join(save_directory, self.config_name)
177
+
178
+ self.to_json_file(output_config_file, to_diffusers=to_diffusers)
179
+ logger.info(f"Configuration saved in {output_config_file}")
180
+
181
+ commit_message = kwargs.pop("commit_message", None)
182
+ create_pr = kwargs.pop("create_pr", False)
183
+ token = kwargs.pop("token", None)
184
+ token_kwargs = {}
185
+ if token is not None:
186
+ token_kwargs["token"] = token
187
+ private = kwargs.pop("private", False)
188
+ exist_ok = kwargs.pop("exist_ok", True)
189
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
190
+ license = kwargs.pop("license", "creativeml-openrail-m")
191
+
192
+ if save_to_aistudio:
193
+ from aistudio_sdk.hub import create_repo as aistudio_create_repo
194
+
195
+ assert "/" in repo_id, "Please specify the repo id in format of `user_id/repo_name`"
196
+ res = aistudio_create_repo(repo_id=repo_id, private=private, license=license, **token_kwargs)
197
+ if "error_code" in res:
198
+ if res["error_code"] == 10003 and exist_ok:
199
+ logger.info(
200
+ f"Repo {repo_id} already exists, it will override files with the same name. To avoid this, please set exist_ok=False"
201
+ )
202
+ else:
203
+ logger.error(
204
+ f"Failed to create repo {repo_id}, error_code: {res['error_code']}, error_msg: {res['error_msg']}"
205
+ )
206
+ else:
207
+ logger.info(f"Successfully created repo {repo_id}")
208
+ self._upload_folder_aistudio(
209
+ save_directory,
210
+ repo_id,
211
+ commit_message=commit_message,
212
+ **token_kwargs,
213
+ )
214
+
215
+ if push_to_hub:
216
+ repo_id = create_repo(repo_id, exist_ok=exist_ok, private=private, **token_kwargs).repo_id
217
+ self._upload_folder(
218
+ save_directory,
219
+ repo_id,
220
+ commit_message=commit_message,
221
+ create_pr=create_pr,
222
+ **token_kwargs,
223
+ )
224
+
225
+ @classmethod
226
+ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
227
+ r"""
228
+ Instantiate a Python class from a config dictionary.
229
+
230
+ Parameters:
231
+ config (`Dict[str, Any]`):
232
+ A config dictionary from which the Python class is instantiated. Make sure to only load configuration
233
+ files of compatible classes.
234
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
235
+ Whether kwargs that are not consumed by the Python class should be returned or not.
236
+ kwargs (remaining dictionary of keyword arguments, *optional*):
237
+ Can be used to update the configuration object (after it is loaded) and initiate the Python class.
238
+ `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
239
+ overwrite the same named arguments in `config`.
240
+
241
+ Returns:
242
+ [`ModelMixin`] or [`SchedulerMixin`]:
243
+ A model or scheduler object instantiated from a config dictionary.
244
+
245
+ Examples:
246
+
247
+ ```python
248
+ >>> from ppdiffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
249
+
250
+ >>> # Download scheduler from huggingface.co and cache.
251
+ >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
252
+
253
+ >>> # Instantiate DDIM scheduler class with same config as DDPM
254
+ >>> scheduler = DDIMScheduler.from_config(scheduler.config)
255
+
256
+ >>> # Instantiate PNDM scheduler class with same config as DDPM
257
+ >>> scheduler = PNDMScheduler.from_config(scheduler.config)
258
+ ```
259
+ """
260
+ # <===== TO BE REMOVED WITH DEPRECATION
261
+ # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
262
+ if "pretrained_model_name_or_path" in kwargs:
263
+ config = kwargs.pop("pretrained_model_name_or_path")
264
+
265
+ if config is None:
266
+ raise ValueError("Please make sure to provide a config as the first positional argument.")
267
+ # ======>
268
+
269
+ if not isinstance(config, dict):
270
+ deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
271
+ if "Scheduler" in cls.__name__:
272
+ deprecation_message += (
273
+ f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
274
+ " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
275
+ " be removed in v1.0.0."
276
+ )
277
+ elif "Model" in cls.__name__:
278
+ deprecation_message += (
279
+ f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
280
+ f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
281
+ " instead. This functionality will be removed in v1.0.0."
282
+ )
283
+ deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
284
+ config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
285
+
286
+ init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
287
+
288
+ # Allow dtype to be specified on initialization
289
+ if "dtype" in unused_kwargs:
290
+ init_dict["dtype"] = unused_kwargs.pop("dtype")
291
+
292
+ # add possible deprecated kwargs
293
+ for deprecated_kwarg in cls._deprecated_kwargs:
294
+ if deprecated_kwarg in unused_kwargs:
295
+ init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
296
+
297
+ # Return model and optionally state and/or unused_kwargs
298
+ model = cls(**init_dict)
299
+
300
+ # make sure to also save config parameters that might be used for compatible classes
301
+ model.register_to_config(**hidden_dict)
302
+
303
+ # add hidden kwargs of compatible classes to unused_kwargs
304
+ unused_kwargs = {**unused_kwargs, **hidden_dict}
305
+
306
+ if return_unused_kwargs:
307
+ return (model, unused_kwargs)
308
+ else:
309
+ return model
310
+
311
+ @classmethod
312
+ def get_config_dict(cls, *args, **kwargs):
313
+ deprecation_message = (
314
+ f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
315
+ " removed in version v1.0.0"
316
+ )
317
+ deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
318
+ return cls.load_config(*args, **kwargs)
319
+
320
+ @classmethod
321
+ def load_config(
322
+ cls,
323
+ pretrained_model_name_or_path: Union[str, os.PathLike],
324
+ return_unused_kwargs=False,
325
+ return_commit_hash=False,
326
+ **kwargs,
327
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
328
+ r"""
329
+ Load a model or scheduler configuration.
330
+
331
+ Parameters:
332
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
333
+ Can be either:
334
+
335
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
336
+ the Hub.
337
+ - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
338
+ [`~ConfigMixin.save_config`].
339
+
340
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
341
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
342
+ is not used.
343
+ force_download (`bool`, *optional*, defaults to `False`):
344
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
345
+ cached versions if they exist.
346
+ resume_download (`bool`, *optional*, defaults to `False`):
347
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
348
+ incompletely downloaded files are deleted.
349
+ proxies (`Dict[str, str]`, *optional*):
350
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
351
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
352
+ output_loading_info(`bool`, *optional*, defaults to `False`):
353
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
354
+ local_files_only (`bool`, *optional*, defaults to `False`):
355
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
356
+ won't be downloaded from the Hub.
357
+ use_auth_token (`str` or *bool*, *optional*):
358
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
359
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
360
+ revision (`str`, *optional*, defaults to `"main"`):
361
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
362
+ allowed by Git.
363
+ subfolder (`str`, *optional*, defaults to `""`):
364
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
365
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
366
+ Whether unused keyword arguments of the config are returned.
367
+ return_commit_hash (`bool`, *optional*, defaults to `False`):
368
+ Whether the `commit_hash` of the loaded configuration is returned.
369
+
370
+ Returns:
371
+ `dict`:
372
+ A dictionary of all the parameters stored in a JSON configuration file.
373
+
374
+ """
375
+ from_hf_hub = kwargs.pop("from_hf_hub", FROM_HF_HUB)
376
+ from_aistudio = kwargs.pop("from_aistudio", FROM_AISTUDIO)
377
+ cache_dir = kwargs.pop("cache_dir", None)
378
+ if cache_dir is None:
379
+ if from_aistudio:
380
+ cache_dir = PPDIFFUSERS_CACHE
381
+ elif from_hf_hub:
382
+ cache_dir = DIFFUSERS_CACHE
383
+ else:
384
+ cache_dir = PPDIFFUSERS_CACHE
385
+
386
+ force_download = kwargs.pop("force_download", False)
387
+ resume_download = kwargs.pop("resume_download", False)
388
+ proxies = kwargs.pop("proxies", None)
389
+ use_auth_token = kwargs.pop("use_auth_token", None)
390
+ local_files_only = kwargs.pop("local_files_only", False)
391
+ revision = kwargs.pop("revision", None)
392
+ _ = kwargs.pop("mirror", None)
393
+ subfolder = kwargs.pop("subfolder", "")
394
+ if subfolder is None:
395
+ subfolder = ""
396
+ user_agent = kwargs.pop("user_agent", {})
397
+
398
+ user_agent = {**user_agent, "file_type": "config"}
399
+ user_agent = http_user_agent(user_agent)
400
+
401
+ # new add return_config_file
402
+ return_config_file = kwargs.pop("return_config_file", False)
403
+
404
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
405
+
406
+ if cls.config_name is None:
407
+ raise ValueError(
408
+ "`self.config_name` is not defined. Note that one should not load a config from "
409
+ "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
410
+ )
411
+
412
+ if os.path.isfile(pretrained_model_name_or_path):
413
+ config_file = pretrained_model_name_or_path
414
+ elif os.path.isdir(pretrained_model_name_or_path):
415
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
416
+ # Load from a PyTorch checkpoint
417
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
418
+ elif subfolder is not None and os.path.isfile(
419
+ os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
420
+ ):
421
+ config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
422
+ else:
423
+ raise EnvironmentError(
424
+ f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
425
+ )
426
+ else:
427
+ config_file = bos_aistudio_hf_download(
428
+ pretrained_model_name_or_path,
429
+ cls.config_name,
430
+ cache_dir=cache_dir,
431
+ force_download=force_download,
432
+ proxies=proxies,
433
+ resume_download=resume_download,
434
+ local_files_only=local_files_only,
435
+ use_auth_token=use_auth_token,
436
+ user_agent=user_agent,
437
+ subfolder=subfolder,
438
+ revision=revision,
439
+ from_hf_hub=from_hf_hub,
440
+ from_aistudio=from_aistudio,
441
+ )
442
+
443
+ try:
444
+ # Load config dict
445
+ config_dict = cls._dict_from_json_file(config_file)
446
+
447
+ commit_hash = extract_commit_hash(config_file)
448
+ except (json.JSONDecodeError, UnicodeDecodeError):
449
+ raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
450
+
451
+ if not (return_unused_kwargs or return_commit_hash or return_config_file):
452
+ return config_dict
453
+
454
+ outputs = (config_dict,)
455
+
456
+ if return_unused_kwargs:
457
+ outputs += (kwargs,)
458
+
459
+ if return_commit_hash:
460
+ outputs += (commit_hash,)
461
+
462
+ if return_config_file:
463
+ outputs += (config_file,)
464
+
465
+ return outputs
466
+
467
+ @staticmethod
468
+ def _get_init_keys(cls):
469
+ return set(dict(inspect.signature(cls.__init__).parameters).keys())
470
+
471
+ @classmethod
472
+ def extract_init_dict(cls, config_dict, **kwargs):
473
+ # Skip keys that were not present in the original config, so default __init__ values were used
474
+ used_defaults = config_dict.get("_use_default_values", [])
475
+ config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}
476
+
477
+ # 0. Copy origin config dict
478
+ original_dict = dict(config_dict.items())
479
+
480
+ # 1. Retrieve expected config attributes from __init__ signature
481
+ expected_keys = cls._get_init_keys(cls)
482
+ expected_keys.remove("self")
483
+ # remove general kwargs if present in dict
484
+ if "kwargs" in expected_keys:
485
+ expected_keys.remove("kwargs")
486
+
487
+ # 2. Remove attributes that cannot be expected from expected config attributes
488
+ # remove keys to be ignored
489
+ if len(cls.ignore_for_config) > 0:
490
+ expected_keys = expected_keys - set(cls.ignore_for_config)
491
+
492
+ # load ppdiffusers library to import compatible and original scheduler
493
+ ppdiffusers_library = importlib.import_module(__name__.split(".")[0])
494
+
495
+ if cls.has_compatibles:
496
+ compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
497
+ else:
498
+ compatible_classes = []
499
+
500
+ expected_keys_comp_cls = set()
501
+ for c in compatible_classes:
502
+ expected_keys_c = cls._get_init_keys(c)
503
+ expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
504
+ expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
505
+ config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
506
+
507
+ # remove attributes from orig class that cannot be expected
508
+ orig_cls_name = config_dict.pop("_class_name", cls.__name__)
509
+ if (
510
+ isinstance(orig_cls_name, str)
511
+ and orig_cls_name != cls.__name__
512
+ and hasattr(ppdiffusers_library, orig_cls_name)
513
+ ):
514
+ orig_cls = getattr(ppdiffusers_library, orig_cls_name)
515
+ unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
516
+ config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
517
+ elif not isinstance(orig_cls_name, str) and not isinstance(orig_cls_name, (list, tuple)):
518
+ raise ValueError(
519
+ "Make sure that the `_class_name` is of type string or list of string (for custom pipelines)."
520
+ )
521
+
522
+ # remove private attributes
523
+ config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
524
+
525
+ # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
526
+ init_dict = {}
527
+ for key in expected_keys:
528
+ # if config param is passed to kwarg and is present in config dict
529
+ # it should overwrite existing config dict key
530
+ if key in kwargs and key in config_dict:
531
+ config_dict[key] = kwargs.pop(key)
532
+
533
+ if key in kwargs:
534
+ # overwrite key
535
+ init_dict[key] = kwargs.pop(key)
536
+ elif key in config_dict:
537
+ # use value from config dict
538
+ init_dict[key] = config_dict.pop(key)
539
+
540
+ # 4. Give nice warning if unexpected values have been passed
541
+ if len(config_dict) > 0:
542
+ logger.warning(
543
+ f"The config attributes {config_dict} were passed to {cls.__name__}, "
544
+ "but are not expected and will be ignored. Please verify your "
545
+ f"{cls.config_name} configuration file."
546
+ )
547
+
548
+ # 5. Give nice info if config attributes are initialized to default because they have not been passed
549
+ passed_keys = set(init_dict.keys())
550
+ if len(expected_keys - passed_keys) > 0:
551
+ logger.info(
552
+ f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
553
+ )
554
+
555
+ # 6. Define unused keyword arguments
556
+ unused_kwargs = {**config_dict, **kwargs}
557
+
558
+ # 7. Define "hidden" config parameters that were saved for compatible classes
559
+ hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
560
+
561
+ return init_dict, unused_kwargs, hidden_config_dict
562
+
563
+ @classmethod
564
+ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
565
+ with open(json_file, "r", encoding="utf-8") as reader:
566
+ text = reader.read()
567
+ data = json.loads(text)
568
+ if "_diffusers_version" in data and "_ppdiffusers_version" not in data:
569
+ data["_ppdiffusers_version"] = data.pop("_diffusers_version", __version__)
570
+ if "_diffusers_version" not in data and "_ppdiffusers_version" not in data:
571
+ data["_ppdiffusers_version"] = __version__
572
+
573
+ # remove Onnx and Flax prefix
574
+ _class_name = data.get("_class_name", None)
575
+ if _class_name is not None:
576
+ if _class_name.startswith("Flax"):
577
+ data["_class_name"] = _class_name[4:]
578
+ elif _class_name.startswith("Onnx"):
579
+ data["_class_name"] = "FastDeploy" + _class_name[4:]
580
+
581
+ return data
582
+
583
+ def __repr__(self):
584
+ return f"{self.__class__.__name__} {self.to_json_string()}"
585
+
586
+ @property
587
+ def config(self) -> Dict[str, Any]:
588
+ """
589
+ Returns the config of the class as a frozen dictionary
590
+
591
+ Returns:
592
+ `Dict[str, Any]`: Config of the class.
593
+ """
594
+ return self._internal_dict
595
+
596
+ def to_json_string(self, to_diffusers=False) -> str:
597
+ """
598
+ Serializes the configuration instance to a JSON string.
599
+
600
+ Returns:
601
+ `str`:
602
+ String containing all the attributes that make up the configuration instance in JSON format.
603
+ """
604
+ config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
605
+ config_dict["_class_name"] = self.__class__.__name__
606
+ # json
607
+ if to_diffusers:
608
+ config_dict["_diffusers_version"] = __version__
609
+ else:
610
+ config_dict["_ppdiffusers_version"] = __version__
611
+
612
+ def to_json_saveable(value):
613
+ if isinstance(value, np.ndarray):
614
+ value = value.tolist()
615
+ elif isinstance(value, PosixPath):
616
+ value = str(value)
617
+ elif _omegaconf_available and isinstance(value, ListConfig):
618
+ value = list(value)
619
+ return value
620
+
621
+ config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
622
+ if to_diffusers:
623
+ config_dict.pop("_ppdiffusers_version", None)
624
+ else:
625
+ config_dict.pop("_diffusers_version", None)
626
+ # Don't save "_ignore_files" or "_use_default_values"
627
+ config_dict.pop("_ignore_files", None)
628
+ config_dict.pop("_use_default_values", None)
629
+
630
+ json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
631
+ if to_diffusers:
632
+ json_string = json_string.replace('"ppdiffusers"', '"diffusers"').replace(
633
+ '"ppdiffusers.transformers"', '"transformers"'
634
+ )
635
+ return json_string
636
+
637
+ def to_json_file(self, json_file_path: Union[str, os.PathLike], to_diffusers=False):
638
+ """
639
+ Save the configuration instance's parameters to a JSON file.
640
+ Args:
641
+ json_file_path (`str` or `os.PathLike`):
642
+ Path to the JSON file to save a configuration instance's parameters.
643
+ """
644
+ with open(json_file_path, "w", encoding="utf-8") as writer:
645
+ writer.write(self.to_json_string(to_diffusers=to_diffusers))
646
+
647
+
648
+ def register_to_config(init):
649
+ r"""
650
+ Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
651
+ automatically sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that
652
+ shouldn't be registered in the config, use the `ignore_for_config` class variable
653
+
654
+ Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
655
+ """
656
+
657
+ @functools.wraps(init)
658
+ def inner_init(self, *args, **kwargs):
659
+ # Ignore private kwargs in the init.
660
+ init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
661
+ config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
662
+ if not isinstance(self, ConfigMixin):
663
+ raise RuntimeError(
664
+ f"`@register_to_config` was applied to {self.__class__.__name__} init method, but this class does "
665
+ "not inherit from `ConfigMixin`."
666
+ )
667
+
668
+ ignore = getattr(self, "ignore_for_config", [])
669
+ # Get positional arguments aligned with kwargs
670
+ new_kwargs = {}
671
+ signature = inspect.signature(init)
672
+ parameters = {
673
+ name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
674
+ }
675
+ for arg, name in zip(args, parameters.keys()):
676
+ new_kwargs[name] = arg
677
+
678
+ # Then add all kwargs
679
+ new_kwargs.update(
680
+ {
681
+ k: init_kwargs.get(k, default)
682
+ for k, default in parameters.items()
683
+ if k not in ignore and k not in new_kwargs
684
+ }
685
+ )
686
+
687
+ # Take note of the parameters that were not present in the loaded config
688
+ if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
689
+ new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
690
+
691
+ new_kwargs = {**config_init_kwargs, **new_kwargs}
692
+ getattr(self, "register_to_config")(**new_kwargs)
693
+ init(self, *args, **init_kwargs)
694
+
695
+ return inner_init
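A minimal sketch of a class opting into this config machinery; `MyScheduler`, its defaults, and the output directory are hypothetical:

```python
from ppdiffusers.configuration_utils import ConfigMixin, register_to_config


class MyScheduler(ConfigMixin):
    config_name = "scheduler_config.json"

    @register_to_config
    def __init__(self, num_train_timesteps: int = 1000, beta_start: float = 0.0001):
        super().__init__()


scheduler = MyScheduler(num_train_timesteps=500)
print(scheduler.config.num_train_timesteps)  # 500, stored in a FrozenDict
scheduler.save_config("./my_scheduler")      # writes ./my_scheduler/scheduler_config.json
reloaded = MyScheduler.from_config(MyScheduler.load_config("./my_scheduler"))
```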
PaddleMIX/ppdiffusers/ppdiffusers/image_processor.py ADDED
@@ -0,0 +1,671 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import warnings
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import paddle
20
+ import PIL.Image
21
+ from PIL import Image
22
+
23
+ from .configuration_utils import ConfigMixin, register_to_config
24
+ from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
25
+
26
+ PipelineImageInput = Union[
27
+ PIL.Image.Image,
28
+ np.ndarray,
29
+ paddle.Tensor,
30
+ List[PIL.Image.Image],
31
+ List[np.ndarray],
32
+ List[paddle.Tensor],
33
+ ]
34
+
35
+ PipelineDepthInput = Union[
36
+ PIL.Image.Image,
37
+ np.ndarray,
38
+ paddle.Tensor,
39
+ List[PIL.Image.Image],
40
+ List[np.ndarray],
41
+ List[paddle.Tensor],
42
+ ]
43
+
44
+
45
+ class VaeImageProcessor(ConfigMixin):
46
+ """
47
+ Image processor for VAE.
48
+
49
+ Args:
50
+ do_resize (`bool`, *optional*, defaults to `True`):
51
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
52
+ `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
53
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
54
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
55
+ resample (`str`, *optional*, defaults to `lanczos`):
56
+ Resampling filter to use when resizing the image.
57
+ do_normalize (`bool`, *optional*, defaults to `True`):
58
+ Whether to normalize the image to [-1,1].
59
+ do_binarize (`bool`, *optional*, defaults to `False`):
60
+ Whether to binarize the image to 0/1.
61
+ do_convert_rgb (`bool`, *optional*, defaults to `False`):
62
+ Whether to convert the images to RGB format.
63
+ do_convert_grayscale (`bool`, *optional*, defaults to `False`):
64
+ Whether to convert the images to grayscale format.
65
+ """
66
+
67
+ config_name = CONFIG_NAME
68
+
69
+ @register_to_config
70
+ def __init__(
71
+ self,
72
+ do_resize: bool = True,
73
+ vae_scale_factor: int = 8,
74
+ vae_latent_channels: int = 4,
75
+ resample: str = "lanczos",
76
+ do_normalize: bool = True,
77
+ do_binarize: bool = False,
78
+ do_convert_rgb: bool = False,
79
+ do_convert_grayscale: bool = False,
80
+ ):
81
+ super().__init__()
82
+ if do_convert_rgb and do_convert_grayscale:
83
+ raise ValueError(
84
+ "`do_convert_rgb` and `do_convert_grayscale` cannot both be set to `True`."
85
+ " If you intended to convert the image into RGB format, please set `do_convert_grayscale = False`."
86
+ " If you intended to convert the image into grayscale format, please set `do_convert_rgb = False`."
87
+ )
88
+ self.config.do_convert_rgb = False
89
+
90
+ @staticmethod
91
+ def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
92
+ """
93
+ Convert a numpy image or a batch of images to a PIL image.
94
+ """
95
+ if images.ndim == 3:
96
+ images = images[None, ...]
97
+ images = (images * 255).round().astype("uint8")
98
+ if images.shape[-1] == 1:
99
+ # special case for grayscale (single channel) images
100
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
101
+ else:
102
+ pil_images = [Image.fromarray(image) for image in images]
103
+
104
+ return pil_images
105
+
106
+ @staticmethod
107
+ def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
108
+ """
109
+ Convert a PIL image or a list of PIL images to NumPy arrays.
110
+ """
111
+ if not isinstance(images, list):
112
+ images = [images]
113
+ images = [np.array(image).astype(np.float32) / 255.0 for image in images]
114
+ images = np.stack(images, axis=0)
115
+
116
+ return images
117
+
118
+ @staticmethod
119
+ def numpy_to_pd(images: np.ndarray) -> paddle.Tensor:
120
+ """
121
+ Convert a NumPy image to a Paddle tensor.
122
+ """
123
+ if images.ndim == 3:
124
+ images = images[..., None]
125
+
126
+ images = paddle.to_tensor(images.transpose(0, 3, 1, 2))
127
+ return images
128
+
129
+ @staticmethod
130
+ def pd_to_numpy(images: paddle.Tensor) -> np.ndarray:
131
+ """
132
+ Convert a Paddle tensor to a NumPy image.
133
+ """
134
+ images = images.cast("float32").cpu().transpose([0, 2, 3, 1]).numpy()
135
+ return images
136
+
137
+ @staticmethod
138
+ def normalize(images: Union[np.ndarray, paddle.Tensor]) -> Union[np.ndarray, paddle.Tensor]:
139
+ """
140
+ Normalize an image array to [-1,1].
141
+ """
142
+ return 2.0 * images - 1.0
143
+
144
+ @staticmethod
145
+ def denormalize(images: Union[np.ndarray, paddle.Tensor]) -> Union[np.ndarray, paddle.Tensor]:
146
+ """
147
+ Denormalize an image array to [0,1].
148
+ """
149
+ return (images / 2 + 0.5).clip(0, 1)
150
+
151
+ @staticmethod
152
+ def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
153
+ """
154
+ Converts a PIL image to RGB format.
155
+ """
156
+ image = image.convert("RGB")
157
+
158
+ return image
159
+
160
+ @staticmethod
161
+ def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
162
+ """
163
+ Converts a PIL image to grayscale format.
164
+ """
165
+ image = image.convert("L")
166
+
167
+ return image
168
+
169
+ def get_default_height_width(
170
+ self,
171
+ image: Union[PIL.Image.Image, np.ndarray, paddle.Tensor],
172
+ height: Optional[int] = None,
173
+ width: Optional[int] = None,
174
+ ) -> Tuple[int, int]:
175
+ """
176
+ This function returns the height and width that are downscaled to the next integer multiple of
177
+ `vae_scale_factor`.
178
+
179
+ Args:
180
+ image(`PIL.Image.Image`, `np.ndarray` or `paddle.Tensor`):
181
+ The image input, which can be a PIL image, NumPy array or Paddle tensor. If it is a NumPy array, it should
182
+ have shape `[batch, height, width]` or `[batch, height, width, channel]`; if it is a Paddle tensor, it should
183
+ have shape `[batch, channel, height, width]`.
184
+ height (`int`, *optional*, defaults to `None`):
185
+ The height of the preprocessed image. If `None`, the height of the `image` input is used.
186
+ width (`int`, *optional*, defaults to `None`):
187
+ The width of the preprocessed image. If `None`, the width of the `image` input is used.
188
+ """
189
+
190
+ if height is None:
191
+ if isinstance(image, PIL.Image.Image):
192
+ height = image.height
193
+ elif isinstance(image, paddle.Tensor):
194
+ height = image.shape[2]
195
+ else:
196
+ height = image.shape[1]
197
+
198
+ if width is None:
199
+ if isinstance(image, PIL.Image.Image):
200
+ width = image.width
201
+ elif isinstance(image, paddle.Tensor):
202
+ width = image.shape[3]
203
+ else:
204
+ width = image.shape[2]
205
+
206
+ width, height = (
207
+ x - x % self.config.vae_scale_factor for x in (width, height)
208
+ ) # resize to integer multiple of vae_scale_factor
209
+
210
+ return height, width
211
+
212
+ def resize(
213
+ self,
214
+ image: Union[PIL.Image.Image, np.ndarray, paddle.Tensor],
215
+ height: Optional[int] = None,
216
+ width: Optional[int] = None,
217
+ ) -> Union[PIL.Image.Image, np.ndarray, paddle.Tensor]:
218
+ """
219
+ Resize image.
220
+
221
+ Args:
222
+ image (`PIL.Image.Image`, `np.ndarray` or `paddle.Tensor`):
223
+ The image input, can be a PIL image, numpy array or paddle tensor.
224
+ height (`int`, *optional*, defaults to `None`):
225
+ The height to resize to.
226
+ width (`int`, *optional*, defaults to `None`):
227
+ The width to resize to.
228
+
229
+ Returns:
230
+ `PIL.Image.Image`, `np.ndarray` or `paddle.Tensor`:
231
+ The resized image.
232
+ """
233
+ if isinstance(image, PIL.Image.Image):
234
+ image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
235
+ elif isinstance(image, paddle.Tensor):
236
+ image = paddle.nn.functional.interpolate(
237
+ image,
238
+ size=(height, width),
239
+ )
240
+ elif isinstance(image, np.ndarray):
241
+ image = self.numpy_to_pd(image)
242
+ image = paddle.nn.functional.interpolate(
243
+ image,
244
+ size=(height, width),
245
+ )
246
+ image = self.pd_to_numpy(image)
247
+ return image
248
+
249
+ def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
250
+ """
251
+ Create a mask.
252
+
253
+ Args:
254
+ image (`PIL.Image.Image`):
255
+ The image input, should be a PIL image.
256
+
257
+ Returns:
258
+ `PIL.Image.Image`:
259
+ The binarized image. Values less than 0.5 are set to 0, values greater than 0.5 are set to 1.
260
+ """
261
+ image[image < 0.5] = 0
262
+ image[image >= 0.5] = 1
263
+ return image
264
+
265
+ def preprocess(
266
+ self,
267
+ image: Union[paddle.Tensor, PIL.Image.Image, np.ndarray],
268
+ height: Optional[int] = None,
269
+ width: Optional[int] = None,
270
+ ) -> paddle.Tensor:
271
+ """
272
+ Preprocess the image input. Accepted formats are PIL images, NumPy arrays or Paddle tensors.
273
+ """
274
+ supported_formats = (PIL.Image.Image, np.ndarray, paddle.Tensor)
275
+
276
+ # Expand the missing dimension for 3-dimensional paddle tensor or numpy array that represents grayscale image
277
+ if self.config.do_convert_grayscale and isinstance(image, (paddle.Tensor, np.ndarray)) and image.ndim == 3:
278
+ if isinstance(image, paddle.Tensor):
279
+ # if image is a paddle tensor could have 2 possible shapes:
280
+ # 1. batch x height x width: we should insert the channel dimension at position 1
281
+ # 2. channel x height x width: we should insert the batch dimension at position 0;
282
+ # however, since both the channel and batch dimensions have size 1, inserting at position 1 is equivalent
283
+ # for simplicity, we insert a dimension of size 1 at position 1 for both cases
284
+ image = image.unsqueeze(1)
285
+ else:
286
+ # if it is a numpy array, it could have 2 possible shapes:
287
+ # 1. batch x height x width: insert channel dimension on last position
288
+ # 2. height x width x channel: insert batch dimension on first position
289
+ if image.shape[-1] == 1:
290
+ image = np.expand_dims(image, axis=0)
291
+ else:
292
+ image = np.expand_dims(image, axis=-1)
293
+
294
+ if isinstance(image, supported_formats):
295
+ image = [image]
296
+ elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
297
+ raise ValueError(
298
+ f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
299
+ )
300
+
301
+ if isinstance(image[0], PIL.Image.Image):
302
+ if self.config.do_convert_rgb:
303
+ image = [self.convert_to_rgb(i) for i in image]
304
+ elif self.config.do_convert_grayscale:
305
+ image = [self.convert_to_grayscale(i) for i in image]
306
+ if self.config.do_resize:
307
+ height, width = self.get_default_height_width(image[0], height, width)
308
+ image = [self.resize(i, height, width) for i in image]
309
+ image = self.pil_to_numpy(image) # to np
310
+ image = self.numpy_to_pd(image) # to paddle tensor
311
+
312
+ elif isinstance(image[0], np.ndarray):
313
+ image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
314
+
315
+ image = self.numpy_to_pd(image)
316
+
317
+ height, width = self.get_default_height_width(image, height, width)
318
+ if self.config.do_resize:
319
+ image = self.resize(image, height, width)
320
+
321
+ elif isinstance(image[0], paddle.Tensor):
322
+ image = paddle.concat(image, axis=0) if image[0].ndim == 4 else paddle.stack(image, axis=0)
323
+
324
+ if self.config.do_convert_grayscale and image.ndim == 3:
325
+ image = image.unsqueeze(1)
326
+
327
+ channel = image.shape[1]
328
+ # don't need any preprocess if the image is latents
329
+ if channel == 4:
330
+ return image
331
+
332
+ height, width = self.get_default_height_width(image, height, width)
333
+ if self.config.do_resize:
334
+ image = self.resize(image, height, width)
335
+
336
+ # expected range [0,1], normalize to [-1,1]
337
+ do_normalize = self.config.do_normalize
338
+ if do_normalize and image.min() < 0:
339
+ warnings.warn(
340
+ "Passing `image` as paddle tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
341
+ f"when passing as paddle tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
342
+ FutureWarning,
343
+ )
344
+ do_normalize = False
345
+
346
+ if do_normalize:
347
+ image = self.normalize(image)
348
+
349
+ if self.config.do_binarize:
350
+ image = self.binarize(image)
351
+
352
+ # laixinlu: cast bool tensors to float32 here, since Paddle does not automatically support float32 * bool
353
+ if isinstance(image, paddle.Tensor) and image.dtype == paddle.bool:
354
+ image = image.cast(dtype="float32")
355
+
356
+ return image
357
+
358
+ def postprocess(
359
+ self,
360
+ image: paddle.Tensor,
361
+ output_type: str = "pil",
362
+ do_denormalize: Optional[List[bool]] = None,
363
+ ) -> Union[PIL.Image.Image, np.ndarray, paddle.Tensor]:
364
+ """
365
+ Postprocess the image output from tensor to `output_type`.
366
+
367
+ Args:
368
+ image (`paddle.Tensor`):
369
+ The image input, should be a paddle tensor with shape `B x C x H x W`.
370
+ output_type (`str`, *optional*, defaults to `pil`):
371
+ The output type of the image, can be one of `pil`, `np`, `pd`, `latent`.
372
+ do_denormalize (`List[bool]`, *optional*, defaults to `None`):
373
+ Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
374
+ `VaeImageProcessor` config.
375
+
376
+ Returns:
377
+ `PIL.Image.Image`, `np.ndarray` or `paddle.Tensor`:
378
+ The postprocessed image.
379
+ """
380
+ if not isinstance(image, paddle.Tensor):
381
+ raise ValueError(
382
+ f"Input for postprocessing is in incorrect format: {type(image)}. We only support paddle tensor"
383
+ )
384
+ if output_type not in ["latent", "pd", "np", "pil"]:
385
+ deprecation_message = (
386
+ f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
387
+ "`pil`, `np`, `pd`, `latent`"
388
+ )
389
+ deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
390
+ output_type = "np"
391
+
392
+ if output_type == "latent":
393
+ return image
394
+
395
+ if do_denormalize is None:
396
+ do_denormalize = [self.config.do_normalize] * image.shape[0]
397
+
398
+ image = paddle.stack(
399
+ [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
400
+ )
401
+
402
+ if output_type == "pd":
403
+ return image
404
+
405
+ image = self.pd_to_numpy(image)
406
+
407
+ if output_type == "np":
408
+ return image
409
+
410
+ if output_type == "pil":
411
+ return self.numpy_to_pil(image)
412
+
413
+
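A minimal sketch of the preprocess/postprocess round trip that pipelines perform with `VaeImageProcessor`; the input size is arbitrary and a working Paddle installation is assumed:

```python
import numpy as np
import PIL.Image

from ppdiffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8)

# Dimensions are snapped down to the nearest multiple of vae_scale_factor (517x773 -> 512x768).
pil_image = PIL.Image.fromarray(np.zeros((517, 773, 3), dtype=np.uint8))
tensor = processor.preprocess(pil_image)  # paddle.Tensor, NCHW layout, values in [-1, 1]
print(tensor.shape)                       # [1, 3, 512, 768]

# ... run the tensor through the VAE / denoising model here ...

images = processor.postprocess(tensor, output_type="pil")
print(images[0].size)                     # (768, 512)
```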
414
+ class VaeImageProcessorLDM3D(VaeImageProcessor):
415
+ """
416
+ Image processor for VAE LDM3D.
417
+
418
+ Args:
419
+ do_resize (`bool`, *optional*, defaults to `True`):
420
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
421
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
422
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
423
+ resample (`str`, *optional*, defaults to `lanczos`):
424
+ Resampling filter to use when resizing the image.
425
+ do_normalize (`bool`, *optional*, defaults to `True`):
426
+ Whether to normalize the image to [-1,1].
427
+ """
428
+
429
+ config_name = CONFIG_NAME
430
+
431
+ @register_to_config
432
+ def __init__(
433
+ self,
434
+ do_resize: bool = True,
435
+ vae_scale_factor: int = 8,
436
+ resample: str = "lanczos",
437
+ do_normalize: bool = True,
438
+ ):
439
+ super().__init__()
440
+
441
+ @staticmethod
442
+ def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
443
+ """
444
+ Convert a NumPy image or a batch of images to a PIL image.
445
+ """
446
+ if images.ndim == 3:
447
+ images = images[None, ...]
448
+ images = (images * 255).round().astype("uint8")
449
+ if images.shape[-1] == 1:
450
+ # special case for grayscale (single channel) images
451
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
452
+ else:
453
+ pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
454
+
455
+ return pil_images
456
+
457
+ @staticmethod
458
+ def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
459
+ """
460
+ Convert a PIL image or a list of PIL images to NumPy arrays.
461
+ """
462
+ if not isinstance(images, list):
463
+ images = [images]
464
+
465
+ images = [np.array(image).astype(np.float32) / (2**16 - 1) for image in images]
466
+ images = np.stack(images, axis=0)
467
+ return images
468
+
469
+ @staticmethod
470
+ def rgblike_to_depthmap(image: Union[np.ndarray, paddle.Tensor]) -> Union[np.ndarray, paddle.Tensor]:
471
+ """
472
+ Args:
473
+ image: RGB-like depth image
474
+
475
+ Returns: depth map
476
+
477
+ """
478
+ return image[:, :, 1] * 2**8 + image[:, :, 2]
479
+
480
+ def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
481
+ """
482
+ Convert a NumPy depth image or a batch of images to a PIL image.
483
+ """
484
+ if images.ndim == 3:
485
+ images = images[None, ...]
486
+ images_depth = images[:, :, :, 3:]
487
+ if images.shape[-1] == 6:
488
+ images_depth = (images_depth * 255).round().astype("uint8")
489
+ pil_images = [
490
+ Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
491
+ ]
492
+ elif images.shape[-1] == 4:
493
+ images_depth = (images_depth * 65535.0).astype(np.uint16)
494
+ pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
495
+ else:
496
+ raise Exception("Not supported")
497
+
498
+ return pil_images
499
+
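The two helpers above pack a 16-bit depth value across the G and B channels of an RGB-like image, and `rgblike_to_depthmap` recovers it as `G * 256 + B`. A small self-contained NumPy check of that round trip (the depth value is arbitrary):

import numpy as np

depth_value = 1234                            # any 16-bit depth sample
g, b = depth_value // 256, depth_value % 256  # high byte -> G channel, low byte -> B channel

rgb_like = np.zeros((1, 1, 3), dtype=np.uint16)
rgb_like[0, 0, 1] = g
rgb_like[0, 0, 2] = b

recovered = rgb_like[:, :, 1] * 2**8 + rgb_like[:, :, 2]
assert recovered[0, 0] == depth_value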
500
+ def postprocess(
501
+ self,
502
+ image: paddle.Tensor,
503
+ output_type: str = "pil",
504
+ do_denormalize: Optional[List[bool]] = None,
505
+ ) -> Union[PIL.Image.Image, np.ndarray, paddle.Tensor]:
506
+ """
507
+ Postprocess the image output from tensor to `output_type`.
508
+
509
+ Args:
510
+ image (`paddle.Tensor`):
511
+ The image input, should be a paddle tensor with shape `B x C x H x W`.
512
+ output_type (`str`, *optional*, defaults to `pil`):
513
+ The output type of the image, can be one of `pil`, `np`, `pd`, `latent`.
514
+ do_denormalize (`List[bool]`, *optional*, defaults to `None`):
515
+ Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
516
+ `VaeImageProcessor` config.
517
+
518
+ Returns:
519
+ `PIL.Image.Image`, `np.ndarray` or `paddle.Tensor`:
520
+ The postprocessed image.
521
+ """
522
+ if not isinstance(image, paddle.Tensor):
523
+ raise ValueError(
524
+ f"Input for postprocessing is in incorrect format: {type(image)}. We only support paddle tensor"
525
+ )
526
+ if output_type not in ["latent", "pd", "np", "pil"]:
527
+ deprecation_message = (
528
+ f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
529
+ "`pil`, `np`, `pd`, `latent`"
530
+ )
531
+ deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
532
+ output_type = "np"
533
+
534
+ if do_denormalize is None:
535
+ do_denormalize = [self.config.do_normalize] * image.shape[0]
536
+
537
+ image = paddle.stack(
538
+ [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
539
+ )
540
+
541
+ image = self.pd_to_numpy(image)
542
+
543
+ if output_type == "np":
544
+ if image.shape[-1] == 6:
545
+ image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
546
+ else:
547
+ image_depth = image[:, :, :, 3:]
548
+ return image[:, :, :, :3], image_depth
549
+
550
+ if output_type == "pil":
551
+ return self.numpy_to_pil(image), self.numpy_to_depth(image)
552
+ else:
553
+ raise Exception(f"This type {output_type} is not supported")
554
+
555
+ def preprocess(
556
+ self,
557
+ rgb: Union[paddle.Tensor, PIL.Image.Image, np.ndarray],
558
+ depth: Union[paddle.Tensor, PIL.Image.Image, np.ndarray],
559
+ height: Optional[int] = None,
560
+ width: Optional[int] = None,
561
+ target_res: Optional[int] = None,
562
+ ) -> paddle.Tensor:
563
+ """
564
+ Preprocess the image input. Accepted formats are PIL images, NumPy arrays or Paddle tensors.
565
+ """
566
+ supported_formats = (PIL.Image.Image, np.ndarray, paddle.Tensor)
567
+
568
+ # Expand the missing dimension for 3-dimensional paddle tensor or numpy array that represents grayscale image
569
+ if self.config.do_convert_grayscale and isinstance(rgb, (paddle.Tensor, np.ndarray)) and rgb.ndim == 3:
570
+ raise Exception("This is not yet supported")
571
+
572
+ if isinstance(rgb, supported_formats):
573
+ rgb = [rgb]
574
+ depth = [depth]
575
+ elif not (isinstance(rgb, list) and all(isinstance(i, supported_formats) for i in rgb)):
576
+ raise ValueError(
577
+ f"Input is in incorrect format: {[type(i) for i in rgb]}. Currently, we only support {', '.join(supported_formats)}"
578
+ )
579
+
580
+ if isinstance(rgb[0], PIL.Image.Image):
581
+ if self.config.do_convert_rgb:
582
+ raise Exception("This is not yet supported")
583
+ # rgb = [self.convert_to_rgb(i) for i in rgb]
584
+ # depth = [self.convert_to_depth(i) for i in depth] #TODO define convert_to_depth
585
+ if self.config.do_resize or target_res:
586
+ height, width = self.get_default_height_width(rgb[0], height, width) if not target_res else target_res
587
+ rgb = [self.resize(i, height, width) for i in rgb]
588
+ depth = [self.resize(i, height, width) for i in depth]
589
+ rgb = self.pil_to_numpy(rgb) # to np
590
+ rgb = self.numpy_to_pd(rgb) # to paddle tensor
591
+
592
+ depth = self.depth_pil_to_numpy(depth) # to np
593
+ depth = self.numpy_to_pd(depth) # to paddle tensor
594
+
595
+ elif isinstance(rgb[0], np.ndarray):
596
+ rgb = np.concatenate(rgb, axis=0) if rgb[0].ndim == 4 else np.stack(rgb, axis=0)
597
+ rgb = self.numpy_to_pd(rgb)
598
+ height, width = self.get_default_height_width(rgb, height, width)
599
+ if self.config.do_resize:
600
+ rgb = self.resize(rgb, height, width)
601
+
602
+ depth = np.concatenate(depth, axis=0) if rgb[0].ndim == 4 else np.stack(depth, axis=0)
603
+ depth = self.numpy_to_pd(depth)
604
+ height, width = self.get_default_height_width(depth, height, width)
605
+ if self.config.do_resize:
606
+ depth = self.resize(depth, height, width)
607
+
608
+ elif isinstance(rgb[0], paddle.Tensor):
609
+ raise Exception("This is not yet supported")
610
+ # rgb = paddle.concat(rgb, axis=0) if rgb[0].ndim == 4 else paddle.stack(rgb, axis=0)
611
+
612
+ # if self.config.do_convert_grayscale and rgb.ndim == 3:
613
+ # rgb = rgb.unsqueeze(1)
614
+
615
+ # channel = rgb.shape[1]
616
+
617
+ # height, width = self.get_default_height_width(rgb, height, width)
618
+ # if self.config.do_resize:
619
+ # rgb = self.resize(rgb, height, width)
620
+
621
+ # depth = paddle.cat(depth, axis=0) if depth[0].ndim == 4 else paddle.stack(depth, axis=0)
622
+
623
+ # if self.config.do_convert_grayscale and depth.ndim == 3:
624
+ # depth = depth.unsqueeze(1)
625
+
626
+ # channel = depth.shape[1]
627
+ # # don't need any preprocess if the image is latents
628
+ # if depth == 4:
629
+ # return rgb, depth
630
+
631
+ # height, width = self.get_default_height_width(depth, height, width)
632
+ # if self.config.do_resize:
633
+ # depth = self.resize(depth, height, width)
634
+ # expected range [0,1], normalize to [-1,1]
635
+ do_normalize = self.config.do_normalize
636
+ if rgb.min() < 0 and do_normalize:
637
+ warnings.warn(
638
+ "Passing `image` as paddle tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
639
+ f"when passing as paddle tensor or numpy Array. You passed `image` with value range [{rgb.min()},{rgb.max()}]",
640
+ FutureWarning,
641
+ )
642
+ do_normalize = False
643
+
644
+ if do_normalize:
645
+ rgb = self.normalize(rgb)
646
+ depth = self.normalize(depth)
647
+
648
+ if self.config.do_binarize:
649
+ rgb = self.binarize(rgb)
650
+ depth = self.binarize(depth)
651
+
652
+ return rgb, depth
653
+
654
+
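As in the base processor, `do_normalize` maps inputs from [0, 1] to [-1, 1] before they reach the VAE, and the warning above fires (and disables normalization) when the data already looks like it lives in [-1, 1]. The affine map and its inverse, shown as plain NumPy for clarity; the exact `normalize`/`denormalize` implementations live in the base class outside this hunk, so this is the conventional form rather than a copy of that code:

import numpy as np

x = np.array([0.0, 0.25, 0.5, 1.0], dtype=np.float32)
normalized = 2.0 * x - 1.0         # [0, 1] -> [-1, 1], what do_normalize applies
recovered = normalized / 2 + 0.5   # the inverse used when denormalizing for output
assert np.allclose(recovered, x)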
655
+ def is_valid_image(image):
656
+ return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, paddle.Tensor)) and image.ndim in (2, 3)
657
+
658
+
659
+ def is_valid_image_imagelist(images):
660
+ # check if the image input is one of the supported formats for image and image list:
661
+ # it can be any one of the following three:
662
+ # (1) a 4-D paddle tensor or numpy array,
663
+ # (2) a valid image: a PIL.Image.Image, a 2-D np.ndarray or paddle.Tensor (grayscale image), or a 3-D np.ndarray or paddle.Tensor
664
+ # (3) a list of valid images
665
+ if isinstance(images, (np.ndarray, paddle.Tensor)) and images.ndim == 4:
666
+ return True
667
+ elif is_valid_image(images):
668
+ return True
669
+ elif isinstance(images, list):
670
+ return all(is_valid_image(image) for image in images)
671
+ return False
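A quick illustration of what these two validators accept; the import path follows this module, and the shapes are arbitrary:

import numpy as np
import paddle
import PIL.Image

from ppdiffusers.image_processor import is_valid_image, is_valid_image_imagelist

print(is_valid_image(PIL.Image.new("RGB", (64, 64))))          # True
print(is_valid_image(np.zeros((64, 64), dtype=np.float32)))    # True: 2-D grayscale array
print(is_valid_image_imagelist(paddle.zeros([2, 3, 64, 64])))  # True: 4-D batched tensor
print(is_valid_image_imagelist([np.zeros((3, 64, 64))] * 2))   # True: list of 3-D images
print(is_valid_image_imagelist("not an image"))                # False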
PaddleMIX/ppdiffusers/ppdiffusers/initializer.py ADDED
@@ -0,0 +1,20 @@
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
16
+
17
+ # NOTE: This file is deprecated and will be removed in a future version.
18
+ # It only exists so that, temporarily, `from ppdiffusers.initializer import *` keeps working
19
+ # flake8: noqa
20
+ from .utils.initializer_utils import * # noqa: F401
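Given the note above, downstream code should move to the canonical location; both imports resolve to the same symbols for as long as the shim exists:

# Forward-compatible import (preferred):
from ppdiffusers.utils.initializer_utils import *  # noqa: F401,F403

# Legacy import kept alive by this shim module (stops working once it is removed):
from ppdiffusers.initializer import *  # noqa: F401,F403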
PaddleMIX/ppdiffusers/ppdiffusers/models/attention_processor.py ADDED
The diff for this file is too large to render. See raw diff
 
PaddleMIX/ppdiffusers/ppdiffusers/models/autoencoder_kl_cogvideox.py ADDED
@@ -0,0 +1,1190 @@
1
+ # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional, Tuple, Union
16
+
17
+ import numpy as np
18
+ import paddle
19
+
20
+ from ..configuration_utils import ConfigMixin, register_to_config
21
+ from ..utils import logging
22
+ from ..utils.accelerate_utils import apply_forward_hook
23
+ from .activations import get_activation
24
+ from .downsampling import CogVideoXDownsample3D
25
+ from .modeling_outputs import AutoencoderKLOutput
26
+ from .modeling_utils import ModelMixin
27
+ from .upsampling import CogVideoXUpsample3D
28
+ from .vae import DecoderOutput, DiagonalGaussianDistribution
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
+ class CogVideoXSafeConv3d(paddle.nn.Conv3D):
34
+ """
35
+ A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
36
+ """
37
+
38
+ def forward(self, input: paddle.Tensor) -> paddle.Tensor:
39
+ memory_count = paddle.prod(x=paddle.to_tensor(data=tuple(input.shape))).item() * 2 / 1024**3
40
+ if memory_count > 2:
41
+ kernel_size = self.kernel_size[0]
42
+ part_num = int(memory_count / 2) + 1
43
+ input_chunks = paddle.chunk(x=input, chunks=part_num, axis=2)
44
+ if kernel_size > 1:
45
+ input_chunks = [input_chunks[0]] + [
46
+ paddle.concat(x=(input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), axis=2)
47
+ for i in range(1, len(input_chunks))
48
+ ]
49
+ output_chunks = []
50
+ for input_chunk in input_chunks:
51
+ output_chunks.append(super().forward(input_chunk))
52
+ output = paddle.concat(x=output_chunks, axis=2)
53
+ return output
54
+ else:
55
+ return super().forward(input)
56
+
57
+
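`forward` above estimates the activation's footprint assuming 2 bytes per element (fp16) and, past roughly 2 GiB, splits the input into `int(memory_count / 2) + 1` chunks along the temporal axis, prepending the last `kernel_size - 1` frames of the previous chunk so the per-chunk convolutions concatenate seamlessly. A worked instance of that estimate, with an illustrative shape:

shape = (1, 128, 49, 240, 360)  # [B, C, T, H, W]
numel = 1
for dim in shape:
    numel *= dim
memory_gib = numel * 2 / 1024**3
print(f"{memory_gib:.2f} GiB")  # ~1.01 GiB -> below the 2 GiB threshold, so no chunking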
58
+ class CogVideoXCausalConv3d(paddle.nn.Layer):
59
+ """A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
60
+
61
+ Args:
62
+ in_channels (`int`): Number of channels in the input tensor.
63
+ out_channels (`int`): Number of output channels produced by the convolution.
64
+ kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
65
+ stride (`int`, defaults to `1`): Stride of the convolution.
66
+ dilation (`int`, defaults to `1`): Dilation rate of the convolution.
67
+ pad_mode (`str`, defaults to `"constant"`): Padding mode.
68
+ """
69
+
70
+ def __init__(
71
+ self,
72
+ in_channels: int,
73
+ out_channels: int,
74
+ kernel_size: Union[int, Tuple[int, int, int]],
75
+ stride: int = 1,
76
+ dilation: int = 1,
77
+ pad_mode: str = "constant",
78
+ ):
79
+ super().__init__()
80
+ if isinstance(kernel_size, int):
81
+ kernel_size = (kernel_size,) * 3
82
+ time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
83
+ self.pad_mode = pad_mode
84
+ time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
85
+ height_pad = height_kernel_size // 2
86
+ width_pad = width_kernel_size // 2
87
+ self.height_pad = height_pad
88
+ self.width_pad = width_pad
89
+ self.time_pad = time_pad
90
+ self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
91
+ self.temporal_dim = 2
92
+ self.time_kernel_size = time_kernel_size
93
+ stride = stride, 1, 1
94
+ dilation = dilation, 1, 1
95
+ self.conv = CogVideoXSafeConv3d(
96
+ in_channels=in_channels,
97
+ out_channels=out_channels,
98
+ kernel_size=kernel_size,
99
+ stride=stride,
100
+ dilation=dilation,
101
+ )
102
+ self.conv_cache = None
103
+
104
+ def fake_context_parallel_forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
105
+ kernel_size = self.time_kernel_size
106
+ if kernel_size > 1:
107
+ cached_inputs = (
108
+ [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
109
+ )
110
+ inputs = paddle.concat(x=cached_inputs + [inputs], axis=2)
111
+ return inputs
112
+
113
+ def _clear_fake_context_parallel_cache(self):
114
+ del self.conv_cache
115
+ self.conv_cache = None
116
+
117
+ def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
118
+ inputs = self.fake_context_parallel_forward(inputs)
119
+ self._clear_fake_context_parallel_cache()
120
+ self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
121
+ padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad, 0, 0)
122
+ inputs = paddle.nn.functional.pad(x=inputs, pad=padding_2d, mode="constant", value=0, data_format="NCDHW")
123
+ output = self.conv(inputs)
124
+ return output
125
+
126
+
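The padding scheme above is asymmetric in time: height and width get the usual symmetric `kernel // 2` padding, while all `dilation * (kernel_t - 1) + (1 - stride)` pad frames go on the past side, supplied either from `conv_cache` (the tail of the previous chunk) or by repeating the first frame. With the common 3x3x3 kernel:

time_kernel, h_kernel, w_kernel = 3, 3, 3
stride, dilation = 1, 1

time_pad = dilation * (time_kernel - 1) + (1 - stride)  # 2 frames, past side only
height_pad, width_pad = h_kernel // 2, w_kernel // 2    # 1 pixel on each spatial side

print(time_pad, height_pad, width_pad)  # 2 1 1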
127
+ class CogVideoXSpatialNorm3D(paddle.nn.Layer):
128
+ """
129
+ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific
130
+ to 3D-video like data.
131
+
132
+ CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model.
133
+
134
+ Args:
135
+ f_channels (`int`):
136
+ The number of channels for input to group normalization layer, and output of the spatial norm layer.
137
+ zq_channels (`int`):
138
+ The number of channels for the quantized vector as described in the paper.
139
+ groups (`int`):
140
+ Number of groups to separate the channels into for group normalization.
141
+ """
142
+
143
+ def __init__(self, f_channels: int, zq_channels: int, groups: int = 32):
144
+ super().__init__()
145
+ self.norm_layer = paddle.nn.GroupNorm(
146
+ num_channels=f_channels, num_groups=groups, epsilon=1e-06, weight_attr=True, bias_attr=True
147
+ )
148
+ self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
149
+ self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
150
+
151
+ def forward(self, f: paddle.Tensor, zq: paddle.Tensor) -> paddle.Tensor:
152
+ if tuple(f.shape)[2] > 1 and tuple(f.shape)[2] % 2 == 1:
153
+ f_first, f_rest = f[:, :, :1], f[:, :, 1:]
154
+ f_first_size, f_rest_size = tuple(f_first.shape)[-3:], tuple(f_rest.shape)[-3:]
155
+ z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
156
+ z_first = paddle.nn.functional.interpolate(x=z_first, size=f_first_size)
157
+ z_rest = paddle.nn.functional.interpolate(x=z_rest, size=f_rest_size)
158
+ zq = paddle.concat(x=[z_first, z_rest], axis=2)
159
+ else:
160
+ zq = paddle.nn.functional.interpolate(x=zq, size=tuple(f.shape)[-3:])
161
+ norm_f = self.norm_layer(f)
162
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
163
+ return new_f
164
+
165
+
166
+ class CogVideoXResnetBlock3D(paddle.nn.Layer):
167
+ """
168
+ A 3D ResNet block used in the CogVideoX model.
169
+
170
+ Args:
171
+ in_channels (`int`):
172
+ Number of input channels.
173
+ out_channels (`int`, *optional*):
174
+ Number of output channels. If None, defaults to `in_channels`.
175
+ dropout (`float`, defaults to `0.0`):
176
+ Dropout rate.
177
+ temb_channels (`int`, defaults to `512`):
178
+ Number of time embedding channels.
179
+ groups (`int`, defaults to `32`):
180
+ Number of groups to separate the channels into for group normalization.
181
+ eps (`float`, defaults to `1e-6`):
182
+ Epsilon value for normalization layers.
183
+ non_linearity (`str`, defaults to `"swish"`):
184
+ Activation function to use.
185
+ conv_shortcut (bool, defaults to `False`):
186
+ Whether or not to use a convolution shortcut.
187
+ spatial_norm_dim (`int`, *optional*):
188
+ The dimension to use for spatial norm if it is to be used instead of group norm.
189
+ pad_mode (str, defaults to `"first"`):
190
+ Padding mode.
191
+ """
192
+
193
+ def __init__(
194
+ self,
195
+ in_channels: int,
196
+ out_channels: Optional[int] = None,
197
+ dropout: float = 0.0,
198
+ temb_channels: int = 512,
199
+ groups: int = 32,
200
+ eps: float = 1e-06,
201
+ non_linearity: str = "swish",
202
+ conv_shortcut: bool = False,
203
+ spatial_norm_dim: Optional[int] = None,
204
+ pad_mode: str = "first",
205
+ ):
206
+ super().__init__()
207
+ out_channels = out_channels or in_channels
208
+ self.in_channels = in_channels
209
+ self.out_channels = out_channels
210
+ self.nonlinearity = get_activation(non_linearity)
211
+ self.use_conv_shortcut = conv_shortcut
212
+ if spatial_norm_dim is None:
213
+ self.norm1 = paddle.nn.GroupNorm(num_channels=in_channels, num_groups=groups, epsilon=eps)
214
+ self.norm2 = paddle.nn.GroupNorm(num_channels=out_channels, num_groups=groups, epsilon=eps)
215
+ else:
216
+ self.norm1 = CogVideoXSpatialNorm3D(f_channels=in_channels, zq_channels=spatial_norm_dim, groups=groups)
217
+ self.norm2 = CogVideoXSpatialNorm3D(f_channels=out_channels, zq_channels=spatial_norm_dim, groups=groups)
218
+ self.conv1 = CogVideoXCausalConv3d(
219
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
220
+ )
221
+ if temb_channels > 0:
222
+ self.temb_proj = paddle.nn.Linear(in_features=temb_channels, out_features=out_channels)
223
+ self.dropout = paddle.nn.Dropout(p=dropout)
224
+ self.conv2 = CogVideoXCausalConv3d(
225
+ in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
226
+ )
227
+ if self.in_channels != self.out_channels:
228
+ if self.use_conv_shortcut:
229
+ self.conv_shortcut = CogVideoXCausalConv3d(
230
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
231
+ )
232
+ else:
233
+ self.conv_shortcut = CogVideoXSafeConv3d(
234
+ in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
235
+ )
236
+
237
+ def forward(
238
+ self, inputs: paddle.Tensor, temb: Optional[paddle.Tensor] = None, zq: Optional[paddle.Tensor] = None
239
+ ) -> paddle.Tensor:
240
+ hidden_states = inputs
241
+ if zq is not None:
242
+ hidden_states = self.norm1(hidden_states, zq)
243
+ else:
244
+ hidden_states = self.norm1(hidden_states)
245
+ hidden_states = self.nonlinearity(hidden_states)
246
+ hidden_states = self.conv1(hidden_states)
247
+ if temb is not None:
248
+ hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
249
+ if zq is not None:
250
+ hidden_states = self.norm2(hidden_states, zq)
251
+ else:
252
+ hidden_states = self.norm2(hidden_states)
253
+ hidden_states = self.nonlinearity(hidden_states)
254
+ hidden_states = self.dropout(hidden_states)
255
+ hidden_states = self.conv2(hidden_states)
256
+ if self.in_channels != self.out_channels:
257
+ inputs = self.conv_shortcut(inputs)
258
+ hidden_states = hidden_states + inputs
259
+ return hidden_states
260
+
261
+
262
+ class CogVideoXDownBlock3D(paddle.nn.Layer):
263
+ """
264
+ A downsampling block used in the CogVideoX model.
265
+
266
+ Args:
267
+ in_channels (`int`):
268
+ Number of input channels.
269
+ out_channels (`int`, *optional*):
270
+ Number of output channels. If None, defaults to `in_channels`.
271
+ temb_channels (`int`, defaults to `512`):
272
+ Number of time embedding channels.
273
+ num_layers (`int`, defaults to `1`):
274
+ Number of resnet layers.
275
+ dropout (`float`, defaults to `0.0`):
276
+ Dropout rate.
277
+ resnet_eps (`float`, defaults to `1e-6`):
278
+ Epsilon value for normalization layers.
279
+ resnet_act_fn (`str`, defaults to `"swish"`):
280
+ Activation function to use.
281
+ resnet_groups (`int`, defaults to `32`):
282
+ Number of groups to separate the channels into for group normalization.
283
+ add_downsample (`bool`, defaults to `True`):
284
+ Whether or not to use a downsampling layer. If not used, the output dimension is the same as the input dimension.
285
+ compress_time (`bool`, defaults to `False`):
286
+ Whether or not to downsample across temporal dimension.
287
+ pad_mode (str, defaults to `"first"`):
288
+ Padding mode.
289
+ """
290
+
291
+ _supports_gradient_checkpointing = True
292
+
293
+ def __init__(
294
+ self,
295
+ in_channels: int,
296
+ out_channels: int,
297
+ temb_channels: int,
298
+ dropout: float = 0.0,
299
+ num_layers: int = 1,
300
+ resnet_eps: float = 1e-06,
301
+ resnet_act_fn: str = "swish",
302
+ resnet_groups: int = 32,
303
+ add_downsample: bool = True,
304
+ downsample_padding: int = 0,
305
+ compress_time: bool = False,
306
+ pad_mode: str = "first",
307
+ ):
308
+ super().__init__()
309
+ resnets = []
310
+ for i in range(num_layers):
311
+ in_channel = in_channels if i == 0 else out_channels
312
+ resnets.append(
313
+ CogVideoXResnetBlock3D(
314
+ in_channels=in_channel,
315
+ out_channels=out_channels,
316
+ dropout=dropout,
317
+ temb_channels=temb_channels,
318
+ groups=resnet_groups,
319
+ eps=resnet_eps,
320
+ non_linearity=resnet_act_fn,
321
+ pad_mode=pad_mode,
322
+ )
323
+ )
324
+ self.resnets = paddle.nn.LayerList(sublayers=resnets)
325
+ self.downsamplers = None
326
+ if add_downsample:
327
+ self.downsamplers = paddle.nn.LayerList(
328
+ sublayers=[
329
+ CogVideoXDownsample3D(
330
+ out_channels, out_channels, padding=downsample_padding, compress_time=compress_time
331
+ )
332
+ ]
333
+ )
334
+ self.gradient_checkpointing = False
335
+
336
+ def forward(
337
+ self, hidden_states: paddle.Tensor, temb: Optional[paddle.Tensor] = None, zq: Optional[paddle.Tensor] = None
338
+ ) -> paddle.Tensor:
339
+ for resnet in self.resnets:
340
+ if self.training and self.gradient_checkpointing:
341
+
342
+ def create_custom_forward(module):
343
+ def create_forward(*inputs):
344
+ return module(*inputs)
345
+
346
+ return create_forward
347
+
348
+ hidden_states = paddle.distributed.fleet.utils.recompute(
349
+ create_custom_forward(resnet), hidden_states, temb, zq
350
+ )
351
+ else:
352
+ hidden_states = resnet(hidden_states, temb, zq)
353
+ if self.downsamplers is not None:
354
+ for downsampler in self.downsamplers:
355
+ hidden_states = downsampler(hidden_states)
356
+ return hidden_states
357
+
358
+
359
+ class CogVideoXMidBlock3D(paddle.nn.Layer):
360
+ """
361
+ A middle block used in the CogVideoX model.
362
+
363
+ Args:
364
+ in_channels (`int`):
365
+ Number of input channels.
366
+ temb_channels (`int`, defaults to `512`):
367
+ Number of time embedding channels.
368
+ dropout (`float`, defaults to `0.0`):
369
+ Dropout rate.
370
+ num_layers (`int`, defaults to `1`):
371
+ Number of resnet layers.
372
+ resnet_eps (`float`, defaults to `1e-6`):
373
+ Epsilon value for normalization layers.
374
+ resnet_act_fn (`str`, defaults to `"swish"`):
375
+ Activation function to use.
376
+ resnet_groups (`int`, defaults to `32`):
377
+ Number of groups to separate the channels into for group normalization.
378
+ spatial_norm_dim (`int`, *optional*):
379
+ The dimension to use for spatial norm if it is to be used instead of group norm.
380
+ pad_mode (str, defaults to `"first"`):
381
+ Padding mode.
382
+ """
383
+
384
+ _supports_gradient_checkpointing = True
385
+
386
+ def __init__(
387
+ self,
388
+ in_channels: int,
389
+ temb_channels: int,
390
+ dropout: float = 0.0,
391
+ num_layers: int = 1,
392
+ resnet_eps: float = 1e-06,
393
+ resnet_act_fn: str = "swish",
394
+ resnet_groups: int = 32,
395
+ spatial_norm_dim: Optional[int] = None,
396
+ pad_mode: str = "first",
397
+ ):
398
+ super().__init__()
399
+ resnets = []
400
+ for _ in range(num_layers):
401
+ resnets.append(
402
+ CogVideoXResnetBlock3D(
403
+ in_channels=in_channels,
404
+ out_channels=in_channels,
405
+ dropout=dropout,
406
+ temb_channels=temb_channels,
407
+ groups=resnet_groups,
408
+ eps=resnet_eps,
409
+ spatial_norm_dim=spatial_norm_dim,
410
+ non_linearity=resnet_act_fn,
411
+ pad_mode=pad_mode,
412
+ )
413
+ )
414
+ self.resnets = paddle.nn.LayerList(sublayers=resnets)
415
+ self.gradient_checkpointing = False
416
+
417
+ def forward(
418
+ self, hidden_states: paddle.Tensor, temb: Optional[paddle.Tensor] = None, zq: Optional[paddle.Tensor] = None
419
+ ) -> paddle.Tensor:
420
+ for resnet in self.resnets:
421
+ if self.training and self.gradient_checkpointing:
422
+
423
+ def create_custom_forward(module):
424
+ def create_forward(*inputs):
425
+ return module(*inputs)
426
+
427
+ return create_forward
428
+
429
+ hidden_states = paddle.distributed.fleet.utils.recompute(
430
+ create_custom_forward(resnet), hidden_states, temb, zq
431
+ )
432
+ else:
433
+ hidden_states = resnet(hidden_states, temb, zq)
434
+ return hidden_states
435
+
436
+
437
+ class CogVideoXUpBlock3D(paddle.nn.Layer):
438
+ """
439
+ An upsampling block used in the CogVideoX model.
440
+
441
+ Args:
442
+ in_channels (`int`):
443
+ Number of input channels.
444
+ out_channels (`int`, *optional*):
445
+ Number of output channels. If None, defaults to `in_channels`.
446
+ temb_channels (`int`, defaults to `512`):
447
+ Number of time embedding channels.
448
+ dropout (`float`, defaults to `0.0`):
449
+ Dropout rate.
450
+ num_layers (`int`, defaults to `1`):
451
+ Number of resnet layers.
452
+ resnet_eps (`float`, defaults to `1e-6`):
453
+ Epsilon value for normalization layers.
454
+ resnet_act_fn (`str`, defaults to `"swish"`):
455
+ Activation function to use.
456
+ resnet_groups (`int`, defaults to `32`):
457
+ Number of groups to separate the channels into for group normalization.
458
+ spatial_norm_dim (`int`, defaults to `16`):
459
+ The dimension to use for spatial norm if it is to be used instead of group norm.
460
+ add_upsample (`bool`, defaults to `True`):
461
+ Whether or not to use an upsampling layer. If not used, the output dimension is the same as the input dimension.
462
+ compress_time (`bool`, defaults to `False`):
463
+ Whether or not to upsample across the temporal dimension.
464
+ pad_mode (str, defaults to `"first"`):
465
+ Padding mode.
466
+ """
467
+
468
+ def __init__(
469
+ self,
470
+ in_channels: int,
471
+ out_channels: int,
472
+ temb_channels: int,
473
+ dropout: float = 0.0,
474
+ num_layers: int = 1,
475
+ resnet_eps: float = 1e-06,
476
+ resnet_act_fn: str = "swish",
477
+ resnet_groups: int = 32,
478
+ spatial_norm_dim: int = 16,
479
+ add_upsample: bool = True,
480
+ upsample_padding: int = 1,
481
+ compress_time: bool = False,
482
+ pad_mode: str = "first",
483
+ ):
484
+ super().__init__()
485
+ resnets = []
486
+ for i in range(num_layers):
487
+ in_channel = in_channels if i == 0 else out_channels
488
+ resnets.append(
489
+ CogVideoXResnetBlock3D(
490
+ in_channels=in_channel,
491
+ out_channels=out_channels,
492
+ dropout=dropout,
493
+ temb_channels=temb_channels,
494
+ groups=resnet_groups,
495
+ eps=resnet_eps,
496
+ non_linearity=resnet_act_fn,
497
+ spatial_norm_dim=spatial_norm_dim,
498
+ pad_mode=pad_mode,
499
+ )
500
+ )
501
+ self.resnets = paddle.nn.LayerList(sublayers=resnets)
502
+ self.upsamplers = None
503
+ if add_upsample:
504
+ self.upsamplers = paddle.nn.LayerList(
505
+ sublayers=[
506
+ CogVideoXUpsample3D(
507
+ out_channels, out_channels, padding=upsample_padding, compress_time=compress_time
508
+ )
509
+ ]
510
+ )
511
+ self.gradient_checkpointing = False
512
+
513
+ def forward(
514
+ self, hidden_states: paddle.Tensor, temb: Optional[paddle.Tensor] = None, zq: Optional[paddle.Tensor] = None
515
+ ) -> paddle.Tensor:
516
+ """Forward method of the `CogVideoXUpBlock3D` class."""
517
+ for resnet in self.resnets:
518
+ if self.training and self.gradient_checkpointing:
519
+
520
+ def create_custom_forward(module):
521
+ def create_forward(*inputs):
522
+ return module(*inputs)
523
+
524
+ return create_forward
525
+
526
+ hidden_states = paddle.distributed.fleet.utils.recompute(
527
+ create_custom_forward(resnet), hidden_states, temb, zq
528
+ )
529
+ else:
530
+ hidden_states = resnet(hidden_states, temb, zq)
531
+ if self.upsamplers is not None:
532
+ for upsampler in self.upsamplers:
533
+ hidden_states = upsampler(hidden_states)
534
+ return hidden_states
535
+
536
+
537
+ class CogVideoXEncoder3D(paddle.nn.Layer):
538
+ """
539
+ The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.
540
+
541
+ Args:
542
+ in_channels (`int`, *optional*, defaults to 3):
543
+ The number of input channels.
544
+ out_channels (`int`, *optional*, defaults to 3):
545
+ The number of output channels.
546
+ down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
547
+ The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
548
+ options.
549
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
550
+ The number of output channels for each block.
551
+ act_fn (`str`, *optional*, defaults to `"silu"`):
552
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
553
+ layers_per_block (`int`, *optional*, defaults to 2):
554
+ The number of layers per block.
555
+ norm_num_groups (`int`, *optional*, defaults to 32):
556
+ The number of groups for normalization.
557
+ """
558
+
559
+ _supports_gradient_checkpointing = True
560
+
561
+ def __init__(
562
+ self,
563
+ in_channels: int = 3,
564
+ out_channels: int = 16,
565
+ down_block_types: Tuple[str, ...] = (
566
+ "CogVideoXDownBlock3D",
567
+ "CogVideoXDownBlock3D",
568
+ "CogVideoXDownBlock3D",
569
+ "CogVideoXDownBlock3D",
570
+ ),
571
+ block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
572
+ layers_per_block: int = 3,
573
+ act_fn: str = "silu",
574
+ norm_eps: float = 1e-06,
575
+ norm_num_groups: int = 32,
576
+ dropout: float = 0.0,
577
+ pad_mode: str = "first",
578
+ temporal_compression_ratio: float = 4,
579
+ ):
580
+ super().__init__()
581
+ temporal_compress_level = int(np.log2(temporal_compression_ratio))
582
+ self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
583
+ self.down_blocks = paddle.nn.LayerList(sublayers=[])
584
+ output_channel = block_out_channels[0]
585
+ for i, down_block_type in enumerate(down_block_types):
586
+ input_channel = output_channel
587
+ output_channel = block_out_channels[i]
588
+ is_final_block = i == len(block_out_channels) - 1
589
+ compress_time = i < temporal_compress_level
590
+ if down_block_type == "CogVideoXDownBlock3D":
591
+ down_block = CogVideoXDownBlock3D(
592
+ in_channels=input_channel,
593
+ out_channels=output_channel,
594
+ temb_channels=0,
595
+ dropout=dropout,
596
+ num_layers=layers_per_block,
597
+ resnet_eps=norm_eps,
598
+ resnet_act_fn=act_fn,
599
+ resnet_groups=norm_num_groups,
600
+ add_downsample=not is_final_block,
601
+ compress_time=compress_time,
602
+ )
603
+ else:
604
+ raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")
605
+ self.down_blocks.append(down_block)
606
+ self.mid_block = CogVideoXMidBlock3D(
607
+ in_channels=block_out_channels[-1],
608
+ temb_channels=0,
609
+ dropout=dropout,
610
+ num_layers=2,
611
+ resnet_eps=norm_eps,
612
+ resnet_act_fn=act_fn,
613
+ resnet_groups=norm_num_groups,
614
+ pad_mode=pad_mode,
615
+ )
616
+ self.norm_out = paddle.nn.GroupNorm(
617
+ num_groups=norm_num_groups, num_channels=block_out_channels[-1], epsilon=1e-06
618
+ )
619
+ self.conv_act = paddle.nn.Silu()
620
+ self.conv_out = CogVideoXCausalConv3d(
621
+ block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode
622
+ )
623
+ self.gradient_checkpointing = False
624
+
625
+ def forward(self, sample: paddle.Tensor, temb: Optional[paddle.Tensor] = None) -> paddle.Tensor:
626
+ """The forward method of the `CogVideoXEncoder3D` class."""
627
+ hidden_states = self.conv_in(sample)
628
+ if self.training and self.gradient_checkpointing:
629
+
630
+ def create_custom_forward(module):
631
+ def custom_forward(*inputs):
632
+ return module(*inputs)
633
+
634
+ return custom_forward
635
+
636
+ for down_block in self.down_blocks:
637
+ hidden_states = paddle.distributed.fleet.utils.recompute(
638
+ create_custom_forward(down_block), hidden_states, temb, None
639
+ )
640
+ hidden_states = paddle.distributed.fleet.utils.recompute(
641
+ create_custom_forward(self.mid_block), hidden_states, temb, None
642
+ )
643
+ else:
644
+ for down_block in self.down_blocks:
645
+ hidden_states = down_block(hidden_states, temb, None)
646
+ hidden_states = self.mid_block(hidden_states, temb, None)
647
+ hidden_states = self.norm_out(hidden_states)
648
+ hidden_states = self.conv_act(hidden_states)
649
+ hidden_states = self.conv_out(hidden_states)
650
+ return hidden_states
651
+
652
+
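With the default configuration above, spatial resolution halves in every down block except the last, and time halves in the first `log2(temporal_compression_ratio)` blocks, so the encoder reduces H and W by 8x and the frame count by 4x (assuming the downsampling modules behave as their names suggest; they are defined outside this file). A small sanity check of those factors:

import numpy as np

block_out_channels = (128, 256, 256, 512)
temporal_compression_ratio = 4

spatial_factor = 2 ** (len(block_out_channels) - 1)              # 8: all but the final block downsample
temporal_factor = 2 ** int(np.log2(temporal_compression_ratio))  # 4: compress_time on the first 2 blocks
print(spatial_factor, temporal_factor)  # 8 4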
653
+ class CogVideoXDecoder3D(paddle.nn.Layer):
654
+ """
655
+ The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
656
+ sample.
657
+
658
+ Args:
659
+ in_channels (`int`, *optional*, defaults to 3):
660
+ The number of input channels.
661
+ out_channels (`int`, *optional*, defaults to 3):
662
+ The number of output channels.
663
+ up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
664
+ The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
665
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
666
+ The number of output channels for each block.
667
+ act_fn (`str`, *optional*, defaults to `"silu"`):
668
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
669
+ layers_per_block (`int`, *optional*, defaults to 2):
670
+ The number of layers per block.
671
+ norm_num_groups (`int`, *optional*, defaults to 32):
672
+ The number of groups for normalization.
673
+ """
674
+
675
+ _supports_gradient_checkpointing = True
676
+
677
+ def __init__(
678
+ self,
679
+ in_channels: int = 16,
680
+ out_channels: int = 3,
681
+ up_block_types: Tuple[str, ...] = (
682
+ "CogVideoXUpBlock3D",
683
+ "CogVideoXUpBlock3D",
684
+ "CogVideoXUpBlock3D",
685
+ "CogVideoXUpBlock3D",
686
+ ),
687
+ block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
688
+ layers_per_block: int = 3,
689
+ act_fn: str = "silu",
690
+ norm_eps: float = 1e-06,
691
+ norm_num_groups: int = 32,
692
+ dropout: float = 0.0,
693
+ pad_mode: str = "first",
694
+ temporal_compression_ratio: float = 4,
695
+ ):
696
+ super().__init__()
697
+ reversed_block_out_channels = list(reversed(block_out_channels))
698
+ self.conv_in = CogVideoXCausalConv3d(
699
+ in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode
700
+ )
701
+ self.mid_block = CogVideoXMidBlock3D(
702
+ in_channels=reversed_block_out_channels[0],
703
+ temb_channels=0,
704
+ num_layers=2,
705
+ resnet_eps=norm_eps,
706
+ resnet_act_fn=act_fn,
707
+ resnet_groups=norm_num_groups,
708
+ spatial_norm_dim=in_channels,
709
+ pad_mode=pad_mode,
710
+ )
711
+ self.up_blocks = paddle.nn.LayerList(sublayers=[])
712
+ output_channel = reversed_block_out_channels[0]
713
+ temporal_compress_level = int(np.log2(temporal_compression_ratio))
714
+ for i, up_block_type in enumerate(up_block_types):
715
+ prev_output_channel = output_channel
716
+ output_channel = reversed_block_out_channels[i]
717
+ is_final_block = i == len(block_out_channels) - 1
718
+ compress_time = i < temporal_compress_level
719
+ if up_block_type == "CogVideoXUpBlock3D":
720
+ up_block = CogVideoXUpBlock3D(
721
+ in_channels=prev_output_channel,
722
+ out_channels=output_channel,
723
+ temb_channels=0,
724
+ dropout=dropout,
725
+ num_layers=layers_per_block + 1,
726
+ resnet_eps=norm_eps,
727
+ resnet_act_fn=act_fn,
728
+ resnet_groups=norm_num_groups,
729
+ spatial_norm_dim=in_channels,
730
+ add_upsample=not is_final_block,
731
+ compress_time=compress_time,
732
+ pad_mode=pad_mode,
733
+ )
734
+ prev_output_channel = output_channel
735
+ else:
736
+ raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")
737
+ self.up_blocks.append(up_block)
738
+ self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
739
+ self.conv_act = paddle.nn.Silu()
740
+ self.conv_out = CogVideoXCausalConv3d(
741
+ reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
742
+ )
743
+ self.gradient_checkpointing = False
744
+
745
+ def forward(self, sample: paddle.Tensor, temb: Optional[paddle.Tensor] = None) -> paddle.Tensor:
746
+ """The forward method of the `CogVideoXDecoder3D` class."""
747
+ hidden_states = self.conv_in(sample)
748
+ if self.training and self.gradient_checkpointing:
749
+
750
+ def create_custom_forward(module):
751
+ def custom_forward(*inputs):
752
+ return module(*inputs)
753
+
754
+ return custom_forward
755
+
756
+ hidden_states = paddle.distributed.fleet.utils.recompute(
757
+ create_custom_forward(self.mid_block), hidden_states, temb, sample
758
+ )
759
+ for up_block in self.up_blocks:
760
+ hidden_states = paddle.distributed.fleet.utils.recompute(
761
+ create_custom_forward(up_block), hidden_states, temb, sample
762
+ )
763
+ else:
764
+ hidden_states = self.mid_block(hidden_states, temb, sample)
765
+ for up_block in self.up_blocks:
766
+ hidden_states = up_block(hidden_states, temb, sample)
767
+ hidden_states = self.norm_out(hidden_states, sample)
768
+ hidden_states = self.conv_act(hidden_states)
769
+ hidden_states = self.conv_out(hidden_states)
770
+ return hidden_states
771
+
772
+
773
+ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin):
774
+ """
775
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
776
+ [CogVideoX](https://github.com/THUDM/CogVideo).
777
+
778
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
779
+ for all models (such as downloading or saving).
780
+
781
+ Parameters:
782
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
783
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
784
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
785
+ Tuple of downsample block types.
786
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
787
+ Tuple of upsample block types.
788
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
789
+ Tuple of block output channels.
790
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
791
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
792
+ scaling_factor (`float`, *optional*, defaults to `1.15258426`):
793
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
794
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
795
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
796
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
797
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
798
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
799
+ force_upcast (`bool`, *optional*, default to `True`):
800
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
801
+ can be fine-tuned / trained to a lower range without losing too much precision, in which case
802
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
803
+ """
804
+
805
+ _supports_gradient_checkpointing = True
806
+ _no_split_modules = ["CogVideoXResnetBlock3D"]
807
+
808
+ @register_to_config
809
+ def __init__(
810
+ self,
811
+ in_channels: int = 3,
812
+ out_channels: int = 3,
813
+ down_block_types: Tuple[str] = (
814
+ "CogVideoXDownBlock3D",
815
+ "CogVideoXDownBlock3D",
816
+ "CogVideoXDownBlock3D",
817
+ "CogVideoXDownBlock3D",
818
+ ),
819
+ up_block_types: Tuple[str] = (
820
+ "CogVideoXUpBlock3D",
821
+ "CogVideoXUpBlock3D",
822
+ "CogVideoXUpBlock3D",
823
+ "CogVideoXUpBlock3D",
824
+ ),
825
+ block_out_channels: Tuple[int] = (128, 256, 256, 512),
826
+ latent_channels: int = 16,
827
+ layers_per_block: int = 3,
828
+ act_fn: str = "silu",
829
+ norm_eps: float = 1e-06,
830
+ norm_num_groups: int = 32,
831
+ temporal_compression_ratio: float = 4,
832
+ sample_height: int = 480,
833
+ sample_width: int = 720,
834
+ scaling_factor: float = 1.15258426,
835
+ shift_factor: Optional[float] = None,
836
+ latents_mean: Optional[Tuple[float]] = None,
837
+ latents_std: Optional[Tuple[float]] = None,
838
+ force_upcast: float = True,
839
+ use_quant_conv: bool = False,
840
+ use_post_quant_conv: bool = False,
841
+ ):
842
+ super().__init__()
843
+ self.encoder = CogVideoXEncoder3D(
844
+ in_channels=in_channels,
845
+ out_channels=latent_channels,
846
+ down_block_types=down_block_types,
847
+ block_out_channels=block_out_channels,
848
+ layers_per_block=layers_per_block,
849
+ act_fn=act_fn,
850
+ norm_eps=norm_eps,
851
+ norm_num_groups=norm_num_groups,
852
+ temporal_compression_ratio=temporal_compression_ratio,
853
+ )
854
+ self.decoder = CogVideoXDecoder3D(
855
+ in_channels=latent_channels,
856
+ out_channels=out_channels,
857
+ up_block_types=up_block_types,
858
+ block_out_channels=block_out_channels,
859
+ layers_per_block=layers_per_block,
860
+ act_fn=act_fn,
861
+ norm_eps=norm_eps,
862
+ norm_num_groups=norm_num_groups,
863
+ temporal_compression_ratio=temporal_compression_ratio,
864
+ )
865
+ self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
866
+ self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None
867
+ self.use_slicing = False
868
+ self.use_tiling = False
869
+ self.num_latent_frames_batch_size = 2
870
+ self.num_sample_frames_batch_size = 8
871
+ self.tile_sample_min_height = sample_height // 2
872
+ self.tile_sample_min_width = sample_width // 2
873
+ self.tile_latent_min_height = int(self.tile_sample_min_height / 2 ** (len(self.config.block_out_channels) - 1))
874
+ self.tile_latent_min_width = int(self.tile_sample_min_width / 2 ** (len(self.config.block_out_channels) - 1))
875
+ self.tile_overlap_factor_height = 1 / 6
876
+ self.tile_overlap_factor_width = 1 / 5
877
+
878
+ def _set_gradient_checkpointing(self, module, value=False):
879
+ if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
880
+ module.gradient_checkpointing = value
881
+
882
+ def _clear_fake_context_parallel_cache(self):
883
+ for name, module in self.named_sublayers():
884
+ if isinstance(module, CogVideoXCausalConv3d):
885
+ logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
886
+ module._clear_fake_context_parallel_cache()
887
+
888
+ def enable_tiling(
889
+ self,
890
+ tile_sample_min_height: Optional[int] = None,
891
+ tile_sample_min_width: Optional[int] = None,
892
+ tile_overlap_factor_height: Optional[float] = None,
893
+ tile_overlap_factor_width: Optional[float] = None,
894
+ ) -> None:
895
+ """
896
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
897
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
898
+ processing larger images.
899
+
900
+ Args:
901
+ tile_sample_min_height (`int`, *optional*):
902
+ The minimum height required for a sample to be separated into tiles across the height dimension.
903
+ tile_sample_min_width (`int`, *optional*):
904
+ The minimum width required for a sample to be separated into tiles across the width dimension.
905
+ tile_overlap_factor_height (`int`, *optional*):
906
+ The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
907
+ no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
908
+ value might cause more tiles to be processed leading to slow down of the decoding process.
909
+ tile_overlap_factor_width (`int`, *optional*):
910
+ The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
911
+ are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
912
+ value might cause more tiles to be processed leading to slow down of the decoding process.
913
+ """
914
+ self.use_tiling = True
915
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
916
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
917
+ self.tile_latent_min_height = int(self.tile_sample_min_height / 2 ** (len(self.config.block_out_channels) - 1))
918
+ self.tile_latent_min_width = int(self.tile_sample_min_width / 2 ** (len(self.config.block_out_channels) - 1))
919
+ self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
920
+ self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
921
+
922
+ def disable_tiling(self) -> None:
923
+ """
924
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
925
+ decoding in one step.
926
+ """
927
+ self.use_tiling = False
928
+
929
+ def enable_slicing(self) -> None:
930
+ """
931
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
932
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
933
+ """
934
+ self.use_slicing = True
935
+
936
+ def disable_slicing(self) -> None:
937
+ """
938
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
939
+ decoding in one step.
940
+ """
941
+ self.use_slicing = False
942
+
943
+ def _encode(self, x: paddle.Tensor) -> paddle.Tensor:
944
+ batch_size, num_channels, num_frames, height, width = tuple(x.shape)
945
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
946
+ return self.tiled_encode(x)
947
+ frame_batch_size = self.num_sample_frames_batch_size
948
+ num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
949
+ enc = []
950
+ for i in range(num_batches):
951
+ remaining_frames = num_frames % frame_batch_size
952
+ start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
953
+ end_frame = frame_batch_size * (i + 1) + remaining_frames
954
+ x_intermediate = x[:, :, start_frame:end_frame]
955
+ x_intermediate = self.encoder(x_intermediate)
956
+ if self.quant_conv is not None:
957
+ x_intermediate = self.quant_conv(x_intermediate)
958
+ enc.append(x_intermediate)
959
+ self._clear_fake_context_parallel_cache()
960
+ enc = paddle.concat(x=enc, axis=2)
961
+ return enc
962
+
963
+ @apply_forward_hook
964
+ def encode(
965
+ self, x: paddle.Tensor, return_dict: bool = True
966
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
967
+ """
968
+ Encode a batch of images into latents.
969
+
970
+ Args:
971
+ x (`paddle.Tensor`): Input batch of images.
972
+ return_dict (`bool`, *optional*, defaults to `True`):
973
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
974
+
975
+ Returns:
976
+ The latent representations of the encoded videos. If `return_dict` is True, a
977
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
978
+ """
979
+ if self.use_slicing and tuple(x.shape)[0] > 1:
980
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
981
+ h = paddle.concat(x=encoded_slices)
982
+ else:
983
+ h = self._encode(x)
984
+
985
+ posterior = DiagonalGaussianDistribution(h)
986
+ if not return_dict:
987
+ return (posterior,)
988
+ return AutoencoderKLOutput(latent_dist=posterior)
989
+
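A usage sketch of the encoding path: the checkpoint id, the `paddle_dtype` keyword, the package-level export, and the input shape below are assumptions for illustration that follow ppdiffusers' usual `from_pretrained` conventions rather than anything pinned down by this diff:

import paddle
from ppdiffusers import AutoencoderKLCogVideoX

# Hypothetical checkpoint; any CogVideoX VAE weights laid out for ppdiffusers would do.
vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="vae", paddle_dtype=paddle.float16
)
vae.enable_tiling()  # optional: much lower peak memory at some speed cost

# [B, C, T, H, W] video in [-1, 1]
video = (paddle.rand([1, 3, 49, 480, 720]) * 2 - 1).cast(paddle.float16)
latents = vae.encode(video).latent_dist.sample() * vae.config.scaling_factor
print(latents.shape)  # roughly [1, 16, 13, 60, 90] with the default config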
990
+ def _decode(self, z: paddle.Tensor, return_dict: bool = True) -> Union[DecoderOutput, paddle.Tensor]:
991
+ batch_size, num_channels, num_frames, height, width = tuple(z.shape)
992
+ if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
993
+ return self.tiled_decode(z, return_dict=return_dict)
994
+ frame_batch_size = self.num_latent_frames_batch_size
995
+ num_batches = num_frames // frame_batch_size
996
+ dec = []
997
+ for i in range(num_batches):
998
+ remaining_frames = num_frames % frame_batch_size
999
+ start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
1000
+ end_frame = frame_batch_size * (i + 1) + remaining_frames
1001
+ z_intermediate = z[:, :, start_frame:end_frame]
1002
+ if self.post_quant_conv is not None:
1003
+ z_intermediate = self.post_quant_conv(z_intermediate)
1004
+ z_intermediate = self.decoder(z_intermediate)
1005
+ dec.append(z_intermediate)
1006
+ self._clear_fake_context_parallel_cache()
1007
+ dec = paddle.concat(x=dec, axis=2)
1008
+ if not return_dict:
1009
+ return (dec,)
1010
+ return DecoderOutput(sample=dec)
1011
+
1012
+ @apply_forward_hook
1013
+ def decode(self, z: paddle.Tensor, return_dict: bool = True) -> Union[DecoderOutput, paddle.Tensor]:
1014
+ """
1015
+ Decode a batch of images.
1016
+
1017
+ Args:
1018
+ z (`paddle.Tensor`): Input batch of latent vectors.
1019
+ return_dict (`bool`, *optional*, defaults to `True`):
1020
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1021
+
1022
+ Returns:
1023
+ [`~models.vae.DecoderOutput`] or `tuple`:
1024
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1025
+ returned.
1026
+ """
1027
+ if self.use_slicing and tuple(z.shape)[0] > 1:
1028
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
1029
+ decoded = paddle.concat(x=decoded_slices)
1030
+ else:
1031
+ decoded = self._decode(z).sample
1032
+ if not return_dict:
1033
+ return (decoded,)
1034
+ return DecoderOutput(sample=decoded)
1035
+
1036
+ def blend_v(self, a: paddle.Tensor, b: paddle.Tensor, blend_extent: int) -> paddle.Tensor:
1037
+ blend_extent = min(tuple(a.shape)[3], tuple(b.shape)[3], blend_extent)
1038
+ for y in range(blend_extent):
1039
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
1040
+ y / blend_extent
1041
+ )
1042
+ return b
1043
+
1044
+ def blend_h(self, a: paddle.Tensor, b: paddle.Tensor, blend_extent: int) -> paddle.Tensor:
1045
+ blend_extent = min(tuple(a.shape)[4], tuple(b.shape)[4], blend_extent)
1046
+ for x in range(blend_extent):
1047
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
1048
+ x / blend_extent
1049
+ )
1050
+ return b
1051
+
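`blend_v` and `blend_h` above are simple linear cross-fades over `blend_extent` rows or columns of the overlap region: tile `a` fades out while tile `b` fades in. For a blend extent of 4 the per-row weights are:

blend_extent = 4
weights_b = [y / blend_extent for y in range(blend_extent)]  # [0.0, 0.25, 0.5, 0.75]
weights_a = [1 - w for w in weights_b]                       # [1.0, 0.75, 0.5, 0.25]
print(weights_a, weights_b)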
1052
+ def tiled_encode(self, x: paddle.Tensor) -> paddle.Tensor:
1053
+ """Encode a batch of images using a tiled encoder.
1054
+
1055
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
1056
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
1057
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
1058
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
1059
+ output, but they should be much less noticeable.
1060
+
1061
+ Args:
1062
+ x (`paddle.Tensor`): Input batch of videos.
1063
+
1064
+ Returns:
1065
+ `paddle.Tensor`:
1066
+ The latent representation of the encoded videos.
1067
+ """
1068
+ batch_size, num_channels, num_frames, height, width = tuple(x.shape)
1069
+ overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
1070
+ overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
1071
+ blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
1072
+ blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
1073
+ row_limit_height = self.tile_latent_min_height - blend_extent_height
1074
+ row_limit_width = self.tile_latent_min_width - blend_extent_width
1075
+ frame_batch_size = self.num_sample_frames_batch_size
1076
+ rows = []
1077
+ for i in range(0, height, overlap_height):
1078
+ row = []
1079
+ for j in range(0, width, overlap_width):
1080
+ num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
1081
+ time = []
1082
+ for k in range(num_batches):
1083
+ remaining_frames = num_frames % frame_batch_size
1084
+ start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
1085
+ end_frame = frame_batch_size * (k + 1) + remaining_frames
1086
+ tile = x[
1087
+ :,
1088
+ :,
1089
+ start_frame:end_frame,
1090
+ i : i + self.tile_sample_min_height,
1091
+ j : j + self.tile_sample_min_width,
1092
+ ]
1093
+ tile = self.encoder(tile)
1094
+ if self.quant_conv is not None:
1095
+ tile = self.quant_conv(tile)
1096
+ time.append(tile)
1097
+ self._clear_fake_context_parallel_cache()
1098
+ row.append(paddle.concat(x=time, axis=2))
1099
+ rows.append(row)
1100
+ result_rows = []
1101
+ for i, row in enumerate(rows):
1102
+ result_row = []
1103
+ for j, tile in enumerate(row):
1104
+ if i > 0:
1105
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
1106
+ if j > 0:
1107
+ tile = self.blend_h(row[j - 1], tile, blend_extent_width)
1108
+ result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
1109
+ result_rows.append(paddle.concat(x=result_row, axis=4))
1110
+ enc = paddle.concat(x=result_rows, axis=3)
1111
+ return enc
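
To make the frame batching used by `_decode`/`tiled_encode` above concrete, here is a small standalone sketch of the slice arithmetic (the values 49 and 8 are assumed example defaults, not read from the file): the first batch absorbs the `num_frames % frame_batch_size` leftover frames and every later batch covers exactly `frame_batch_size` frames.

num_frames, frame_batch_size = 49, 8              # assumed example values
remaining_frames = num_frames % frame_batch_size  # 1
num_batches = num_frames // frame_batch_size      # 6
for k in range(num_batches):
    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
    end_frame = frame_batch_size * (k + 1) + remaining_frames
    print(k, start_frame, end_frame)              # (0, 0, 9), (1, 9, 17), ..., (5, 41, 49)
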
1112
+
1113
+ def tiled_decode(self, z: paddle.Tensor, return_dict: bool = True) -> Union[DecoderOutput, paddle.Tensor]:
1114
+ """
1115
+ Decode a batch of images using a tiled decoder.
1116
+
1117
+ Args:
1118
+             z (`paddle.Tensor`): Input batch of latent vectors.
1119
+ return_dict (`bool`, *optional*, defaults to `True`):
1120
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1121
+
1122
+ Returns:
1123
+ [`~models.vae.DecoderOutput`] or `tuple`:
1124
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1125
+ returned.
1126
+ """
1127
+ batch_size, num_channels, num_frames, height, width = tuple(z.shape)
1128
+ overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
1129
+ overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
1130
+ blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
1131
+ blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
1132
+ row_limit_height = self.tile_sample_min_height - blend_extent_height
1133
+ row_limit_width = self.tile_sample_min_width - blend_extent_width
1134
+ frame_batch_size = self.num_latent_frames_batch_size
1135
+ rows = []
1136
+ for i in range(0, height, overlap_height):
1137
+ row = []
1138
+ for j in range(0, width, overlap_width):
1139
+ num_batches = num_frames // frame_batch_size
1140
+ time = []
1141
+ for k in range(num_batches):
1142
+ remaining_frames = num_frames % frame_batch_size
1143
+ start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
1144
+ end_frame = frame_batch_size * (k + 1) + remaining_frames
1145
+ tile = z[
1146
+ :,
1147
+ :,
1148
+ start_frame:end_frame,
1149
+ i : i + self.tile_latent_min_height,
1150
+ j : j + self.tile_latent_min_width,
1151
+ ]
1152
+ if self.post_quant_conv is not None:
1153
+ tile = self.post_quant_conv(tile)
1154
+ tile = self.decoder(tile)
1155
+ time.append(tile)
1156
+ self._clear_fake_context_parallel_cache()
1157
+ row.append(paddle.concat(x=time, axis=2))
1158
+ rows.append(row)
1159
+ result_rows = []
1160
+ for i, row in enumerate(rows):
1161
+ result_row = []
1162
+ for j, tile in enumerate(row):
1163
+ if i > 0:
1164
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
1165
+ if j > 0:
1166
+ tile = self.blend_h(row[j - 1], tile, blend_extent_width)
1167
+ result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
1168
+ result_rows.append(paddle.concat(x=result_row, axis=4))
1169
+ dec = paddle.concat(x=result_rows, axis=3)
1170
+ if not return_dict:
1171
+ return (dec,)
1172
+ return DecoderOutput(sample=dec)
1173
+
1174
+ def forward(
1175
+ self,
1176
+ sample: paddle.Tensor,
1177
+ sample_posterior: bool = False,
1178
+ return_dict: bool = True,
1179
+         generator: Optional[paddle.Generator] = None,
1180
+     ) -> Union[DecoderOutput, paddle.Tensor]:
1181
+ x = sample
1182
+ posterior = self.encode(x).latent_dist
1183
+ if sample_posterior:
1184
+ z = posterior.sample(generator=generator)
1185
+ else:
1186
+ z = posterior.mode()
1187
+ dec = self.decode(z)
1188
+ if not return_dict:
1189
+ return (dec,)
1190
+ return dec
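
For orientation, a minimal usage sketch of the VAE defined in this file, assuming it is exported as `AutoencoderKLCogVideoX` and that a real checkpoint path is substituted for the placeholder; the latent layout `(B, C, T, H, W)` follows the 5-D convention used by `_decode`/`tiled_decode` above, and the sizes are only examples.

import paddle
from ppdiffusers import AutoencoderKLCogVideoX  # exported class name assumed

vae = AutoencoderKLCogVideoX.from_pretrained("path/to/cogvideox-vae")  # placeholder path
vae.eval()
latents = paddle.randn([1, vae.config.latent_channels, 13, 60, 90])  # (B, C, T, H, W), example sizes
with paddle.no_grad():
    video = vae.decode(latents).sample  # decoded video; roughly (B, 3, frames, H*8, W*8) for the usual config
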
PaddleMIX/ppdiffusers/ppdiffusers/models/autoencoder_kl_temporal_decoder.py ADDED
@@ -0,0 +1,396 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Dict, Optional, Tuple, Union
15
+
16
+ import paddle
17
+ import paddle.nn as nn
18
+ from paddle.distributed.fleet.utils import recompute
19
+
20
+ from ..configuration_utils import ConfigMixin, register_to_config
21
+ from ..loaders import FromOriginalVAEMixin
22
+ from ..utils import recompute_use_reentrant
23
+ from ..utils.accelerate_utils import apply_forward_hook
24
+ from .attention_processor import (
25
+ CROSS_ATTENTION_PROCESSORS,
26
+ AttentionProcessor,
27
+ AttnProcessor,
28
+ )
29
+ from .modeling_outputs import AutoencoderKLOutput
30
+ from .modeling_utils import ModelMixin
31
+ from .unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
32
+ from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
33
+
34
+
35
+ class TemporalDecoder(nn.Layer):
36
+ def __init__(
37
+ self,
38
+ in_channels: int = 4,
39
+ out_channels: int = 3,
40
+ block_out_channels: Tuple[int] = (128, 256, 512, 512),
41
+ layers_per_block: int = 2,
42
+ ):
43
+ super().__init__()
44
+ self.layers_per_block = layers_per_block
45
+
46
+ self.conv_in = nn.Conv2D(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
47
+ self.mid_block = MidBlockTemporalDecoder(
48
+ num_layers=self.layers_per_block,
49
+ in_channels=block_out_channels[-1],
50
+ out_channels=block_out_channels[-1],
51
+ attention_head_dim=block_out_channels[-1],
52
+ )
53
+
54
+ # up
55
+ self.up_blocks = nn.LayerList([])
56
+ reversed_block_out_channels = list(reversed(block_out_channels))
57
+ output_channel = reversed_block_out_channels[0]
58
+ for i in range(len(block_out_channels)):
59
+ prev_output_channel = output_channel
60
+ output_channel = reversed_block_out_channels[i]
61
+
62
+ is_final_block = i == len(block_out_channels) - 1
63
+ up_block = UpBlockTemporalDecoder(
64
+ num_layers=self.layers_per_block + 1,
65
+ in_channels=prev_output_channel,
66
+ out_channels=output_channel,
67
+ add_upsample=not is_final_block,
68
+ )
69
+ self.up_blocks.append(up_block)
70
+ prev_output_channel = output_channel
71
+
72
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, epsilon=1e-6)
73
+
74
+ self.conv_act = nn.Silu()
75
+ self.conv_out = nn.Conv2D(
76
+ in_channels=block_out_channels[0],
77
+ out_channels=out_channels,
78
+ kernel_size=3,
79
+ padding=1,
80
+ )
81
+
82
+ conv_out_kernel_size = (3, 1, 1)
83
+ padding = [int(k // 2) for k in conv_out_kernel_size]
84
+ self.time_conv_out = nn.Conv3D(
85
+ in_channels=out_channels,
86
+ out_channels=out_channels,
87
+ kernel_size=conv_out_kernel_size,
88
+ padding=padding,
89
+ )
90
+
91
+ self.gradient_checkpointing = False
92
+
93
+ def forward(
94
+ self,
95
+ sample: paddle.Tensor,
96
+ image_only_indicator: paddle.Tensor,
97
+ num_frames: int = 1,
98
+ ) -> paddle.Tensor:
99
+ r"""The forward method of the `Decoder` class."""
100
+
101
+ sample = self.conv_in(sample)
102
+
103
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
104
+ if self.gradient_checkpointing and not sample.stop_gradient:
105
+
106
+ def create_custom_forward(module):
107
+ def custom_forward(*inputs):
108
+ return module(*inputs)
109
+
110
+ return custom_forward
111
+
112
+ ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False}
113
+ # middle
114
+ sample = recompute(
115
+ create_custom_forward(self.mid_block),
116
+ sample,
117
+ image_only_indicator,
118
+ **ckpt_kwargs,
119
+ )
120
+ sample = sample.cast(upscale_dtype)
121
+
122
+ # up
123
+ for up_block in self.up_blocks:
124
+ sample = recompute(
125
+ create_custom_forward(up_block),
126
+ sample,
127
+ image_only_indicator,
128
+ **ckpt_kwargs,
129
+ )
130
+ else:
131
+ # middle
132
+ sample = self.mid_block(sample, image_only_indicator=image_only_indicator)
133
+ sample = sample.cast(upscale_dtype)
134
+
135
+ # up
136
+ for up_block in self.up_blocks:
137
+ sample = up_block(sample, image_only_indicator=image_only_indicator)
138
+
139
+ # post-process
140
+ sample = self.conv_norm_out(sample)
141
+ sample = self.conv_act(sample)
142
+ sample = self.conv_out(sample)
143
+
144
+ batch_frames, channels, height, width = sample.shape
145
+ batch_size = batch_frames // num_frames
146
+ sample = sample[None, :].reshape([batch_size, num_frames, channels, height, width]).transpose([0, 2, 1, 3, 4])
147
+ sample = self.time_conv_out(sample)
148
+
149
+ sample = sample.transpose([0, 2, 1, 3, 4]).reshape([batch_frames, channels, height, width])
150
+
151
+ return sample
152
+
153
+
154
+ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
155
+ r"""
156
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
157
+
158
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
159
+ for all models (such as downloading or saving).
160
+
161
+ Parameters:
162
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
163
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
164
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
165
+ Tuple of downsample block types.
166
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
167
+ Tuple of block output channels.
168
+ layers_per_block: (`int`, *optional*, defaults to 1): Number of layers per block.
169
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
170
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
171
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
172
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
173
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
174
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
175
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
176
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
177
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
178
+ force_upcast (`bool`, *optional*, default to `True`):
179
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
180
+ can be fine-tuned / trained to a lower range without loosing too much precision in which case
181
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
182
+ """
183
+
184
+ _supports_gradient_checkpointing = True
185
+
186
+ @register_to_config
187
+ def __init__(
188
+ self,
189
+ in_channels: int = 3,
190
+ out_channels: int = 3,
191
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
192
+ block_out_channels: Tuple[int] = (64,),
193
+ layers_per_block: int = 1,
194
+ latent_channels: int = 4,
195
+ sample_size: int = 32,
196
+ scaling_factor: float = 0.18215,
197
+ force_upcast: float = True,
198
+ ):
199
+ super().__init__()
200
+
201
+ # pass init params to Encoder
202
+ self.encoder = Encoder(
203
+ in_channels=in_channels,
204
+ out_channels=latent_channels,
205
+ down_block_types=down_block_types,
206
+ block_out_channels=block_out_channels,
207
+ layers_per_block=layers_per_block,
208
+ double_z=True,
209
+ )
210
+
211
+ # pass init params to Decoder
212
+ self.decoder = TemporalDecoder(
213
+ in_channels=latent_channels,
214
+ out_channels=out_channels,
215
+ block_out_channels=block_out_channels,
216
+ layers_per_block=layers_per_block,
217
+ )
218
+
219
+ self.quant_conv = nn.Conv2D(2 * latent_channels, 2 * latent_channels, 1)
220
+
221
+ sample_size = (
222
+ self.config.sample_size[0]
223
+ if isinstance(self.config.sample_size, (list, tuple))
224
+ else self.config.sample_size
225
+ )
226
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
227
+ self.tile_overlap_factor = 0.25
228
+
229
+ def _set_gradient_checkpointing(self, module, value=False):
230
+ if isinstance(module, (Encoder, TemporalDecoder)):
231
+ module.gradient_checkpointing = value
232
+
233
+ @property
234
+ # Copied from ppdiffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
235
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
236
+ r"""
237
+ Returns:
238
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
239
+ indexed by its weight name.
240
+ """
241
+ # set recursively
242
+ processors = {}
243
+
244
+ def fn_recursive_add_processors(name: str, module: nn.Layer, processors: Dict[str, AttentionProcessor]):
245
+ if hasattr(module, "get_processor"):
246
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
247
+
248
+ for sub_name, child in module.named_children():
249
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
250
+
251
+ return processors
252
+
253
+ for name, module in self.named_children():
254
+ fn_recursive_add_processors(name, module, processors)
255
+
256
+ return processors
257
+
258
+ # Copied from ppdiffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
259
+ def set_attn_processor(
260
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
261
+ ):
262
+ r"""
263
+ Sets the attention processor to use to compute attention.
264
+
265
+ Parameters:
266
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
267
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
268
+ for **all** `Attention` layers.
269
+
270
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
271
+ processor. This is strongly recommended when setting trainable attention processors.
272
+
273
+ """
274
+ count = len(self.attn_processors.keys())
275
+
276
+ if isinstance(processor, dict) and len(processor) != count:
277
+ raise ValueError(
278
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
279
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
280
+ )
281
+
282
+ def fn_recursive_attn_processor(name: str, module: nn.Layer, processor):
283
+ if hasattr(module, "set_processor"):
284
+ if not isinstance(processor, dict):
285
+ module.set_processor(processor, _remove_lora=_remove_lora)
286
+ else:
287
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
288
+
289
+ for sub_name, child in module.named_children():
290
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
291
+
292
+ for name, module in self.named_children():
293
+ fn_recursive_attn_processor(name, module, processor)
294
+
295
+ def set_default_attn_processor(self):
296
+ """
297
+ Disables custom attention processors and sets the default attention implementation.
298
+ """
299
+ if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
300
+ processor = AttnProcessor()
301
+ else:
302
+ raise ValueError(
303
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
304
+ )
305
+
306
+ self.set_attn_processor(processor, _remove_lora=True)
307
+
308
+ @apply_forward_hook
309
+ def encode(
310
+ self, x: paddle.Tensor, return_dict: bool = True
311
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
312
+ """
313
+ Encode a batch of images into latents.
314
+
315
+ Args:
316
+ x (`paddle.Tensor`): Input batch of images.
317
+ return_dict (`bool`, *optional*, defaults to `True`):
318
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
319
+
320
+ Returns:
321
+ The latent representations of the encoded images. If `return_dict` is True, a
322
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
323
+ """
324
+ # TODO junnyu, support float16
325
+ x = x.cast(self.encoder.conv_in.weight.dtype)
326
+ h = self.encoder(x)
327
+ moments = self.quant_conv(h)
328
+ posterior = DiagonalGaussianDistribution(moments)
329
+
330
+ if not return_dict:
331
+ return (posterior,)
332
+
333
+ return AutoencoderKLOutput(latent_dist=posterior)
334
+
335
+ @apply_forward_hook
336
+ def decode(
337
+ self,
338
+ z: paddle.Tensor,
339
+ num_frames: int,
340
+ return_dict: bool = True,
341
+ ) -> Union[DecoderOutput, paddle.Tensor]:
342
+ """
343
+ Decode a batch of images.
344
+
345
+ Args:
346
+ z (`paddle.Tensor`): Input batch of latent vectors.
347
+ return_dict (`bool`, *optional*, defaults to `True`):
348
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
349
+
350
+ Returns:
351
+ [`~models.vae.DecoderOutput`] or `tuple`:
352
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
353
+ returned.
354
+
355
+ """
356
+ # TODO junnyu, add this to support pure fp16
357
+ z = z.cast(self.quant_conv.weight.dtype)
358
+
359
+ batch_size = z.shape[0] // num_frames
360
+ image_only_indicator = paddle.zeros([batch_size, num_frames], dtype=z.dtype)
361
+ decoded = self.decoder(z, num_frames=num_frames, image_only_indicator=image_only_indicator)
362
+
363
+ if not return_dict:
364
+ return (decoded,)
365
+
366
+ return DecoderOutput(sample=decoded)
367
+
368
+ def forward(
369
+ self,
370
+ sample: paddle.Tensor,
371
+ sample_posterior: bool = False,
372
+ return_dict: bool = True,
373
+ generator: Optional[paddle.Generator] = None,
374
+ num_frames: int = 1,
375
+ ) -> Union[DecoderOutput, paddle.Tensor]:
376
+ r"""
377
+ Args:
378
+ sample (`paddle.Tensor`): Input sample.
379
+ sample_posterior (`bool`, *optional*, defaults to `False`):
380
+ Whether to sample from the posterior.
381
+ return_dict (`bool`, *optional*, defaults to `True`):
382
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
383
+ """
384
+ x = sample
385
+ posterior = self.encode(x).latent_dist
386
+ if sample_posterior:
387
+ z = posterior.sample(generator=generator)
388
+ else:
389
+ z = posterior.mode()
390
+
391
+ dec = self.decode(z, num_frames=num_frames).sample
392
+
393
+ if not return_dict:
394
+ return (dec,)
395
+
396
+ return DecoderOutput(sample=dec)
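
As a rough usage sketch (not from the file): this VAE encodes frames independently as a `(batch*frames, C, H, W)` batch, and only the decoder is time-aware via `num_frames`. The checkpoint id below is the SVD layout this class targets, but treat it as an assumption and substitute whatever weights you actually use.

import paddle
from ppdiffusers import AutoencoderKLTemporalDecoder  # exported class name assumed

vae = AutoencoderKLTemporalDecoder.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", subfolder="vae"
)
num_frames = 4
frames = paddle.randn([num_frames, 3, 256, 256])  # (batch*frames, C, H, W)
with paddle.no_grad():
    latents = vae.encode(frames).latent_dist.mode()
    video = vae.decode(latents, num_frames=num_frames).sample  # back to (batch*frames, 3, H, W)
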
PaddleMIX/ppdiffusers/ppdiffusers/models/autoencoder_tiny.py ADDED
@@ -0,0 +1,363 @@
1
+ # Copyright 2023 Ollin Boer Bohan and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from dataclasses import dataclass
17
+ from typing import Optional, Tuple, Union
18
+
19
+ import paddle
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from ..utils import BaseOutput
23
+ from ..utils.accelerate_utils import apply_forward_hook
24
+ from .modeling_utils import ModelMixin
25
+ from .vae import DecoderOutput, DecoderTiny, EncoderTiny
26
+
27
+
28
+ @dataclass
29
+ class AutoencoderTinyOutput(BaseOutput):
30
+ """
31
+ Output of AutoencoderTiny encoding method.
32
+
33
+ Args:
34
+ latents (`paddle.Tensor`): Encoded outputs of the `Encoder`.
35
+
36
+ """
37
+
38
+ latents: paddle.Tensor
39
+
40
+
41
+ class AutoencoderTiny(ModelMixin, ConfigMixin):
42
+ r"""
43
+ A tiny distilled VAE model for encoding images into latents and decoding latent representations into images.
44
+
45
+ [`AutoencoderTiny`] is a wrapper around the original implementation of `TAESD`.
46
+
47
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
48
+ all models (such as downloading or saving).
49
+
50
+ Parameters:
51
+ in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
52
+ out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
53
+ encoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
54
+ Tuple of integers representing the number of output channels for each encoder block. The length of the
55
+ tuple should be equal to the number of encoder blocks.
56
+ decoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
57
+ Tuple of integers representing the number of output channels for each decoder block. The length of the
58
+ tuple should be equal to the number of decoder blocks.
59
+ act_fn (`str`, *optional*, defaults to `"relu"`):
60
+ Activation function to be used throughout the model.
61
+ latent_channels (`int`, *optional*, defaults to 4):
62
+ Number of channels in the latent representation. The latent space acts as a compressed representation of
63
+ the input image.
64
+ upsampling_scaling_factor (`int`, *optional*, defaults to 2):
65
+ Scaling factor for upsampling in the decoder. It determines the size of the output image during the
66
+ upsampling process.
67
+ num_encoder_blocks (`Tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`):
68
+ Tuple of integers representing the number of encoder blocks at each stage of the encoding process. The
69
+ length of the tuple should be equal to the number of stages in the encoder. Each stage has a different
70
+ number of encoder blocks.
71
+ num_decoder_blocks (`Tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`):
72
+ Tuple of integers representing the number of decoder blocks at each stage of the decoding process. The
73
+ length of the tuple should be equal to the number of stages in the decoder. Each stage has a different
74
+ number of decoder blocks.
75
+ latent_magnitude (`float`, *optional*, defaults to 3.0):
76
+ Magnitude of the latent representation. This parameter scales the latent representation values to control
77
+ the extent of information preservation.
78
+ latent_shift (float, *optional*, defaults to 0.5):
79
+ Shift applied to the latent representation. This parameter controls the center of the latent space.
80
+ scaling_factor (`float`, *optional*, defaults to 1.0):
81
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
82
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
83
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
84
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
85
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
86
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. For this Autoencoder,
87
+ however, no such scaling factor was used, hence the value of 1.0 as the default.
88
+ force_upcast (`bool`, *optional*, default to `False`):
89
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
90
+ can be fine-tuned / trained to a lower range without losing too much precision, in which case
91
+ `force_upcast` can be set to `False` (see this fp16-friendly
92
+ [AutoEncoder](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
93
+ """
94
+
95
+ _supports_gradient_checkpointing = True
96
+
97
+ @register_to_config
98
+ def __init__(
99
+ self,
100
+ in_channels: int = 3,
101
+ out_channels: int = 3,
102
+ encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
103
+ decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
104
+ act_fn: str = "relu",
105
+ latent_channels: int = 4,
106
+ upsampling_scaling_factor: int = 2,
107
+ num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3),
108
+ num_decoder_blocks: Tuple[int, ...] = (3, 3, 3, 1),
109
+ latent_magnitude: int = 3,
110
+ latent_shift: float = 0.5,
111
+ force_upcast: bool = False,
112
+ scaling_factor: float = 1.0,
113
+ ):
114
+ super().__init__()
115
+
116
+ if len(encoder_block_out_channels) != len(num_encoder_blocks):
117
+ raise ValueError("`encoder_block_out_channels` should have the same length as `num_encoder_blocks`.")
118
+ if len(decoder_block_out_channels) != len(num_decoder_blocks):
119
+ raise ValueError("`decoder_block_out_channels` should have the same length as `num_decoder_blocks`.")
120
+
121
+ self.encoder = EncoderTiny(
122
+ in_channels=in_channels,
123
+ out_channels=latent_channels,
124
+ num_blocks=num_encoder_blocks,
125
+ block_out_channels=encoder_block_out_channels,
126
+ act_fn=act_fn,
127
+ )
128
+
129
+ self.decoder = DecoderTiny(
130
+ in_channels=latent_channels,
131
+ out_channels=out_channels,
132
+ num_blocks=num_decoder_blocks,
133
+ block_out_channels=decoder_block_out_channels,
134
+ upsampling_scaling_factor=upsampling_scaling_factor,
135
+ act_fn=act_fn,
136
+ )
137
+
138
+ self.latent_magnitude = latent_magnitude
139
+ self.latent_shift = latent_shift
140
+ self.scaling_factor = scaling_factor
141
+
142
+ self.use_slicing = False
143
+ self.use_tiling = False
144
+
145
+ # only relevant if vae tiling is enabled
146
+ self.spatial_scale_factor = 2**out_channels
147
+ self.tile_overlap_factor = 0.125
148
+ self.tile_sample_min_size = 512
149
+ self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor
150
+
151
+ self.register_to_config(block_out_channels=decoder_block_out_channels)
152
+ self.register_to_config(force_upcast=False)
153
+
154
+ def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
155
+ if isinstance(module, (EncoderTiny, DecoderTiny)):
156
+ module.gradient_checkpointing = value
157
+
158
+ def scale_latents(self, x: paddle.Tensor) -> paddle.Tensor:
159
+ """raw latents -> [0, 1]"""
160
+ return ((x / 2 * self.latent_magnitude) + self.latent_shift).clip(0, 1)
161
+
162
+ def unscale_latents(self, x: paddle.Tensor) -> paddle.Tensor:
163
+ """[0, 1] -> raw latents"""
164
+ return (x - self.latent_shift) * (2 * self.latent_magnitude)
165
+
166
+ def enable_slicing(self) -> None:
167
+ r"""
168
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
169
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
170
+ """
171
+ self.use_slicing = True
172
+
173
+ def disable_slicing(self) -> None:
174
+ r"""
175
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
176
+ decoding in one step.
177
+ """
178
+ self.use_slicing = False
179
+
180
+ def enable_tiling(self, use_tiling: bool = True) -> None:
181
+ r"""
182
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
183
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
184
+ processing larger images.
185
+ """
186
+ self.use_tiling = use_tiling
187
+
188
+ def disable_tiling(self) -> None:
189
+ r"""
190
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
191
+ decoding in one step.
192
+ """
193
+ self.enable_tiling(False)
194
+
195
+ def _tiled_encode(self, x: paddle.Tensor) -> paddle.Tensor:
196
+ r"""Encode a batch of images using a tiled encoder.
197
+
198
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
199
+ steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
200
+ tiles overlap and are blended together to form a smooth output.
201
+
202
+ Args:
203
+ x (`paddle.Tensor`): Input batch of images.
204
+
205
+ Returns:
206
+ `paddle.Tensor`: Encoded batch of images.
207
+ """
208
+ # scale of encoder output relative to input
209
+ sf = self.spatial_scale_factor
210
+ tile_size = self.tile_sample_min_size
211
+
212
+ # number of pixels to blend and to traverse between tile
213
+ blend_size = int(tile_size * self.tile_overlap_factor)
214
+ traverse_size = tile_size - blend_size
215
+
216
+ # tiles index (up/left)
217
+ ti = range(0, x.shape[-2], traverse_size)
218
+ tj = range(0, x.shape[-1], traverse_size)
219
+
220
+ # mask for blending
221
+ blend_masks = paddle.stack(
222
+ paddle.meshgrid([paddle.arange(tile_size / sf) / (blend_size / sf - 1)] * 2, indexing="ij")
223
+ )
224
+ blend_masks = blend_masks.clip(0, 1)
225
+
226
+ # output array
227
+ out = paddle.zeros([x.shape[0], 4, x.shape[-2] // sf, x.shape[-1] // sf])
228
+ for i in ti:
229
+ for j in tj:
230
+ tile_in = x[..., i : i + tile_size, j : j + tile_size]
231
+ # tile result
232
+ tile_out = out[..., i // sf : (i + tile_size) // sf, j // sf : (j + tile_size) // sf]
233
+ tile = self.encoder(tile_in)
234
+ h, w = tile.shape[-2], tile.shape[-1]
235
+ # blend tile result into output
236
+ blend_mask_i = paddle.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
237
+ blend_mask_j = paddle.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
238
+ blend_mask = blend_mask_i * blend_mask_j
239
+ tile, blend_mask = tile[..., :h, :w], blend_mask[..., :h, :w]
240
+
241
+ # NOTE this copy_ method is not work in paddlepaddle
242
+ # tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
243
+ out[..., i // sf : (i + tile_size) // sf, j // sf : (j + tile_size) // sf] = (
244
+ blend_mask * tile + (1 - blend_mask) * tile_out
245
+ )
246
+
247
+ return out
248
+
249
+ def _tiled_decode(self, x: paddle.Tensor) -> paddle.Tensor:
250
+ r"""Encode a batch of images using a tiled encoder.
251
+
252
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
253
+ steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
254
+ tiles overlap and are blended together to form a smooth output.
255
+
256
+ Args:
257
+ x (`paddle.Tensor`): Input batch of images.
258
+
259
+ Returns:
260
+ `paddle.Tensor`: Encoded batch of images.
261
+ """
262
+ # scale of decoder output relative to input
263
+ sf = self.spatial_scale_factor
264
+ tile_size = self.tile_latent_min_size
265
+
266
+ # number of pixels to blend and to traverse between tiles
267
+ blend_size = int(tile_size * self.tile_overlap_factor)
268
+ traverse_size = tile_size - blend_size
269
+
270
+ # tiles index (up/left)
271
+ ti = range(0, x.shape[-2], traverse_size)
272
+ tj = range(0, x.shape[-1], traverse_size)
273
+
274
+ # mask for blending
275
+ blend_masks = paddle.stack(
276
+ paddle.meshgrid([paddle.arange(tile_size * sf) / (blend_size * sf - 1)] * 2, indexing="ij")
277
+ )
278
+ blend_masks = blend_masks.clip(0, 1)
279
+
280
+ # output array
281
+ out = paddle.zeros([x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf])
282
+ for i in ti:
283
+ for j in tj:
284
+ tile_in = x[..., i : i + tile_size, j : j + tile_size]
285
+ # tile result
286
+ tile_out = out[..., i * sf : (i + tile_size) * sf, j * sf : (j + tile_size) * sf]
287
+ tile = self.decoder(tile_in)
288
+ h, w = tile.shape[-2], tile.shape[-1]
289
+ # blend tile result into output
290
+ blend_mask_i = paddle.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
291
+ blend_mask_j = paddle.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
292
+ blend_mask = (blend_mask_i * blend_mask_j)[..., :h, :w]
293
+
294
+ # NOTE this copy_ method is not work in paddlepaddle
295
+ # tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
296
+ out[..., i * sf : (i + tile_size) * sf, j * sf : (j + tile_size) * sf] = (
297
+ blend_mask * tile + (1 - blend_mask) * tile_out
298
+ )
299
+ return out
300
+
301
+ @apply_forward_hook
302
+ def encode(self, x: paddle.Tensor, return_dict: bool = True) -> Union[AutoencoderTinyOutput, Tuple[paddle.Tensor]]:
303
+ # TODO junnyu, support float16
304
+ x = x.cast(self.encoder.layers[0].weight.dtype)
305
+ if self.use_slicing and x.shape[0] > 1:
306
+ output = [
307
+ self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x) for x_slice in x.chunk(x.shape[0])
308
+ ]
309
+ output = paddle.concat(output)
310
+ else:
311
+ output = self._tiled_encode(x) if self.use_tiling else self.encoder(x)
312
+
313
+ if not return_dict:
314
+ return (output,)
315
+
316
+ return AutoencoderTinyOutput(latents=output)
317
+
318
+ @apply_forward_hook
319
+ def decode(
320
+ self, x: paddle.Tensor, generator: Optional[paddle.Generator] = None, return_dict: bool = True
321
+ ) -> Union[DecoderOutput, Tuple[paddle.Tensor]]:
322
+ # TODO junnyu, add this to support pure fp16
323
+ x = x.cast(self.decoder.layers[0].weight.dtype)
324
+
325
+ if self.use_slicing and x.shape[0] > 1:
326
+ output = [
327
+ self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.chunk(x.shape[0])
328
+ ]
329
+ output = paddle.concat(output)
330
+ else:
331
+ output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
332
+
333
+ if not return_dict:
334
+ return (output,)
335
+
336
+ return DecoderOutput(sample=output)
337
+
338
+ def forward(
339
+ self,
340
+ sample: paddle.Tensor,
341
+ return_dict: bool = True,
342
+ ) -> Union[DecoderOutput, Tuple[paddle.Tensor]]:
343
+ r"""
344
+ Args:
345
+ sample (`paddle.Tensor`): Input sample.
346
+ return_dict (`bool`, *optional*, defaults to `True`):
347
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
348
+ """
349
+ enc = self.encode(sample).latents
350
+
351
+ # scale latents to be in [0, 1], then quantize latents to a byte tensor,
352
+ # as if we were storing the latents in an RGBA uint8 image.
353
+ scaled_enc = (self.scale_latents(enc) * 255).round().cast("byte")
354
+
355
+ # unquantize latents back into [0, 1], then unscale latents back to their original range,
356
+ # as if we were loading the latents from an RGBA uint8 image.
357
+ unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
358
+
359
+ dec = self.decode(unscaled_enc)
360
+
361
+ if not return_dict:
362
+ return (dec,)
363
+ return DecoderOutput(sample=dec)
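
A hedged usage sketch of the round-trip that `forward` performs above: encode, pack the latents into `uint8` as if they were an RGBA image, unpack, and decode. The `madebyollin/taesd` checkpoint id is assumed to be loadable here; replace it if your setup differs.

import paddle
from ppdiffusers import AutoencoderTiny  # exported class name assumed

vae = AutoencoderTiny.from_pretrained("madebyollin/taesd")  # assumed-compatible TAESD weights
img = paddle.rand([1, 3, 512, 512])
with paddle.no_grad():
    latents = vae.encode(img).latents
    packed = (vae.scale_latents(latents) * 255).round().cast("uint8")  # store like RGBA bytes
    restored = vae.unscale_latents(packed.cast("float32") / 255.0)
    recon = vae.decode(restored).sample  # (1, 3, 512, 512)
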
PaddleMIX/ppdiffusers/ppdiffusers/models/cogvideox_transformer_3d.py ADDED
@@ -0,0 +1,394 @@
1
+ import paddle
2
+ from typing import Any, Dict, Optional, Tuple, Union
3
+ from ..configuration_utils import ConfigMixin, register_to_config
4
+ from ..utils import logging
5
+ from ..utils.paddle_utils import maybe_allow_in_graph
6
+ from .attention import Attention, FeedForward
7
+ from .attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
8
+ from .embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
9
+ from .modeling_outputs import Transformer2DModelOutput
10
+ from .modeling_utils import ModelMixin
11
+ from .normalization import AdaLayerNorm, CogVideoXLayerNormZero
12
+ logger = logging.get_logger(__name__)
13
+
14
+
15
+ @maybe_allow_in_graph
16
+ class CogVideoXBlock(paddle.nn.Layer):
17
+ """
18
+ Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
19
+
20
+ Parameters:
21
+ dim (`int`):
22
+ The number of channels in the input and output.
23
+ num_attention_heads (`int`):
24
+ The number of heads to use for multi-head attention.
25
+ attention_head_dim (`int`):
26
+ The number of channels in each head.
27
+ time_embed_dim (`int`):
28
+ The number of channels in timestep embedding.
29
+ dropout (`float`, defaults to `0.0`):
30
+ The dropout probability to use.
31
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
32
+ Activation function to be used in feed-forward.
33
+ attention_bias (`bool`, defaults to `False`):
34
+ Whether or not to use bias in attention projection layers.
35
+ qk_norm (`bool`, defaults to `True`):
36
+ Whether or not to use normalization after query and key projections in Attention.
37
+ norm_elementwise_affine (`bool`, defaults to `True`):
38
+ Whether to use learnable elementwise affine parameters for normalization.
39
+ norm_eps (`float`, defaults to `1e-5`):
40
+ Epsilon value for normalization layers.
41
+         final_dropout (`bool`, defaults to `True`):
42
+ Whether to apply a final dropout after the last feed-forward layer.
43
+ ff_inner_dim (`int`, *optional*, defaults to `None`):
44
+ Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
45
+ ff_bias (`bool`, defaults to `True`):
46
+ Whether or not to use bias in Feed-forward layer.
47
+ attention_out_bias (`bool`, defaults to `True`):
48
+ Whether or not to use bias in Attention output projection layer.
49
+ """
50
+
51
+ def __init__(
52
+ self,
53
+ dim: int,
54
+ num_attention_heads: int,
55
+ attention_head_dim: int,
56
+ time_embed_dim: int,
57
+ dropout: float=0.0,
58
+ activation_fn: str='gelu-approximate',
59
+ attention_bias: bool=False,
60
+ qk_norm: bool=True,
61
+ norm_elementwise_affine: bool=True,
62
+ norm_eps: float=1e-05,
63
+ final_dropout: bool=True,
64
+ ff_inner_dim: Optional[int]=None,
65
+ ff_bias: bool=True,
66
+ attention_out_bias: bool=True
67
+ ):
68
+ super().__init__()
69
+
70
+ # 1. self attention
71
+ self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
72
+
73
+ self.attn1 = Attention(
74
+ query_dim=dim,
75
+ dim_head=attention_head_dim,
76
+ heads=num_attention_heads,
77
+ qk_norm='layer_norm' if qk_norm else None,
78
+ eps=1e-06,
79
+ bias=attention_bias,
80
+ out_bias=attention_out_bias,
81
+ processor=CogVideoXAttnProcessor2_0()
82
+ )
83
+
84
+ # 2. feed forward
85
+ self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
86
+ self.ff = FeedForward(
87
+ dim,
88
+ dropout=dropout,
89
+ activation_fn=activation_fn,
90
+ final_dropout=final_dropout,
91
+ inner_dim=ff_inner_dim,
92
+ bias=ff_bias
93
+ )
94
+
95
+ def forward(
96
+ self,
97
+ hidden_states: paddle.Tensor,
98
+         encoder_hidden_states: paddle.Tensor,
99
+ temb: paddle.Tensor,
100
+ image_rotary_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]]=None
101
+     ) -> paddle.Tensor:
102
+ text_seq_length = encoder_hidden_states.shape[1]
103
+
104
+ # norm and modulate
105
+ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
106
+ hidden_states, encoder_hidden_states, temb
107
+ )
108
+
109
+ # attention
110
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
111
+ hidden_states=norm_hidden_states,
112
+ encoder_hidden_states=norm_encoder_hidden_states,
113
+ image_rotary_emb=image_rotary_emb
114
+ )
115
+
116
+ hidden_states = hidden_states + gate_msa * attn_hidden_states
117
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
118
+
119
+ # norm and modulate
120
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
121
+ hidden_states, encoder_hidden_states, temb
122
+ )
123
+
124
+ # feed forward
125
+ norm_hidden_states = paddle.concat([norm_encoder_hidden_states, norm_hidden_states], axis=1)
126
+ ff_output = self.ff(norm_hidden_states)
127
+
128
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
129
+ encoder_hidden_states = (encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length])
130
+
131
+ return hidden_states, encoder_hidden_states
132
+
133
+
134
+ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
135
+ """
136
+ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
137
+
138
+ Parameters:
139
+ num_attention_heads (`int`, defaults to `30`):
140
+ The number of heads to use for multi-head attention.
141
+ attention_head_dim (`int`, defaults to `64`):
142
+ The number of channels in each head.
143
+ in_channels (`int`, defaults to `16`):
144
+ The number of channels in the input.
145
+ out_channels (`int`, *optional*, defaults to `16`):
146
+ The number of channels in the output.
147
+ flip_sin_to_cos (`bool`, defaults to `True`):
148
+ Whether to flip the sin to cos in the time embedding.
149
+ time_embed_dim (`int`, defaults to `512`):
150
+ Output dimension of timestep embeddings.
151
+ text_embed_dim (`int`, defaults to `4096`):
152
+ Input dimension of text embeddings from the text encoder.
153
+ num_layers (`int`, defaults to `30`):
154
+ The number of layers of Transformer blocks to use.
155
+ dropout (`float`, defaults to `0.0`):
156
+ The dropout probability to use.
157
+ attention_bias (`bool`, defaults to `True`):
158
+ Whether or not to use bias in the attention projection layers.
159
+ sample_width (`int`, defaults to `90`):
160
+ The width of the input latents.
161
+ sample_height (`int`, defaults to `60`):
162
+ The height of the input latents.
163
+ sample_frames (`int`, defaults to `49`):
164
+ The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
165
+ instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
166
+ but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
167
+ K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
168
+ patch_size (`int`, defaults to `2`):
169
+ The size of the patches to use in the patch embedding layer.
170
+ temporal_compression_ratio (`int`, defaults to `4`):
171
+ The compression ratio across the temporal dimension. See documentation for `sample_frames`.
172
+ max_text_seq_length (`int`, defaults to `226`):
173
+ The maximum sequence length of the input text embeddings.
174
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
175
+ Activation function to use in feed-forward.
176
+ timestep_activation_fn (`str`, defaults to `"silu"`):
177
+ Activation function to use when generating the timestep embeddings.
178
+ norm_elementwise_affine (`bool`, defaults to `True`):
179
+ Whether or not to use elementwise affine in normalization layers.
180
+ norm_eps (`float`, defaults to `1e-5`):
181
+ The epsilon value to use in normalization layers.
182
+ spatial_interpolation_scale (`float`, defaults to `1.875`):
183
+ Scaling factor to apply in 3D positional embeddings across spatial dimensions.
184
+ temporal_interpolation_scale (`float`, defaults to `1.0`):
185
+ Scaling factor to apply in 3D positional embeddings across temporal dimensions.
186
+ """
187
+ _supports_gradient_checkpointing = True
188
+
189
+ @register_to_config
190
+ def __init__(self, num_attention_heads: int=30, attention_head_dim: int
191
+ =64, in_channels: int=16, out_channels: Optional[int]=16,
192
+ flip_sin_to_cos: bool=True, freq_shift: int=0, time_embed_dim: int=
193
+ 512, text_embed_dim: int=4096, num_layers: int=30, dropout: float=
194
+ 0.0, attention_bias: bool=True, sample_width: int=90, sample_height:
195
+ int=60, sample_frames: int=49, patch_size: int=2,
196
+ temporal_compression_ratio: int=4, max_text_seq_length: int=226,
197
+ activation_fn: str='gelu-approximate', timestep_activation_fn: str=
198
+ 'silu', norm_elementwise_affine: bool=True, norm_eps: float=1e-05,
199
+ spatial_interpolation_scale: float=1.875,
200
+ temporal_interpolation_scale: float=1.0,
201
+ use_rotary_positional_embeddings: bool=False,
202
+ use_learned_positional_embeddings: bool=False):
203
+ super().__init__()
204
+ inner_dim = num_attention_heads * attention_head_dim
205
+ if (not use_rotary_positional_embeddings and
206
+ use_learned_positional_embeddings):
207
+ raise ValueError(
208
+                 "There are no CogVideoX checkpoints available with rotary embeddings disabled and learned positional embeddings enabled. If you're using a custom model and/or believe this should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
209
+ )
210
+ self.patch_embed = CogVideoXPatchEmbed(patch_size=patch_size,
211
+ in_channels=in_channels, embed_dim=inner_dim, text_embed_dim=
212
+ text_embed_dim, bias=True, sample_width=sample_width,
213
+ sample_height=sample_height, sample_frames=sample_frames,
214
+ temporal_compression_ratio=temporal_compression_ratio,
215
+ max_text_seq_length=max_text_seq_length,
216
+ spatial_interpolation_scale=spatial_interpolation_scale,
217
+ temporal_interpolation_scale=temporal_interpolation_scale,
218
+ use_positional_embeddings=not use_rotary_positional_embeddings,
219
+ use_learned_positional_embeddings=use_learned_positional_embeddings
220
+ )
221
+ self.embedding_dropout = paddle.nn.Dropout(p=dropout)
222
+ self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
223
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim,
224
+ timestep_activation_fn)
225
+ self.transformer_blocks = paddle.nn.LayerList(sublayers=[
226
+ CogVideoXBlock(dim=inner_dim, num_attention_heads=
227
+ num_attention_heads, attention_head_dim=attention_head_dim,
228
+ time_embed_dim=time_embed_dim, dropout=dropout, activation_fn=
229
+ activation_fn, attention_bias=attention_bias,
230
+ norm_elementwise_affine=norm_elementwise_affine, norm_eps=
231
+ norm_eps) for _ in range(num_layers)])
232
+ self.norm_final = paddle.nn.LayerNorm(normalized_shape=inner_dim,
233
+ epsilon=norm_eps, weight_attr=norm_elementwise_affine,
234
+ bias_attr=norm_elementwise_affine)
235
+ self.norm_out = AdaLayerNorm(embedding_dim=time_embed_dim,
236
+ output_dim=2 * inner_dim, norm_elementwise_affine=
237
+ norm_elementwise_affine, norm_eps=norm_eps, chunk_dim=1)
238
+ self.proj_out = paddle.nn.Linear(in_features=inner_dim,
239
+ out_features=patch_size * patch_size * out_channels)
240
+ self.gradient_checkpointing = False
241
+
242
+ def _set_gradient_checkpointing(self, module, value=False):
243
+ self.gradient_checkpointing = value
244
+
245
+ @property
246
+     def attn_processors(self) -> Dict[str, AttentionProcessor]:
247
+ """
248
+ Returns:
249
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
250
+ indexed by its weight name.
251
+ """
252
+ processors = {}
253
+
254
+ def fn_recursive_add_processors(name: str, module: paddle.nn.Layer,
255
+ processors: Dict[str, AttentionProcessor]):
256
+ if hasattr(module, 'get_processor'):
257
+ processors[f'{name}.processor'] = module.get_processor()
258
+ for sub_name, child in module.named_children():
259
+ fn_recursive_add_processors(f'{name}.{sub_name}', child,
260
+ processors)
261
+ return processors
262
+ for name, module in self.named_children():
263
+ fn_recursive_add_processors(name, module, processors)
264
+ return processors
265
+
266
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[
267
+ str, AttentionProcessor]]):
268
+ """
269
+ Sets the attention processor to use to compute attention.
270
+
271
+ Parameters:
272
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
273
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
274
+ for **all** `Attention` layers.
275
+
276
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
277
+ processor. This is strongly recommended when setting trainable attention processors.
278
+
279
+ """
280
+ count = len(self.attn_processors.keys())
281
+ if isinstance(processor, dict) and len(processor) != count:
282
+ raise ValueError(
283
+ f'A dict of processors was passed, but the number of processors {len(processor)} does not match the number of attention layers: {count}. Please make sure to pass {count} processor classes.'
284
+ )
285
+
286
+ def fn_recursive_attn_processor(name: str, module: paddle.nn.Layer,
287
+ processor):
288
+ if hasattr(module, 'set_processor'):
289
+ if not isinstance(processor, dict):
290
+ module.set_processor(processor)
291
+ else:
292
+ module.set_processor(processor.pop(f'{name}.processor'))
293
+ for sub_name, child in module.named_children():
294
+ fn_recursive_attn_processor(f'{name}.{sub_name}', child,
295
+ processor)
296
+ for name, module in self.named_children():
297
+ fn_recursive_attn_processor(name, module, processor)
298
+
299
+ def fuse_qkv_projections(self):
300
+ """
301
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
302
+ are fused. For cross-attention modules, key and value projection matrices are fused.
303
+
304
+ <Tip warning={true}>
305
+
306
+         This API is 🧪 experimental.
307
+
308
+ </Tip>
309
+ """
310
+ self.original_attn_processors = None
311
+ for _, attn_processor in self.attn_processors.items():
312
+ if 'Added' in str(attn_processor.__class__.__name__):
313
+ raise ValueError(
314
+ '`fuse_qkv_projections()` is not supported for models having added KV projections.'
315
+ )
316
+ self.original_attn_processors = self.attn_processors
317
+ for module in self.sublayers():
318
+ if isinstance(module, Attention):
319
+ module.fuse_projections(fuse=True)
320
+ self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
321
+
322
+ def unfuse_qkv_projections(self):
323
+ """Disables the fused QKV projection if enabled.
324
+
325
+ <Tip warning={true}>
326
+
327
+ This API is 🧪 experimental.
328
+
329
+ </Tip>
330
+
331
+ """
332
+ if self.original_attn_processors is not None:
333
+ self.set_attn_processor(self.original_attn_processors)
334
+
335
+ def forward(
336
+ self,
337
+ hidden_states: paddle.Tensor,
338
+ encoder_hidden_states: paddle.Tensor,
339
+ timestep: Union[int, float, paddle.Tensor],
340
+ timestep_cond: Optional[paddle.Tensor]=None,
341
+ image_rotary_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]]=None,
342
+ return_dict: bool=True
343
+ ):
344
+ batch_size, num_frames, channels, height, width = hidden_states.shape
345
+
346
+ # 1. Time embedding
347
+ timesteps = timestep
348
+ t_emb = self.time_proj(timesteps)
349
+ t_emb = t_emb.cast(hidden_states.dtype)
350
+ emb = self.time_embedding(t_emb, timestep_cond)
351
+
352
+ # 2. Patch embedding
353
+ hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
354
+ hidden_states = self.embedding_dropout(hidden_states)
355
+
356
+ text_seq_length = tuple(encoder_hidden_states.shape)[1]
357
+ encoder_hidden_states = hidden_states[:, :text_seq_length]
358
+ hidden_states = hidden_states[:, text_seq_length:]
359
+
360
+ # 3. Transformer blocks
361
+ for i, block in enumerate(self.transformer_blocks):
362
+ if self.training and self.gradient_checkpointing:
363
+ raise NotImplementedError
364
+ else:
365
+ hidden_states, encoder_hidden_states = block(
366
+ hidden_states=hidden_states,
367
+ encoder_hidden_states=encoder_hidden_states,
368
+ temb=emb,
369
+ image_rotary_emb=image_rotary_emb
370
+ )
371
+ # print("hidden_states:", hidden_states.abs().mean().item(), hidden_states.min().item(), hidden_states.max().item())
372
+ # print("encoder_hidden_states:", encoder_hidden_states.abs().mean().item(), encoder_hidden_states.min().item(), encoder_hidden_states.max().item())
373
+
374
+ if not self.config.use_rotary_positional_embeddings:
375
+ # 2B
376
+ hidden_states = self.norm_final(hidden_states)
377
+ else:
378
+ # 5B
379
+ hidden_states = paddle.concat(x=[encoder_hidden_states,
380
+ hidden_states], axis=1)
381
+ hidden_states = self.norm_final(hidden_states)
382
+ hidden_states = hidden_states[:, text_seq_length:]
383
+
384
+ # 4. Final block
385
+ hidden_states = self.norm_out(hidden_states, temb=emb)
386
+ hidden_states = self.proj_out(hidden_states)
387
+
388
+ # 5. Unpatchify
389
+ p = self.config.patch_size
390
+ output = hidden_states.reshape([batch_size, num_frames, height // p, width // p, -1, p, p])
391
+ output = output.transpose(perm=[0, 1, 4, 2, 5, 3, 6]).flatten(5, 6).flatten(3, 4)
392
+ if not return_dict:
393
+             return (output,)
394
+ return Transformer2DModelOutput(sample=output)
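
To clarify the unpatchify step at the end of `forward`, here is a shape-only sketch on random data (all sizes are assumed examples): the projected tokens are reshaped back into per-frame patch grids, and the two `flatten` calls stitch the `p x p` patches into full spatial maps.

import paddle

batch, frames, height, width, p, out_ch = 1, 13, 60, 90, 2, 16
proj = paddle.randn([batch, frames * (height // p) * (width // p), p * p * out_ch])  # proj_out output
x = proj.reshape([batch, frames, height // p, width // p, -1, p, p])
x = x.transpose([0, 1, 4, 2, 5, 3, 6])   # (B, T, C, H/p, p, W/p, p)
video = x.flatten(5, 6).flatten(3, 4)    # (B, T, C, H, W)
assert tuple(video.shape) == (batch, frames, out_ch, height, width)
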
PaddleMIX/ppdiffusers/ppdiffusers/models/consistency_decoder_vae.py ADDED
@@ -0,0 +1,445 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Dict, Optional, Tuple, Union
16
+
17
+ import paddle
18
+ import paddle.nn.functional as F
19
+ from paddle import nn
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from ..schedulers import ConsistencyDecoderScheduler
23
+ from ..utils import BaseOutput
24
+ from ..utils.accelerate_utils import apply_forward_hook
25
+ from ..utils.paddle_utils import randn_tensor
26
+ from .attention_processor import (
27
+ ADDED_KV_ATTENTION_PROCESSORS,
28
+ CROSS_ATTENTION_PROCESSORS,
29
+ AttentionProcessor,
30
+ AttnAddedKVProcessor,
31
+ AttnProcessor,
32
+ )
33
+ from .modeling_utils import ModelMixin
34
+ from .unet_2d import UNet2DModel
35
+ from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
36
+
37
+
38
+ @dataclass
39
+ class ConsistencyDecoderVAEOutput(BaseOutput):
40
+ """
41
+ Output of encoding method.
42
+
43
+ Args:
44
+ latent_dist (`DiagonalGaussianDistribution`):
45
+ Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
46
+ `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
47
+ """
48
+
49
+ latent_dist: "DiagonalGaussianDistribution"
50
+
51
+
52
+ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
53
+ r"""
54
+ The consistency decoder used with DALL-E 3.
55
+
56
+ Examples:
57
+ ```py
58
+ >>> import paddle
59
+ >>> from ppdiffusers import StableDiffusionPipeline, ConsistencyDecoderVAE
60
+
61
+ >>> vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", paddle_dtype=paddle.float16)
62
+ >>> pipe = StableDiffusionPipeline.from_pretrained(
63
+ ... "runwayml/stable-diffusion-v1-5", vae=vae, paddle_dtype=paddle.float16
64
+ ... )
65
+
66
+ >>> pipe("horse", generator=paddle.Generator().manual_seed(0)).images
67
+ ```
68
+ """
69
+
70
+ @register_to_config
71
+ def __init__(
72
+ self,
73
+ scaling_factor: float = 0.18215,
74
+ latent_channels: int = 4,
75
+ encoder_act_fn: str = "silu",
76
+ encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
77
+ encoder_double_z: bool = True,
78
+ encoder_down_block_types: Tuple[str, ...] = (
79
+ "DownEncoderBlock2D",
80
+ "DownEncoderBlock2D",
81
+ "DownEncoderBlock2D",
82
+ "DownEncoderBlock2D",
83
+ ),
84
+ encoder_in_channels: int = 3,
85
+ encoder_layers_per_block: int = 2,
86
+ encoder_norm_num_groups: int = 32,
87
+ encoder_out_channels: int = 4,
88
+ decoder_add_attention: bool = False,
89
+ decoder_block_out_channels: Tuple[int, ...] = (320, 640, 1024, 1024),
90
+ decoder_down_block_types: Tuple[str, ...] = (
91
+ "ResnetDownsampleBlock2D",
92
+ "ResnetDownsampleBlock2D",
93
+ "ResnetDownsampleBlock2D",
94
+ "ResnetDownsampleBlock2D",
95
+ ),
96
+ decoder_downsample_padding: int = 1,
97
+ decoder_in_channels: int = 7,
98
+ decoder_layers_per_block: int = 3,
99
+ decoder_norm_eps: float = 1e-05,
100
+ decoder_norm_num_groups: int = 32,
101
+ decoder_num_train_timesteps: int = 1024,
102
+ decoder_out_channels: int = 6,
103
+ decoder_resnet_time_scale_shift: str = "scale_shift",
104
+ decoder_time_embedding_type: str = "learned",
105
+ decoder_up_block_types: Tuple[str, ...] = (
106
+ "ResnetUpsampleBlock2D",
107
+ "ResnetUpsampleBlock2D",
108
+ "ResnetUpsampleBlock2D",
109
+ "ResnetUpsampleBlock2D",
110
+ ),
111
+ ):
112
+ super().__init__()
113
+ self.encoder = Encoder(
114
+ act_fn=encoder_act_fn,
115
+ block_out_channels=encoder_block_out_channels,
116
+ double_z=encoder_double_z,
117
+ down_block_types=encoder_down_block_types,
118
+ in_channels=encoder_in_channels,
119
+ layers_per_block=encoder_layers_per_block,
120
+ norm_num_groups=encoder_norm_num_groups,
121
+ out_channels=encoder_out_channels,
122
+ )
123
+
124
+ self.decoder_unet = UNet2DModel(
125
+ add_attention=decoder_add_attention,
126
+ block_out_channels=decoder_block_out_channels,
127
+ down_block_types=decoder_down_block_types,
128
+ downsample_padding=decoder_downsample_padding,
129
+ in_channels=decoder_in_channels,
130
+ layers_per_block=decoder_layers_per_block,
131
+ norm_eps=decoder_norm_eps,
132
+ norm_num_groups=decoder_norm_num_groups,
133
+ num_train_timesteps=decoder_num_train_timesteps,
134
+ out_channels=decoder_out_channels,
135
+ resnet_time_scale_shift=decoder_resnet_time_scale_shift,
136
+ time_embedding_type=decoder_time_embedding_type,
137
+ up_block_types=decoder_up_block_types,
138
+ )
139
+ self.decoder_scheduler = ConsistencyDecoderScheduler()
140
+ self.register_to_config(block_out_channels=encoder_block_out_channels)
141
+ self.register_to_config(force_upcast=False)
142
+
143
+ self.register_buffer(
144
+ "means",
145
+ paddle.to_tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None],
146
+ persistable=False,
147
+ )
148
+ self.register_buffer(
149
+ "stds",
150
+ paddle.to_tensor([0.9654121, 1.0440036, 0.76147926, 0.77022034])[None, :, None, None],
151
+ persistable=False,
152
+ )
153
+
154
+ self.quant_conv = nn.Conv2D(2 * latent_channels, 2 * latent_channels, 1)
155
+
156
+ self.use_slicing = False
157
+ self.use_tiling = False
158
+
159
+ # Copied from ppdiffusers.models.autoencoder_kl.AutoencoderKL.enable_tiling
160
+ def enable_tiling(self, use_tiling: bool = True):
161
+ r"""
162
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
163
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
164
+ processing larger images.
165
+ """
166
+ self.use_tiling = use_tiling
167
+
168
+ # Copied from ppdiffusers.models.autoencoder_kl.AutoencoderKL.disable_tiling
169
+ def disable_tiling(self):
170
+ r"""
171
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
172
+ decoding in one step.
173
+ """
174
+ self.enable_tiling(False)
175
+
176
+ # Copied from ppdiffusers.models.autoencoder_kl.AutoencoderKL.enable_slicing
177
+ def enable_slicing(self):
178
+ r"""
179
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
180
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
181
+ """
182
+ self.use_slicing = True
183
+
184
+ # Copied from ppdiffusers.models.autoencoder_kl.AutoencoderKL.disable_slicing
185
+ def disable_slicing(self):
186
+ r"""
187
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
188
+ decoding in one step.
189
+ """
190
+ self.use_slicing = False
191
+
192
+ @property
193
+ # Copied from ppdiffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
194
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
195
+ r"""
196
+ Returns:
197
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
198
+ indexed by its weight name.
199
+ """
200
+ # set recursively
201
+ processors = {}
202
+
203
+ def fn_recursive_add_processors(name: str, module: nn.Layer, processors: Dict[str, AttentionProcessor]):
204
+ if hasattr(module, "get_processor"):
205
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
206
+
207
+ for sub_name, child in module.named_children():
208
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
209
+
210
+ return processors
211
+
212
+ for name, module in self.named_children():
213
+ fn_recursive_add_processors(name, module, processors)
214
+
215
+ return processors
216
+
217
+ # Copied from ppdiffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
218
+ def set_attn_processor(
219
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
220
+ ):
221
+ r"""
222
+ Sets the attention processor to use to compute attention.
223
+
224
+ Parameters:
225
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
226
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
227
+ for **all** `Attention` layers.
228
+
229
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
230
+ processor. This is strongly recommended when setting trainable attention processors.
231
+
232
+ """
233
+ count = len(self.attn_processors.keys())
234
+
235
+ if isinstance(processor, dict) and len(processor) != count:
236
+ raise ValueError(
237
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
238
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
239
+ )
240
+
241
+ def fn_recursive_attn_processor(name: str, module: nn.Layer, processor):
242
+ if hasattr(module, "set_processor"):
243
+ if not isinstance(processor, dict):
244
+ module.set_processor(processor, _remove_lora=_remove_lora)
245
+ else:
246
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
247
+
248
+ for sub_name, child in module.named_children():
249
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
250
+
251
+ for name, module in self.named_children():
252
+ fn_recursive_attn_processor(name, module, processor)
253
+
254
+ # Copied from ppdiffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
255
+ def set_default_attn_processor(self):
256
+ """
257
+ Disables custom attention processors and sets the default attention implementation.
258
+ """
259
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
260
+ processor = AttnAddedKVProcessor()
261
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
262
+ processor = AttnProcessor()
263
+ else:
264
+ raise ValueError(
265
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
266
+ )
267
+
268
+ self.set_attn_processor(processor, _remove_lora=True)
269
+
270
+ @apply_forward_hook
271
+ def encode(
272
+ self, x: paddle.Tensor, return_dict: bool = True
273
+ ) -> Union[ConsistencyDecoderVAEOutput, Tuple[DiagonalGaussianDistribution]]:
274
+ """
275
+ Encode a batch of images into latents.
276
+
277
+ Args:
278
+ x (`paddle.Tensor`): Input batch of images.
279
+ return_dict (`bool`, *optional*, defaults to `True`):
280
+ Whether to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a plain
281
+ tuple.
282
+
283
+ Returns:
284
+ The latent representations of the encoded images. If `return_dict` is True, a
285
+ [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned, otherwise a plain `tuple`
286
+ is returned.
287
+ """
288
+ # TODO junnyu, support float16
289
+ x = x.cast(self.encoder.conv_in.weight.dtype)
290
+ if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
291
+ return self.tiled_encode(x, return_dict=return_dict)
292
+
293
+ if self.use_slicing and x.shape[0] > 1:
294
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.chunk(x.shape[0])]
295
+ h = paddle.concat(encoded_slices)
296
+ else:
297
+ h = self.encoder(x)
298
+
299
+ moments = self.quant_conv(h)
300
+ posterior = DiagonalGaussianDistribution(moments)
301
+
302
+ if not return_dict:
303
+ return (posterior,)
304
+
305
+ return ConsistencyDecoderVAEOutput(latent_dist=posterior)
306
+
307
+ @apply_forward_hook
308
+ def decode(
309
+ self,
310
+ z: paddle.Tensor,
311
+ generator: Optional[paddle.Generator] = None,
312
+ return_dict: bool = True,
313
+ num_inference_steps: int = 2,
314
+ ) -> Union[DecoderOutput, Tuple[paddle.Tensor]]:
315
+
316
+ z = (z * self.config.scaling_factor - self.means) / self.stds
317
+
318
+ scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
319
+ z = F.interpolate(z, mode="nearest", scale_factor=scale_factor)
320
+
321
+ batch_size, _, height, width = z.shape
322
+
323
+ self.decoder_scheduler.set_timesteps(num_inference_steps)
324
+
325
+ x_t = self.decoder_scheduler.init_noise_sigma * randn_tensor(
326
+ (batch_size, 3, height, width),
327
+ generator=generator,
328
+ dtype=z.dtype,
329
+ )
330
+
331
+ for t in self.decoder_scheduler.timesteps:
332
+ model_input = paddle.concat([self.decoder_scheduler.scale_model_input(x_t, t).cast(z.dtype), z], axis=1)
333
+ model_output = self.decoder_unet(model_input, t).sample[:, :3, :, :]
334
+ prev_sample = self.decoder_scheduler.step(model_output, t, x_t, generator).prev_sample
335
+ x_t = prev_sample
336
+
337
+ x_0 = x_t
338
+
339
+ if not return_dict:
340
+ return (x_0,)
341
+
342
+ return DecoderOutput(sample=x_0)
343
+
344
+ # Copied from ppdiffusers.models.autoencoder_kl.AutoencoderKL.blend_v
345
+ def blend_v(self, a: paddle.Tensor, b: paddle.Tensor, blend_extent: int) -> paddle.Tensor:
346
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
347
+ for y in range(blend_extent):
348
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
349
+ return b
350
+
351
+ # Copied from ppdiffusers.models.autoencoder_kl.AutoencoderKL.blend_h
352
+ def blend_h(self, a: paddle.Tensor, b: paddle.Tensor, blend_extent: int) -> paddle.Tensor:
353
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
354
+ for x in range(blend_extent):
355
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
356
+ return b
357
+
358
+ def tiled_encode(self, x: paddle.Tensor, return_dict: bool = True) -> ConsistencyDecoderVAEOutput:
359
+ r"""Encode a batch of images using a tiled encoder.
360
+
361
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
362
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
363
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
364
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
365
+ output, but they should be much less noticeable.
366
+
367
+ Args:
368
+ x (`paddle.Tensor`): Input batch of images.
369
+ return_dict (`bool`, *optional*, defaults to `True`):
370
+ Whether or not to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a
371
+ plain tuple.
372
+
373
+ Returns:
374
+ [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] or `tuple`:
375
+ If return_dict is True, a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned,
376
+ otherwise a plain `tuple` is returned.
377
+ """
378
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
379
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
380
+ row_limit = self.tile_latent_min_size - blend_extent
381
+
382
+ # Split the image into 512x512 tiles and encode them separately.
383
+ rows = []
384
+ for i in range(0, x.shape[2], overlap_size):
385
+ row = []
386
+ for j in range(0, x.shape[3], overlap_size):
387
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
388
+ tile = self.encoder(tile)
389
+ tile = self.quant_conv(tile)
390
+ row.append(tile)
391
+ rows.append(row)
392
+ result_rows = []
393
+ for i, row in enumerate(rows):
394
+ result_row = []
395
+ for j, tile in enumerate(row):
396
+ # blend the above tile and the left tile
397
+ # to the current tile and add the current tile to the result row
398
+ if i > 0:
399
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
400
+ if j > 0:
401
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
402
+ result_row.append(tile[:, :, :row_limit, :row_limit])
403
+ result_rows.append(paddle.concat(result_row, axis=3))
404
+
405
+ moments = paddle.concat(result_rows, axis=2)
406
+ posterior = DiagonalGaussianDistribution(moments)
407
+
408
+ if not return_dict:
409
+ return (posterior,)
410
+
411
+ return ConsistencyDecoderVAEOutput(latent_dist=posterior)
412
+
413
+ def forward(
414
+ self,
415
+ sample: paddle.Tensor,
416
+ sample_posterior: bool = False,
417
+ return_dict: bool = True,
418
+ generator: Optional[paddle.Generator] = None,
419
+ ) -> Union[DecoderOutput, Tuple[paddle.Tensor]]:
420
+ r"""
421
+ Args:
422
+ sample (`paddle.Tensor`): Input sample.
423
+ sample_posterior (`bool`, *optional*, defaults to `False`):
424
+ Whether to sample from the posterior.
425
+ return_dict (`bool`, *optional*, defaults to `True`):
426
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
427
+ generator (`paddle.Generator`, *optional*, defaults to `None`):
428
+ Generator to use for sampling.
429
+
430
+ Returns:
431
+ [`DecoderOutput`] or `tuple`:
432
+ If return_dict is True, a [`DecoderOutput`] is returned, otherwise a plain `tuple` is returned.
433
+ """
434
+ x = sample
435
+ posterior = self.encode(x).latent_dist
436
+ if sample_posterior:
437
+ z = posterior.sample(generator=generator)
438
+ else:
439
+ z = posterior.mode()
440
+ dec = self.decode(z, generator=generator).sample
441
+
442
+ if not return_dict:
443
+ return (dec,)
444
+
445
+ return DecoderOutput(sample=dec)
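
The tiled encoder above stitches overlapping tiles together with a simple linear cross-fade (`blend_h` / `blend_v`). Below is a standalone sketch of the horizontal blend on toy tensors; the sizes are arbitrary and chosen only for illustration.

```py
import paddle

def blend_h(a, b, blend_extent):
    # Linearly fade from tile `a` (left) into tile `b` (right) across the overlap columns,
    # mirroring the blend_h method above.
    blend_extent = min(a.shape[3], b.shape[3], blend_extent)
    for x in range(blend_extent):
        b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
    return b

left = paddle.ones([1, 4, 8, 8])    # tile already placed in the result row
right = paddle.zeros([1, 4, 8, 8])  # incoming tile that overlaps it
blended = blend_h(left, right, blend_extent=4)
print(blended[0, 0, 0, :4].numpy())  # [1.0, 0.75, 0.5, 0.25] -- smooth ramp across the seam
```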
PaddleMIX/ppdiffusers/ppdiffusers/models/controlnet.py ADDED
@@ -0,0 +1,889 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import paddle
18
+ import paddle.nn as nn
19
+ import paddle.nn.functional as F
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from ..loaders import FromOriginalControlnetMixin
23
+ from ..utils import BaseOutput, logging
24
+ from .attention_processor import (
25
+ ADDED_KV_ATTENTION_PROCESSORS,
26
+ CROSS_ATTENTION_PROCESSORS,
27
+ AttentionProcessor,
28
+ AttnAddedKVProcessor,
29
+ AttnProcessor,
30
+ )
31
+ from .embeddings import (
32
+ TextImageProjection,
33
+ TextImageTimeEmbedding,
34
+ TextTimeEmbedding,
35
+ TimestepEmbedding,
36
+ Timesteps,
37
+ )
38
+ from .modeling_utils import ModelMixin
39
+ from .unet_2d_blocks import (
40
+ CrossAttnDownBlock2D,
41
+ DownBlock2D,
42
+ UNetMidBlock2D,
43
+ UNetMidBlock2DCrossAttn,
44
+ get_down_block,
45
+ )
46
+ from .unet_2d_condition import UNet2DConditionModel
47
+
48
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
49
+
50
+
51
+ @dataclass
52
+ class ControlNetOutput(BaseOutput):
53
+ """
54
+ The output of [`ControlNetModel`].
55
+
56
+ Args:
57
+ down_block_res_samples (`tuple[paddle.Tensor]`):
58
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
59
+ be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
+ used to condition the original UNet's downsampling activations.
+ mid_block_res_sample (`paddle.Tensor`):
62
+ The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
63
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
64
+ Output can be used to condition the original UNet's middle block activation.
65
+ """
66
+
67
+ down_block_res_samples: Tuple[paddle.Tensor]
68
+ mid_block_res_sample: paddle.Tensor
69
+
70
+
71
+ class ControlNetConditioningEmbedding(nn.Layer):
72
+ """
73
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
74
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
75
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
76
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
77
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
78
+ model) to encode image-space conditions ... into feature maps ..."
79
+ """
80
+
81
+ def __init__(
82
+ self,
83
+ conditioning_embedding_channels: int,
84
+ conditioning_channels: int = 3,
85
+ block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
86
+ ):
87
+ super().__init__()
88
+
89
+ self.conv_in = nn.Conv2D(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
90
+
91
+ self.blocks = nn.LayerList([])
92
+
93
+ for i in range(len(block_out_channels) - 1):
94
+ channel_in = block_out_channels[i]
95
+ channel_out = block_out_channels[i + 1]
96
+ self.blocks.append(nn.Conv2D(channel_in, channel_in, kernel_size=3, padding=1))
97
+ self.blocks.append(nn.Conv2D(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
98
+
99
+ self.conv_out = zero_module(
100
+ nn.Conv2D(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
101
+ )
102
+
103
+ def forward(self, conditioning):
104
+ embedding = self.conv_in(conditioning)
105
+ embedding = F.silu(embedding)
106
+
107
+ for block in self.blocks:
108
+ embedding = block(embedding)
109
+ embedding = F.silu(embedding)
110
+
111
+ embedding = self.conv_out(embedding)
112
+
113
+ return embedding
114
+
115
+
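
As the quoted paper text explains, the conditioning embedder is just a few stride-2 convolutions that bring an image-space condition down to the latent resolution. A minimal usage sketch follows, assuming the class can be imported from `ppdiffusers.models.controlnet` and using SD 1.5-like channel counts.

```py
import paddle
from ppdiffusers.models.controlnet import ControlNetConditioningEmbedding  # assumed import path

cond_embed = ControlNetConditioningEmbedding(
    conditioning_embedding_channels=320,        # e.g. block_out_channels[0] of SD 1.5
    conditioning_channels=3,
    block_out_channels=(16, 32, 96, 256),
)
cond_image = paddle.rand([1, 3, 512, 512])      # image-space condition, e.g. a canny edge map
features = cond_embed(cond_image)
print(features.shape)                           # [1, 320, 64, 64]: three stride-2 convs, 512 / 2**3 = 64
```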
116
+ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
117
+ """
118
+ A ControlNet model.
119
+
120
+ Args:
121
+ in_channels (`int`, defaults to 4):
122
+ The number of channels in the input sample.
123
+ flip_sin_to_cos (`bool`, defaults to `True`):
124
+ Whether to flip the sin to cos in the time embedding.
125
+ freq_shift (`int`, defaults to 0):
126
+ The frequency shift to apply to the time embedding.
127
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
128
+ The tuple of downsample blocks to use.
129
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
130
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
131
+ The tuple of output channels for each block.
132
+ layers_per_block (`int`, defaults to 2):
133
+ The number of layers per block.
134
+ downsample_padding (`int`, defaults to 1):
135
+ The padding to use for the downsampling convolution.
136
+ mid_block_scale_factor (`float`, defaults to 1):
137
+ The scale factor to use for the mid block.
138
+ act_fn (`str`, defaults to "silu"):
139
+ The activation function to use.
140
+ norm_num_groups (`int`, *optional*, defaults to 32):
141
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
142
+ in post-processing.
143
+ norm_eps (`float`, defaults to 1e-5):
144
+ The epsilon to use for the normalization.
145
+ cross_attention_dim (`int`, defaults to 1280):
146
+ The dimension of the cross attention features.
147
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
148
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
149
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
150
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
151
+ encoder_hid_dim (`int`, *optional*, defaults to None):
152
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
153
+ dimension to `cross_attention_dim`.
154
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
155
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
156
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
157
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
158
+ The dimension of the attention heads.
159
+ use_linear_projection (`bool`, defaults to `False`):
160
+ class_embed_type (`str`, *optional*, defaults to `None`):
161
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
162
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
163
+ addition_embed_type (`str`, *optional*, defaults to `None`):
164
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
165
+ "text". "text" will use the `TextTimeEmbedding` layer.
166
+ num_class_embeds (`int`, *optional*, defaults to 0):
167
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
168
+ class conditioning with `class_embed_type` equal to `None`.
169
+ upcast_attention (`bool`, defaults to `False`):
170
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
171
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
172
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
173
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
174
+ `class_embed_type="projection"`.
175
+ controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
176
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
177
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
178
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
179
+ global_pool_conditions (`bool`, defaults to `False`):
180
+ TODO(Patrick) - unused parameter.
181
+ addition_embed_type_num_heads (`int`, defaults to 64):
182
+ The number of heads to use for the `TextTimeEmbedding` layer.
183
+ """
184
+
185
+ _supports_gradient_checkpointing = True
186
+
187
+ @register_to_config
188
+ def __init__(
189
+ self,
190
+ in_channels: int = 4,
191
+ conditioning_channels: int = 3,
192
+ flip_sin_to_cos: bool = True,
193
+ freq_shift: int = 0,
194
+ down_block_types: Tuple[str, ...] = (
195
+ "CrossAttnDownBlock2D",
196
+ "CrossAttnDownBlock2D",
197
+ "CrossAttnDownBlock2D",
198
+ "DownBlock2D",
199
+ ),
200
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
201
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
202
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
203
+ layers_per_block: int = 2,
204
+ downsample_padding: int = 1,
205
+ mid_block_scale_factor: float = 1,
206
+ act_fn: str = "silu",
207
+ norm_num_groups: Optional[int] = 32,
208
+ norm_eps: float = 1e-5,
209
+ cross_attention_dim: int = 1280,
210
+ transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
211
+ encoder_hid_dim: Optional[int] = None,
212
+ encoder_hid_dim_type: Optional[str] = None,
213
+ attention_head_dim: Union[int, Tuple[int, ...]] = 8,
214
+ num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
215
+ use_linear_projection: bool = False,
216
+ class_embed_type: Optional[str] = None,
217
+ addition_embed_type: Optional[str] = None,
218
+ addition_time_embed_dim: Optional[int] = None,
219
+ num_class_embeds: Optional[int] = None,
220
+ upcast_attention: bool = False,
221
+ resnet_time_scale_shift: str = "default",
222
+ projection_class_embeddings_input_dim: Optional[int] = None,
223
+ controlnet_conditioning_channel_order: str = "rgb",
224
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
225
+ global_pool_conditions: bool = False,
226
+ addition_embed_type_num_heads: int = 64,
227
+ ):
228
+ super().__init__()
229
+
230
+ # If `num_attention_heads` is not defined (which is the case for most models)
231
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
232
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
233
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
234
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
235
+ # which is why we correct for the naming here.
236
+ num_attention_heads = num_attention_heads or attention_head_dim
237
+
238
+ # Check inputs
239
+ if len(block_out_channels) != len(down_block_types):
240
+ raise ValueError(
241
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
242
+ )
243
+
244
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
245
+ raise ValueError(
246
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
247
+ )
248
+
249
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
250
+ raise ValueError(
251
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
252
+ )
253
+
254
+ if isinstance(transformer_layers_per_block, int):
255
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
256
+
257
+ # input
258
+ conv_in_kernel = 3
259
+ conv_in_padding = (conv_in_kernel - 1) // 2
260
+ self.conv_in = nn.Conv2D(
261
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
262
+ )
263
+
264
+ # time
265
+ time_embed_dim = block_out_channels[0] * 4
266
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
267
+ timestep_input_dim = block_out_channels[0]
268
+ self.time_embedding = TimestepEmbedding(
269
+ timestep_input_dim,
270
+ time_embed_dim,
271
+ act_fn=act_fn,
272
+ )
273
+
274
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
275
+ encoder_hid_dim_type = "text_proj"
276
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
277
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
278
+
279
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
280
+ raise ValueError(
281
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
282
+ )
283
+
284
+ if encoder_hid_dim_type == "text_proj":
285
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
286
+ elif encoder_hid_dim_type == "text_image_proj":
287
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
288
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
289
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
290
+ self.encoder_hid_proj = TextImageProjection(
291
+ text_embed_dim=encoder_hid_dim,
292
+ image_embed_dim=cross_attention_dim,
293
+ cross_attention_dim=cross_attention_dim,
294
+ )
295
+
296
+ elif encoder_hid_dim_type is not None:
297
+ raise ValueError(
298
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
299
+ )
300
+ else:
301
+ self.encoder_hid_proj = None
302
+
303
+ # class embedding
304
+ if class_embed_type is None and num_class_embeds is not None:
305
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
306
+ elif class_embed_type == "timestep":
307
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
308
+ elif class_embed_type == "identity":
309
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
310
+ elif class_embed_type == "projection":
311
+ if projection_class_embeddings_input_dim is None:
312
+ raise ValueError(
313
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
314
+ )
315
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
316
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
317
+ # 2. it projects from an arbitrary input dimension.
318
+ #
319
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
320
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
321
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
322
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
323
+ else:
324
+ self.class_embedding = None
325
+
326
+ if addition_embed_type == "text":
327
+ if encoder_hid_dim is not None:
328
+ text_time_embedding_from_dim = encoder_hid_dim
329
+ else:
330
+ text_time_embedding_from_dim = cross_attention_dim
331
+
332
+ self.add_embedding = TextTimeEmbedding(
333
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
334
+ )
335
+ elif addition_embed_type == "text_image":
336
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
337
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
338
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
339
+ self.add_embedding = TextImageTimeEmbedding(
340
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
341
+ )
342
+ elif addition_embed_type == "text_time":
343
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
344
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
345
+
346
+ elif addition_embed_type is not None:
347
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
348
+
349
+ # control net conditioning embedding
350
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
351
+ conditioning_embedding_channels=block_out_channels[0],
352
+ block_out_channels=conditioning_embedding_out_channels,
353
+ conditioning_channels=conditioning_channels,
354
+ )
355
+
356
+ self.down_blocks = nn.LayerList([])
357
+ self.controlnet_down_blocks = nn.LayerList([])
358
+
359
+ if isinstance(only_cross_attention, bool):
360
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
361
+
362
+ if isinstance(attention_head_dim, int):
363
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
364
+
365
+ if isinstance(num_attention_heads, int):
366
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
367
+
368
+ # down
369
+ output_channel = block_out_channels[0]
370
+
371
+ controlnet_block = nn.Conv2D(output_channel, output_channel, kernel_size=1)
372
+ controlnet_block = zero_module(controlnet_block)
373
+ self.controlnet_down_blocks.append(controlnet_block)
374
+
375
+ for i, down_block_type in enumerate(down_block_types):
376
+ input_channel = output_channel
377
+ output_channel = block_out_channels[i]
378
+ is_final_block = i == len(block_out_channels) - 1
379
+
380
+ down_block = get_down_block(
381
+ down_block_type,
382
+ num_layers=layers_per_block,
383
+ transformer_layers_per_block=transformer_layers_per_block[i],
384
+ in_channels=input_channel,
385
+ out_channels=output_channel,
386
+ temb_channels=time_embed_dim,
387
+ add_downsample=not is_final_block,
388
+ resnet_eps=norm_eps,
389
+ resnet_act_fn=act_fn,
390
+ resnet_groups=norm_num_groups,
391
+ cross_attention_dim=cross_attention_dim,
392
+ num_attention_heads=num_attention_heads[i],
393
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
394
+ downsample_padding=downsample_padding,
395
+ use_linear_projection=use_linear_projection,
396
+ only_cross_attention=only_cross_attention[i],
397
+ upcast_attention=upcast_attention,
398
+ resnet_time_scale_shift=resnet_time_scale_shift,
399
+ )
400
+ self.down_blocks.append(down_block)
401
+
402
+ for _ in range(layers_per_block):
403
+ controlnet_block = nn.Conv2D(output_channel, output_channel, kernel_size=1)
404
+ controlnet_block = zero_module(controlnet_block)
405
+ self.controlnet_down_blocks.append(controlnet_block)
406
+
407
+ if not is_final_block:
408
+ controlnet_block = nn.Conv2D(output_channel, output_channel, kernel_size=1)
409
+ controlnet_block = zero_module(controlnet_block)
410
+ self.controlnet_down_blocks.append(controlnet_block)
411
+
412
+ # mid
413
+ mid_block_channel = block_out_channels[-1]
414
+
415
+ controlnet_block = nn.Conv2D(mid_block_channel, mid_block_channel, kernel_size=1)
416
+ controlnet_block = zero_module(controlnet_block)
417
+ self.controlnet_mid_block = controlnet_block
418
+
419
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
420
+ self.mid_block = UNetMidBlock2DCrossAttn(
421
+ transformer_layers_per_block=transformer_layers_per_block[-1],
422
+ in_channels=mid_block_channel,
423
+ temb_channels=time_embed_dim,
424
+ resnet_eps=norm_eps,
425
+ resnet_act_fn=act_fn,
426
+ output_scale_factor=mid_block_scale_factor,
427
+ resnet_time_scale_shift=resnet_time_scale_shift,
428
+ cross_attention_dim=cross_attention_dim,
429
+ num_attention_heads=num_attention_heads[-1],
430
+ resnet_groups=norm_num_groups,
431
+ use_linear_projection=use_linear_projection,
432
+ upcast_attention=upcast_attention,
433
+ )
434
+ elif mid_block_type == "UNetMidBlock2D":
435
+ self.mid_block = UNetMidBlock2D(
436
+ in_channels=block_out_channels[-1],
437
+ temb_channels=time_embed_dim,
438
+ num_layers=0,
439
+ resnet_eps=norm_eps,
440
+ resnet_act_fn=act_fn,
441
+ output_scale_factor=mid_block_scale_factor,
442
+ resnet_groups=norm_num_groups,
443
+ resnet_time_scale_shift=resnet_time_scale_shift,
444
+ add_attention=False,
445
+ )
446
+ else:
447
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
448
+
449
+ @classmethod
450
+ def from_unet(
451
+ cls,
452
+ unet: UNet2DConditionModel,
453
+ controlnet_conditioning_channel_order: str = "rgb",
454
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
455
+ load_weights_from_unet: bool = True,
456
+ conditioning_channels: int = 3,
457
+ ):
458
+ r"""
459
+ Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
460
+
461
+ Parameters:
462
+ unet (`UNet2DConditionModel`):
463
+ The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
464
+ where applicable.
465
+ """
466
+ transformer_layers_per_block = (
467
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
468
+ )
469
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
470
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
471
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
472
+ addition_time_embed_dim = (
473
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
474
+ )
475
+
476
+ controlnet = cls(
477
+ encoder_hid_dim=encoder_hid_dim,
478
+ encoder_hid_dim_type=encoder_hid_dim_type,
479
+ addition_embed_type=addition_embed_type,
480
+ addition_time_embed_dim=addition_time_embed_dim,
481
+ transformer_layers_per_block=transformer_layers_per_block,
482
+ in_channels=unet.config.in_channels,
483
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
484
+ freq_shift=unet.config.freq_shift,
485
+ down_block_types=unet.config.down_block_types,
486
+ only_cross_attention=unet.config.only_cross_attention,
487
+ block_out_channels=unet.config.block_out_channels,
488
+ layers_per_block=unet.config.layers_per_block,
489
+ downsample_padding=unet.config.downsample_padding,
490
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
491
+ act_fn=unet.config.act_fn,
492
+ norm_num_groups=unet.config.norm_num_groups,
493
+ norm_eps=unet.config.norm_eps,
494
+ cross_attention_dim=unet.config.cross_attention_dim,
495
+ attention_head_dim=unet.config.attention_head_dim,
496
+ num_attention_heads=unet.config.num_attention_heads,
497
+ use_linear_projection=unet.config.use_linear_projection,
498
+ class_embed_type=unet.config.class_embed_type,
499
+ num_class_embeds=unet.config.num_class_embeds,
500
+ upcast_attention=unet.config.upcast_attention,
501
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
502
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
503
+ mid_block_type=unet.config.mid_block_type,
504
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
505
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
506
+ conditioning_channels=conditioning_channels,
507
+ )
508
+
509
+ if load_weights_from_unet:
510
+ controlnet.conv_in.load_dict(unet.conv_in.state_dict())
511
+ controlnet.time_proj.load_dict(unet.time_proj.state_dict())
512
+ controlnet.time_embedding.load_dict(unet.time_embedding.state_dict())
513
+
514
+ if controlnet.class_embedding:
515
+ controlnet.class_embedding.load_dict(unet.class_embedding.state_dict())
516
+
517
+ controlnet.down_blocks.load_dict(unet.down_blocks.state_dict())
518
+ controlnet.mid_block.load_dict(unet.mid_block.state_dict())
519
+
520
+ return controlnet
521
+
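
`from_unet` above mirrors the base UNet's configuration and, optionally, copies its encoder weights, which is the usual way to initialise a ControlNet before training. A hedged sketch, assuming `ControlNetModel` and `UNet2DConditionModel` are exported from the top-level `ppdiffusers` package:

```py
from ppdiffusers import ControlNetModel, UNet2DConditionModel  # assumed public exports

# Load a base UNet and derive a ControlNet that shares its configuration and encoder weights.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
controlnet = ControlNetModel.from_unet(unet, load_weights_from_unet=True)
print(type(controlnet).__name__, len(controlnet.controlnet_down_blocks))
```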
522
+ @property
523
+ # Copied from ppdiffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
524
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
525
+ r"""
526
+ Returns:
527
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
529
+ """
530
+ # set recursively
531
+ processors = {}
532
+
533
+ def fn_recursive_add_processors(name: str, module: nn.Layer, processors: Dict[str, AttentionProcessor]):
534
+ if hasattr(module, "get_processor"):
535
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
536
+
537
+ for sub_name, child in module.named_children():
538
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
539
+
540
+ return processors
541
+
542
+ for name, module in self.named_children():
543
+ fn_recursive_add_processors(name, module, processors)
544
+
545
+ return processors
546
+
547
+ # Copied from ppdiffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
548
+ def set_attn_processor(
549
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
550
+ ):
551
+ r"""
552
+ Sets the attention processor to use to compute attention.
553
+
554
+ Parameters:
555
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
556
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
557
+ for **all** `Attention` layers.
558
+
559
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
560
+ processor. This is strongly recommended when setting trainable attention processors.
561
+
562
+ """
563
+ count = len(self.attn_processors.keys())
564
+
565
+ if isinstance(processor, dict) and len(processor) != count:
566
+ raise ValueError(
567
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
568
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
569
+ )
570
+
571
+ def fn_recursive_attn_processor(name: str, module: nn.Layer, processor):
572
+ if hasattr(module, "set_processor"):
573
+ if not isinstance(processor, dict):
574
+ module.set_processor(processor, _remove_lora=_remove_lora)
575
+ else:
576
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
577
+
578
+ for sub_name, child in module.named_children():
579
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
580
+
581
+ for name, module in self.named_children():
582
+ fn_recursive_attn_processor(name, module, processor)
583
+
584
+ # Copied from ppdiffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
585
+ def set_default_attn_processor(self):
586
+ """
587
+ Disables custom attention processors and sets the default attention implementation.
588
+ """
589
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
590
+ processor = AttnAddedKVProcessor()
591
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
592
+ processor = AttnProcessor()
593
+ else:
594
+ raise ValueError(
595
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
596
+ )
597
+
598
+ self.set_attn_processor(processor, _remove_lora=True)
599
+
600
+ # Copied from ppdiffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
601
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
602
+ r"""
603
+ Enable sliced attention computation.
604
+
605
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
606
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
607
+
608
+ Args:
609
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
610
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
611
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
612
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
613
+ must be a multiple of `slice_size`.
614
+ """
615
+ sliceable_head_dims = []
616
+
617
+ def fn_recursive_retrieve_sliceable_dims(module: nn.Layer):
618
+ if hasattr(module, "set_attention_slice"):
619
+ sliceable_head_dims.append(module.sliceable_head_dim)
620
+
621
+ for child in module.children():
622
+ fn_recursive_retrieve_sliceable_dims(child)
623
+
624
+ # retrieve number of attention layers
625
+ for module in self.children():
626
+ fn_recursive_retrieve_sliceable_dims(module)
627
+
628
+ num_sliceable_layers = len(sliceable_head_dims)
629
+
630
+ if slice_size == "auto":
631
+ # half the attention head size is usually a good trade-off between
632
+ # speed and memory
633
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
634
+ elif slice_size == "max":
635
+ # make smallest slice possible
636
+ slice_size = num_sliceable_layers * [1]
637
+
638
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
639
+
640
+ if len(slice_size) != len(sliceable_head_dims):
641
+ raise ValueError(
642
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
643
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
644
+ )
645
+
646
+ for i in range(len(slice_size)):
647
+ size = slice_size[i]
648
+ dim = sliceable_head_dims[i]
649
+ if size is not None and size > dim:
650
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
651
+
652
+ # Recursively walk through all the children.
653
+ # Any children which exposes the set_attention_slice method
654
+ # gets the message
655
+ def fn_recursive_set_attention_slice(module: nn.Layer, slice_size: List[int]):
656
+ if hasattr(module, "set_attention_slice"):
657
+ module.set_attention_slice(slice_size.pop())
658
+
659
+ for child in module.children():
660
+ fn_recursive_set_attention_slice(child, slice_size)
661
+
662
+ reversed_slice_size = list(reversed(slice_size))
663
+ for module in self.children():
664
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
665
+
666
+ def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
667
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
668
+ module.gradient_checkpointing = value
669
+
670
+ def forward(
671
+ self,
672
+ sample: paddle.Tensor,
673
+ timestep: Union[paddle.Tensor, float, int],
674
+ encoder_hidden_states: paddle.Tensor,
675
+ controlnet_cond: paddle.Tensor,
676
+ conditioning_scale: float = 1.0,
677
+ class_labels: Optional[paddle.Tensor] = None,
678
+ timestep_cond: Optional[paddle.Tensor] = None,
679
+ attention_mask: Optional[paddle.Tensor] = None,
680
+ added_cond_kwargs: Optional[Dict[str, paddle.Tensor]] = None,
681
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
682
+ guess_mode: bool = False,
683
+ return_dict: bool = True,
684
+ ) -> Union[ControlNetOutput, Tuple[Tuple[paddle.Tensor, ...], paddle.Tensor]]:
685
+ """
686
+ The [`ControlNetModel`] forward method.
687
+
688
+ Args:
689
+ sample (`paddle.Tensor`):
690
+ The noisy input tensor.
691
+ timestep (`Union[paddle.Tensor, float, int]`):
692
+ The number of timesteps to denoise an input.
693
+ encoder_hidden_states (`paddle.Tensor`):
694
+ The encoder hidden states.
695
+ controlnet_cond (`paddle.Tensor`):
696
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
697
+ conditioning_scale (`float`, defaults to `1.0`):
698
+ The scale factor for ControlNet outputs.
699
+ class_labels (`paddle.Tensor`, *optional*, defaults to `None`):
700
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
701
+ timestep_cond (`paddle.Tensor`, *optional*, defaults to `None`):
702
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
703
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
704
+ embeddings.
705
+ attention_mask (`paddle.Tensor`, *optional*, defaults to `None`):
706
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
707
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
708
+ negative values to the attention scores corresponding to "discard" tokens.
709
+ added_cond_kwargs (`dict`):
710
+ Additional conditions for the Stable Diffusion XL UNet.
711
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
712
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
713
+ guess_mode (`bool`, defaults to `False`):
714
+ In this mode, the ControlNet encoder tries its best to recognize the content of the input image even if
715
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
716
+ return_dict (`bool`, defaults to `True`):
717
+ Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
718
+
719
+ Returns:
720
+ [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
721
+ If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
722
+ returned where the first element is the sample tensor.
723
+ """
724
+ # TODO junnyu, add this to support pure fp16
725
+ sample = sample.cast(self.dtype)
726
+
727
+ # check channel order
728
+ channel_order = self.config.controlnet_conditioning_channel_order
729
+
730
+ if channel_order == "rgb":
731
+ # in rgb order by default
732
+ ...
733
+ elif channel_order == "bgr":
734
+ controlnet_cond = paddle.flip(controlnet_cond, axis=[1])
735
+ else:
736
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
737
+
738
+ # prepare attention_mask
739
+ if attention_mask is not None:
740
+ attention_mask = (1 - attention_mask.cast(sample.dtype)) * -10000.0
741
+ attention_mask = attention_mask.unsqueeze(1)
742
+
743
+ # 1. time
744
+ timesteps = timestep
745
+ if not paddle.is_tensor(timesteps):
746
+ timesteps = paddle.to_tensor([timesteps], dtype="int64")
747
+ elif len(timesteps.shape) == 0:
748
+ timesteps = timesteps[None]
749
+
750
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
751
+ timesteps = timesteps.expand(
752
+ [
753
+ sample.shape[0],
754
+ ]
755
+ )
756
+ t_emb = self.time_proj(timesteps)
757
+
758
+ # timesteps does not contain any weights and will always return f32 tensors
759
+ # but time_embedding might actually be running in fp16. so we need to cast here.
760
+ # there might be better ways to encapsulate this.
761
+ t_emb = t_emb.cast(dtype=sample.dtype)
762
+
763
+ emb = self.time_embedding(t_emb, timestep_cond)
764
+ aug_emb = None
765
+
766
+ if self.class_embedding is not None:
767
+ if class_labels is None:
768
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
769
+
770
+ # maybe cast it to float16
771
+ class_labels = class_labels.cast(sample.dtype)
772
+ if self.config.class_embed_type == "timestep":
773
+ class_labels = self.time_proj(class_labels)
774
+
775
+ # maybe cast it to int64
776
+ if isinstance(self.class_embedding, nn.Embedding):
777
+ class_labels = class_labels.cast(paddle.int64)
778
+ class_emb = self.class_embedding(class_labels).cast(dtype=sample.dtype)
779
+ emb = emb + class_emb
780
+
781
+ if self.config.addition_embed_type is not None:
782
+ if self.config.addition_embed_type == "text":
783
+ aug_emb = self.add_embedding(encoder_hidden_states)
784
+
785
+ elif self.config.addition_embed_type == "text_time":
786
+ if "text_embeds" not in added_cond_kwargs:
787
+ raise ValueError(
788
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
789
+ )
790
+ text_embeds = added_cond_kwargs.get("text_embeds")
791
+ if "time_ids" not in added_cond_kwargs:
792
+ raise ValueError(
793
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
794
+ )
795
+ time_ids = added_cond_kwargs.get("time_ids")
796
+ time_embeds = self.add_time_proj(time_ids.flatten())
797
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
798
+ # make sure [text_embeds, time_embeds] has the same dtype
799
+ time_embeds = time_embeds.cast(text_embeds.dtype)
800
+
801
+ add_embeds = paddle.concat([text_embeds, time_embeds], axis=-1)
802
+ add_embeds = add_embeds.cast(emb.dtype)
803
+ aug_emb = self.add_embedding(add_embeds)
804
+
805
+ emb = emb + aug_emb if aug_emb is not None else emb
806
+
807
+ # 2. pre-process
808
+ sample = self.conv_in(sample)
809
+
810
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
811
+ sample = sample + controlnet_cond
812
+
813
+ # 3. down
814
+ down_block_res_samples = (sample,)
815
+ for downsample_block in self.down_blocks:
816
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
817
+ sample, res_samples = downsample_block(
818
+ hidden_states=sample,
819
+ temb=emb,
820
+ encoder_hidden_states=encoder_hidden_states,
821
+ attention_mask=attention_mask,
822
+ cross_attention_kwargs=cross_attention_kwargs,
823
+ )
824
+ else:
825
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
826
+
827
+ down_block_res_samples += res_samples
828
+
829
+ # 4. mid
830
+ if self.mid_block is not None:
831
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
832
+ sample = self.mid_block(
833
+ sample,
834
+ emb,
835
+ encoder_hidden_states=encoder_hidden_states,
836
+ attention_mask=attention_mask,
837
+ cross_attention_kwargs=cross_attention_kwargs,
838
+ )
839
+ else:
840
+ sample = self.mid_block(sample, emb)
841
+
842
+ # 5. Control net blocks
843
+
844
+ controlnet_down_block_res_samples = ()
845
+
846
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
847
+ down_block_res_sample = controlnet_block(down_block_res_sample)
848
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
849
+
850
+ down_block_res_samples = controlnet_down_block_res_samples
851
+
852
+ mid_block_res_sample = self.controlnet_mid_block(sample)
853
+
854
+ # 6. scaling
855
+ if guess_mode and not self.config.global_pool_conditions:
856
+ scales = paddle.logspace(-1, 0, len(down_block_res_samples) + 1) # 0.1 to 1.0
857
+ scales = scales * conditioning_scale
858
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
859
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
860
+ else:
861
+ if isinstance(conditioning_scale, (float, int)):
862
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
863
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
864
+ else:
865
+ # Newly added: support a per-block list of conditioning scales
866
+ down_block_res_samples = [
867
+ sample * ccs for sample, ccs in zip(down_block_res_samples, conditioning_scale[:-1])
868
+ ]
869
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale[-1]
870
+
871
+ if self.config.global_pool_conditions:
872
+ down_block_res_samples = [
873
+ paddle.mean(sample, axis=(2, 3), keepdim=True) for sample in down_block_res_samples
874
+ ]
875
+ mid_block_res_sample = paddle.mean(mid_block_res_sample, axis=(2, 3), keepdim=True)
876
+
877
+ if not return_dict:
878
+ return (down_block_res_samples, mid_block_res_sample)
879
+
880
+ return ControlNetOutput(
881
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
882
+ )
883
+
884
+
885
+ @paddle.no_grad()
886
+ def zero_module(module):
887
+ for p in module.parameters():
888
+ p.zero_()
889
+ return module
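For reference, a minimal sketch of the guess-mode scaling applied in step "6. scaling" of the forward pass above; the count of 12 down-block residuals is an assumption (typical of SD-style UNets), not something fixed by this file.

# Hedged sketch: reproduce the guess-mode scaling above in isolation.
# Assumption: 12 down-block residuals plus one mid-block residual (SD-style UNet).
import paddle

num_down_residuals = 12
scales = paddle.logspace(-1, 0, num_down_residuals + 1)   # 0.1 ... 1.0, last entry scales the mid block
down_scales, mid_scale = scales[:-1], scales[-1]
print([round(float(s), 3) for s in down_scales], round(float(mid_scale), 3))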
PaddleMIX/ppdiffusers/ppdiffusers/models/dit_llama_t2i.py ADDED
@@ -0,0 +1,582 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional
16
+
17
+ import paddle
18
+ import paddle.nn as nn
19
+ import paddle.nn.functional as F
20
+ from paddle.nn.functional.flash_attention import (
21
+ flash_attention,
22
+ scaled_dot_product_attention,
23
+ )
24
+
25
+ from ..configuration_utils import ConfigMixin, register_to_config
26
+ from .dit_llama import FeedForward, FinalLayer, TimestepEmbedder, TypePromote, modulate
27
+ from .modeling_utils import ModelMixin
28
+ from .transformer_2d import Transformer2DModelOutput
29
+
30
+
31
+ class Attention(nn.Layer):
32
+ def __init__(self, dim, n_heads, n_kv_heads, qk_norm=True, fused_attn=True, y_dim=0):
33
+ """
34
+ Initialize the Attention module.
35
+
36
+ Args:
37
+ dim (int): Number of input dimensions.
38
+ n_heads (int): Number of heads.
39
+ n_kv_heads (Optional[int]): Number of kv heads, if using GQA.
40
+
41
+ Attributes:
42
+ n_kv_heads (int): Number of key and value heads.
43
+ n_local_heads (int): Number of local query heads.
44
+ n_local_kv_heads (int): Number of local key and value heads.
45
+ n_rep (int): Number of repetitions for local heads.
46
+ head_dim (int): Dimension size of each attention head.
47
+ wq (nn.Linear): Linear transformation for queries.
48
+ wk (nn.Linear): Linear transformation for keys.
49
+ wv (nn.Linear): Linear transformation for values.
50
+ wo (nn.Linear): Linear transformation for output.
51
53
+
54
+ """
55
+ super().__init__()
56
+ self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
57
+ self.n_local_heads = n_heads
58
+ self.n_local_kv_heads = self.n_kv_heads
59
+ self.n_rep = self.n_local_heads // self.n_local_kv_heads
60
+ self.head_dim = dim // n_heads
61
+
62
+ self.wq = nn.Linear(dim, n_heads * self.head_dim, bias_attr=False)
63
+ self.wk = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False)
64
+ self.wv = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False)
65
+ self.wo = nn.Linear(n_heads * self.head_dim, dim, bias_attr=False)
66
+
67
+ if y_dim > 0:
68
+ self.wk_y = nn.Linear(y_dim, self.n_kv_heads * self.head_dim, bias_attr=False)
69
+ self.wv_y = nn.Linear(y_dim, self.n_kv_heads * self.head_dim, bias_attr=False)
70
+ self.gate = nn.Parameter(paddle.zeros([self.n_local_heads]))
71
+
72
+ if qk_norm:
73
+ self.q_norm = nn.LayerNorm(self.n_local_heads * self.head_dim)
74
+ self.k_norm = nn.LayerNorm(self.n_local_kv_heads * self.head_dim)
75
+ if y_dim > 0:
76
+ self.ky_norm = nn.LayerNorm(self.n_local_kv_heads * self.head_dim)
77
+ else:
78
+ self.ky_norm = nn.Identity()
79
+ else:
80
+ self.q_norm = self.k_norm = nn.Identity()
81
+ self.ky_norm = nn.Identity()
82
+
83
+ self.fused_attn = fused_attn
84
+ self.scale = self.head_dim**-0.5
85
+
86
+ @staticmethod
87
+ def reshape_for_broadcast(freqs_cis, x):
88
+ """
89
+ Reshape frequency tensor for broadcasting it with another tensor.
90
+
91
+ This function reshapes the frequency tensor to have the same shape as
92
+ the target tensor 'x' for the purpose of broadcasting the frequency
93
+ tensor during element-wise operations.
94
+
95
+ Args:
96
+ freqs_cis (paddle.Tensor): Frequency tensor to be reshaped.
97
+ x (paddle.Tensor): Target tensor for broadcasting compatibility.
98
+
99
+ Returns:
100
+ paddle.Tensor: Reshaped frequency tensor.
101
+
102
+ Raises:
103
+ AssertionError: If the frequency tensor doesn't match the expected
104
+ shape.
105
+ AssertionError: If the target tensor 'x' doesn't have the expected
106
+ number of dimensions.
107
+ """
108
+ ndim = x.ndim
109
+ assert 0 <= 1 < ndim
110
+ assert tuple(freqs_cis.shape) == (tuple(x.shape)[1], tuple(x.shape)[-1])
111
+ shape = [(d if i == 1 or i == ndim - 1 else 1) for i, d in enumerate(tuple(x.shape))]
112
+ return freqs_cis.reshape([*shape])
113
+
114
+ @staticmethod
115
+ def apply_rotary_emb(xq, xk, freqs_cis):
116
+ """
117
+ Apply rotary embeddings to input tensors using the given frequency
118
+ tensor.
119
+
120
+ This function applies rotary embeddings to the given query 'xq' and
121
+ key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
122
+ input tensors are reshaped as complex numbers, and the frequency tensor
123
+ is reshaped for broadcasting compatibility. The resulting tensors
124
+ contain rotary embeddings and are returned as real tensors.
125
+
126
+ Args:
127
+ xq (paddle.Tensor): Query tensor to apply rotary embeddings.
128
+ xk (paddle.Tensor): Key tensor to apply rotary embeddings.
129
+ freqs_cis (paddle.Tensor): Precomputed frequency tensor for complex
130
+ exponentials.
131
+
132
+ Returns:
133
+ Tuple[paddle.Tensor, paddle.Tensor]: Tuple of modified query tensor
134
+ and key tensor with rotary embeddings.
135
+ """
136
+ with paddle.amp.auto_cast(enable=False):
137
+ xq_ = paddle.as_complex(xq.cast("float32").reshape([*tuple(xq.shape)[:-1], -1, 2]))
138
+ xk_ = paddle.as_complex(xk.cast("float32").reshape([*tuple(xk.shape)[:-1], -1, 2]))
139
+ freqs_cis = Attention.reshape_for_broadcast(freqs_cis, xq_)
140
+ xq_out = paddle.as_real(xq_ * freqs_cis).flatten(start_axis=3)
141
+ xk_out = paddle.as_real(xk_ * freqs_cis).flatten(start_axis=3)
142
+ return xq_out.cast(xq.dtype), xk_out.cast(xk.dtype)
143
+
144
+ def forward(self, x, freqs_cis, y, y_mask):
145
+ """
146
+ Forward pass of the attention module.
147
+
148
+ Args:
149
+ x (paddle.Tensor): Input tensor.
+ freqs_cis (paddle.Tensor): Precomputed frequency tensor.
+ y (paddle.Tensor): Caption features used for the gated cross attention (when `y_dim` > 0).
+ y_mask (paddle.Tensor): Attention mask for the caption features.
151
+
152
+ Returns:
153
+ paddle.Tensor: Output tensor after attention.
154
+
155
+ """
156
+ bsz, seqlen, _ = tuple(x.shape)
157
+ xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
158
+ dtype = xq.dtype
159
+
160
+ xq = self.q_norm(xq)
161
+ xk = self.k_norm(xk)
162
+
163
+ xq = xq.reshape([bsz, seqlen, self.n_local_heads, self.head_dim])
164
+ xk = xk.reshape([bsz, seqlen, self.n_local_kv_heads, self.head_dim])
165
+ xv = xv.reshape([bsz, seqlen, self.n_local_kv_heads, self.head_dim])
166
+
167
+ xq, xk = Attention.apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
168
+ xq, xk = xq.cast(dtype), xk.cast(dtype)
169
+
170
+ n_rep = self.n_local_heads // self.n_local_kv_heads
171
+
172
+ if dtype in [paddle.float16, paddle.bfloat16]:
173
+ output, _ = flash_attention(
174
+ xq,
175
+ xk,
176
+ xv,
177
+ dropout=0.0,
178
+ causal=False,
179
+ return_softmax=False,
180
+ )
181
+ else:
182
+ if n_rep > 1:
183
+ xk = xk.unsqueeze(axis=3).tile([1, 1, 1, n_rep, 1]).flatten(start_axis=2, stop_axis=3)
184
+ xv = xv.unsqueeze(axis=3).tile([1, 1, 1, n_rep, 1]).flatten(start_axis=2, stop_axis=3)
185
+ if self.fused_attn:
186
+ output = F.scaled_dot_product_attention_(
187
+ xq,
188
+ xk,
189
+ xv,
190
+ dropout_p=0.0,
191
+ is_causal=False,
192
+ )
193
+ else:
194
+ q = xq.transpose([0, 2, 1, 3]) * self.scale
195
+ attn = q @ xk.transpose([0, 2, 1, 3]).transpose([0, 1, 3, 2])
196
+ attn = F.softmax(attn, axis=-1)
197
+ output = attn @ xv.transpose([0, 2, 1, 3])
198
+ output = output.transpose([0, 2, 1, 3])
199
+
200
+ output = output.flatten(start_axis=-2)
201
+
202
+ if hasattr(self, "wk_y"):
203
+ yk = self.ky_norm(self.wk_y(y)).reshape([bsz, -1, self.n_local_kv_heads, self.head_dim])
204
+ yv = self.wv_y(y).reshape([bsz, -1, self.n_local_kv_heads, self.head_dim])
205
+ n_rep = self.n_local_heads // self.n_local_kv_heads
206
+
207
+ y_mask = y_mask.reshape([bsz, 1, 1, -1]).expand([bsz, self.n_local_heads, seqlen, -1])
208
+
209
+ if dtype in [paddle.float16, paddle.bfloat16]:
210
+ output_y = scaled_dot_product_attention(
211
+ xq,
212
+ yk,
213
+ yv,
214
+ attn_mask=y_mask.cast(dtype), # no need to transpose
215
+ )
216
+ else:
217
+ if n_rep > 1:
218
+ yk = yk.unsqueeze(3).tile([1, 1, 1, n_rep, 1]).flatten(2, 3)
219
+ yv = yv.unsqueeze(3).tile([1, 1, 1, n_rep, 1]).flatten(2, 3)
220
+
221
+ output_y = F.scaled_dot_product_attention_(
222
+ xq,
223
+ yk,
224
+ yv,
225
+ attn_mask=y_mask,
226
+ )
227
+
228
+ output_y = output_y * self.gate.tanh().reshape([1, 1, -1, 1])
229
+ output_y = output_y.flatten(-2)
230
+ output = output + output_y
231
+
232
+ return self.wo(output)
233
+
234
+
235
+ class TransformerBlock(nn.Layer):
236
+ def __init__(
237
+ self,
238
+ layer_id: int,
239
+ dim: int,
240
+ n_heads: int,
241
+ n_kv_heads: int,
242
+ multiple_of: int,
243
+ mlp_ratio: float,
244
+ ffn_dim_multiplier: float,
245
+ norm_eps: float,
246
+ qk_norm: bool,
247
+ fused_attn: bool,
248
+ y_dim: int,
249
+ ) -> None:
250
+ """
251
+ Initialize a TransformerBlock.
252
+
253
+ Args:
254
+ layer_id (int): Identifier for the layer.
255
+ dim (int): Embedding dimension of the input features.
256
+ n_heads (int): Number of attention heads.
257
+ n_kv_heads (Optional[int]): Number of attention heads in key and
258
+ value features (if using GQA), or set to None for the same as
259
+ query.
260
+ multiple_of (int): Value to ensure hidden dimension is a multiple
261
+ of this value in the FeedForward block.
262
+ ffn_dim_multiplier (float, optional): Custom multiplier for hidden
263
+ dimension in the FeedForward block. Defaults to None.
264
+ norm_eps (float): A small value added to the norm layer
265
+ denominators to avoid division-by-zero.
266
+
267
+ Attributes:
268
+ n_heads (int): Number of attention heads.
269
+ dim (int): Dimension size of the model.
270
+ head_dim (int): Dimension size of each attention head.
271
+ attention (Attention): Attention module.
272
+ feed_forward (FeedForward): FeedForward module.
273
+ layer_id (int): Identifier for the layer.
274
+ attention_norm (RMSNorm): Layer normalization for attention output.
275
+ ffn_norm (RMSNorm): Layer normalization for feedforward output.
276
+ adaLN_modulation (nn.Sequential): A small network to generate
277
+ feature modulation factors.
278
+
279
+ """
280
+ super().__init__()
281
+ self.dim = dim
282
+ self.head_dim = dim // n_heads
283
+ self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, fused_attn, y_dim)
284
+ mlp_hidden_dim = int(dim * mlp_ratio)
285
+ self.feed_forward = FeedForward(
286
+ dim=dim, hidden_dim=mlp_hidden_dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier
287
+ )
288
+ self.layer_id = layer_id
289
+ self.attention_norm = nn.LayerNorm(dim, epsilon=norm_eps, bias_attr=False)
290
+ self.ffn_norm = nn.LayerNorm(dim, epsilon=norm_eps, bias_attr=False)
291
+
292
+ self.adaLN_modulation = nn.Sequential(
293
+ nn.Silu(),
294
+ nn.Linear(min(dim, 1024), 6 * dim),
295
+ )
296
+ self.attention_y_norm = nn.LayerNorm(y_dim, epsilon=norm_eps, bias_attr=False)
297
+
298
+ def forward(self, x, y, y_mask, freqs_cis, adaln_input=None):
299
+ """
300
+ Perform a forward pass through the TransformerBlock.
301
+
302
+ Args:
303
+ x (paddle.Tensor): Input tensor.
+ y (paddle.Tensor): Caption features for cross attention.
+ y_mask (paddle.Tensor): Attention mask for the caption features.
+ freqs_cis (paddle.Tensor): Precomputed cosine and sine frequencies.
+ adaln_input (paddle.Tensor, optional): Conditioning vector for adaLN
+ modulation. Defaults to None.
307
+
308
+ Returns:
309
+ paddle.Tensor: Output tensor after applying attention and
310
+ feedforward layers.
311
+
312
+ """
313
+ y = y.cast(x.dtype)
314
+ if adaln_input is not None:
315
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(
316
+ 6, axis=1
317
+ )
318
+ h = x + gate_msa.unsqueeze(1) * self.attention(
319
+ modulate(self.attention_norm(x), shift_msa, scale_msa), freqs_cis, self.attention_y_norm(y), y_mask
320
+ )
321
+ out = h + gate_mlp.unsqueeze(1) * self.feed_forward(modulate(self.ffn_norm(h), shift_mlp, scale_mlp))
322
+ else:
323
+ h = x + self.attention(self.attention_norm(x), freqs_cis, self.attention_y_norm(y), y_mask)
324
+ out = h + self.feed_forward(self.ffn_norm(h))
325
+ return out
326
+
327
+
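A standalone sketch of the adaLN gating used in `TransformerBlock.forward` above. `modulate` is imported from `dit_llama` and is not shown in this diff; the `x * (1 + scale) + shift` form below is the standard DiT definition and is stated here as an assumption.

# Hedged sketch of the adaLN pattern: one projection produces six modulation vectors,
# the scale/shift pair is applied per token, and the gate weights the residual branch.
import paddle

batch, seq, dim = 2, 4, 8
adaln_input = paddle.randn([batch, dim])
to_mods = paddle.nn.Linear(dim, 6 * dim)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = to_mods(adaln_input).chunk(6, axis=1)

x = paddle.randn([batch, seq, dim])
modulated = x * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)   # assumed body of modulate(...)
h = x + gate_msa.unsqueeze(1) * modulated                               # gated residual, as in the block above
print(h.shape)                                                          # [2, 4, 8]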
328
+ class DiTLLaMAT2IModel(ModelMixin, ConfigMixin):
329
+ _supports_gradient_checkpointing = True
330
+ _use_memory_efficient_attention_xformers = True
331
+
332
+ @register_to_config
333
+ def __init__(
334
+ self,
335
+ patch_size: int = 2,
336
+ in_channels: int = 4,
337
+ out_channels: int = 8,
338
+ max_seq_len: int = 4224,
339
+ num_layers: int = 32,
340
+ num_attention_heads: int = 16,
341
+ attention_head_dim: int = 96,
342
+ mlp_ratio: float = 4.0,
343
+ n_kv_heads=None,
344
+ multiple_of: int = 256,
345
+ ffn_dim_multiplier=None,
346
+ norm_eps: float = 1e-05,
347
+ learn_sigma: bool = True,
348
+ qk_norm: bool = True,
349
+ cap_feat_dim: int = 4096,
350
+ rope_scaling_factor: float = 1.0,
351
+ ):
352
+ super().__init__()
353
+ self.max_seq_len = max_seq_len
354
+ self.patch_size = patch_size
355
+ self.in_channels = in_channels
356
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
357
+ dim = attention_head_dim * num_attention_heads
358
+
359
+ self.num_layers = num_layers
360
+ self.num_attention_heads = num_attention_heads
361
+ self.mlp_ratio = mlp_ratio
362
+ self.multiple_of = multiple_of
363
+ self.ffn_dim_multiplier = ffn_dim_multiplier
364
+ self.norm_eps = norm_eps
365
+ self.learn_sigma = learn_sigma
366
+ self.qk_norm = qk_norm
367
+
368
+ self.gradient_checkpointing = True
369
+ self.fused_attn = True
370
+
371
+ self.x_embedder = nn.Linear(in_channels * patch_size**2, dim)
372
+ self.t_embedder = TimestepEmbedder(min(dim, 1024))
373
+ self.cap_embedder = nn.Sequential(
374
+ nn.LayerNorm(cap_feat_dim),
375
+ nn.Linear(cap_feat_dim, min(dim, 1024)),
376
+ )
377
+
378
+ # 2. Define transformers blocks
379
+ self.layers = nn.LayerList(
380
+ [
381
+ TransformerBlock(
382
+ layer_id=idx,
383
+ dim=dim,
384
+ n_heads=num_attention_heads,
385
+ n_kv_heads=n_kv_heads,
386
+ multiple_of=multiple_of,
387
+ mlp_ratio=mlp_ratio,
388
+ ffn_dim_multiplier=ffn_dim_multiplier,
389
+ norm_eps=norm_eps,
390
+ qk_norm=qk_norm,
391
+ fused_attn=self.fused_attn,
392
+ y_dim=cap_feat_dim,
393
+ )
394
+ for idx in range(num_layers)
395
+ ]
396
+ )
397
+
398
+ # 3. Define output layers
399
+ self.final_layer = FinalLayer(dim, patch_size, self.out_channels)
400
+ self.freqs_cis = self.precompute_freqs_cis(
401
+ dim // num_attention_heads, max_seq_len, rope_scaling_factor=rope_scaling_factor
402
+ )
403
+ self.eol_token = self.create_parameter(shape=[dim])
404
+ self.pad_token = self.create_parameter(shape=[dim])
405
+
406
+ def _set_gradient_checkpointing(self, module, value=False):
407
+ if hasattr(module, "gradient_checkpointing"):
408
+ module.gradient_checkpointing = value
409
+
410
+ def enable_gradient_checkpointing(self, value=True):
411
+ self.gradient_checkpointing = value
412
+
413
+ def enable_xformers_memory_efficient_attention(self, attention_op: Optional[str] = None):
414
+ self._use_memory_efficient_attention_xformers = True
415
+ self.fused_attn = True
416
+
417
+ def unpatchify(self, x, img_size, return_tensor=False):
418
+ """
419
+ Args:
420
+ x: (N, T, patch_size**2 * C)
421
+ imgs: (N, H, W, C)
422
+ """
423
+ pH = pW = self.patch_size
424
+ if return_tensor:
425
+ H, W = img_size[0]
426
+ B = x.shape[0]
427
+ L = (H // pH) * (W // pW + 1) # one additional for eol
428
+ x = x[:, :L].reshape([B, H // pH, W // pW + 1, pH, pW, self.out_channels])
429
+ x = x[:, :, :-1]
430
+ x = x.transpose([0, 5, 1, 3, 2, 4]).flatten(4, 5).flatten(2, 3)
431
+ return x
432
+ else:
433
+ imgs = []
434
+ for i in range(x.shape[0]):
435
+ H, W = img_size[i]
436
+ L = (H // pH) * (W // pW + 1)
437
+ imgs.append(
438
+ x[i][:L]
439
+ .reshape([H // pH, W // pW + 1, pH, pW, self.out_channels])[:, :-1, :, :, :]
440
+ .transpose([4, 0, 2, 1, 3])
441
+ .flatten(3, 4)
442
+ .flatten(1, 2)
443
+ )
444
+ return imgs
445
+
446
+ def patchify_and_embed(self, x):
447
+ if isinstance(x, paddle.Tensor):
448
+ pH = pW = self.patch_size
449
+ B, C, H, W = x.shape[:]
450
+ x = x.reshape([B, C, H // pH, pH, W // pW, pW]).transpose([0, 2, 4, 1, 3, 5]).flatten(3)
451
+ x = self.x_embedder(x)
452
+
453
+ x = paddle.concat(
454
+ [
455
+ x,
456
+ self.eol_token.reshape([1, 1, 1, -1]).expand([B, H // pH, 1, -1]),
457
+ ],
458
+ axis=2,
459
+ )
460
+ x = x.flatten(1, 2)
461
+
462
+ if x.shape[1] < self.max_seq_len:
463
+ x = paddle.concat(
464
+ [
465
+ x,
466
+ self.pad_token.reshape([1, 1, -1]).expand([B, self.max_seq_len - x.shape[1], -1]),
467
+ ],
468
+ axis=1,
469
+ )
470
+ return x, [(H, W)] * B
471
+ else:
472
+ pH = pW = self.patch_size
473
+ x_embed = []
474
+ img_size = []
475
+ for img in x:
476
+ C, H, W = img.shape[:]
477
+ img_size.append((H, W))
478
+ img = img.reshape([C, H // pH, pH, W // pW, pW]).transpose([1, 3, 0, 2, 4]).flatten(2)
479
+ img = self.x_embedder(img)
480
+ img = paddle.concat(
481
+ [
482
+ img,
483
+ self.eol_token.reshape([1, 1, -1]).expand([H // pH, 1, -1]),
484
+ ],
485
+ axis=1,
486
+ )
487
+ img = img.flatten(0, 1)
488
+ if img.shape[0] < self.max_seq_len:
489
+ img = paddle.concat(
490
+ [
491
+ img,
492
+ self.pad_token.reshape([1, -1]).expand([self.max_seq_len - img.shape[0], -1]),
493
+ ],
494
+ axis=0,
495
+ )
496
+ x_embed.append(img)
497
+ x_embed = paddle.stack(x_embed, axis=0)
498
+ return x_embed, img_size
499
+
500
+ @staticmethod
501
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, rope_scaling_factor: float = 1.0):
502
+
503
+ """
504
+ Precompute the frequency tensor for complex exponentials (cis) with
505
+ given dimensions.
506
+
507
+ This function calculates a frequency tensor with complex exponentials
508
+ using the given dimension 'dim' and the end index 'end'. The 'theta'
509
+ parameter scales the frequencies. The returned tensor contains complex
510
+ values in complex64 data type.
511
+
512
+ Args:
513
+ dim (int): Dimension of the frequency tensor.
514
+ end (int): End index for precomputing frequencies.
515
+ theta (float, optional): Scaling factor for frequency computation.
516
+ Defaults to 10000.0.
517
+
518
+ Returns:
519
+ paddle.Tensor: Precomputed frequency tensor with complex
520
+ exponentials.
521
+ """
522
+ freqs = 1.0 / theta ** (paddle.arange(start=0, end=dim, step=2)[: dim // 2].cast("float32") / dim)
523
+ t = paddle.arange(end=end, dtype=paddle.float32)
524
+ t = t / rope_scaling_factor
525
+ input_0, vec2_0 = TypePromote(t, freqs)
526
+ freqs = paddle.outer(input_0, vec2_0).cast("float32")
527
+ freqs_cis = paddle.complex(
528
+ paddle.ones_like(freqs) * paddle.cos(freqs), paddle.ones_like(freqs) * paddle.sin(freqs)
529
+ )
530
+ return freqs_cis
531
+
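A hedged shape check for the rotary-embedding helpers defined in this file (`precompute_freqs_cis` here and `Attention.apply_rotary_emb` earlier); the import path is assumed from the file location.

# Hedged sketch: freqs_cis has shape (seq_len, head_dim // 2) and is complex-valued;
# apply_rotary_emb rotates query/key without changing their shapes or dtypes.
import paddle
from ppdiffusers.models.dit_llama_t2i import Attention, DiTLLaMAT2IModel  # assumed import path

head_dim, seq_len, n_heads = 64, 16, 8
freqs_cis = DiTLLaMAT2IModel.precompute_freqs_cis(head_dim, seq_len)
xq = paddle.randn([2, seq_len, n_heads, head_dim])
xk = paddle.randn([2, seq_len, n_heads, head_dim])
xq_rot, xk_rot = Attention.apply_rotary_emb(xq, xk, freqs_cis)
print(freqs_cis.shape, xq_rot.shape, xk_rot.shape)   # [16, 32] [2, 16, 8, 64] [2, 16, 8, 64]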
532
+ def forward(
533
+ self,
534
+ hidden_states: paddle.Tensor,
535
+ timestep: paddle.Tensor,
536
+ cap_feats: paddle.Tensor,
537
+ cap_mask: paddle.Tensor,
538
+ return_dict: bool = True,
539
+ ):
540
+ """
541
+ Args:
542
+ hidden_states: (N, C, H, W) tensor of spatial inputs (images or latent
543
+ representations of images)
544
+ timestep: (N,) tensor of diffusion timesteps
545
+ class_labels: (N,) tensor of class labels
546
+ """
547
+ hidden_states = hidden_states.cast(self.dtype)
548
+ timestep = timestep.cast(self.dtype)
549
+
550
+ # 1. Input
551
+ x_is_tensor = isinstance(hidden_states, paddle.Tensor)
552
+ hidden_states, img_size = self.patchify_and_embed(hidden_states)
553
+
554
+ t = self.t_embedder(timestep).cast(self.dtype)
555
+ cap_mask_float = cap_mask.cast("float32").unsqueeze(-1)
556
+ cap_feats_pool = (cap_feats * cap_mask_float).sum(axis=1) / cap_mask_float.sum(axis=1)
557
+ cap_emb = self.cap_embedder(cap_feats_pool.cast(self.dtype))
558
+ adaln_input = t + cap_emb
559
+
560
+ # 2. Blocks
561
+ for i, layer in enumerate(self.layers):
562
+ if self.gradient_checkpointing:
563
+ hidden_states = paddle.distributed.fleet.utils.recompute(
564
+ layer, hidden_states, cap_feats, cap_mask, self.freqs_cis[: hidden_states.shape[1]], adaln_input
565
+ )
566
+ else:
567
+ hidden_states = layer(
568
+ hidden_states,
569
+ cap_feats,
570
+ cap_mask,
571
+ self.freqs_cis[: hidden_states.shape[1]],
572
+ adaln_input,
573
+ )
574
+
575
+ # 3. Output
576
+ hidden_states = self.final_layer(hidden_states, adaln_input)
577
+ output = self.unpatchify(hidden_states, img_size, return_tensor=x_is_tensor)
578
+
579
+ if not return_dict:
580
+ return (output,)
581
+
582
+ return Transformer2DModelOutput(sample=output)
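A small sketch of the token layout that `patchify_and_embed` and `unpatchify` above agree on; the latent size below is an assumed example.

# Minimal sketch: each row of patches is followed by one eol token, and the flattened
# sequence is then padded with the pad token up to max_seq_len.
patch_size = 2
H, W = 64, 64                                   # assumed latent height / width
rows, cols = H // patch_size, W // patch_size   # 32 x 32 patches
seq_len = rows * (cols + 1)                     # one eol token per row
print(rows, cols, seq_len)                      # 32 32 1056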
PaddleMIX/ppdiffusers/ppdiffusers/models/downsampling.py ADDED
@@ -0,0 +1,383 @@
1
+ # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional, Tuple
16
+
17
+ import paddle
18
+
19
+ from .normalization import RMSNorm
20
+ from .upsampling import upfirdn2d_native
21
+
22
+
23
+ class Downsample1D(paddle.nn.Layer):
24
+ """A 1D downsampling layer with an optional convolution.
25
+
26
+ Parameters:
27
+ channels (`int`):
28
+ number of channels in the inputs and outputs.
29
+ use_conv (`bool`, default `False`):
30
+ option to use a convolution.
31
+ out_channels (`int`, optional):
32
+ number of output channels. Defaults to `channels`.
33
+ padding (`int`, default `1`):
34
+ padding for the convolution.
35
+ name (`str`, default `conv`):
36
+ name of the downsampling 1D layer.
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ channels: int,
42
+ use_conv: bool = False,
43
+ out_channels: Optional[int] = None,
44
+ padding: int = 1,
45
+ name: str = "conv",
46
+ ):
47
+ super().__init__()
48
+ self.channels = channels
49
+ self.out_channels = out_channels or channels
50
+ self.use_conv = use_conv
51
+ self.padding = padding
52
+ stride = 2
53
+ self.name = name
54
+ if use_conv:
55
+ self.conv = paddle.nn.Conv1D(
56
+ in_channels=self.channels,
57
+ out_channels=self.out_channels,
58
+ kernel_size=3,
59
+ stride=stride,
60
+ padding=padding,
61
+ )
62
+ else:
63
+ assert self.channels == self.out_channels
64
+ self.conv = paddle.nn.AvgPool1D(kernel_size=stride, stride=stride, exclusive=False)
65
+
66
+ def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
67
+ assert tuple(inputs.shape)[1] == self.channels
68
+ return self.conv(inputs)
69
+
70
+
71
+ class Downsample2D(paddle.nn.Layer):
72
+ """A 2D downsampling layer with an optional convolution.
73
+
74
+ Parameters:
75
+ channels (`int`):
76
+ number of channels in the inputs and outputs.
77
+ use_conv (`bool`, default `False`):
78
+ option to use a convolution.
79
+ out_channels (`int`, optional):
80
+ number of output channels. Defaults to `channels`.
81
+ padding (`int`, default `1`):
82
+ padding for the convolution.
83
+ name (`str`, default `conv`):
84
+ name of the downsampling 2D layer.
85
+ """
86
+
87
+ def __init__(
88
+ self,
89
+ channels: int,
90
+ use_conv: bool = False,
91
+ out_channels: Optional[int] = None,
92
+ padding: int = 1,
93
+ name: str = "conv",
94
+ kernel_size=3,
95
+ norm_type=None,
96
+ eps=None,
97
+ elementwise_affine=None,
98
+ bias=True,
99
+ ):
100
+ super().__init__()
101
+ self.channels = channels
102
+ self.out_channels = out_channels or channels
103
+ self.use_conv = use_conv
104
+ self.padding = padding
105
+ stride = 2
106
+ self.name = name
107
+ if norm_type == "ln_norm":
108
+ self.norm = paddle.nn.LayerNorm(
109
+ normalized_shape=channels, epsilon=eps, weight_attr=elementwise_affine, bias_attr=elementwise_affine
110
+ )
111
+ elif norm_type == "rms_norm":
112
+ self.norm = RMSNorm(channels, eps, elementwise_affine)
113
+ elif norm_type is None:
114
+ self.norm = None
115
+ else:
116
+ raise ValueError(f"unknown norm_type: {norm_type}")
117
+ if use_conv:
118
+ conv = paddle.nn.Conv2D(
119
+ in_channels=self.channels,
120
+ out_channels=self.out_channels,
121
+ kernel_size=kernel_size,
122
+ stride=stride,
123
+ padding=padding,
124
+ bias_attr=bias,
125
+ )
126
+ else:
127
+ assert self.channels == self.out_channels
128
+ conv = paddle.nn.AvgPool2D(kernel_size=stride, stride=stride, exclusive=False)
129
+ if name == "conv":
130
+ self.Conv2d_0 = conv
131
+ self.conv = conv
132
+ elif name == "Conv2d_0":
133
+ self.conv = conv
134
+ else:
135
+ self.conv = conv
136
+
137
+ def forward(self, hidden_states: paddle.Tensor, *args, **kwargs) -> paddle.Tensor:
138
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
139
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
140
+ print("scale", "1.0.0", deprecation_message)
141
+ assert tuple(hidden_states.shape)[1] == self.channels
142
+ if self.norm is not None:
143
+ hidden_states = self.norm(hidden_states.transpose(perm=[0, 2, 3, 1])).transpose(perm=[0, 3, 1, 2])
144
+ if self.use_conv and self.padding == 0:
145
+ pad = 0, 1, 0, 1
146
+ hidden_states = paddle.nn.functional.pad(
147
+ x=hidden_states, pad=pad, mode="constant", value=0, pad_from_left_axis=False
148
+ )
149
+ assert tuple(hidden_states.shape)[1] == self.channels
150
+ hidden_states = self.conv(hidden_states)
151
+ return hidden_states
152
+
153
+
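A hedged usage sketch for `Downsample2D` above (import path assumed from the file location). With the default `padding=1`, the stride-2 convolution halves H and W; with `padding=0`, the forward pads the right and bottom edges by one pixel first, which keeps the output at exactly H // 2 x W // 2.

# Hedged sketch: the default conv path of Downsample2D halves the spatial resolution.
import paddle
from ppdiffusers.models.downsampling import Downsample2D  # assumed import path

down = Downsample2D(channels=16, use_conv=True)   # Conv2D(kernel 3, stride 2, padding 1)
x = paddle.randn([1, 16, 64, 64])
print(down(x).shape)                              # [1, 16, 32, 32]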
154
+ class FirDownsample2D(paddle.nn.Layer):
155
+ """A 2D FIR downsampling layer with an optional convolution.
156
+
157
+ Parameters:
158
+ channels (`int`):
159
+ number of channels in the inputs and outputs.
160
+ use_conv (`bool`, default `False`):
161
+ option to use a convolution.
162
+ out_channels (`int`, optional):
163
+ number of output channels. Defaults to `channels`.
164
+ fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
165
+ kernel for the FIR filter.
166
+ """
167
+
168
+ def __init__(
169
+ self,
170
+ channels: Optional[int] = None,
171
+ out_channels: Optional[int] = None,
172
+ use_conv: bool = False,
173
+ fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
174
+ ):
175
+ super().__init__()
176
+ out_channels = out_channels if out_channels else channels
177
+ if use_conv:
178
+ self.Conv2d_0 = paddle.nn.Conv2D(
179
+ in_channels=channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1
180
+ )
181
+ self.fir_kernel = fir_kernel
182
+ self.use_conv = use_conv
183
+ self.out_channels = out_channels
184
+
185
+ def _downsample_2d(
186
+ self,
187
+ hidden_states: paddle.Tensor,
188
+ weight: Optional[paddle.Tensor] = None,
189
+ kernel: Optional[paddle.Tensor] = None,
190
+ factor: int = 2,
191
+ gain: float = 1,
192
+ ) -> paddle.Tensor:
193
+ """Fused `Conv2d()` followed by `downsample_2d()`.
194
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
195
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
196
+ arbitrary order.
197
+
198
+ Args:
199
+ hidden_states (`paddle.Tensor`):
200
+ Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
201
+ weight (`paddle.Tensor`, *optional*):
202
+ Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
203
+ performed by `inChannels = x.shape[0] // numGroups`.
204
+ kernel (`paddle.Tensor`, *optional*):
205
+ FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
206
+ corresponds to average pooling.
207
+ factor (`int`, *optional*, default to `2`):
208
+ Integer downsampling factor.
209
+ gain (`float`, *optional*, default to `1.0`):
210
+ Scaling factor for signal magnitude.
211
+
212
+ Returns:
213
+ output (`paddle.Tensor`):
214
+ Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
215
+ datatype as `hidden_states`.
216
+ """
217
+ assert isinstance(factor, int) and factor >= 1
218
+ if kernel is None:
219
+ kernel = [1] * factor
220
+ kernel = paddle.to_tensor(data=kernel, dtype="float32")
221
+ if kernel.ndim == 1:
222
+ kernel = paddle.outer(x=kernel, y=kernel)
223
+ kernel /= paddle.sum(x=kernel)
224
+ kernel = kernel * gain
225
+ if self.use_conv:
226
+ _, _, convH, convW = tuple(weight.shape)
227
+ pad_value = tuple(kernel.shape)[0] - factor + (convW - 1)
228
+ stride_value = [factor, factor]
229
+ upfirdn_input = upfirdn2d_native(
230
+ hidden_states,
231
+ paddle.to_tensor(data=kernel, place=hidden_states.place),
232
+ pad=((pad_value + 1) // 2, pad_value // 2),
233
+ )
234
+ output = paddle.nn.functional.conv2d(x=upfirdn_input, weight=weight, stride=stride_value, padding=0)
235
+ else:
236
+ pad_value = tuple(kernel.shape)[0] - factor
237
+ output = upfirdn2d_native(
238
+ hidden_states,
239
+ paddle.to_tensor(data=kernel, place=hidden_states.place),
240
+ down=factor,
241
+ pad=((pad_value + 1) // 2, pad_value // 2),
242
+ )
243
+ return output
244
+
245
+ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
246
+ if self.use_conv:
247
+ downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
248
+ hidden_states = downsample_input + self.Conv2d_0.bias.reshape([1, -1, 1, 1])
249
+ else:
250
+ hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
251
+ return hidden_states
252
+
253
+
254
+ class KDownsample2D(paddle.nn.Layer):
255
+ """A 2D K-downsampling layer.
256
+
257
+ Parameters:
258
+ pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
259
+ """
260
+
261
+ def __init__(self, pad_mode: str = "reflect"):
262
+ super().__init__()
263
+ self.pad_mode = pad_mode
264
+ kernel_1d = paddle.to_tensor(data=[[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
265
+ self.pad = tuple(kernel_1d.shape)[1] // 2 - 1
266
+ self.register_buffer(name="kernel", tensor=kernel_1d.T @ kernel_1d, persistable=False)
267
+
268
+ def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
269
+ inputs = paddle.nn.functional.pad(x=inputs, pad=(self.pad,) * 4, mode=self.pad_mode, pad_from_left_axis=False)
270
+ weight = paddle.zeros(
271
+ shape=[
272
+ tuple(inputs.shape)[1],
273
+ tuple(inputs.shape)[1],
274
+ tuple(self.kernel.shape)[0],
275
+ tuple(self.kernel.shape)[1],
276
+ ],
277
+ dtype=inputs.dtype,
278
+ )
279
+ indices = paddle.arange(end=tuple(inputs.shape)[1])
280
+ kernel = self.kernel.to(weight)[None, :].expand(shape=[tuple(inputs.shape)[1], -1, -1])
281
+ weight[indices, indices] = kernel
282
+ return paddle.nn.functional.conv2d(x=inputs, weight=weight, stride=2)
283
+
284
+
285
+ class CogVideoXDownsample3D(paddle.nn.Layer):
286
+ """
287
+ A 3D downsampling layer used in [CogVideoX]() by Tsinghua University & ZhipuAI.
288
+
289
+ Args:
290
+ in_channels (`int`):
291
+ Number of channels in the input image.
292
+ out_channels (`int`):
293
+ Number of channels produced by the convolution.
294
+ kernel_size (`int`, defaults to `3`):
295
+ Size of the convolving kernel.
296
+ stride (`int`, defaults to `2`):
297
+ Stride of the convolution.
298
+ padding (`int`, defaults to `0`):
299
+ Padding added to all four sides of the input.
300
+ compress_time (`bool`, defaults to `False`):
301
+ Whether or not to compress the time dimension.
302
+ """
303
+
304
+ def __init__(
305
+ self,
306
+ in_channels: int,
307
+ out_channels: int,
308
+ kernel_size: int = 3,
309
+ stride: int = 2,
310
+ padding: int = 0,
311
+ compress_time: bool = False,
312
+ ):
313
+ super().__init__()
314
+ self.conv = paddle.nn.Conv2D(
315
+ in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding
316
+ )
317
+ self.compress_time = compress_time
318
+
319
+ def forward(self, x: paddle.Tensor) -> paddle.Tensor:
320
+ if self.compress_time:
321
+ batch_size, channels, frames, height, width = tuple(x.shape)
322
+ x = x.transpose(perm=[0, 3, 4, 1, 2]).reshape([batch_size * height * width, channels, frames])
323
+ if tuple(x.shape)[-1] % 2 == 1:
324
+ x_first, x_rest = x[..., 0], x[..., 1:]
325
+ if tuple(x_rest.shape)[-1] > 0:
326
+ x_rest = paddle.nn.functional.avg_pool1d(kernel_size=2, stride=2, x=x_rest, exclusive=False)
327
+ x = paddle.concat(x=[x_first[..., None], x_rest], axis=-1)
328
+ x = x.reshape([batch_size, height, width, channels, tuple(x.shape)[-1]]).transpose(
329
+ perm=[0, 3, 4, 1, 2]
330
+ )
331
+ else:
332
+ x = paddle.nn.functional.avg_pool1d(kernel_size=2, stride=2, x=x, exclusive=False)
333
+ x = x.reshape([batch_size, height, width, channels, tuple(x.shape)[-1]]).transpose(
334
+ perm=[0, 3, 4, 1, 2]
335
+ )
336
+ pad = (0, 1, 0, 1, 0, 0)
337
+ x = paddle.nn.functional.pad(x=x, pad=pad, mode="constant", value=0, data_format="NCDHW")
338
+ batch_size, channels, frames, height, width = tuple(x.shape)
339
+ x = x.transpose(perm=[0, 2, 1, 3, 4]).reshape([batch_size * frames, channels, height, width])
340
+ x = self.conv(x)
341
+ x = x.reshape([batch_size, frames, tuple(x.shape)[1], tuple(x.shape)[2], tuple(x.shape)[3]]).transpose(
342
+ perm=[0, 2, 1, 3, 4]
343
+ )
344
+ return x
345
+
346
+
347
+ def downsample_2d(
348
+ hidden_states: paddle.Tensor, kernel: Optional[paddle.Tensor] = None, factor: int = 2, gain: float = 1
349
+ ) -> paddle.Tensor:
350
+ """Downsample2D a batch of 2D images with the given filter.
351
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
352
+ given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
353
+ specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
354
+ shape is a multiple of the downsampling factor.
355
+
356
+ Args:
357
+ hidden_states (`paddle.Tensor`):
358
+ Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
359
+ kernel (`paddle.Tensor`, *optional*):
360
+ FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
361
+ corresponds to average pooling.
362
+ factor (`int`, *optional*, default to `2`):
363
+ Integer downsampling factor.
364
+ gain (`float`, *optional*, default to `1.0`):
365
+ Scaling factor for signal magnitude.
366
+
367
+ Returns:
368
+ output (`paddle.Tensor`):
369
+ Tensor of the shape `[N, C, H // factor, W // factor]`
370
+ """
371
+ assert isinstance(factor, int) and factor >= 1
372
+ if kernel is None:
373
+ kernel = [1] * factor
374
+ kernel = paddle.to_tensor(data=kernel, dtype="float32")
375
+ if kernel.ndim == 1:
376
+ kernel = paddle.outer(x=kernel, y=kernel)
377
+ kernel /= paddle.sum(x=kernel)
378
+ kernel = kernel * gain
379
+ pad_value = tuple(kernel.shape)[0] - factor
380
+ output = upfirdn2d_native(
381
+ hidden_states, kernel.to(device=hidden_states.place), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
382
+ )
383
+ return output
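A hedged sketch of the module-level `downsample_2d` above (import path assumed). As its docstring notes, the default kernel `[1] * factor` corresponds to average pooling; here only the output shape is checked.

# Hedged sketch: factor-2 downsampling of an 8x8 feature map yields a 4x4 map.
import paddle
from ppdiffusers.models.downsampling import downsample_2d  # assumed import path

x = paddle.randn([1, 3, 8, 8])
out = downsample_2d(x, factor=2)
print(out.shape)   # [1, 3, 4, 4]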
PaddleMIX/ppdiffusers/ppdiffusers/models/dual_transformer_2d.py ADDED
@@ -0,0 +1,158 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Optional
15
+
16
+ import paddle.nn as nn
17
+
18
+ from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
19
+
20
+
21
+ class DualTransformer2DModel(nn.Layer):
22
+ """
23
+ Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
24
+
25
+ Parameters:
26
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
27
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
28
+ in_channels (`int`, *optional*):
29
+ Pass if the input is continuous. The number of channels in the input and output.
30
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
31
+ dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
32
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
33
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
34
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
35
+ `ImagePositionalEmbeddings`.
36
+ num_vector_embeds (`int`, *optional*):
37
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
38
+ Includes the class for the masked latent pixel.
39
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
40
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
41
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
42
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
43
+ up to, but not more than, `num_embeds_ada_norm` steps.
44
+ attention_bias (`bool`, *optional*):
45
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
46
+ """
47
+
48
+ def __init__(
49
+ self,
50
+ num_attention_heads: int = 16,
51
+ attention_head_dim: int = 88,
52
+ in_channels: Optional[int] = None,
53
+ num_layers: int = 1,
54
+ dropout: float = 0.0,
55
+ norm_num_groups: int = 32,
56
+ cross_attention_dim: Optional[int] = None,
57
+ attention_bias: bool = False,
58
+ sample_size: Optional[int] = None,
59
+ num_vector_embeds: Optional[int] = None,
60
+ activation_fn: str = "geglu",
61
+ num_embeds_ada_norm: Optional[int] = None,
62
+ ):
63
+ super().__init__()
64
+ self.transformers = nn.LayerList(
65
+ [
66
+ Transformer2DModel(
67
+ num_attention_heads=num_attention_heads,
68
+ attention_head_dim=attention_head_dim,
69
+ in_channels=in_channels,
70
+ num_layers=num_layers,
71
+ dropout=dropout,
72
+ norm_num_groups=norm_num_groups,
73
+ cross_attention_dim=cross_attention_dim,
74
+ attention_bias=attention_bias,
75
+ sample_size=sample_size,
76
+ num_vector_embeds=num_vector_embeds,
77
+ activation_fn=activation_fn,
78
+ num_embeds_ada_norm=num_embeds_ada_norm,
79
+ )
80
+ for _ in range(2)
81
+ ]
82
+ )
83
+
84
+ # Variables that can be set by a pipeline:
85
+
86
+ # The ratio of transformer1 to transformer2's output states to be combined during inference
87
+ self.mix_ratio = 0.5
88
+
89
+ # The shape of `encoder_hidden_states` is expected to be
90
+ # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
91
+ self.condition_lengths = [77, 257]
92
+
93
+ # Which transformer to use to encode which condition.
94
+ # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
95
+ self.transformer_index_for_condition = [1, 0]
96
+
97
+ def forward(
98
+ self,
99
+ hidden_states,
100
+ encoder_hidden_states,
101
+ timestep=None,
102
+ added_cond_kwargs=None,
103
+ class_labels=None,
104
+ cross_attention_kwargs=None,
105
+ attention_mask=None,
106
+ encoder_attention_mask=None,
107
+ return_dict: bool = True,
108
+ ):
109
+ """
110
+ Args:
111
+ hidden_states ( When discrete, `paddle.Tensor` of shape `(batch size, num latent pixels)`.
112
+ When continuous, `paddle.Tensor` of shape `(batch size, channel, height, width)`): Input
113
+ hidden_states.
114
+ encoder_hidden_states ( `paddle.Tensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
115
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
116
+ self-attention.
117
+ timestep ( `paddle.Tensor`, *optional*):
118
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
119
+ attention_mask (`paddle.Tensor`, *optional*):
120
+ Optional attention mask to be applied in Attention.
121
+ cross_attention_kwargs (`dict`, *optional*):
122
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
123
+ `self.processor` in
124
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
125
+ return_dict (`bool`, *optional*, defaults to `True`):
126
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
127
+
128
+ Returns:
129
+ [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
130
+ [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
131
+ returning a tuple, the first element is the sample tensor.
132
+ """
133
+ input_states = hidden_states
134
+
135
+ encoded_states = []
136
+ tokens_start = 0
137
+ # attention_mask is not used yet
138
+ for i in range(2):
139
+ # for each of the two transformers, pass the corresponding condition tokens
140
+ condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
141
+ transformer_index = self.transformer_index_for_condition[i]
142
+ encoded_state = self.transformers[transformer_index](
143
+ input_states,
144
+ encoder_hidden_states=condition_state,
145
+ timestep=timestep,
146
+ cross_attention_kwargs=cross_attention_kwargs,
147
+ return_dict=False,
148
+ )[0]
149
+ encoded_states.append(encoded_state - input_states)
150
+ tokens_start += self.condition_lengths[i]
151
+
152
+ output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
153
+ output_states = output_states + input_states
154
+
155
+ if not return_dict:
156
+ return (output_states,)
157
+
158
+ return Transformer2DModelOutput(sample=output_states)
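A minimal sketch of the conditioning split performed in `DualTransformer2DModel.forward` above; it only reproduces the index arithmetic, using the default `condition_lengths` and `transformer_index_for_condition` set in `__init__`.

# Minimal sketch: encoder_hidden_states packs a 77-token and a 257-token condition along
# axis 1; each slice is routed to the *other* transformer via transformer_index_for_condition.
condition_lengths = [77, 257]
transformer_index_for_condition = [1, 0]

tokens_start = 0
for i, length in enumerate(condition_lengths):
    token_range = (tokens_start, tokens_start + length)
    print(f"condition {i}: tokens {token_range} -> transformers[{transformer_index_for_condition[i]}]")
    tokens_start += length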
PaddleMIX/ppdiffusers/ppdiffusers/models/lora.py ADDED
@@ -0,0 +1,462 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ # IMPORTANT: #
17
+ ###################################################################
18
+ # ----------------------------------------------------------------#
19
+ # This file is deprecated and will be removed soon #
20
+ # (as soon as PEFT will become a required dependency for LoRA) #
21
+ # ----------------------------------------------------------------#
22
+ ###################################################################
23
+
24
+ import contextlib
25
+ from typing import Optional, Tuple, Union
26
+
27
+ import paddle
28
+ from paddle import nn
29
+
30
+ from ppdiffusers.transformers import CLIPTextModel, CLIPTextModelWithProjection
31
+
32
+ from ..utils import logging
33
+
34
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
+
36
+
37
+ def text_encoder_attn_modules(text_encoder):
38
+ attn_modules = []
39
+
40
+ if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
41
+ for i, layer in enumerate(text_encoder.text_model.encoder.layers):
42
+ name = f"text_model.encoder.layers.{i}.self_attn"
43
+ mod = layer.self_attn
44
+ attn_modules.append((name, mod))
45
+ else:
46
+ raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")
47
+
48
+ return attn_modules
49
+
50
+
51
+ def text_encoder_mlp_modules(text_encoder):
52
+ mlp_modules = []
53
+
54
+ if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
55
+ for i, layer in enumerate(text_encoder.text_model.encoder.layers):
56
+ mlp_mod = layer.mlp
57
+ name = f"text_model.encoder.layers.{i}.mlp"
58
+ mlp_modules.append((name, mlp_mod))
59
+ else:
60
+ raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}")
61
+
62
+ return mlp_modules
63
+
64
+
65
+ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
66
+ for _, attn_module in text_encoder_attn_modules(text_encoder):
67
+ if isinstance(attn_module.q_proj, PatchedLoraProjection):
68
+ attn_module.q_proj.lora_scale = lora_scale
69
+ attn_module.k_proj.lora_scale = lora_scale
70
+ attn_module.v_proj.lora_scale = lora_scale
71
+ attn_module.out_proj.lora_scale = lora_scale
72
+
73
+ for _, mlp_module in text_encoder_mlp_modules(text_encoder):
74
+ if isinstance(mlp_module.fc1, PatchedLoraProjection):
75
+ mlp_module.fc1.lora_scale = lora_scale
76
+ mlp_module.fc2.lora_scale = lora_scale
77
+
78
+
79
+ class PatchedLoraProjection(nn.Layer):
80
+ def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
81
+ super().__init__()
82
+ from ..models.lora import LoRALinearLayer
83
+
84
+ self.regular_linear_layer = regular_linear_layer
85
+
86
+ if dtype is None:
87
+ dtype = self.regular_linear_layer.weight.dtype
88
+
89
+ self.lora_linear_layer = LoRALinearLayer(
90
+ self.regular_linear_layer.in_features,
91
+ self.regular_linear_layer.out_features,
92
+ network_alpha=network_alpha,
93
+ dtype=dtype,
94
+ rank=rank,
95
+ )
96
+
97
+ self.lora_scale = lora_scale
98
+
99
+ # overwrite Paddle's `state_dict` to be sure that only the 'regular_linear_layer' weights are saved
100
+ # when saving the whole text encoder model and when LoRA is unloaded or fused
101
+ def state_dict(self, destination=None, include_sublayers=True, structured_name_prefix="", use_hook=True):
102
+ if self.lora_linear_layer is None:
103
+ return self.regular_linear_layer.state_dict(
104
+ destination=destination,
105
+ include_sublayers=include_sublayers,
106
+ structured_name_prefix=structured_name_prefix,
107
+ use_hook=use_hook,
108
+ )
109
+
110
+ return super().state_dict(
111
+ destination=destination,
112
+ include_sublayers=include_sublayers,
113
+ structured_name_prefix=structured_name_prefix,
114
+ use_hook=use_hook,
115
+ )
116
+
117
+ def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
118
+ if self.lora_linear_layer is None:
119
+ return
120
+
121
+ dtype = self.regular_linear_layer.weight.dtype
122
+
123
+ w_orig = self.regular_linear_layer.weight.cast("float32")
124
+ w_up = self.lora_linear_layer.up.weight.cast("float32")
125
+ w_down = self.lora_linear_layer.down.weight.cast("float32")
126
+
127
+ if self.lora_linear_layer.network_alpha is not None:
128
+ w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank
129
+
130
+ fused_weight = w_orig + (lora_scale * paddle.matmul(w_down[None, :], w_up[None, :])[0])
131
+
132
+ if safe_fusing and paddle.isnan(fused_weight).any().item():
133
+ raise ValueError(
134
+ "This LoRA weight seems to be broken. "
135
+ f"Encountered NaN values when trying to fuse LoRA weights for {self}."
136
+ "LoRA weights will not be fused."
137
+ )
138
+
139
+ self.regular_linear_layer.weight.copy_(fused_weight.cast(dtype=dtype), False)
140
+
141
+ # we can drop the lora layer now
142
+ self.lora_linear_layer = None
143
+
144
+ # offload the up and down matrices to CPU to not blow the memory
145
+ self.w_up = w_up.cpu()
146
+ self.w_down = w_down.cpu()
147
+ self.lora_scale = lora_scale
148
+
149
+ def _unfuse_lora(self):
150
+ if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
151
+ return
152
+
153
+ fused_weight = self.regular_linear_layer.weight
154
+ dtype = fused_weight.dtype
155
+
156
+ w_up = self.w_up.cast("float32")
157
+ w_down = self.w_down.cast("float32")
158
+
159
+ unfused_weight = fused_weight.cast("float32") - (
160
+ self.lora_scale * paddle.matmul(w_down[None, :], w_up[None, :])[0]
161
+ )
162
+ self.regular_linear_layer.weight.copy_(unfused_weight.cast(dtype=dtype), False)
163
+
164
+ self.w_up = None
165
+ self.w_down = None
166
+
167
+ def forward(self, input):
168
+ if self.lora_scale is None:
169
+ self.lora_scale = 1.0
170
+ if self.lora_linear_layer is None:
171
+ return self.regular_linear_layer(input)
172
+ return self.regular_linear_layer(input) + (self.lora_scale * self.lora_linear_layer(input))
173
+
174
+
175
+ class LoRALinearLayer(nn.Layer):
176
+ r"""
177
+ A linear layer that is used with LoRA.
178
+
179
+ Parameters:
180
+ in_features (`int`):
181
+ Number of input features.
182
+ out_features (`int`):
183
+ Number of output features.
184
+ rank (`int`, `optional`, defaults to 4):
185
+ The rank of the LoRA layer.
186
+ network_alpha (`float`, `optional`, defaults to `None`):
187
+ The value of the network alpha used for stable learning and preventing underflow. This value has the same
188
+ meaning as the `--network_alpha` option in the kohya-ss trainer script. See
189
+ https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
190
+ dtype (`paddle.dtype`, `optional`, defaults to `None`):
191
+ The dtype to use for the layer's weights.
192
+ """
193
+
194
+ def __init__(
195
+ self,
196
+ in_features: int,
197
+ out_features: int,
198
+ rank: int = 4,
199
+ network_alpha: Optional[float] = None,
200
+ dtype: Optional[paddle.dtype] = None,
201
+ ):
202
+ super().__init__()
203
+ if dtype is not None:
204
+ ctx = paddle.dtype_guard(dtype)
205
+ else:
206
+ ctx = contextlib.nullcontext()
207
+ with ctx:
208
+ self.down = nn.Linear(in_features, rank, bias_attr=False)
209
+ self.up = nn.Linear(rank, out_features, bias_attr=False)
210
+
211
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
212
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
213
+ self.network_alpha = network_alpha
214
+ self.rank = rank
215
+ self.out_features = out_features
216
+ self.in_features = in_features
217
+
218
+ nn.init.normal_(self.down.weight, std=1 / rank)
219
+ nn.init.zeros_(self.up.weight)
220
+
221
+ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
222
+ orig_dtype = hidden_states.dtype
223
+ dtype = self.down.weight.dtype
224
+ down_hidden_states = self.down(hidden_states.cast(dtype))
225
+ up_hidden_states = self.up(down_hidden_states)
226
+
227
+ if self.network_alpha is not None:
228
+ up_hidden_states *= self.network_alpha / self.rank
229
+
230
+ return up_hidden_states.cast(orig_dtype)
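For orientation, here is a minimal usage sketch of the `LoRALinearLayer` defined above (assuming Paddle is installed and the class is importable from `ppdiffusers.models.lora`, this file's package path). Because `up` is zero-initialized, a freshly constructed layer contributes a zero update until it is trained.

```python
import paddle
from ppdiffusers.models.lora import LoRALinearLayer  # import path assumed from this file's location

lora = LoRALinearLayer(in_features=768, out_features=768, rank=4, network_alpha=4.0)
x = paddle.randn([2, 77, 768])   # (batch, seq_len, features)
delta = lora(x)                  # down-projection to rank 4, up-projection back, scaled by alpha / rank
print(delta.shape)               # [2, 77, 768]
print(float(delta.abs().max()))  # 0.0 -- `up` starts zero-initialized, so the update is zero before training
```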
231
+
232
+
233
+ class LoRAConv2dLayer(nn.Layer):
234
+ r"""
235
+ A convolutional layer that is used with LoRA.
236
+
237
+ Parameters:
238
+ in_features (`int`):
239
+ Number of input features.
240
+ out_features (`int`):
241
+ Number of output features.
242
+ rank (`int`, `optional`, defaults to 4):
243
+ The rank of the LoRA layer.
244
+ kernel_size (`int` or `tuple` of two `int`, `optional`, defaults to 1):
245
+ The kernel size of the convolution.
246
+ stride (`int` or `tuple` of two `int`, `optional`, defaults to 1):
247
+ The stride of the convolution.
248
+ padding (`int` or `tuple` of two `int` or `str`, `optional`, defaults to 0):
249
+ The padding of the convolution.
250
+ network_alpha (`float`, `optional`, defaults to `None`):
251
+ The value of the network alpha used for stable learning and preventing underflow. This value has the same
252
+ meaning as the `--network_alpha` option in the kohya-ss trainer script. See
253
+ https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
254
+ """
255
+
256
+ def __init__(
257
+ self,
258
+ in_features: int,
259
+ out_features: int,
260
+ rank: int = 4,
261
+ kernel_size: Union[int, Tuple[int, int]] = (1, 1),
262
+ stride: Union[int, Tuple[int, int]] = (1, 1),
263
+ padding: Union[int, Tuple[int, int], str] = 0,
264
+ network_alpha: Optional[float] = None,
265
+ ):
266
+ super().__init__()
267
+
268
+ self.down = nn.Conv2D(
269
+ in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias_attr=False
270
+ )
271
+ # according to the official kohya_ss trainer kernel_size are always fixed for the up layer
272
+ # # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
273
+ self.up = nn.Conv2D(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias_attr=False)
274
+
275
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
276
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
277
+ self.network_alpha = network_alpha
278
+ self.rank = rank
279
+
280
+ nn.init.normal_(self.down.weight, std=1 / rank)
281
+ nn.init.zeros_(self.up.weight)
282
+
283
+ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
284
+ orig_dtype = hidden_states.dtype
285
+ dtype = self.down.weight.dtype
286
+
287
+ down_hidden_states = self.down(hidden_states.cast(dtype))
288
+ up_hidden_states = self.up(down_hidden_states)
289
+
290
+ if self.network_alpha is not None:
291
+ up_hidden_states *= self.network_alpha / self.rank
292
+
293
+ return up_hidden_states.cast(orig_dtype)
294
+
295
+
296
+ class LoRACompatibleConv(nn.Conv2D):
297
+ """
298
+ A convolutional layer that can be used with LoRA.
299
+ """
300
+
301
+ def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs):
302
+ super().__init__(*args, **kwargs)
303
+ self.in_channels = self._in_channels
304
+ self.out_channels = self._out_channels
305
+ self.kernel_size = self._kernel_size
306
+ self.lora_layer = lora_layer
307
+ self.data_format = kwargs.get("data_format", "NCHW")
308
+
309
+ def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
310
+ self.lora_layer = lora_layer
311
+
312
+ def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
313
+ if self.lora_layer is None:
314
+ return
315
+
316
+ dtype = self.weight.dtype
317
+
318
+ w_orig = self.weight.cast("float32")
319
+ w_up = self.lora_layer.up.weight.cast("float32")
320
+ w_down = self.lora_layer.down.weight.cast("float32")
321
+
322
+ if self.lora_layer.network_alpha is not None:
323
+ w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
324
+
325
+ fusion = paddle.matmul(w_up.flatten(start_axis=1), w_down.flatten(start_axis=1))
326
+ fusion = fusion.reshape(w_orig.shape)
327
+ fused_weight = w_orig + (lora_scale * fusion)
328
+
329
+ if safe_fusing and paddle.isnan(fused_weight).any().item():
330
+ raise ValueError(
331
+ "This LoRA weight seems to be broken. "
332
+ f"Encountered NaN values when trying to fuse LoRA weights for {self}. "
333
+ "LoRA weights will not be fused."
334
+ )
335
+
336
+ self.weight.copy_(fused_weight.cast(dtype=dtype), False)
337
+
338
+ # we can drop the lora layer now
339
+ self.lora_layer = None
340
+
341
+ # offload the up and down matrices to CPU to not blow the memory
342
+ self.w_up = w_up.cpu()
343
+ self.w_down = w_down.cpu()
344
+ self._lora_scale = lora_scale
345
+
346
+ def _unfuse_lora(self):
347
+ if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
348
+ return
349
+
350
+ fused_weight = self.weight
351
+ dtype = fused_weight.dtype
352
+
353
+ w_up = self.w_up.cast("float32")
354
+ w_down = self.w_down.cast("float32")
355
+
356
+ fusion = paddle.matmul(w_up.flatten(start_axis=1), w_down.flatten(start_axis=1))
357
+ fusion = fusion.reshape(fused_weight.shape)
358
+ unfused_weight = fused_weight.cast("float32") - (self._lora_scale * fusion)
359
+ self.weight.copy_(unfused_weight.cast(dtype=dtype), False)
360
+
361
+ self.w_up = None
362
+ self.w_down = None
363
+
364
+ def forward(self, hidden_states: paddle.Tensor, scale: float = 1.0) -> paddle.Tensor:
365
+ if self.lora_layer is None:
366
+ # use the functional conv2d call here, mirroring upstream diffusers (where it avoids torch.compile graph breaks)
367
+ # see: https://github.com/huggingface/diffusers/pull/4315
368
+ return nn.functional.conv2d(
369
+ hidden_states,
370
+ self.weight,
371
+ self.bias,
372
+ self._stride,
373
+ self._padding,
374
+ self._dilation,
375
+ self._groups,
376
+ data_format=self.data_format,
377
+ )
378
+ else:
379
+ original_outputs = nn.functional.conv2d(
380
+ hidden_states,
381
+ self.weight,
382
+ self.bias,
383
+ self._stride,
384
+ self._padding,
385
+ self._dilation,
386
+ self._groups,
387
+ data_format=self.data_format,
388
+ )
389
+ return original_outputs + (scale * self.lora_layer(hidden_states))
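The conv fusion above works because the LoRA `up` kernel has shape `(out, rank, 1, 1)` and the `down` kernel has shape `(rank, in, kH, kW)`; flattening both to 2-D, multiplying, and reshaping yields a dense `(out, in, kH, kW)` kernel that can be added onto the frozen conv weight. A small self-contained sketch of that identity (plain Paddle, illustrative shapes only):

```python
import paddle

out_c, in_c, rank, k = 8, 4, 2, 3
w_up = paddle.randn([out_c, rank, 1, 1])    # LoRA "up" conv kernel
w_down = paddle.randn([rank, in_c, k, k])   # LoRA "down" conv kernel

fusion = paddle.matmul(w_up.flatten(start_axis=1), w_down.flatten(start_axis=1))
fusion = fusion.reshape([out_c, in_c, k, k])  # dense kernel to add onto the frozen conv weight
print(fusion.shape)                           # [8, 4, 3, 3]
```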
390
+
391
+
392
+ class LoRACompatibleLinear(nn.Linear):
393
+ """
394
+ A Linear layer that can be used with LoRA.
395
+ """
396
+
397
+ def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
398
+ super().__init__(*args, **kwargs)
399
+ self.lora_layer = lora_layer
400
+
401
+ def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
402
+ self.lora_layer = lora_layer
403
+
404
+ def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
405
+ if self.lora_layer is None:
406
+ return
407
+
408
+ dtype = self.weight.dtype
409
+
410
+ w_orig = self.weight.cast("float32")
411
+ w_up = self.lora_layer.up.weight.cast("float32")
412
+ w_down = self.lora_layer.down.weight.cast("float32")
413
+
414
+ if self.lora_layer.network_alpha is not None:
415
+ w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
416
+
417
+ fused_weight = w_orig + (lora_scale * paddle.matmul(w_down[None, :], w_up[None, :])[0])
418
+
419
+ if safe_fusing and paddle.isnan(fused_weight).any().item():
420
+ raise ValueError(
421
+ "This LoRA weight seems to be broken. "
422
+ f"Encountered NaN values when trying to fuse LoRA weights for {self}. "
423
+ "LoRA weights will not be fused."
424
+ )
425
+ self.weight.copy_(fused_weight.cast(dtype=dtype), False)
426
+
427
+ # we can drop the lora layer now
428
+ self.lora_layer = None
429
+
430
+ # offload the up and down matrices to CPU to not blow the memory
431
+ self.w_up = w_up.cpu()
432
+ self.w_down = w_down.cpu()
433
+ self._lora_scale = lora_scale
434
+
435
+ def _unfuse_lora(self):
436
+ if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
437
+ return
438
+
439
+ fused_weight = self.weight
440
+ dtype = fused_weight.dtype
441
+
442
+ w_up = self.w_up.cast("float32")
443
+ w_down = self.w_down.cast("float32")
444
+
445
+ unfused_weight = fused_weight.cast("float32") - (
446
+ self._lora_scale * paddle.matmul(w_down[None, :], w_up[None, :])[0]
447
+ )
448
+ self.weight.copy_(unfused_weight.cast(dtype=dtype), False)
449
+
450
+ self.w_up = None
451
+ self.w_down = None
452
+
453
+ def forward(self, hidden_states: paddle.Tensor, scale: float = 1.0) -> paddle.Tensor:
454
+ if self.lora_layer is None:
455
+ return nn.functional.linear(
456
+ hidden_states,
457
+ self.weight,
458
+ self.bias,
459
+ )
460
+ else:
461
+ out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
462
+ return out
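To close out this file, a hedged sketch of how the `_fuse_lora` / `_unfuse_lora` pair on `LoRACompatibleLinear` is typically exercised (import path assumed; with a freshly initialized LoRA layer the low-rank update is zero, so the fused and unfused outputs coincide exactly):

```python
import paddle
from ppdiffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer  # assumed import path

linear = LoRACompatibleLinear(64, 64)
linear.set_lora_layer(LoRALinearLayer(64, 64, rank=4))

x = paddle.randn([1, 64])
y_lora = linear(x, scale=1.0)       # base output + scale * LoRA update

linear._fuse_lora(lora_scale=1.0)   # folds scale * (down @ up) into `weight`; `lora_layer` is dropped
y_fused = linear(x)                 # plain linear path now

print(bool(paddle.allclose(y_lora, y_fused)))  # True here: the freshly initialized LoRA update is zero

linear._unfuse_lora()               # subtracts the cached low-rank product from `weight` again
```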
PaddleMIX/ppdiffusers/ppdiffusers/models/lvdm_attention_temporal.py ADDED
@@ -0,0 +1,462 @@
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import paddle
16
+ import paddle.nn.functional as F
17
+ from paddle.distributed.fleet.utils import recompute
18
+
19
+ try:
20
+ from paddle.incubate.nn.memory_efficient_attention import ( # noqa
21
+ memory_efficient_attention,
22
+ )
23
+
24
+ _ppxformers_available = True
25
+ except ImportError:
26
+ _ppxformers_available = False
27
+
28
+ import math
29
+
30
+ from einops import rearrange, repeat
31
+
32
+ from ..utils.initializer_utils import constant_, xavier_uniform_
33
+ from .lvdm_util import (
34
+ GEGLU,
35
+ Normalize,
36
+ conv_nd,
37
+ default,
38
+ exists,
39
+ normalization,
40
+ zero_module,
41
+ )
42
+
43
+
44
+ class FeedForward(paddle.nn.Layer):
45
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
46
+ super().__init__()
47
+ inner_dim = int(dim * mult)
48
+ dim_out = default(dim_out, dim)
49
+ project_in = (
50
+ paddle.nn.Sequential(paddle.nn.Linear(in_features=dim, out_features=inner_dim), paddle.nn.GELU())
51
+ if not glu
52
+ else GEGLU(dim, inner_dim)
53
+ )
54
+ self.net = paddle.nn.Sequential(
55
+ project_in, paddle.nn.Dropout(p=dropout), paddle.nn.Linear(in_features=inner_dim, out_features=dim_out)
56
+ )
57
+
58
+ def forward(self, x):
59
+ return self.net(x)
60
+
61
+
62
+ class RelativePosition(paddle.nn.Layer):
63
+ """https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py"""
64
+
65
+ def __init__(self, num_units, max_relative_position):
66
+ super().__init__()
67
+ self.num_units = num_units
68
+ self.max_relative_position = max_relative_position
69
+ self.embeddings_table = paddle.nn.Parameter(paddle.empty(shape=[max_relative_position * 2 + 1, num_units]))
70
+ xavier_uniform_(self.embeddings_table)
71
+
72
+ def forward(self, length_q, length_k):
73
+ range_vec_q = paddle.arange(end=length_q)
74
+ range_vec_k = paddle.arange(end=length_k)
75
+ distance_mat = range_vec_k[(None), :] - range_vec_q[:, (None)]
76
+ distance_mat_clipped = paddle.clip(
77
+ x=distance_mat, min=-self.max_relative_position, max=self.max_relative_position
78
+ )
79
+ final_mat = distance_mat_clipped + self.max_relative_position
80
+ final_mat = final_mat.astype(dtype="int64")
81
+ embeddings = self.embeddings_table[final_mat]
82
+ return embeddings
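The lookup above is driven by a clipped pairwise-offset matrix that is shifted into a non-negative index range before indexing the learned table. A standalone sketch of just that index construction for a short sequence (the embedding table itself is omitted):

```python
import paddle

length_q = length_k = 4
max_relative_position = 2

range_q = paddle.arange(end=length_q)
range_k = paddle.arange(end=length_k)
distance = range_k[None, :] - range_q[:, None]                                # signed offsets j - i
clipped = paddle.clip(x=distance, min=-max_relative_position, max=max_relative_position)
index = (clipped + max_relative_position).astype(dtype="int64")               # shifted into [0, 2 * max + 1)
print(index.numpy())
# [[2 3 4 4]
#  [1 2 3 4]
#  [0 1 2 3]
#  [0 0 1 2]]
```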
83
+
84
+
85
+ class TemporalCrossAttention(paddle.nn.Layer):
86
+ def __init__(
87
+ self,
88
+ query_dim,
89
+ context_dim=None,
90
+ heads=8,
91
+ dim_head=64,
92
+ dropout=0.0,
93
+ use_relative_position=False,
94
+ temporal_length=None,
95
+ **kwargs
96
+ ):
97
+ super().__init__()
98
+ inner_dim = dim_head * heads
99
+ context_dim = default(context_dim, query_dim)
100
+ self.context_dim = context_dim
101
+ self.scale = dim_head**-0.5
102
+ self.heads = heads
103
+ self.temporal_length = temporal_length
104
+ self.use_relative_position = use_relative_position
105
+ self.to_q = paddle.nn.Linear(in_features=query_dim, out_features=inner_dim, bias_attr=False)
106
+ self.to_k = paddle.nn.Linear(in_features=context_dim, out_features=inner_dim, bias_attr=False)
107
+ self.to_v = paddle.nn.Linear(in_features=context_dim, out_features=inner_dim, bias_attr=False)
108
+ self.to_out = paddle.nn.Sequential(
109
+ paddle.nn.Linear(in_features=inner_dim, out_features=query_dim), paddle.nn.Dropout(p=dropout)
110
+ )
111
+ if use_relative_position:
112
+ assert temporal_length is not None
113
+ self.relative_position_k = RelativePosition(num_units=dim_head, max_relative_position=temporal_length)
114
+ self.relative_position_v = RelativePosition(num_units=dim_head, max_relative_position=temporal_length)
115
+ constant_(self.to_q.weight, 0)
116
+ constant_(self.to_k.weight, 0)
117
+ constant_(self.to_v.weight, 0)
118
+ constant_(self.to_out[0].weight, 0)
119
+ constant_(self.to_out[0].bias, 0)
120
+
121
+ def forward(self, x, context=None, mask=None):
122
+ nh = self.heads
123
+ out = x
124
+ q = self.to_q(out)
125
+ context = default(context, x)
126
+ k = self.to_k(context)
127
+ v = self.to_v(context)
128
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=nh), (q, k, v))
129
+ sim = paddle.einsum("b i d, b j d -> b i j", q, k) * self.scale
130
+ if self.use_relative_position:
131
+ len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
132
+ k2 = self.relative_position_k(len_q, len_k)
133
+ sim2 = paddle.einsum("b t d, t s d -> b t s", q, k2) * self.scale
134
+ sim += sim2
135
+ if mask is not None:
136
+ max_neg_value = -1000000000.0
137
+ sim = sim + (1 - mask.astype(dtype="float32")) * max_neg_value
138
+ attn = paddle.nn.functional.softmax(sim, axis=-1)
139
+ out = paddle.einsum("b i j, b j d -> b i d", attn, v)
140
+ if self.use_relative_position:
141
+ v2 = self.relative_position_v(len_q, len_v)
142
+ out2 = paddle.einsum("b t s, t s d -> b t d", attn, v2)
143
+ out += out2
144
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=nh)
145
+ return self.to_out(out)
146
+
147
+
148
+ class CrossAttention(paddle.nn.Layer):
149
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs):
150
+ super().__init__()
151
+ inner_dim = dim_head * heads
152
+ context_dim = default(context_dim, query_dim)
153
+ self.scale = dim_head**-0.5
154
+ self.heads = heads
155
+ self.to_q = paddle.nn.Linear(in_features=query_dim, out_features=inner_dim, bias_attr=False)
156
+ self.to_k = paddle.nn.Linear(in_features=context_dim, out_features=inner_dim, bias_attr=False)
157
+ self.to_v = paddle.nn.Linear(in_features=context_dim, out_features=inner_dim, bias_attr=False)
158
+ self.to_out = paddle.nn.Sequential(
159
+ paddle.nn.Linear(in_features=inner_dim, out_features=query_dim), paddle.nn.Dropout(p=dropout)
160
+ )
161
+
162
+ def forward(self, x, context=None, mask=None):
163
+ h = self.heads
164
+ # b = x.shape[0]
165
+ q = self.to_q(x)
166
+ context = default(context, x)
167
+ k = self.to_k(context)
168
+ v = self.to_v(context)
169
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
170
+ sim = paddle.einsum("b i d, b j d -> b i j", q, k) * self.scale
171
+ if exists(mask):
172
+ mask = rearrange(mask, "b ... -> b (...)")
173
+ max_neg_value = -paddle.finfo(sim.dtype).max
174
+ mask = repeat(mask, "b j -> (b h) () j", h=h)
175
+ sim = paddle.masked_fill(sim, ~mask, max_neg_value)
176
+ attn = paddle.nn.functional.softmax(sim, axis=-1)
177
+ out = paddle.einsum("b i j, b j d -> b i d", attn, v)
178
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
179
+ return self.to_out(out)
180
+
181
+
182
+ class MemoryEfficientCrossAttention(paddle.nn.Layer):
183
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs):
184
+ super().__init__()
185
+ # print(
186
+ # f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using {heads} heads."
187
+ # )
188
+ inner_dim = dim_head * heads
189
+ context_dim = default(context_dim, query_dim)
190
+ self.heads = heads
191
+ self.dim_head = dim_head
192
+ self.to_q = paddle.nn.Linear(in_features=query_dim, out_features=inner_dim, bias_attr=False)
193
+ self.to_k = paddle.nn.Linear(in_features=context_dim, out_features=inner_dim, bias_attr=False)
194
+ self.to_v = paddle.nn.Linear(in_features=context_dim, out_features=inner_dim, bias_attr=False)
195
+ self.to_out = paddle.nn.Sequential(
196
+ paddle.nn.Linear(in_features=inner_dim, out_features=query_dim), paddle.nn.Dropout(p=dropout)
197
+ )
198
+ self.attention_op = "cutlass"
199
+
200
+ def forward(self, x, context=None, mask=None):
201
+ q = self.to_q(x)
202
+ context = default(context, x)
203
+ k = self.to_k(context)
204
+ v = self.to_v(context)
205
+ b, _, _ = q.shape
206
+ q, k, v = map(lambda t: t.reshape([0, 0, self.heads, self.dim_head]), (q, k, v))
207
+ out = F.scaled_dot_product_attention_(
208
+ q,
209
+ k,
210
+ v,
211
+ attn_mask=None,
212
+ dropout_p=0.0,
213
+ attention_op=self.attention_op,
214
+ training=True,
215
+ )
216
+ if exists(mask):
217
+ raise NotImplementedError
218
+ out = out.reshape([0, 0, self.heads * self.dim_head])
219
+ return self.to_out(out)
220
+
221
+
222
+ class BasicTransformerBlockST(paddle.nn.Layer):
223
+ """
224
+ if no context is given to forward function, cross-attention defaults to self-attention
225
+ """
226
+
227
+ def __init__(
228
+ self,
229
+ dim,
230
+ n_heads,
231
+ d_head,
232
+ dropout=0.0,
233
+ context_dim=None,
234
+ gated_ff=True,
235
+ checkpoint=True,
236
+ temporal_length=None,
237
+ use_relative_position=True,
238
+ **kwargs
239
+ ):
240
+ super().__init__()
241
+ if _ppxformers_available:
242
+ self.attn1 = MemoryEfficientCrossAttention(
243
+ query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, **kwargs
244
+ )
245
+ self.attn2 = MemoryEfficientCrossAttention(
246
+ query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout, **kwargs
247
+ )
248
+ else:
249
+ self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, **kwargs)
250
+ self.attn2 = CrossAttention(
251
+ query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout, **kwargs
252
+ )
253
+
254
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
255
+ self.norm1 = paddle.nn.LayerNorm(normalized_shape=dim, epsilon=1e-05, weight_attr=None, bias_attr=None)
256
+ self.norm2 = paddle.nn.LayerNorm(normalized_shape=dim, epsilon=1e-05, weight_attr=None, bias_attr=None)
257
+ self.norm3 = paddle.nn.LayerNorm(normalized_shape=dim, epsilon=1e-05, weight_attr=None, bias_attr=None)
258
+ self.checkpoint = checkpoint
259
+ self.attn1_tmp = TemporalCrossAttention(
260
+ query_dim=dim,
261
+ heads=n_heads,
262
+ dim_head=d_head,
263
+ dropout=dropout,
264
+ temporal_length=temporal_length,
265
+ use_relative_position=use_relative_position,
266
+ **kwargs,
267
+ )
268
+ self.attn2_tmp = TemporalCrossAttention(
269
+ query_dim=dim,
270
+ heads=n_heads,
271
+ dim_head=d_head,
272
+ dropout=dropout,
273
+ context_dim=None,
274
+ temporal_length=temporal_length,
275
+ use_relative_position=use_relative_position,
276
+ **kwargs,
277
+ )
278
+ self.norm4 = paddle.nn.LayerNorm(normalized_shape=dim, epsilon=1e-05, weight_attr=None, bias_attr=None)
279
+ self.norm5 = paddle.nn.LayerNorm(normalized_shape=dim, epsilon=1e-05, weight_attr=None, bias_attr=None)
280
+
281
+ def forward(self, x, context=None, **kwargs):
282
+ if self.checkpoint:
283
+ return recompute(self._forward, x, context)
284
+ else:
285
+ return self._forward(x, context)
286
+
287
+ def _forward(self, x, context=None, mask=None):
288
+ assert x.dim() == 5, f"x shape = {x.shape}"
289
+ b, c, t, h, w = x.shape
290
+ x = rearrange(x, "b c t h w -> (b t) (h w) c")
291
+ x = self.attn1(self.norm1(x)) + x
292
+ x = rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)
293
+ x = rearrange(x, "b c t h w -> (b h w) t c")
294
+ x = self.attn1_tmp(self.norm4(x), mask=mask) + x
295
+ x = rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w)
296
+ x = rearrange(x, "b c t h w -> (b t) (h w) c")
297
+ if context is not None:
298
+ context_ = []
299
+ for i in range(context.shape[0]):
300
+ context_.append(context[i].unsqueeze(axis=0).tile(repeat_times=[t, 1, 1]))
301
+ context_ = paddle.concat(x=context_, axis=0)
302
+ else:
303
+ context_ = None
304
+ x = self.attn2(self.norm2(x), context=context_) + x
305
+ x = rearrange(x, "(b t) (h w) c -> b c t h w", b=b, h=h)
306
+ x = rearrange(x, "b c t h w -> (b h w) t c")
307
+ x = self.attn2_tmp(self.norm5(x), context=None, mask=mask) + x
308
+ x = self.ff(self.norm3(x)) + x
309
+ x = rearrange(x, "(b h w) t c -> b c t h w", b=b, h=h, w=w)
310
+ return x
311
+
312
+
313
+ class SpatialTemporalTransformer(paddle.nn.Layer):
314
+ """
315
+ Transformer block for video-like data (5D tensor).
316
+ First, project the input (aka embedding) with NO reshape.
317
+ Then apply standard transformer action.
318
+ The 5D -> 3D reshape operation will be done in the specific attention module.
319
+ """
320
+
321
+ def __init__(
322
+ self,
323
+ in_channels,
324
+ n_heads,
325
+ d_head,
326
+ depth=1,
327
+ dropout=0.0,
328
+ context_dim=None,
329
+ temporal_length=None,
330
+ use_relative_position=True,
331
+ **kwargs
332
+ ):
333
+ super().__init__()
334
+ self.in_channels = in_channels
335
+ inner_dim = n_heads * d_head
336
+ self.norm = Normalize(in_channels)
337
+ self.proj_in = paddle.nn.Conv3D(
338
+ in_channels=in_channels, out_channels=inner_dim, kernel_size=1, stride=1, padding=0
339
+ )
340
+ self.transformer_blocks = paddle.nn.LayerList(
341
+ sublayers=[
342
+ BasicTransformerBlockST(
343
+ inner_dim,
344
+ n_heads,
345
+ d_head,
346
+ dropout=dropout,
347
+ context_dim=context_dim,
348
+ temporal_length=temporal_length,
349
+ use_relative_position=use_relative_position,
350
+ **kwargs,
351
+ )
352
+ for d in range(depth)
353
+ ]
354
+ )
355
+ self.proj_out = zero_module(
356
+ paddle.nn.Conv3D(in_channels=inner_dim, out_channels=in_channels, kernel_size=1, stride=1, padding=0)
357
+ )
358
+
359
+ def forward(self, x, context=None, **kwargs):
360
+ assert x.dim() == 5, f"x shape = {x.shape}"
361
+ x_in = x
362
+ x = self.norm(x)
363
+ x = self.proj_in(x)
364
+ for block in self.transformer_blocks:
365
+ x = block(x, context=context, **kwargs)
366
+ x = self.proj_out(x)
367
+ return x + x_in
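As the docstring says, the block keeps the 5D layout end to end. A hedged instantiation sketch (hyperparameters are illustrative, the module path is assumed from this file's location, and whether the memory-efficient attention or recompute paths are taken depends on the local Paddle build):

```python
import paddle
from ppdiffusers.models.lvdm_attention_temporal import SpatialTemporalTransformer  # assumed import path

block = SpatialTemporalTransformer(in_channels=32, n_heads=4, d_head=8, depth=1, temporal_length=16)
x = paddle.randn([1, 32, 16, 8, 8])   # (b, c, t, h, w)
out = block(x)                        # spatial attention, temporal attention, feed-forward, residual
print(out.shape)                      # [1, 32, 16, 8, 8] -- same shape as the input
```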
368
+
369
+
370
+ class STAttentionBlock(paddle.nn.Layer):
371
+ def __init__(
372
+ self,
373
+ channels,
374
+ num_heads=1,
375
+ num_head_channels=-1,
376
+ use_checkpoint=False,
377
+ temporal_length=16,
378
+ use_relative_position=False,
379
+ ):
380
+ super().__init__()
381
+ if num_head_channels == -1:
382
+ self.num_heads = num_heads
383
+ else:
384
+ assert (
385
+ channels % num_head_channels == 0
386
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
387
+ self.num_heads = channels // num_head_channels
388
+ self.use_checkpoint = use_checkpoint
389
+ self.temporal_length = temporal_length
390
+ self.use_relative_position = use_relative_position
391
+ self.norm_s = normalization(channels)
392
+ self.norm_t = normalization(channels)
393
+ self.qkv_s = conv_nd(1, channels, channels * 3, 1)
394
+ self.qkv_t = conv_nd(1, channels, channels * 3, 1)
395
+ self.attention_s = QKVAttention(self.num_heads)
396
+ self.attention_t = QKVAttention(self.num_heads)
397
+ if use_relative_position:
398
+ self.relative_position_k = RelativePosition(
399
+ num_units=channels // self.num_heads, max_relative_position=temporal_length
400
+ )
401
+ self.relative_position_v = RelativePosition(
402
+ num_units=channels // self.num_heads, max_relative_position=temporal_length
403
+ )
404
+ self.proj_out_s = zero_module(conv_nd(1, channels, channels, 1))
405
+ self.proj_out_t = zero_module(conv_nd(1, channels, channels, 1))
406
+
407
+ def forward(self, x, mask=None):
408
+ b, c, t, h, w = x.shape
409
+ out = rearrange(x, "b c t h w -> (b t) c (h w)")
410
+ qkv = self.qkv_s(self.norm_s(out))
411
+ out = self.attention_s(qkv)
412
+ out = self.proj_out_s(out)
413
+ out = rearrange(out, "(b t) c (h w) -> b c t h w", b=b, h=h)
414
+ x += out
415
+ out = rearrange(x, "b c t h w -> (b h w) c t")
416
+ qkv = self.qkv_t(self.norm_t(out))
417
+ if self.use_relative_position:
418
+ len_q = qkv.shape[-1]
419
+ len_k, len_v = len_q, len_q
420
+ k_rp = self.relative_position_k(len_q, len_k)
421
+ v_rp = self.relative_position_v(len_q, len_v)
422
+ out = self.attention_t(qkv, rp=(k_rp, v_rp), mask=mask)
423
+ else:
424
+ out = self.attention_t(qkv, rp=None, mask=mask)
425
+ out = self.proj_out_t(out)
426
+ out = rearrange(out, "(b h w) c t -> b c t h w", b=b, h=h, w=w)
427
+ return x + out
428
+
429
+
430
+ class QKVAttention(paddle.nn.Layer):
431
+ def __init__(self, n_heads):
432
+ super().__init__()
433
+ self.n_heads = n_heads
434
+
435
+ def forward(self, qkv, rp=None, mask=None):
436
+ bs, width, length = qkv.shape
437
+ assert width % (3 * self.n_heads) == 0
438
+ ch = width // (3 * self.n_heads)
439
+ q, k, v = qkv.chunk(chunks=3, axis=1)
440
+ scale = 1 / math.sqrt(math.sqrt(ch))
441
+ weight = paddle.einsum(
442
+ "bct,bcs->bts",
443
+ (q * scale).reshape([bs * self.n_heads, ch, length]),
444
+ (k * scale).reshape([bs * self.n_heads, ch, length]),
445
+ )
446
+ if rp is not None:
447
+ k_rp, v_rp = rp
448
+ weight2 = paddle.einsum("bct,tsc->bst", (q * scale).reshape([bs * self.n_heads, ch, length]), k_rp)
449
+ weight += weight2
450
+ if mask is not None:
451
+ INF = -100000000.0
452
+ weight = weight.astype(dtype="float32")
+ weight = paddle.where(mask == 0, paddle.full_like(weight, INF), weight)  # fill masked (mask == 0) positions with INF, matching the masking convention used above
453
+ weight = paddle.nn.functional.softmax(x=weight.astype(dtype="float32"), axis=-1).astype(weight.dtype)
454
+ a = paddle.einsum("bts,bcs->bct", weight, v.reshape([bs * self.n_heads, ch, length]))
455
+ if rp is not None:
456
+ x = paddle.einsum("bts,tsc->btc", weight, v_rp)
457
+ perm_3 = list(range(x.ndim))
458
+ perm_3[1] = 2
459
+ perm_3[2] = 1
460
+ a2 = x.transpose(perm=perm_3)
461
+ a += a2
462
+ return a.reshape([bs, -1, length])
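`QKVAttention` expects the query, key, and value packed along the channel axis, exactly as produced by the 1x1 `qkv_s` / `qkv_t` convolutions in `STAttentionBlock`. A minimal sketch of that input contract (import path assumed):

```python
import paddle
from ppdiffusers.models.lvdm_attention_temporal import QKVAttention  # assumed import path

n_heads, ch, length = 2, 8, 16                       # heads, per-head channels, sequence length
attn = QKVAttention(n_heads)
qkv = paddle.randn([4, 3 * n_heads * ch, length])    # (batch, 3 * heads * ch, length)
out = attn(qkv)
print(out.shape)                                     # [4, 16, 16] -- (batch, heads * ch, length)
```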
PaddleMIX/ppdiffusers/ppdiffusers/models/lvdm_unet_3d.py ADDED
@@ -0,0 +1,713 @@
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from abc import abstractmethod
16
+ from dataclasses import dataclass
17
+
18
+ import paddle
19
+ from einops import rearrange
20
+ from paddle.distributed.fleet.utils import recompute
21
+
22
+ from ..configuration_utils import ConfigMixin, register_to_config
23
+ from ..utils import BaseOutput
24
+ from .lvdm_attention_temporal import SpatialTemporalTransformer, STAttentionBlock
25
+ from .lvdm_util import (
26
+ avg_pool_nd,
27
+ conv_nd,
28
+ linear,
29
+ nonlinearity,
30
+ normalization,
31
+ timestep_embedding,
32
+ zero_module,
33
+ )
34
+ from .modeling_utils import ModelMixin
35
+
36
+
37
+ @dataclass
38
+ class LVDMUNet3DModelOutput(BaseOutput):
39
+ """
40
+ Args:
41
+ sample (`paddle.Tensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
42
+ Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
43
+ """
44
+
45
+ sample: paddle.Tensor
46
+
47
+
48
+ def convert_module_to_f16(x):
49
+ pass
50
+
51
+
52
+ def convert_module_to_f32(x):
53
+ pass
54
+
55
+
56
+ class TimestepBlock(paddle.nn.Layer):
57
+ """
58
+ Any module where forward() takes timestep embeddings as a second argument.
59
+ """
60
+
61
+ @abstractmethod
62
+ def forward(self, x, emb):
63
+ """
64
+ Apply the module to `x` given `emb` timestep embeddings.
65
+ """
66
+
67
+
68
+ class TimestepEmbedSequential(paddle.nn.Sequential, TimestepBlock):
69
+ """
70
+ A sequential module that passes timestep embeddings to the children that
71
+ support it as an extra input.
72
+ """
73
+
74
+ def forward(self, x, emb, context=None, **kwargs):
75
+ for layer in self:
76
+ if isinstance(layer, TimestepBlock):
77
+ x = layer(x, emb, **kwargs)
78
+ # elif isinstance(layer, STTransformerClass):
79
+ elif isinstance(layer, SpatialTemporalTransformer):
80
+ x = layer(x, context, **kwargs)
81
+ else:
82
+ x = layer(x)
83
+ return x
84
+
85
+
86
+ class Upsample(paddle.nn.Layer):
87
+ """
88
+ An upsampling layer with an optional convolution.
89
+ :param channels: channels in the inputs and outputs.
90
+ :param use_conv: a bool determining if a convolution is applied.
91
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
92
+ upsampling occurs in the inner-two dimensions.
93
+ """
94
+
95
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, kernel_size_t=3, padding_t=1):
96
+ super().__init__()
97
+ self.channels = channels
98
+ self.out_channels = out_channels or channels
99
+ self.use_conv = use_conv
100
+ self.dims = dims
101
+ if use_conv:
102
+ self.conv = conv_nd(
103
+ dims, self.channels, self.out_channels, (kernel_size_t, 3, 3), padding=(padding_t, 1, 1)
104
+ )
105
+
106
+ def forward(self, x):
107
+ assert x.shape[1] == self.channels
108
+ if self.dims == 3:
109
+ x = paddle.nn.functional.interpolate(
110
+ x=x, size=(x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest", data_format="NCDHW"
111
+ )
112
+ else:
113
+ x = paddle.nn.functional.interpolate(x=x, scale_factor=2, mode="nearest")
114
+ if self.use_conv:
115
+ x = self.conv(x)
116
+ return x
117
+
118
+
119
+ class Downsample(paddle.nn.Layer):
120
+ """
121
+ A downsampling layer with an optional convolution.
122
+ :param channels: channels in the inputs and outputs.
123
+ :param use_conv: a bool determining if a convolution is applied.
124
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
125
+ downsampling occurs in the inner-two dimensions.
126
+ """
127
+
128
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, kernel_size_t=3, padding_t=1):
129
+ super().__init__()
130
+ self.channels = channels
131
+ self.out_channels = out_channels or channels
132
+ self.use_conv = use_conv
133
+ self.dims = dims
134
+ stride = 2 if dims != 3 else (1, 2, 2)
135
+ if use_conv:
136
+ self.op = conv_nd(
137
+ dims, self.channels, self.out_channels, (kernel_size_t, 3, 3), stride=stride, padding=(padding_t, 1, 1)
138
+ )
139
+ else:
140
+ assert self.channels == self.out_channels
141
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
142
+
143
+ def forward(self, x):
144
+ assert x.shape[1] == self.channels
145
+ return self.op(x)
146
+
147
+
148
+ class ResBlock(TimestepBlock):
149
+ """
150
+ A residual block that can optionally change the number of channels.
151
+ :param channels: the number of input channels.
152
+ :param emb_channels: the number of timestep embedding channels.
153
+ :param dropout: the rate of dropout.
154
+ :param out_channels: if specified, the number of out channels.
155
+ :param use_conv: if True and out_channels is specified, use a spatial
156
+ convolution instead of a smaller 1x1 convolution to change the
157
+ channels in the skip connection.
158
+ :param dims: determines if the signal is 1D, 2D, or 3D.
159
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
160
+ :param up: if True, use this block for upsampling.
161
+ :param down: if True, use this block for downsampling.
162
+ """
163
+
164
+ def __init__(
165
+ self,
166
+ channels,
167
+ emb_channels,
168
+ dropout,
169
+ out_channels=None,
170
+ use_conv=False,
171
+ use_scale_shift_norm=False,
172
+ dims=2,
173
+ use_checkpoint=False,
174
+ up=False,
175
+ down=False,
176
+ kernel_size_t=3,
177
+ padding_t=1,
178
+ nonlinearity_type="silu",
179
+ **kwargs
180
+ ):
181
+ super().__init__()
182
+ self.channels = channels
183
+ self.emb_channels = emb_channels
184
+ self.dropout = dropout
185
+ self.out_channels = out_channels or channels
186
+ self.use_conv = use_conv
187
+ self.use_checkpoint = use_checkpoint
188
+ self.use_scale_shift_norm = use_scale_shift_norm
189
+ self.nonlinearity_type = nonlinearity_type
190
+ self.in_layers = paddle.nn.Sequential(
191
+ normalization(channels),
192
+ nonlinearity(nonlinearity_type),
193
+ conv_nd(dims, channels, self.out_channels, (kernel_size_t, 3, 3), padding=(padding_t, 1, 1)),
194
+ )
195
+ self.updown = up or down
196
+ if up:
197
+ self.h_upd = Upsample(channels, False, dims, kernel_size_t=kernel_size_t, padding_t=padding_t)
198
+ self.x_upd = Upsample(channels, False, dims, kernel_size_t=kernel_size_t, padding_t=padding_t)
199
+ elif down:
200
+ self.h_upd = Downsample(channels, False, dims, kernel_size_t=kernel_size_t, padding_t=padding_t)
201
+ self.x_upd = Downsample(channels, False, dims, kernel_size_t=kernel_size_t, padding_t=padding_t)
202
+ else:
203
+ self.h_upd = self.x_upd = paddle.nn.Identity()
204
+ self.emb_layers = paddle.nn.Sequential(
205
+ nonlinearity(nonlinearity_type),
206
+ linear(emb_channels, 2 * self.out_channels if use_scale_shift_norm else self.out_channels),
207
+ )
208
+ self.out_layers = paddle.nn.Sequential(
209
+ normalization(self.out_channels),
210
+ nonlinearity(nonlinearity_type),
211
+ paddle.nn.Dropout(p=dropout),
212
+ zero_module(
213
+ conv_nd(dims, self.out_channels, self.out_channels, (kernel_size_t, 3, 3), padding=(padding_t, 1, 1))
214
+ ),
215
+ )
216
+ if self.out_channels == channels:
217
+ self.skip_connection = paddle.nn.Identity()
218
+ elif use_conv:
219
+ self.skip_connection = conv_nd(
220
+ dims, channels, self.out_channels, (kernel_size_t, 3, 3), padding=(padding_t, 1, 1)
221
+ )
222
+ else:
223
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
224
+
225
+ def forward(self, x, emb, **kwargs):
226
+ """
227
+ Apply the block to a Tensor, conditioned on a timestep embedding.
228
+ :param x: an [N x C x ...] Tensor of features.
229
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
230
+ :return: an [N x C x ...] Tensor of outputs.
231
+ """
232
+ if self.use_checkpoint:
233
+ return recompute(self._forward, x, emb)
234
+ else:
235
+ return self._forward(x, emb)
236
+
237
+ def _forward(self, x, emb):
238
+ if self.updown:
239
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
240
+ h = in_rest(x)
241
+ h = self.h_upd(h)
242
+ x = self.x_upd(x)
243
+ h = in_conv(h)
244
+ else:
245
+ h = self.in_layers(x)
246
+ emb_out = self.emb_layers(emb).astype(h.dtype)
247
+ if emb_out.dim() == 3:
248
+ emb_out = rearrange(emb_out, "b t c -> b c t")
249
+ while len(emb_out.shape) < h.dim():
250
+ emb_out = emb_out[..., None]
251
+ if self.use_scale_shift_norm:
252
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
253
+ scale, shift = paddle.chunk(x=emb_out, chunks=2, axis=1)
254
+ h = out_norm(h) * (1 + scale) + shift
255
+ h = out_rest(h)
256
+ else:
257
+ h = h + emb_out
258
+ h = self.out_layers(h)
259
+ out = self.skip_connection(x) + h
260
+ return out
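When `use_scale_shift_norm=True`, the projected timestep embedding is split into a scale and a shift that modulate the normalized features (FiLM-style conditioning), which is what the `out_norm(h) * (1 + scale) + shift` line above implements. A standalone sketch of just that step, assuming the features have already been normalized:

```python
import paddle

channels = 8
h = paddle.randn([2, channels, 4, 16, 16])   # normalized features: (b, c, t, h, w)
emb = paddle.randn([2, 2 * channels])        # emb_layers output: 2 * out_channels when scale-shift norm is on
emb = emb[..., None, None, None]             # broadcast over (t, h, w), as in `_forward`
scale, shift = paddle.chunk(x=emb, chunks=2, axis=1)
h = h * (1 + scale) + shift                  # FiLM-style conditioning on the timestep embedding
print(h.shape)                               # [2, 8, 4, 16, 16]
```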
261
+
262
+
263
+ # def make_spatialtemporal_transformer(module_name='attention_temporal',
264
+ # class_name='SpatialTemporalTransformer'):
265
+ # module = __import__(f'.lvdm_attention_temporal', fromlist=[
266
+ # class_name])
267
+ # global STTransformerClass
268
+ # STTransformerClass = getattr(module, class_name)
269
+ # return STTransformerClass
270
+
271
+
272
+ def make_spatialtemporal_transformer(module_name="attention_temporal", class_name="SpatialTemporalTransformer"):
273
+ # Todo: Support loading more types of transformers
274
+ assert module_name == "attention_temporal" and class_name == "SpatialTemporalTransformer"
275
+ return SpatialTemporalTransformer
276
+
277
+
278
+ class LVDMUNet3DModel(ModelMixin, ConfigMixin):
279
+ """
280
+ The full UNet model with attention and timestep embedding.
281
+ :param in_channels: channels in the input Tensor.
282
+ :param model_channels: base channel count for the model.
283
+ :param out_channels: channels in the output Tensor.
284
+ :param num_res_blocks: number of residual blocks per downsample.
285
+ :param attention_resolutions: a collection of downsample rates at which
286
+ attention will take place. May be a set, list, or tuple.
287
+ For example, if this contains 4, then at 4x downsampling, attention
288
+ will be used.
289
+ :param dropout: the dropout probability.
290
+ :param channel_mult: channel multiplier for each level of the UNet.
291
+ :param conv_resample: if True, use learned convolutions for upsampling and
292
+ downsampling.
293
+ :param dims: determines if the signal is 1D, 2D, or 3D.
294
+ :param num_classes: if specified (as an int), then this model will be
295
+ class-conditional with `num_classes` classes.
296
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
297
+ :param num_heads: the number of attention heads in each attention layer.
298
+ :param num_heads_channels: if specified, ignore num_heads and instead use
299
+ a fixed channel width per attention head.
300
+ :param num_heads_upsample: works with num_heads to set a different number
301
+ of heads for upsampling. Deprecated.
302
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
303
+ :param resblock_updown: use residual blocks for up/downsampling.
304
+ """
305
+
306
+ @register_to_config
307
+ def __init__(
308
+ self,
309
+ image_size,
310
+ in_channels,
311
+ model_channels,
312
+ out_channels,
313
+ num_res_blocks,
314
+ attention_resolutions,
315
+ dropout=0,
316
+ channel_mult=(1, 2, 4, 8),
317
+ conv_resample=True,
318
+ dims=3,
319
+ num_classes=None,
320
+ use_checkpoint=False,
321
+ use_fp16=False,
322
+ num_heads=-1,
323
+ num_head_channels=-1,
324
+ num_heads_upsample=-1,
325
+ use_scale_shift_norm=False,
326
+ resblock_updown=False,
327
+ transformer_depth=1,
328
+ context_dim=None,
329
+ legacy=True,
330
+ kernel_size_t=1,
331
+ padding_t=1,
332
+ use_temporal_transformer=False,
333
+ temporal_length=None,
334
+ use_relative_position=False,
335
+ nonlinearity_type="silu",
336
+ ST_transformer_module="attention_temporal",
337
+ ST_transformer_class="SpatialTemporalTransformer",
338
+ **kwargs
339
+ ):
340
+ super().__init__()
341
+ if use_temporal_transformer:
342
+ assert (
343
+ context_dim is not None
344
+ ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
345
+ if context_dim is not None:
346
+ assert (
347
+ use_temporal_transformer
348
+ ), "Fool!! You forgot to use the temporal transformer for your cross-attention conditioning..."
349
+ from omegaconf.listconfig import ListConfig
350
+
351
+ if type(context_dim) == ListConfig:
352
+ context_dim = list(context_dim)
353
+ if num_heads_upsample == -1:
354
+ num_heads_upsample = num_heads
355
+ if num_heads == -1:
356
+ assert num_head_channels != -1, "Either num_heads or num_head_channels has to be set"
357
+ if num_head_channels == -1:
358
+ assert num_heads != -1, "Either num_heads or num_head_channels has to be set"
359
+ self.image_size = image_size
360
+ self.in_channels = in_channels
361
+ self.model_channels = model_channels
362
+ self.out_channels = out_channels
363
+ self.num_res_blocks = num_res_blocks
364
+ self.attention_resolutions = attention_resolutions
365
+ self.dropout = dropout
366
+ self.channel_mult = channel_mult
367
+ self.conv_resample = conv_resample
368
+ self.num_classes = num_classes
369
+ self.use_checkpoint = use_checkpoint
370
+ # Todo: support customized self.dtype
371
+ # self.dtype = 'float16' if use_fp16 else 'float32'
372
+ self.num_heads = num_heads
373
+ self.num_head_channels = num_head_channels
374
+ self.num_heads_upsample = num_heads_upsample
375
+ self.use_relative_position = use_relative_position
376
+ self.temporal_length = temporal_length
377
+ self.nonlinearity_type = nonlinearity_type
378
+ time_embed_dim = model_channels * 4
379
+ self.time_embed_dim = time_embed_dim
380
+ self.time_embed = paddle.nn.Sequential(
381
+ linear(model_channels, time_embed_dim),
382
+ nonlinearity(nonlinearity_type),
383
+ linear(time_embed_dim, time_embed_dim),
384
+ )
385
+ if self.num_classes is not None:
386
+ self.label_emb = paddle.nn.Embedding(num_classes, time_embed_dim)
387
+ STTransformerClass = make_spatialtemporal_transformer(
388
+ module_name=ST_transformer_module, class_name=ST_transformer_class
389
+ )
390
+ self.input_blocks = paddle.nn.LayerList(
391
+ sublayers=[
392
+ TimestepEmbedSequential(
393
+ conv_nd(dims, in_channels, model_channels, (kernel_size_t, 3, 3), padding=(padding_t, 1, 1))
394
+ )
395
+ ]
396
+ )
397
+ self._feature_size = model_channels
398
+ input_block_chans = [model_channels]
399
+ ch = model_channels
400
+ ds = 1
401
+ for level, mult in enumerate(channel_mult):
402
+ for _ in range(num_res_blocks):
403
+ layers = [
404
+ ResBlock(
405
+ ch,
406
+ time_embed_dim,
407
+ dropout,
408
+ out_channels=mult * model_channels,
409
+ dims=dims,
410
+ use_checkpoint=use_checkpoint,
411
+ use_scale_shift_norm=use_scale_shift_norm,
412
+ kernel_size_t=kernel_size_t,
413
+ padding_t=padding_t,
414
+ nonlinearity_type=nonlinearity_type,
415
+ **kwargs,
416
+ )
417
+ ]
418
+ ch = mult * model_channels
419
+ if ds in attention_resolutions:
420
+ if num_head_channels == -1:
421
+ dim_head = ch // num_heads
422
+ else:
423
+ num_heads = ch // num_head_channels
424
+ dim_head = num_head_channels
425
+ if legacy:
426
+ dim_head = ch // num_heads if use_temporal_transformer else num_head_channels
427
+ layers.append(
428
+ STAttentionBlock(
429
+ ch,
430
+ use_checkpoint=use_checkpoint,
431
+ num_heads=num_heads,
432
+ num_head_channels=dim_head,
433
+ temporal_length=temporal_length,
434
+ use_relative_position=use_relative_position,
435
+ )
436
+ if not use_temporal_transformer
437
+ else STTransformerClass(
438
+ ch,
439
+ num_heads,
440
+ dim_head,
441
+ depth=transformer_depth,
442
+ context_dim=context_dim,
443
+ temporal_length=temporal_length,
444
+ use_relative_position=use_relative_position,
445
+ **kwargs,
446
+ )
447
+ )
448
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
449
+ self._feature_size += ch
450
+ input_block_chans.append(ch)
451
+ if level != len(channel_mult) - 1:
452
+ out_ch = ch
453
+ self.input_blocks.append(
454
+ TimestepEmbedSequential(
455
+ ResBlock(
456
+ ch,
457
+ time_embed_dim,
458
+ dropout,
459
+ out_channels=out_ch,
460
+ dims=dims,
461
+ use_checkpoint=use_checkpoint,
462
+ use_scale_shift_norm=use_scale_shift_norm,
463
+ down=True,
464
+ kernel_size_t=kernel_size_t,
465
+ padding_t=padding_t,
466
+ nonlinearity_type=nonlinearity_type,
467
+ **kwargs,
468
+ )
469
+ if resblock_updown
470
+ else Downsample(
471
+ ch,
472
+ conv_resample,
473
+ dims=dims,
474
+ out_channels=out_ch,
475
+ kernel_size_t=kernel_size_t,
476
+ padding_t=padding_t,
477
+ )
478
+ )
479
+ )
480
+ ch = out_ch
481
+ input_block_chans.append(ch)
482
+ ds *= 2
483
+ self._feature_size += ch
484
+ if num_head_channels == -1:
485
+ dim_head = ch // num_heads
486
+ else:
487
+ num_heads = ch // num_head_channels
488
+ dim_head = num_head_channels
489
+ if legacy:
490
+ dim_head = ch // num_heads if use_temporal_transformer else num_head_channels
491
+ self.middle_block = TimestepEmbedSequential(
492
+ ResBlock(
493
+ ch,
494
+ time_embed_dim,
495
+ dropout,
496
+ dims=dims,
497
+ use_checkpoint=use_checkpoint,
498
+ use_scale_shift_norm=use_scale_shift_norm,
499
+ kernel_size_t=kernel_size_t,
500
+ padding_t=padding_t,
501
+ nonlinearity_type=nonlinearity_type,
502
+ **kwargs,
503
+ ),
504
+ STAttentionBlock(
505
+ ch,
506
+ use_checkpoint=use_checkpoint,
507
+ num_heads=num_heads,
508
+ num_head_channels=dim_head,
509
+ temporal_length=temporal_length,
510
+ use_relative_position=use_relative_position,
511
+ )
512
+ if not use_temporal_transformer
513
+ else STTransformerClass(
514
+ ch,
515
+ num_heads,
516
+ dim_head,
517
+ depth=transformer_depth,
518
+ context_dim=context_dim,
519
+ temporal_length=temporal_length,
520
+ use_relative_position=use_relative_position,
521
+ **kwargs,
522
+ ),
523
+ ResBlock(
524
+ ch,
525
+ time_embed_dim,
526
+ dropout,
527
+ dims=dims,
528
+ use_checkpoint=use_checkpoint,
529
+ use_scale_shift_norm=use_scale_shift_norm,
530
+ kernel_size_t=kernel_size_t,
531
+ padding_t=padding_t,
532
+ nonlinearity_type=nonlinearity_type,
533
+ **kwargs,
534
+ ),
535
+ )
536
+ self._feature_size += ch
537
+ self.output_blocks = paddle.nn.LayerList(sublayers=[])
538
+ for level, mult in list(enumerate(channel_mult))[::-1]:
539
+ for i in range(num_res_blocks + 1):
540
+ ich = input_block_chans.pop()
541
+ layers = [
542
+ ResBlock(
543
+ ch + ich,
544
+ time_embed_dim,
545
+ dropout,
546
+ out_channels=model_channels * mult,
547
+ dims=dims,
548
+ use_checkpoint=use_checkpoint,
549
+ use_scale_shift_norm=use_scale_shift_norm,
550
+ kernel_size_t=kernel_size_t,
551
+ padding_t=padding_t,
552
+ nonlinearity_type=nonlinearity_type,
553
+ **kwargs,
554
+ )
555
+ ]
556
+ ch = model_channels * mult
557
+ if ds in attention_resolutions:
558
+ if num_head_channels == -1:
559
+ dim_head = ch // num_heads
560
+ else:
561
+ num_heads = ch // num_head_channels
562
+ dim_head = num_head_channels
563
+ if legacy:
564
+ dim_head = ch // num_heads if use_temporal_transformer else num_head_channels
565
+ layers.append(
566
+ STAttentionBlock(
567
+ ch,
568
+ use_checkpoint=use_checkpoint,
569
+ num_heads=num_heads,
570
+ num_head_channels=dim_head,
571
+ temporal_length=temporal_length,
572
+ use_relative_position=use_relative_position,
573
+ )
574
+ if not use_temporal_transformer
575
+ else STTransformerClass(
576
+ ch,
577
+ num_heads,
578
+ dim_head,
579
+ depth=transformer_depth,
580
+ context_dim=context_dim,
581
+ temporal_length=temporal_length,
582
+ use_relative_position=use_relative_position,
583
+ **kwargs,
584
+ )
585
+ )
586
+ if level and i == num_res_blocks:
587
+ out_ch = ch
588
+ layers.append(
589
+ ResBlock(
590
+ ch,
591
+ time_embed_dim,
592
+ dropout,
593
+ out_channels=out_ch,
594
+ dims=dims,
595
+ use_checkpoint=use_checkpoint,
596
+ use_scale_shift_norm=use_scale_shift_norm,
597
+ up=True,
598
+ kernel_size_t=kernel_size_t,
599
+ padding_t=padding_t,
600
+ nonlinearity_type=nonlinearity_type,
601
+ **kwargs,
602
+ )
603
+ if resblock_updown
604
+ else Upsample(
605
+ ch,
606
+ conv_resample,
607
+ dims=dims,
608
+ out_channels=out_ch,
609
+ kernel_size_t=kernel_size_t,
610
+ padding_t=padding_t,
611
+ )
612
+ )
613
+ ds //= 2
614
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
615
+ self._feature_size += ch
616
+ self.out = paddle.nn.Sequential(
617
+ normalization(ch),
618
+ nonlinearity(nonlinearity_type),
619
+ zero_module(conv_nd(dims, model_channels, out_channels, (kernel_size_t, 3, 3), padding=(padding_t, 1, 1))),
620
+ )
621
+
622
+ def convert_to_fp16(self):
623
+ """
624
+ Convert the torso of the model to float16.
625
+ """
626
+ self.input_blocks.apply(fn=convert_module_to_f16)
627
+ self.middle_block.apply(fn=convert_module_to_f16)
628
+ self.output_blocks.apply(fn=convert_module_to_f16)
629
+
630
+ def convert_to_fp32(self):
631
+ """
632
+ Convert the torso of the model to float32.
633
+ """
634
+ self.input_blocks.apply(fn=convert_module_to_f32)
635
+ self.middle_block.apply(fn=convert_module_to_f32)
636
+ self.output_blocks.apply(fn=convert_module_to_f32)
637
+
638
+ def forward(self, x, timesteps=None, time_emb_replace=None, context=None, y=None, **kwargs):
639
+ """
640
+ Apply the model to an input batch.
641
+ :param x: an [N x C x ...] Tensor of inputs.
642
+ :param timesteps: a 1-D batch of timesteps.
643
+ :param context: conditioning plugged in via crossattn
644
+ :param y: an [N] Tensor of labels, if class-conditional.
645
+ :return: an [N x C x ...] Tensor of outputs.
646
+ """
647
+ # Fix 0D tensor bug
648
+ if timesteps.ndim == 0:
649
+ timesteps = timesteps.unsqueeze(0)
650
+ hs = []
651
+ if time_emb_replace is None:
652
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
653
+ emb = self.time_embed(t_emb)
654
+ else:
655
+ emb = time_emb_replace
656
+ if y is not None:
657
+ assert y.shape == (x.shape[0],)
658
+ emb = emb + self.label_emb(y)
659
+ h = x.astype(self.dtype)
660
+ for module in self.input_blocks:
661
+ h = module(h, emb, context, **kwargs)
662
+ hs.append(h)
663
+ h = self.middle_block(h, emb, context, **kwargs)
664
+ for module in self.output_blocks:
665
+ h = paddle.concat(x=[h, hs.pop()], axis=1)
666
+ h = module(h, emb, context, **kwargs)
667
+ h = h.astype(x.dtype)
668
+ h = self.out(h)
669
+
670
+ return LVDMUNet3DModelOutput(sample=h)
671
+
672
+
673
+ class FrameInterpPredUNet(LVDMUNet3DModel):
674
+ """
675
+ A UNet for unconditional generation, frame prediction, and interpolation.
676
+ May need an input `mask` to indicate the conditioning frames, as well as a noise level `s` for condition augmentation.
677
+ """
678
+
679
+ def __init__(self, image_size, in_channels, cond_aug_mode=None, *args, **kwargs):
680
+ super().__init__(image_size, in_channels, *args, **kwargs)
681
+ if cond_aug_mode == "time_embed":
682
+ self.time_embed_cond = paddle.nn.Sequential(
683
+ linear(self.model_channels, self.time_embed_dim),
684
+ nonlinearity(self.nonlinearity_type),
685
+ linear(self.time_embed_dim, self.time_embed_dim),
686
+ )
687
+ elif cond_aug_mode == "learned_embed":
688
+ pass
689
+
690
+ def forward(self, x, timesteps, context=None, y=None, s=None, mask=None, **kwargs):
691
+ # Fix 0D tensor bug
692
+ if timesteps.ndim == 0:
693
+ timesteps = timesteps.unsqueeze(0)
694
+ if s is not None:
695
+ s_emb = timestep_embedding(s, self.model_channels, repeat_only=False)
696
+ s_emb = self.time_embed_cond(s_emb)
697
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
698
+ emb = self.time_embed(t_emb)
699
+ assert emb.dim() == 2
700
+ mask_ = mask[:, :, :, (0), (0)]
701
+ t = mask.shape[2]
702
+ emb_mix = (
703
+ emb.unsqueeze(axis=2).tile(repeat_times=[1, 1, t]) * (1 - mask_)
704
+ + s_emb.unsqueeze(axis=2).tile(repeat_times=[1, 1, t]) * mask_
705
+ )
706
+ assert emb_mix.dim() == 3
707
+ emb_mix = rearrange(emb_mix, "b c t -> b t c")
708
+ time_emb_replace = emb_mix
709
+ timesteps = None
710
+ else:
711
+ time_emb_replace = None
712
+ timesteps = timesteps
713
+ return super().forward(x, timesteps, time_emb_replace=time_emb_replace, context=context, y=y, **kwargs)
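A hedged, small-scale instantiation sketch of the 3D UNet defined in this file (the hyperparameters below are illustrative, not a released configuration; `kernel_size_t=3` with `padding_t=1` keeps the temporal length unchanged through every 3D convolution):

```python
import paddle
from ppdiffusers.models.lvdm_unet_3d import LVDMUNet3DModel  # assumed import path

unet = LVDMUNet3DModel(
    image_size=32, in_channels=4, model_channels=32, out_channels=4,
    num_res_blocks=1, attention_resolutions=(4,), channel_mult=(1, 2, 4),
    num_heads=4, kernel_size_t=3, padding_t=1, temporal_length=16,
)
latents = paddle.randn([1, 4, 16, 32, 32])   # (b, c, t, h, w) video latents
t = paddle.to_tensor([500])                  # one diffusion timestep per batch element
out = unet(latents, timesteps=t)
print(out.sample.shape)                      # [1, 4, 16, 32, 32]
```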
PaddleMIX/ppdiffusers/ppdiffusers/models/lvdm_util.py ADDED
@@ -0,0 +1,296 @@
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ import random
17
+ from inspect import isfunction
18
+
19
+ import numpy as np
20
+ import paddle
21
+ from einops import repeat
22
+
23
+
24
+ def make_interp_mask_with_bothsidescond(t, device, n_interp1, n_interp2):
25
+ """1: cond frames
26
+ 0: generated frames
27
+ """
28
+ mask = paddle.zeros(shape=[t])
29
+ mask[:n_interp1] = 1
30
+ mask[t - n_interp2 :] = 1
31
+ return mask
32
+
33
+
34
+ def make_interp_mask_with_framestride(t, device, frame_stride):
35
+ """1: cond frames
36
+ 0: generated frames
37
+ """
38
+ mask = paddle.zeros(shape=[t])
39
+ for i in range(0, t, frame_stride):
40
+ mask[i] = 1
41
+ return mask
42
+
43
+
44
+ def random_temporal_masking(
45
+ input_shape, p_interp, p_pred, device, n_interp1=1, n_interp2=1, n_prevs=[1], interp_frame_stride=None
46
+ ):
47
+ """return mask for masking input, where 1 indicates given real image as condition,
48
+ 0 indicates noisy samples.
49
+ """
50
+ if p_pred == 0.0:
51
+ n_prevs = None
52
+ b, c, t, h, w = input_shape
53
+ mask = paddle.zeros(shape=[b, t])
54
+ for i in range(b):
55
+ r = random.random()
56
+ if r < p_interp:
57
+ if interp_frame_stride is not None:
58
+ mask[i] = make_interp_mask_with_framestride(t, device, interp_frame_stride)
59
+ else:
60
+ mask[i] = make_interp_mask_with_bothsidescond(t, device, n_interp1, n_interp2)
61
+ elif p_interp <= r < p_interp + p_pred:
62
+ n_pred = random.choice(n_prevs)
63
+ mask[i, :n_pred] = 1
64
+ else:
65
+ pass
66
+ mask = mask.unsqueeze(axis=1).unsqueeze(axis=3).unsqueeze(axis=4)
67
+ mask = mask.tile(repeat_times=[1, 1, 1, h, w])
68
+ return mask
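For intuition, the two temporal mask patterns produced above look as follows on a single clip; the snippet is a standalone illustration (the frame count and number of conditioning frames are arbitrary choices), not part of this commit.

```python
# Illustration of the temporal masks above (1 = conditioning frame, 0 = generated frame).
import paddle

t = 8
interp = paddle.zeros([t])
interp[:1] = 1             # both-sides interpolation: the first frame is given ...
interp[t - 1:] = 1         # ... and so is the last frame
pred = paddle.zeros([t])
pred[:2] = 1               # frame prediction: the first two frames are given

print(interp.numpy())      # [1. 0. 0. 0. 0. 0. 0. 1.]
print(pred.numpy())        # [1. 1. 0. 0. 0. 0. 0. 0.]
```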
69
+
70
+
71
+ def make_beta_schedule(schedule, n_timestep, linear_start=0.0001, linear_end=0.02, cosine_s=0.008):
72
+ if schedule == "linear":
73
+ betas = (
74
+ paddle.linspace(start=linear_start**0.5, stop=linear_end**0.5, num=n_timestep).astype("float64") ** 2
75
+ )
76
+ elif schedule == "cosine":
77
+ timesteps = paddle.arange(end=n_timestep + 1).astype("float64") / n_timestep + cosine_s
78
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
79
+ alphas = paddle.cos(x=alphas).pow(y=2)
80
+ alphas = alphas / alphas[0]
81
+ betas = 1 - alphas[1:] / alphas[:-1]
82
+ betas = paddle.clip(betas, min=0, max=0.999)
83
+ elif schedule == "sqrt_linear":
84
+ betas = paddle.linspace(start=linear_start, stop=linear_end, num=n_timestep).astype("float64")
85
+ elif schedule == "sqrt":
86
+ betas = paddle.linspace(start=linear_start, stop=linear_end, num=n_timestep).astype("float64") ** 0.5
87
+ else:
88
+ raise ValueError(f"schedule '{schedule}' unknown.")
89
+ return betas.numpy()
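As a quick check of the `linear` branch above, the standalone sketch below reproduces the schedule (betas spaced linearly in square-root space, then squared) and derives the cumulative alphas used downstream; the hyperparameters mirror the function defaults.

```python
# Standalone reproduction of the "linear" beta schedule and its cumulative alphas.
import paddle

linear_start, linear_end, n_timestep = 1e-4, 2e-2, 1000
betas = paddle.linspace(linear_start**0.5, linear_end**0.5, n_timestep).astype("float64") ** 2
alphas_cumprod = paddle.cumprod(1.0 - betas, dim=0)
print(float(betas[0]), float(betas[-1]))   # 1e-04 and 2e-02, the schedule endpoints
print(float(alphas_cumprod[-1]))           # close to 0: nearly pure noise at the last timestep
```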
90
+
91
+
92
+ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
93
+ if ddim_discr_method == "uniform":
94
+ c = num_ddpm_timesteps // num_ddim_timesteps
95
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
96
+ elif ddim_discr_method == "quad":
97
+ ddim_timesteps = (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps) ** 2).astype(int)
98
+ else:
99
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
100
+ steps_out = ddim_timesteps + 1
101
+ if verbose:
102
+ print(f"Selected timesteps for ddim sampler: {steps_out}")
103
+ return steps_out
104
+
105
+
106
+ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
107
+ alphas = alphacums[ddim_timesteps]
108
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
109
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
110
+ if verbose:
111
+ print(f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}")
112
+ print(
113
+ f"For the chosen value of eta, which is {eta}, this results in the following sigma_t schedule for ddim sampler {sigmas}"
114
+ )
115
+ return sigmas, alphas, alphas_prev
116
+
117
+
118
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
119
+ """
120
+ Create a beta schedule that discretizes the given alpha_t_bar function,
121
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
122
+ :param num_diffusion_timesteps: the number of betas to produce.
123
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
124
+ produces the cumulative product of (1-beta) up to that
125
+ part of the diffusion process.
126
+ :param max_beta: the maximum beta to use; use values lower than 1 to
127
+ prevent singularities.
128
+ """
129
+ betas = []
130
+ for i in range(num_diffusion_timesteps):
131
+ t1 = i / num_diffusion_timesteps
132
+ t2 = (i + 1) / num_diffusion_timesteps
133
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
134
+ return np.array(betas)
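A typical `alpha_bar` passed to this helper is a squared-cosine schedule; the short standalone check below derives betas directly from that definition (a tiny `T` is used just so the values can be printed).

```python
# Standalone check: betas from a squared-cosine alpha_bar, mirroring the helper above.
import math
import numpy as np

def alpha_bar(t):
    # cumulative product of (1 - beta) as a smooth function of t in [0, 1]
    return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

T = 8
betas = np.array([min(1 - alpha_bar((i + 1) / T) / alpha_bar(i / T), 0.999) for i in range(T)])
print(betas.round(4))   # small near t=0, growing (and capped at 0.999) toward t=T
```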
135
+
136
+
137
+ def extract_into_tensor(a, t, x_shape):
138
+ b, *_ = t.shape
139
+ out = a.take_along_axis(axis=-1, indices=t)
140
+ return out.reshape([b, *((1,) * (len(x_shape) - 1))])
141
+
142
+
143
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
144
+ """
145
+ Create sinusoidal timestep embeddings.
146
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
147
+ These may be fractional.
148
+ :param dim: the dimension of the output.
149
+ :param max_period: controls the minimum frequency of the embeddings.
150
+ :return: an [N x dim] Tensor of positional embeddings.
151
+ """
152
+ if not repeat_only:
153
+ half = dim // 2
154
+ freqs = paddle.exp(
155
+ x=(-math.log(max_period) * paddle.arange(start=0, end=half).astype("float32") / half).astype("float32")
156
+ )
157
+ args = timesteps[:, (None)].astype(dtype="float32") * freqs[None]
158
+ embedding = paddle.concat(x=[paddle.cos(x=args), paddle.sin(x=args)], axis=-1)
159
+ if dim % 2:
160
+ embedding = paddle.concat(x=[embedding, paddle.zeros_like(x=embedding[:, :1])], axis=-1)
161
+ else:
162
+ embedding = repeat(timesteps, "b -> b d", d=dim)
163
+ return embedding
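The sinusoidal embedding above can be reproduced in a few lines; the sketch below mirrors the math for a small even `dim` (all values are illustrative).

```python
# Standalone reproduction of the sinusoidal timestep embedding for an even dim.
import math
import paddle

timesteps = paddle.to_tensor([0.0, 10.0, 999.0])
dim, max_period = 8, 10000
half = dim // 2
freqs = paddle.exp(-math.log(max_period) * paddle.arange(0, half).astype("float32") / half)
args = timesteps[:, None].astype("float32") * freqs[None]
emb = paddle.concat([paddle.cos(args), paddle.sin(args)], axis=-1)
print(emb.shape)   # [3, 8]: cosine features followed by sine features
```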
164
+
165
+
166
+ def zero_module(module):
167
+ """
168
+ Zero out the parameters of a module and return it.
169
+ """
170
+ for p in module.parameters():
171
+ p.detach().zero_()
172
+ return module
173
+
174
+
175
+ def scale_module(module, scale):
176
+ """
177
+ Scale the parameters of a module and return it.
178
+ """
179
+ for p in module.parameters():
180
+ p.detach().scale_(scale=scale)
181
+ return module
182
+
183
+
184
+ def mean_flat(tensor):
185
+ """
186
+ Take the mean over all non-batch dimensions.
187
+ """
188
+ return tensor.mean(axis=list(range(1, len(tensor.shape))))
189
+
190
+
191
+ def normalization(channels):
192
+ """
193
+ Make a standard normalization layer.
194
+ :param channels: number of input channels.
195
+ :return: an nn.Module for normalization.
196
+ """
197
+ return GroupNorm32(32, channels)
198
+
199
+
200
+ def Normalize(in_channels):
201
+ return paddle.nn.GroupNorm(
202
+ num_groups=32, num_channels=in_channels, epsilon=1e-06, weight_attr=None, bias_attr=None
203
+ )
204
+
205
+
206
+ def identity(*args, **kwargs):
207
+ return paddle.nn.Identity()
208
+
209
+
210
+ def nonlinearity(type="silu"):
211
+ if type == "silu":
212
+ return paddle.nn.Silu()
213
+ elif type == "leaky_relu":
214
+ return paddle.nn.LeakyReLU()
215
+
216
+
217
+ class GEGLU(paddle.nn.Layer):
218
+ def __init__(self, dim_in, dim_out):
219
+ super().__init__()
220
+ self.proj = paddle.nn.Linear(in_features=dim_in, out_features=dim_out * 2)
221
+
222
+ def forward(self, x):
223
+ x, gate = self.proj(x).chunk(chunks=2, axis=-1)
224
+ return x * paddle.nn.functional.gelu(x=gate)
225
+
226
+
227
+ class SiLU(paddle.nn.Layer):
228
+ def forward(self, x):
229
+ return x * paddle.nn.functional.sigmoid(x=x)
230
+
231
+
232
+ class GroupNorm32(paddle.nn.GroupNorm):
233
+ def forward(self, x):
234
+ return super().forward(x.astype(dtype="float32")).astype(x.dtype)
235
+
236
+
237
+ def conv_nd(dims, *args, **kwargs):
238
+ """
239
+ Create a 1D, 2D, or 3D convolution module.
240
+ """
241
+ if dims == 1:
242
+ return paddle.nn.Conv1D(*args, **kwargs)
243
+ elif dims == 2:
244
+ return paddle.nn.Conv2D(*args, **kwargs)
245
+ elif dims == 3:
246
+ return paddle.nn.Conv3D(*args, **kwargs)
247
+ raise ValueError(f"unsupported dimensions: {dims}")
248
+
249
+
250
+ def linear(*args, **kwargs):
251
+ """
252
+ Create a linear module.
253
+ """
254
+ return paddle.nn.Linear(*args, **kwargs)
255
+
256
+
257
+ def avg_pool_nd(dims, *args, **kwargs):
258
+ """
259
+ Create a 1D, 2D, or 3D average pooling module.
260
+ """
261
+ if dims == 1:
262
+ return paddle.nn.AvgPool1D(*args, **kwargs, exclusive=False)
263
+ elif dims == 2:
264
+ return paddle.nn.AvgPool2D(*args, **kwargs, exclusive=False)
265
+ elif dims == 3:
266
+ return paddle.nn.AvgPool3D(*args, **kwargs, exclusive=False)
267
+ raise ValueError(f"unsupported dimensions: {dims}")
268
+
269
+
270
+ def noise_like(shape, device, repeat=False):
271
+ repeat_noise = lambda: paddle.randn(shape=(1, *shape[1:])).tile(
272
+ repeat_times=[shape[0], *((1,) * (len(shape) - 1))]
273
+ )
274
+ noise = lambda: paddle.randn(shape=shape)
275
+ return repeat_noise() if repeat else noise()
276
+
277
+
278
+ def init_(tensor):
279
+ dim = tensor.shape[-1]
280
+ std = 1 / math.sqrt(dim)
281
+ tensor.uniform_(min=-std, max=std)
282
+ return tensor
283
+
284
+
285
+ def exists(val):
286
+ return val is not None
287
+
288
+
289
+ def uniq(arr):
290
+ return {el: True for el in arr}.keys()
291
+
292
+
293
+ def default(val, d):
294
+ if exists(val):
295
+ return val
296
+ return d() if isfunction(d) else d
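Putting a few of these helpers together, the sketch below builds a tiny norm → activation → 3D-conv block for a video tensor. It assumes the file is importable as `ppdiffusers.models.lvdm_util` (its path in this repo) and uses toy tensor sizes; it is not part of the commit.

```python
# Quick sanity check of the dimension-dispatch helpers above, with toy sizes.
import paddle
from ppdiffusers.models.lvdm_util import conv_nd, nonlinearity, normalization

block = paddle.nn.Sequential(
    normalization(64),                             # GroupNorm32: computes in float32, casts back
    nonlinearity("silu"),
    conv_nd(3, 64, 64, kernel_size=3, padding=1),  # dims=3 -> Conv3D for video tensors
)
x = paddle.randn([1, 64, 4, 32, 32])               # [batch, channels, frames, height, width]
print(block(x).shape)                              # [1, 64, 4, 32, 32]
```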
PaddleMIX/ppdiffusers/ppdiffusers/models/modeling_pytorch_paddle_utils.py ADDED
@@ -0,0 +1,117 @@
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+ import paddle.nn as nn
17
+
18
+ from ..utils import logging
19
+ from ..utils.import_utils import is_torch_available
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+ if is_torch_available():
24
+ import torch
25
+
26
+
27
+ def convert_pytorch_state_dict_to_paddle(self: nn.Layer, pt_state_dict, sub_layer=None):
28
+ # Step 1: Find Linear layers whose weights need to be transposed
29
+ linear_need_transpose = []
30
+ for k, v in self.named_sublayers(include_self=True):
31
+ if isinstance(v, nn.Linear):
32
+ if sub_layer is not None and sub_layer not in k:
33
+ continue
34
+ linear_need_transpose.append(k + ".weight")
35
+
36
+ ignore_keys = ["position_ids", ".num_batches_tracked"]
37
+ ptname2pdname = {
38
+ # torch.nn.BatchNorm2d -> paddle.nn.BatchNorm2D
39
+ ".running_var": "._variance",
40
+ ".running_mean": "._mean",
41
+ }
42
+ # Some parameter names must be renamed to match Paddle naming conventions
43
+ keys = list(pt_state_dict.keys())
44
+ for pt_key in keys:
45
+ pt_tensor = pt_state_dict.pop(pt_key)
46
+ # only convert sub_layer state dict
47
+ if sub_layer is not None and sub_layer not in pt_key:
48
+ continue
49
+ # (0) ignore_keys
50
+ if any(i in pt_key for i in ignore_keys):
51
+ continue
52
+ # (1) transpose linear
53
+ if pt_key in linear_need_transpose and pt_tensor.ndim == 2:
54
+ pt_tensor = pt_tensor.T
55
+ # (2) 0d tensor -> 1d tensor
56
+ # if pt_tensor.ndim == 0:
57
+ # pt_tensor = pt_tensor.reshape((1,))
58
+ # (3) name mapping
59
+ for old_key, new_key in ptname2pdname.items():
60
+ pt_key = pt_key.replace(old_key, new_key)
61
+
62
+ pt_state_dict[pt_key] = pt_tensor
63
+ return pt_state_dict
64
+
65
+
66
+ def convert_paddle_state_dict_to_pytorch(self: nn.Layer, pd_state_dict):
67
+ # Step 1: Find Linear layers whose weights need to be transposed
68
+ linear_need_transpose = []
69
+ for k, v in self.named_sublayers(include_self=True):
70
+ if isinstance(v, nn.Linear):
71
+ linear_need_transpose.append(k + ".weight")
72
+
73
+ ignore_keys = ["position_ids"]
74
+ ptname2pdname = {
75
+ # torch.nn.BatchNorm2d -> paddle.nn.BatchNorm2D
76
+ ".running_var": "._variance",
77
+ ".running_mean": "._mean",
78
+ }
79
+ keys = list(pd_state_dict.keys())
80
+ detect_bfloat16 = False
81
+ for pd_key in keys:
82
+ pd_tensor = pd_state_dict.pop(pd_key)
83
+ # (0) ignore_keys
84
+ if any(i in pd_key for i in ignore_keys):
85
+ continue
86
+ # (1) transpose linear
87
+ if pd_key in linear_need_transpose and pd_tensor.ndim == 2:
88
+ pd_tensor = pd_tensor.T
89
+ # TODO maybe not true
90
+ # (2) 1d tensor -> 0d tensor
91
+ if pd_tensor.ndim == 1:
92
+ pd_tensor = pd_tensor.squeeze()
93
+ # (3) name mapping
94
+ for old_key, new_key in ptname2pdname.items():
95
+ pd_key = pd_key.replace(new_key, old_key)
96
+
97
+ pd_tensor = np.ascontiguousarray(pd_tensor)
98
+
99
+ if is_torch_available():
100
+ if pd_tensor.dtype in ["uint16", np.uint16]:
101
+ pd_tensor = pd_tensor.astype(np.float32)
102
+ pd_state_dict[pd_key] = torch.from_numpy(pd_tensor).to(torch.bfloat16)
103
+ else:
104
+ pd_state_dict[pd_key] = torch.from_numpy(pd_tensor)
105
+ else:
106
+ if pd_tensor.dtype in ["uint16", np.uint16]:
107
+ pd_tensor = pd_tensor.astype(np.float16)
108
+ detect_bfloat16 = True
109
+ pd_state_dict[pd_key] = pd_tensor
110
+
111
+ if detect_bfloat16:
112
+ logger.warning(
113
+ "PyTorch is not installed, so we cannot save as `bfloat16` tensor. "
114
+ "To ensure the model can still be loaded, we will save it as `float16` tensor instead. "
115
+ "Please note that this may affect the precision of the saved model."
116
+ )
117
+ return pd_state_dict
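The central convention in both converters above is the Linear weight transpose: PyTorch stores `nn.Linear` weights as `[out_features, in_features]`, while Paddle stores them as `[in_features, out_features]`. The minimal standalone check below (toy shapes, no torch required) shows why the transpose is necessary; it is illustrative only.

```python
# Why Linear weights are transposed above: a torch-style [out, in] weight must be
# transposed to Paddle's [in, out] layout to produce the same outputs.
import numpy as np
import paddle

pt_weight = np.random.randn(4, 3).astype("float32")    # torch-style weight, shape [out=4, in=3]
layer = paddle.nn.Linear(3, 4)                          # paddle weight has shape [in=3, out=4]
layer.weight.set_value(paddle.to_tensor(pt_weight.T))   # transpose when importing
layer.bias.set_value(paddle.zeros([4]))

x = np.random.randn(2, 3).astype("float32")
np.testing.assert_allclose(
    x @ pt_weight.T,                                    # reference: torch-style matmul
    layer(paddle.to_tensor(x)).numpy(),                 # paddle layer with transposed weight
    rtol=1e-5, atol=1e-6,
)
```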
PaddleMIX/ppdiffusers/ppdiffusers/models/modeling_utils.py ADDED
@@ -0,0 +1,1356 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import gc
18
+ import json
19
+ import os
20
+ from collections import OrderedDict
21
+ from contextlib import ExitStack
22
+ from functools import partial
23
+ from typing import Any, Callable, List, Optional, Union
24
+
25
+ import numpy as np
26
+ from aistudio_sdk.hub import create_repo as aistudio_create_repo
27
+ from huggingface_hub import create_repo
28
+ from paddle import nn
29
+ from tqdm import tqdm
30
+
31
+ from ..utils import (
32
+ CONFIG_NAME,
33
+ DIFFUSERS_CACHE,
34
+ FROM_AISTUDIO,
35
+ FROM_DIFFUSERS,
36
+ FROM_HF_HUB,
37
+ HF_HUB_OFFLINE,
38
+ LOW_CPU_MEM_USAGE_DEFAULT,
39
+ MIN_PEFT_VERSION,
40
+ PADDLE_SAFETENSORS_WEIGHTS_NAME,
41
+ PADDLE_SAFETENSORS_WEIGHTS_NAME_INDEX_NAME,
42
+ PADDLE_WEIGHTS_NAME,
43
+ PADDLE_WEIGHTS_NAME_INDEX_NAME,
44
+ PPDIFFUSERS_CACHE,
45
+ TO_DIFFUSERS,
46
+ TORCH_SAFETENSORS_WEIGHTS_NAME,
47
+ TORCH_SAFETENSORS_WEIGHTS_NAME_INDEX_NAME,
48
+ TORCH_WEIGHTS_NAME,
49
+ TORCH_WEIGHTS_NAME_INDEX_NAME,
50
+ _add_variant,
51
+ _get_model_file,
52
+ check_peft_version,
53
+ deprecate,
54
+ get_checkpoint_shard_files,
55
+ is_paddle_available,
56
+ is_paddle_version,
57
+ is_paddlenlp_available,
58
+ is_safetensors_available,
59
+ is_torch_available,
60
+ logging,
61
+ smart_load,
62
+ )
63
+ from ..version import VERSION as __version__
64
+ from .modeling_pytorch_paddle_utils import (
65
+ convert_paddle_state_dict_to_pytorch,
66
+ convert_pytorch_state_dict_to_paddle,
67
+ )
68
+
69
+ logger = logging.get_logger(__name__)
70
+
71
+ if is_torch_available():
72
+ import torch
73
+
74
+ if is_safetensors_available():
75
+ from safetensors import safe_open
76
+ from safetensors.numpy import save_file as np_safe_save_file
77
+
78
+ if is_torch_available():
79
+ from safetensors.torch import save_file as torch_safe_save_file
80
+
81
+ if is_paddle_available():
82
+ import paddle
83
+
84
+ if is_paddlenlp_available():
85
+ try:
86
+ from paddlenlp.transformers.model_utils import no_init_weights
87
+ except ImportError:
88
+ from ..utils.paddle_utils import no_init_weights
89
+ from paddlenlp.transformers.model_utils import shard_checkpoint
90
+
91
+
92
+ def faster_set_state_dict(model, state_dict):
93
+ # the state_dict will be destroyed (consumed in place).
94
+ with paddle.no_grad():
95
+ for k, v in model.state_dict(use_hook=False).items():
96
+ if k in state_dict:
97
+ v_new = state_dict.pop(k)
98
+ # with device_guard(): do not use a device guard here
99
+ if isinstance(v_new, np.ndarray):
100
+ v_new = paddle.Tensor(v_new, zero_copy=True)
101
+ if v.dtype != v_new.dtype:
102
+ v_new = v_new.cast(v.dtype)
103
+ v.copy_(v_new, False)
104
+ else:
105
+ if (hasattr(v, "_is_initialized") and not v._is_initialized()) or "undefined" in str(v.place):
106
+ v.initialize()
107
+ # logger.warning(f"key {k} is not in state_dict. And it is lazy tensor. We will initialize it.")
108
+
109
+
110
+ class ContextManagers:
111
+ """
112
+ Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers`
113
+ in the `fastcore` library.
114
+ """
115
+
116
+ def __init__(self, context_managers):
117
+ self.context_managers = context_managers
118
+ self.stack = ExitStack()
119
+
120
+ def __enter__(self):
121
+ for context_manager in self.context_managers:
122
+ self.stack.enter_context(context_manager)
123
+
124
+ def __exit__(self, *args, **kwargs):
125
+ self.stack.__exit__(*args, **kwargs)
126
+
127
+
128
+ def get_parameter_device(parameter: nn.Layer):
129
+ try:
130
+ # TODO https://github.com/huggingface/diffusers/compare/v0.15.0...v0.16.0#diff-6a3b9a08c1d37dbc341131632415fea800af242a84fb31f1bcd40d725e2eeeebR64
131
+ return next(parameter.named_parameters())[1].place
132
+ except StopIteration:
133
+ try:
134
+ return next(parameter.named_buffers())[1].place
135
+ except StopIteration:
136
+ return paddle.get_device()
137
+
138
+
139
+ def get_parameter_dtype(parameter: nn.Layer) -> paddle.dtype:
140
+ try:
141
+ # TODO https://github.com/huggingface/diffusers/compare/v0.15.0...v0.16.0#diff-6a3b9a08c1d37dbc341131632415fea800af242a84fb31f1bcd40d725e2eeeebR80
142
+ return next(parameter.named_parameters())[1].dtype
143
+ except StopIteration:
144
+ try:
145
+ return next(parameter.named_buffers())[1].dtype
146
+ except StopIteration:
147
+ return parameter._dtype
148
+
149
+
150
+ def load_state_dict(
151
+ checkpoint_file: Union[str, os.PathLike], state_dict, tensor_parallel_split_mapping=None, ignore_keys=None
152
+ ):
153
+ """
154
+ Reads a checkpoint file (Paddle, PyTorch, or safetensors), fills `state_dict` in place, and returns the detected data format.
155
+ """
156
+ if tensor_parallel_split_mapping is None:
157
+ tensor_parallel_split_mapping = {}
158
+ data_format = "pd"
159
+ if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
160
+ # Check format of the archive
161
+ with safe_open(checkpoint_file, framework="np") as f:
162
+ metadata = f.metadata()
163
+ if metadata is None:
164
+ metadata = {}
165
+ if metadata.get("format", "pt") not in ["pt", "pd", "np"]:
166
+ raise OSError(
167
+ f"The safetensors archive passed at {checkpoint_file} does not contain valid metadata. Make sure "
168
+ "you save your model with the `save_pretrained` method."
169
+ )
170
+ data_format = metadata.get("format", "pt")
171
+ with safe_open(checkpoint_file, framework="np") as f:
172
+ for key in f.keys():
173
+ need_continue = False
174
+ if ignore_keys is not None:
175
+ for ik in ignore_keys:
176
+ if key.startswith(ik):
177
+ logger.info("Deleting key {} from state_dict.".format(key))
178
+ need_continue = True
179
+ break
180
+ if need_continue:
181
+ continue
182
+ if key in tensor_parallel_split_mapping:
183
+ py_safe_slice_ = f.get_slice(key)
184
+ weight = tensor_parallel_split_mapping[key](py_safe_slice_)
185
+ else:
186
+ weight = f.get_tensor(key)
187
+ state_dict[key] = paddle.Tensor(weight, zero_copy=True)
188
+
189
+ else:
190
+ if any(checkpoint_file.endswith(suffix) for suffix in [".pt", ".pth", ".bin", ".ckpt"]):
191
+ data_format = "pt"
192
+
193
+ tmp_state_dict = smart_load(checkpoint_file, return_numpy=True)
194
+ for key in list(tmp_state_dict.keys()):
195
+ need_continue = False
196
+ if ignore_keys is not None:
197
+ for ik in ignore_keys:
198
+ if key.startswith(ik):
199
+ logger.info("Deleting key {} from state_dict.".format(key))
200
+ need_continue = True
201
+ break
202
+ if need_continue:
203
+ continue
204
+ # with device_guard():
205
+ t = tmp_state_dict.pop(key)
206
+ if key in tensor_parallel_split_mapping:
207
+ t = tensor_parallel_split_mapping[key](t)
208
+ if isinstance(t, dict):
209
+ if len(t) == 0:
210
+ state_dict[key] = {}
211
+ else:
212
+ state_dict[key] = paddle.Tensor(t, zero_copy=True)
213
+
214
+ return data_format
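The safetensors branch above relies on the `format` metadata tag written at save time. The standalone sketch below (a toy one-tensor checkpoint; the file name `toy.safetensors` is just an example) shows the round-trip this loader expects; it is not part of the commit.

```python
# Round-trip the safetensors loader expects: save with a "format" metadata tag,
# reopen with framework="np", and wrap each array as a zero-copy paddle.Tensor.
import numpy as np
import paddle
from safetensors import safe_open
from safetensors.numpy import save_file

save_file({"w": np.ones((2, 2), dtype="float32")}, "toy.safetensors", metadata={"format": "pd"})

state_dict = {}
with safe_open("toy.safetensors", framework="np") as f:
    assert f.metadata().get("format") == "pd"          # the tag checked by load_state_dict
    for key in f.keys():
        state_dict[key] = paddle.Tensor(f.get_tensor(key), zero_copy=True)
print(state_dict["w"].shape)   # [2, 2]
```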
215
+
216
+
217
+ class ModelMixin(nn.Layer):
218
+ r"""
219
+ Base class for all models.
220
+
221
+ [`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
222
+ saving models.
223
+
224
+ - **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
225
+ """
226
+
227
+ config_name = CONFIG_NAME
228
+ _automatically_saved_args = ["_ppdiffusers_version", "_class_name", "_name_or_path"]
229
+ _supports_gradient_checkpointing = False
230
+ _keys_to_ignore_on_load_unexpected = None
231
+ _pp_peft_config_loaded = False
232
+
233
+ def __init__(self):
234
+ super().__init__()
235
+
236
+ def __getattr__(self, name: str) -> Any:
237
+ """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
238
+ config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
239
+ __getattr__ here in addition so that we don't trigger `nn.Layer`'s __getattr__':
240
+ https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
241
+ """
242
+
243
+ is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
244
+ is_attribute = name in self.__dict__
245
+
246
+ if is_in_config and not is_attribute:
247
+ deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
248
+ deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
249
+ return self._internal_dict[name]
250
+
251
+ # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
252
+ return super().__getattr__(name)
253
+
254
+ @property
255
+ def is_gradient_checkpointing(self) -> bool:
256
+ """
257
+ Whether gradient checkpointing is activated for this model or not.
258
+ """
259
+ return any(
260
+ hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing
261
+ for m in self.sublayers(include_self=True)
262
+ )
263
+
264
+ def enable_gradient_checkpointing(self) -> None:
265
+ """
266
+ Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
267
+ *checkpoint activations* in other frameworks).
268
+ """
269
+ if not self._supports_gradient_checkpointing:
270
+ raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
271
+ self.apply(partial(self._set_gradient_checkpointing, value=True))
272
+
273
+ def disable_gradient_checkpointing(self) -> None:
274
+ """
275
+ Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
276
+ *checkpoint activations* in other frameworks).
277
+ """
278
+ if self._supports_gradient_checkpointing:
279
+ self.apply(partial(self._set_gradient_checkpointing, value=False))
280
+
281
+ def set_use_memory_efficient_attention_xformers(self, valid: bool, attention_op: Optional[str] = None) -> None:
282
+ # Recursively walk through all the children.
283
+ # Any children which exposes the set_use_memory_efficient_attention_xformers method
284
+ # gets the message
285
+ def fn_recursive_set_mem_eff(module: nn.Layer):
286
+ if hasattr(module, "set_use_memory_efficient_attention_xformers"):
287
+ module.set_use_memory_efficient_attention_xformers(valid, attention_op)
288
+
289
+ for child in module.children():
290
+ fn_recursive_set_mem_eff(child)
291
+
292
+ for module in self.children():
293
+ if isinstance(module, nn.Layer):
294
+ fn_recursive_set_mem_eff(module)
295
+
296
+ def enable_xformers_memory_efficient_attention(self, attention_op: Optional[str] = None) -> None:
297
+ r"""
298
+ Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
299
+
300
+ When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
301
+ inference. Speed up during training is not guaranteed.
302
+
303
+ <Tip warning={true}>
304
+
305
+ ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
306
+ precedence.
307
+
308
+ </Tip>
309
+
310
+ Parameters:
311
+ attention_op (`str`, *optional*):
312
+ Override the default `None`
313
+
314
+ Examples:
315
+
316
+ ```py
317
+ >>> import paddle
318
+ >>> from ppdiffusers import UNet2DConditionModel
319
+
320
+ >>> model = UNet2DConditionModel.from_pretrained(
321
+ ... "stabilityai/stable-diffusion-2-1", subfolder="unet", paddle_dtype=paddle.float16
322
+ ... )
323
+ >>> model.enable_xformers_memory_efficient_attention(attention_op="auto")
324
+ ```
325
+ """
326
+ self.set_use_memory_efficient_attention_xformers(True, attention_op)
327
+
328
+ def disable_xformers_memory_efficient_attention(self) -> None:
329
+ r"""
330
+ Disable memory efficient attention as implemented in xformers.
331
+ """
332
+ self.set_use_memory_efficient_attention_xformers(False)
333
+
334
+ def add_adapter(self, adapter_config, adapter_name: str = "default") -> None:
335
+ r"""
336
+ Adds a new adapter to the current model for training. If no adapter name is passed, a default name is assigned
337
+ to the adapter to follow the convention of the PEFT library.
338
+
339
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
340
+ [documentation](https://huggingface.co/docs/peft).
341
+
342
+ Args:
343
+ adapter_config (`[~peft.PeftConfig]`):
344
+ The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption prompt
345
+ methods.
346
+ adapter_name (`str`, *optional*, defaults to `"default"`):
347
+ The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
348
+ """
349
+ check_peft_version(min_version=MIN_PEFT_VERSION)
350
+
351
+ from ppdiffusers.peft import PeftConfig, inject_adapter_in_model
352
+
353
+ if not self._pp_peft_config_loaded:
354
+ self._pp_peft_config_loaded = True
355
+ elif adapter_name in self.peft_config:
356
+ raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
357
+
358
+ if not isinstance(adapter_config, PeftConfig):
359
+ raise ValueError(
360
+ f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
361
+ )
362
+
363
+ # Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is
364
+ # handled by the `load_lora_layers` or `LoraLoaderMixin`. Therefore we set it to `None` here.
365
+ adapter_config.base_model_name_or_path = None
366
+ inject_adapter_in_model(adapter_config, self, adapter_name)
367
+ self.set_adapter(adapter_name)
368
+
369
+ def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
370
+ """
371
+ Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters.
372
+
373
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
374
+ official documentation: https://huggingface.co/docs/peft
375
+
376
+ Args:
377
+ adapter_name (Union[str, List[str]])):
378
+ The list of adapters to set or the adapter name in case of single adapter.
379
+ """
380
+ check_peft_version(min_version=MIN_PEFT_VERSION)
381
+
382
+ if not self._pp_peft_config_loaded:
383
+ raise ValueError("No adapter loaded. Please load an adapter first.")
384
+
385
+ if isinstance(adapter_name, str):
386
+ adapter_name = [adapter_name]
387
+
388
+ missing = set(adapter_name) - set(self.peft_config)
389
+ if len(missing) > 0:
390
+ raise ValueError(
391
+ f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
392
+ f" current loaded adapters are: {list(self.peft_config.keys())}"
393
+ )
394
+
395
+ from ppdiffusers.peft.tuners.tuners_utils import BaseTunerLayer
396
+
397
+ _adapters_has_been_set = False
398
+
399
+ for _, module in self.named_sublayers(include_self=True):
400
+ if isinstance(module, BaseTunerLayer):
401
+ if hasattr(module, "set_adapter"):
402
+ module.set_adapter(adapter_name)
403
+ # Previous versions of PEFT does not support multi-adapter inference
404
+ elif not hasattr(module, "set_adapter") and len(adapter_name) != 1:
405
+ raise ValueError(
406
+ "You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT."
407
+ " `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`"
408
+ )
409
+ else:
410
+ module.active_adapter = adapter_name
411
+ _adapters_has_been_set = True
412
+
413
+ if not _adapters_has_been_set:
414
+ raise ValueError(
415
+ "Did not succeed in setting the adapter. Please make sure you are using a model that supports adapters."
416
+ )
417
+
418
+ def disable_adapters(self) -> None:
419
+ r"""
420
+ Disable all adapters attached to the model and fallback to inference with the base model only.
421
+
422
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
423
+ official documentation: https://huggingface.co/docs/peft
424
+ """
425
+ check_peft_version(min_version=MIN_PEFT_VERSION)
426
+
427
+ if not self._pp_peft_config_loaded:
428
+ raise ValueError("No adapter loaded. Please load an adapter first.")
429
+
430
+ from ppdiffusers.peft.tuners.tuners_utils import BaseTunerLayer
431
+
432
+ for _, module in self.named_sublayers(include_self=True):
433
+ if isinstance(module, BaseTunerLayer):
434
+ if hasattr(module, "enable_adapters"):
435
+ module.enable_adapters(enabled=False)
436
+ else:
437
+ # support for older PEFT versions
438
+ module.disable_adapters = True
439
+
440
+ def enable_adapters(self) -> None:
441
+ """
442
+ Enable adapters that are attached to the model. The model will use `self.active_adapters()` to retrieve the
443
+ list of adapters to enable.
444
+
445
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
446
+ official documentation: https://huggingface.co/docs/peft
447
+ """
448
+ check_peft_version(min_version=MIN_PEFT_VERSION)
449
+
450
+ if not self._pp_peft_config_loaded:
451
+ raise ValueError("No adapter loaded. Please load an adapter first.")
452
+
453
+ from ppdiffusers.peft.tuners.tuners_utils import BaseTunerLayer
454
+
455
+ for _, module in self.named_sublayers(include_self=True):
456
+ if isinstance(module, BaseTunerLayer):
457
+ if hasattr(module, "enable_adapters"):
458
+ module.enable_adapters(enabled=True)
459
+ else:
460
+ # support for older PEFT versions
461
+ module.disable_adapters = False
462
+
463
+ def active_adapters(self) -> List[str]:
464
+ """
465
+ Gets the current list of active adapters of the model.
466
+
467
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
468
+ official documentation: https://huggingface.co/docs/peft
469
+ """
470
+ check_peft_version(min_version=MIN_PEFT_VERSION)
471
+
472
+ if not self._pp_peft_config_loaded:
473
+ raise ValueError("No adapter loaded. Please load an adapter first.")
474
+
475
+ from ppdiffusers.peft.tuners.tuners_utils import BaseTunerLayer
476
+
477
+ for _, module in self.named_sublayers(include_self=True):
478
+ if isinstance(module, BaseTunerLayer):
479
+ return module.active_adapter
480
+
481
+ def save_pretrained(
482
+ self,
483
+ save_directory: Union[str, os.PathLike],
484
+ is_main_process: bool = True,
485
+ save_function: Optional[Callable] = None,
486
+ max_shard_size: Union[int, str] = "10GB",
487
+ safe_serialization: bool = True,
488
+ variant: Optional[str] = None,
489
+ push_to_hub: bool = False,
490
+ save_to_aistudio: bool = False,
491
+ to_diffusers: Optional[bool] = None,
492
+ **kwargs,
493
+ ):
494
+ """
495
+ Save a model and its configuration file to a directory so that it can be reloaded using the
496
+ [`~models.ModelMixin.from_pretrained`] class method.
497
+
498
+ Arguments:
499
+ save_directory (`str` or `os.PathLike`):
500
+ Directory to save a model and its configuration file to. Will be created if it doesn't exist.
501
+ is_main_process (`bool`, *optional*, defaults to `True`):
502
+ Whether the process calling this is the main process or not. Useful during distributed training and you
503
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
504
+ process to avoid race conditions.
505
+ save_function (`Callable`):
506
+ The function to use to save the state dictionary. Useful during distributed training when you need to
507
+ replace `torch.save` with another method. Can be configured with the environment variable
508
+ `DIFFUSERS_SAVE_MODE`.
509
+ safe_serialization (`bool`, *optional*, defaults to `True`):
510
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
511
+ variant (`str`, *optional*):
512
+ If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
513
+ push_to_hub (`bool`, *optional*, defaults to `False`):
514
+ Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
515
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
516
+ namespace).
517
+ kwargs (`Dict[str, Any]`, *optional*):
518
+ Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
519
+ """
520
+ # distributed kwargs
521
+ merge_tensor_parallel = kwargs.get("merge_tensor_parallel", False)
522
+ tensor_parallel_degree = kwargs.pop("tensor_parallel_degree", 1)
523
+
524
+ if to_diffusers is None:
525
+ to_diffusers = TO_DIFFUSERS
526
+
527
+ if os.path.isfile(save_directory):
528
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
529
+ return
530
+
531
+ os.makedirs(save_directory, exist_ok=True)
532
+
533
+ # create repo
534
+ commit_message = kwargs.pop("commit_message", None)
535
+ private = kwargs.pop("private", False)
536
+ create_pr = kwargs.pop("create_pr", False)
537
+ token = kwargs.pop("token", None)
538
+ token_kwargs = {}
539
+ if token is not None:
540
+ token_kwargs["token"] = token
541
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
542
+ license = kwargs.pop("license", "creativeml-openrail-m")
543
+ exist_ok = kwargs.pop("exist_ok", True)
544
+
545
+ if push_to_hub:
546
+ repo_id = create_repo(repo_id, exist_ok=True, private=private, **token_kwargs).repo_id
547
+
548
+ if save_to_aistudio:
549
+ assert "/" in repo_id, "Please specify the repo id in format of `user_id/repo_name`"
550
+ res = aistudio_create_repo(repo_id=repo_id, private=private, license=license, **token_kwargs)
551
+ if "error_code" in res:
552
+ if res["error_code"] == 10003 and exist_ok:
553
+ logger.info(
554
+ f"Repo {repo_id} already exists, it will override files with the same name. To avoid this, please set exist_ok=False"
555
+ )
556
+ else:
557
+ logger.error(
558
+ f"Failed to create repo {repo_id}, error_code: {res['error_code']}, error_msg: {res['error_msg']}"
559
+ )
560
+ else:
561
+ logger.info(f"Successfully created repo {repo_id}")
562
+
563
+ # Only save the model itself if we are using distributed training
564
+ model_to_save = self
565
+
566
+ # Attach architecture to the config
567
+ # Save the config
568
+ if is_main_process:
569
+ model_to_save.save_config(save_directory, to_diffusers=to_diffusers)
570
+
571
+ # Save the model
572
+ state_dict = model_to_save.state_dict()
573
+ if tensor_parallel_degree > 1:
574
+ if merge_tensor_parallel:
575
+ config_to_save = model_to_save._internal_dict
576
+ state_dict = model_to_save.merge_tensor_parallel(state_dict, config_to_save)
577
+ tensor_parallel_degree = 1
578
+ if paddle.distributed.fleet.get_hybrid_communicate_group().get_model_parallel_rank() != 0:
579
+ logger.info("Saving with merge_tensor_parallel; ranks with tensor_parallel_rank > 0 do not need to save")
580
+ return
581
+
582
+ if to_diffusers:
583
+ if not is_torch_available() and not safe_serialization:
584
+ safe_serialization = True
585
+ logger.warning(
586
+ "PyTorch is not installed, and `safe_serialization` is currently set to `False`. "
587
+ "To ensure proper model saving, we will automatically set `safe_serialization=True`. "
588
+ "If you want to keep `safe_serialization=False`, please make sure PyTorch is installed."
589
+ )
590
+ if safe_serialization:
591
+ save_index_file = TORCH_SAFETENSORS_WEIGHTS_NAME_INDEX_NAME
592
+ weights_name = TORCH_SAFETENSORS_WEIGHTS_NAME
593
+ if is_torch_available():
594
+ save_function = partial(torch_safe_save_file, metadata={"format": "pt"})
595
+ else:
596
+ save_function = partial(np_safe_save_file, metadata={"format": "pt"})
597
+ else:
598
+ save_index_file = TORCH_WEIGHTS_NAME_INDEX_NAME
599
+ weights_name = TORCH_WEIGHTS_NAME
600
+ save_function = torch.save
601
+ else:
602
+ if safe_serialization:
603
+ save_index_file = PADDLE_SAFETENSORS_WEIGHTS_NAME_INDEX_NAME
604
+ weights_name = PADDLE_SAFETENSORS_WEIGHTS_NAME
605
+ save_function = partial(np_safe_save_file, metadata={"format": "pd"})
606
+ else:
607
+ save_index_file = PADDLE_WEIGHTS_NAME_INDEX_NAME
608
+ weights_name = PADDLE_WEIGHTS_NAME
609
+ save_function = paddle.save
610
+
611
+ weights_name = _add_variant(weights_name, variant)
612
+
613
+ # Save model
614
+ shards, index = shard_checkpoint(
615
+ state_dict,
616
+ max_shard_size=max_shard_size,
617
+ weights_name=weights_name,
618
+ )
619
+ # Save the model
620
+ for shard_file, shard in shards.items():
621
+ for k in list(shard.keys()):
622
+ if isinstance(shard[k], paddle.Tensor):
623
+ shard[k] = np.ascontiguousarray(shard.pop(k).cpu().numpy())
624
+ if to_diffusers:
625
+ convert_paddle_state_dict_to_pytorch(self, shard)
626
+ save_function(shard, os.path.join(save_directory, shard_file))
627
+
628
+ # Save the model
629
+ if index is None:
630
+ logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
631
+
632
+ else:
633
+ save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
634
+ # Save the index as well
635
+ with open(save_index_file, "w", encoding="utf-8") as f:
636
+ content = json.dumps(index, indent=2) + "\n"
637
+ f.write(content)
638
+ logger.info(
639
+ f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
640
+ f"split into {len(shards)} checkpoint shards. You can find where each parameter has been saved in the "
641
+ f"index located at {save_index_file}."
642
+ )
643
+ # upload to aistudio or huggingface hub
644
+ if save_to_aistudio:
645
+ self._upload_folder_aistudio(
646
+ save_directory,
647
+ repo_id,
648
+ commit_message=commit_message,
649
+ **token_kwargs,
650
+ )
651
+ if push_to_hub:
652
+ self._upload_folder(
653
+ save_directory,
654
+ repo_id,
655
+ commit_message=commit_message,
656
+ create_pr=create_pr,
657
+ **token_kwargs,
658
+ )
659
+
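A typical call to the method above, shown as a hypothetical usage sketch: the repo id mirrors the docstring example later in this file, while the output directory and sharding settings are placeholders.

```python
# Hypothetical usage of save_pretrained: reload a UNet and re-save it locally as
# variant-tagged safetensors weights (paths and sizes are placeholders).
from ppdiffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet.save_pretrained(
    "./my_unet",
    safe_serialization=True,   # safetensors shards, plus an index file when sharding kicks in
    max_shard_size="5GB",
    variant="fp16",            # weight filenames get a ".fp16." variant infix
)
```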
660
+ @classmethod
661
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
662
+ r"""
663
+ Instantiate a pretrained Paddle model from a pretrained model configuration.
664
+
665
+ The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
666
+ train the model, set it back in training mode with `model.train()`.
667
+
668
+ Parameters:
669
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
670
+ Can be either:
671
+
672
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
673
+ the Hub.
674
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
675
+ with [`~ModelMixin.save_pretrained`].
676
+
677
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
678
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
679
+ is not used.
680
+ paddle_dtype (`str` or `paddle.dtype`, *optional*):
681
+ Override the default `paddle.dtype` and load the model with another dtype. If `"auto"` is passed, the
682
+ dtype is automatically derived from the model's weights.
683
+ force_download (`bool`, *optional*, defaults to `False`):
684
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
685
+ cached versions if they exist.
686
+ resume_download (`bool`, *optional*, defaults to `False`):
687
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
688
+ incompletely downloaded files are deleted.
689
+ proxies (`Dict[str, str]`, *optional*):
690
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
691
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
692
+ output_loading_info (`bool`, *optional*, defaults to `False`):
693
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
694
+ local_files_only(`bool`, *optional*, defaults to `False`):
695
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
696
+ won't be downloaded from the Hub.
697
+ use_auth_token (`str` or *bool*, *optional*):
698
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
699
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
700
+ revision (`str`, *optional*, defaults to `"main"`):
701
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
702
+ allowed by Git.
703
+ from_flax (`bool`, *optional*, defaults to `False`):
704
+ Load the model weights from a Flax checkpoint save file.
705
+ subfolder (`str`, *optional*, defaults to `""`):
706
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
707
+ mirror (`str`, *optional*):
708
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
709
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
710
+ information.
711
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
712
+ A map that specifies where each submodule should go. It doesn't need to be defined for each
713
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
714
+ same device.
715
+
716
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
717
+ more information about each option see [designing a device
718
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
719
+ max_memory (`Dict`, *optional*):
720
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
721
+ each GPU and the available CPU RAM if unset.
722
+ offload_folder (`str` or `os.PathLike`, *optional*):
723
+ The path to offload weights if `device_map` contains the value `"disk"`.
724
+ offload_state_dict (`bool`, *optional*):
725
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
726
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
727
+ when there is some disk offload.
728
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if PaddlePaddle version >= 2.5.0 else `False`):
729
+ Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
730
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
731
+ Only supported for PaddlePaddle >= 2.5.0. If you are using an older version of PaddlePaddle, setting this
732
+ argument to `True` will raise an error.
733
+ variant (`str`, *optional*):
734
+ Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
735
+ loading `from_flax`.
736
+ use_safetensors (`bool`, *optional*, defaults to `None`):
737
+ If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
738
+ `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
739
+ weights. If set to `False`, `safetensors` weights are not loaded.
740
+
741
+ <Tip>
742
+
743
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
744
+ `huggingface-cli login`. You can also activate the special
745
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
746
+ firewalled environment.
747
+
748
+ </Tip>
749
+
750
+ Example:
751
+
752
+ ```py
753
+ from ppdiffusers import UNet2DConditionModel
754
+
755
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
756
+ ```
757
+
758
+ If you get the error message below, you need to finetune the weights for your downstream task:
759
+
760
+ ```bash
761
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
762
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
763
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
764
+ ```
765
+ """
766
+ from_hf_hub = kwargs.pop("from_hf_hub", FROM_HF_HUB)
767
+ from_aistudio = kwargs.pop("from_aistudio", FROM_AISTUDIO)
768
+ cache_dir = kwargs.pop("cache_dir", None)
769
+ if cache_dir is None:
770
+ if from_aistudio:
771
+ cache_dir = None # TODO, check aistudio cache
772
+ elif from_hf_hub:
773
+ cache_dir = DIFFUSERS_CACHE
774
+ else:
775
+ cache_dir = PPDIFFUSERS_CACHE
776
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
777
+ force_download = kwargs.pop("force_download", False)
778
+ from_diffusers = kwargs.pop("from_diffusers", FROM_DIFFUSERS)
779
+ resume_download = kwargs.pop("resume_download", False)
780
+ proxies = kwargs.pop("proxies", None)
781
+ output_loading_info = kwargs.pop("output_loading_info", False)
782
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
783
+ use_auth_token = kwargs.pop("use_auth_token", None)
784
+ revision = kwargs.pop("revision", None)
785
+ paddle_dtype = kwargs.pop("paddle_dtype", None)
786
+ subfolder = kwargs.pop("subfolder", "")
787
+ if subfolder is None:
788
+ subfolder = ""
789
+ device_map = kwargs.pop("device_map", None)
790
+ max_memory = kwargs.pop("max_memory", None)
791
+ offload_folder = kwargs.pop("offload_folder", None)
792
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
793
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", LOW_CPU_MEM_USAGE_DEFAULT)
794
+ variant = kwargs.pop("variant", None)
795
+ use_safetensors = kwargs.pop("use_safetensors", None)
796
+ ignore_keys = kwargs.pop("ignore_keys", [])
797
+
798
+ # distributed kwargs
799
+ tensor_parallel_degree = kwargs.pop("tensor_parallel_degree", 1)
800
+
801
+ if use_safetensors is None:
802
+ use_safetensors = True
803
+
804
+ if low_cpu_mem_usage and (not is_paddle_version(">=", "2.5.0") and not is_paddle_version("==", "0.0.0")):
805
+ raise NotImplementedError(
806
+ "Low memory initialization requires paddlepaddle-gpu >= 2.5.0. Please either update your PaddlePaddle version or set"
807
+ " `low_cpu_mem_usage=False`."
808
+ )
809
+
810
+ # Load config if we don't provide a configuration
811
+ config_path = pretrained_model_name_or_path
812
+
813
+ user_agent = {
814
+ "ppdiffusers": __version__,
815
+ "file_type": "model",
816
+ "framework": "pytorch" if from_diffusers else "paddle",
817
+ }
818
+
819
+ # load config
820
+ config, unused_kwargs, commit_hash, config_file = cls.load_config(
821
+ config_path,
822
+ cache_dir=cache_dir,
823
+ return_unused_kwargs=True,
824
+ return_commit_hash=True,
825
+ return_config_file=True,
826
+ force_download=force_download,
827
+ resume_download=resume_download,
828
+ proxies=proxies,
829
+ local_files_only=local_files_only,
830
+ use_auth_token=use_auth_token,
831
+ revision=revision,
832
+ subfolder=subfolder,
833
+ device_map=device_map,
834
+ max_memory=max_memory,
835
+ offload_folder=offload_folder,
836
+ offload_state_dict=offload_state_dict,
837
+ user_agent=user_agent,
838
+ from_hf_hub=from_hf_hub,
839
+ from_aistudio=from_aistudio,
840
+ **kwargs,
841
+ )
842
+ index_file = None
843
+
844
+ variant_list = [variant]
845
+ if None not in variant_list:
846
+ variant_list.append(None)
847
+ if "fp16" not in variant_list:
848
+ variant_list.append("fp16")
849
+ if "fp32" not in variant_list:
850
+ variant_list.append("fp32")
851
+ for v_index, variant in enumerate(variant_list):
852
+ try:
853
+ if use_safetensors:
854
+ try:
855
+ # is sharded model
856
+ index_file = _get_model_file(
857
+ pretrained_model_name_or_path,
858
+ weights_name=_add_variant(TORCH_SAFETENSORS_WEIGHTS_NAME_INDEX_NAME, variant)
859
+ if from_diffusers
860
+ else _add_variant(PADDLE_SAFETENSORS_WEIGHTS_NAME_INDEX_NAME, variant),
861
+ cache_dir=cache_dir,
862
+ force_download=force_download,
863
+ resume_download=resume_download,
864
+ proxies=proxies,
865
+ local_files_only=local_files_only,
866
+ use_auth_token=use_auth_token,
867
+ revision=revision,
868
+ subfolder=subfolder,
869
+ user_agent=user_agent,
870
+ commit_hash=commit_hash,
871
+ from_hf_hub=from_hf_hub,
872
+ from_aistudio=from_aistudio,
873
+ )
874
+ except Exception:
875
+ index_file = None
876
+ if index_file is None:
877
+ # is sharded model
878
+ try:
879
+ index_file = _get_model_file(
880
+ pretrained_model_name_or_path,
881
+ weights_name=_add_variant(TORCH_WEIGHTS_NAME_INDEX_NAME, variant)
882
+ if from_diffusers
883
+ else _add_variant(PADDLE_WEIGHTS_NAME_INDEX_NAME, variant),
884
+ cache_dir=cache_dir,
885
+ force_download=force_download,
886
+ resume_download=resume_download,
887
+ proxies=proxies,
888
+ local_files_only=local_files_only,
889
+ use_auth_token=use_auth_token,
890
+ revision=revision,
891
+ subfolder=subfolder,
892
+ user_agent=user_agent,
893
+ commit_hash=commit_hash,
894
+ from_hf_hub=from_hf_hub,
895
+ from_aistudio=from_aistudio,
896
+ )
897
+ except Exception:
898
+ index_file = None
899
+ is_sharded = index_file is not None
900
+
901
+ if is_sharded:
902
+ resolved_model_files, sharded_metadata = get_checkpoint_shard_files(
903
+ pretrained_model_name_or_path,
904
+ index_filename=index_file,
905
+ cache_dir=cache_dir,
906
+ force_download=force_download,
907
+ resume_download=resume_download,
908
+ proxies=proxies,
909
+ local_files_only=local_files_only,
910
+ use_auth_token=use_auth_token,
911
+ revision=revision,
912
+ subfolder=subfolder,
913
+ user_agent=user_agent,
914
+ commit_hash=commit_hash,
915
+ from_hf_hub=from_hf_hub,
916
+ from_aistudio=from_aistudio,
917
+ )
918
+ if not isinstance(resolved_model_files, list):
919
+ resolved_model_files = [resolved_model_files]
920
+ else:
921
+ # load model
922
+ model_file = None
923
+ if use_safetensors:
924
+ try:
925
+ model_file = _get_model_file(
926
+ pretrained_model_name_or_path,
927
+ weights_name=_add_variant(TORCH_SAFETENSORS_WEIGHTS_NAME, variant)
928
+ if from_diffusers
929
+ else _add_variant(PADDLE_SAFETENSORS_WEIGHTS_NAME, variant),
930
+ cache_dir=cache_dir,
931
+ force_download=force_download,
932
+ resume_download=resume_download,
933
+ proxies=proxies,
934
+ local_files_only=local_files_only,
935
+ use_auth_token=use_auth_token,
936
+ revision=revision,
937
+ subfolder=subfolder,
938
+ user_agent=user_agent,
939
+ commit_hash=commit_hash,
940
+ from_hf_hub=from_hf_hub,
941
+ from_aistudio=from_aistudio,
942
+ )
943
+ except Exception:
944
+ model_file = None
945
+ pass
946
+ if model_file is None:
947
+ model_file = _get_model_file(
948
+ pretrained_model_name_or_path,
949
+ weights_name=_add_variant(TORCH_WEIGHTS_NAME, variant)
950
+ if from_diffusers
951
+ else _add_variant(PADDLE_WEIGHTS_NAME, variant),
952
+ cache_dir=cache_dir,
953
+ force_download=force_download,
954
+ resume_download=resume_download,
955
+ proxies=proxies,
956
+ local_files_only=local_files_only,
957
+ use_auth_token=use_auth_token,
958
+ revision=revision,
959
+ subfolder=subfolder,
960
+ user_agent=user_agent,
961
+ commit_hash=commit_hash,
962
+ from_hf_hub=from_hf_hub,
963
+ from_aistudio=from_aistudio,
964
+ )
965
+ resolved_model_files = [model_file]
966
+ except Exception as e: # NOQA
967
+ logger.warning(
968
+ f"Unable to load the `variant={variant}` of the model from `{pretrained_model_name_or_path}`! "
969
+ "Please make sure the specified variant exists and is correct."
970
+ )
971
+ resolved_model_files = []
972
+ if len(resolved_model_files) > 0:
973
+ if v_index > 0:
974
+ name = (
975
+ ", ".join([config_file, index_file] + resolved_model_files)
976
+ if index_file is not None
977
+ else ", ".join(resolved_model_files)
978
+ )
979
+ logger.warning(
980
+ f"Proceeding to load the `variant={variant}` of the model with the resolved model files: {name}. "
981
+ "Please note that this might not be the desired variant."
982
+ )
983
+ break
984
+ variant_str = ", ".join(map(lambda x: "`" + str(x) + "`", variant_list))
985
+ assert len(resolved_model_files) > 0, (
986
+ f"We are attempting to load the variant in [{variant_str}]. "
987
+ f"But unfortunately, no model files were found in the path {pretrained_model_name_or_path}. "
988
+ "Please check if the provided path is correct and ensure that it contains the necessary model files. "
989
+ "If the issue persists, consider redownloading the model files or contacting the model provider for assistance."
990
+ )
991
+ init_contexts = []
992
+
993
+ dtype = paddle.float32 if paddle_dtype is None else paddle_dtype
994
+ init_contexts.append(paddle.dtype_guard(dtype))
995
+
996
+ if low_cpu_mem_usage:
997
+ # Instantiate model.
998
+ init_contexts.append(no_init_weights(_enable=True))
999
+ if hasattr(paddle, "LazyGuard"):
1000
+ init_contexts.append(paddle.LazyGuard())
1001
+
1002
+ with ContextManagers(init_contexts):
1003
+ model = cls.from_config(config, **unused_kwargs)
1004
+
1005
+ # (westfish) 2024/04/01:
1006
+ # Tensor parallel is only supported for models that inherit from `ConversionMixin`
1007
+ if tensor_parallel_degree > 1:
1008
+ from paddlenlp.transformers.conversion_utils import ConversionMixin
1009
+
1010
+ if not issubclass(cls, ConversionMixin):
1011
+ raise NotImplementedError(
1012
+ "Tensor parallel is only supported for models that inherit from `ConversionMixin`."
1013
+ )
1014
+ if len(resolved_model_files) > 1:
1015
+ raise NotImplementedError("Tensor parallel is not supported for multiple shards yet.")
1016
+ tmp_state_dict = smart_load(resolved_model_files[0], return_numpy=True)
1017
+ tensor_parallel_split_mapping = cls.get_tensor_parallel_convert_actions(config, tmp_state_dict.keys())
1018
+ else:
1019
+ tensor_parallel_split_mapping = None
1020
+
1021
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
1022
+ model,
1023
+ resolved_model_files,
1024
+ pretrained_model_name_or_path,
1025
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
1026
+ ignore_keys=ignore_keys,
1027
+ from_diffusers=from_diffusers,
1028
+ tensor_parallel_split_mapping=tensor_parallel_split_mapping,
1029
+ tensor_parallel_degree=tensor_parallel_degree,
1030
+ )
1031
+
1032
+ loading_info = {
1033
+ "missing_keys": missing_keys,
1034
+ "unexpected_keys": unexpected_keys,
1035
+ "mismatched_keys": mismatched_keys,
1036
+ "error_msgs": error_msgs,
1037
+ }
1038
+
1039
+ if paddle_dtype is not None:
1040
+ model = model.to(dtype=paddle_dtype)
1041
+
1042
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
1043
+
1044
+ # Set model in evaluation mode to deactivate DropOut modules by default
1045
+ model.eval()
1046
+ if output_loading_info:
1047
+ return model, loading_info
1048
+
1049
+ return model
1050
+
1051
+ @classmethod
1052
+ def custom_modify_weight(cls, model_to_load, state_dict):
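+ # Hook for subclasses: adjust state_dict in place right before it is loaded into model_to_load; the base implementation is a no-op.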
1053
+ pass
1054
+
1055
+ @classmethod
1056
+ def _load_pretrained_model(
1057
+ cls,
1058
+ model: "ModelMixin",
1059
+ resolved_model_files,
1060
+ pretrained_model_name_or_path: Union[str, os.PathLike],
1061
+ ignore_mismatched_sizes: bool = False,
1062
+ ignore_keys=None,
1063
+ from_diffusers=False,
1064
+ tensor_parallel_split_mapping=None,
1065
+ tensor_parallel_degree=1,
1066
+ ):
1067
+ state_dict = OrderedDict()
1068
+ model_state_dict = model.state_dict()
1069
+ loaded_keys = []
1070
+ expected_keys = list(model_state_dict.keys())
1071
+ error_msgs = []
1072
+ mismatched_keys = []
1073
+
1074
+ if len(resolved_model_files) > 1:
1075
+ resolved_model_files = tqdm(resolved_model_files, desc="Loading checkpoint shards")
1076
+ if tensor_parallel_degree > 1:
1077
+ raise NotImplementedError("Tensor parallel is not supported for multiple shards yet.")
1078
+
1079
+ # load shard state dict
1080
+ for shard_file in resolved_model_files:
1081
+ data_format = load_state_dict(
1082
+ shard_file,
1083
+ state_dict, # inplace update state_dict
1084
+ tensor_parallel_split_mapping=tensor_parallel_split_mapping,
1085
+ ignore_keys=ignore_keys,
1086
+ )
1087
+ # NOTE: newly added to support loading old-format (deprecated) state_dicts
1088
+ model._update_deprecated_state_dict(state_dict)
1089
+ # NOTE: convert old model state dict!
1090
+ model._convert_deprecated_attention_blocks(state_dict)
1091
+
1092
+ # NOTE: convert torch model state dict!
1093
+ if from_diffusers or data_format in ["pt"]:
1094
+ convert_pytorch_state_dict_to_paddle(model, state_dict)
1095
+
1096
+ original_loaded_keys = list(state_dict.keys())
1097
+ loaded_keys.extend(original_loaded_keys)
1098
+
1099
+ # Make sure we are able to load base models as well as derived models (with heads)
1100
+ model_to_load = model
1101
+
1102
+ def _find_mismatched_keys(
1103
+ state_dict,
1104
+ model_state_dict,
1105
+ loaded_keys,
1106
+ ignore_mismatched_sizes,
1107
+ ):
1108
+ mismatched_keys = []
1109
+ for checkpoint_key in loaded_keys:
1110
+ model_key = checkpoint_key
1111
+
1112
+ if model_key in model_state_dict and list(state_dict[checkpoint_key].shape) != list(
1113
+ model_state_dict[model_key].shape
1114
+ ):
1115
+ mismatched_keys.append(
1116
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
1117
+ )
1118
+ del state_dict[checkpoint_key]
1119
+ if ignore_mismatched_sizes:
1120
+ mismatched_keys = []
1121
+ return mismatched_keys
1122
+
1123
+ if state_dict is not None and len(state_dict) > 0:
1124
+ _mismatched_keys = _find_mismatched_keys(
1125
+ state_dict,
1126
+ model_state_dict,
1127
+ original_loaded_keys,
1128
+ ignore_mismatched_sizes,
1129
+ )
1130
+ mismatched_keys.extend(_mismatched_keys)
1131
+ for key_name, loaded_shape, model_shape in _mismatched_keys:
1132
+ error_msgs.append(
1133
+ f"Error: size mismatch for {key_name}, the checkpoint provides shape {loaded_shape} but the expected shape is {model_shape}."
1134
+ )
1135
+ cls.custom_modify_weight(model_to_load, state_dict)
1136
+ faster_set_state_dict(model_to_load, state_dict)
1137
+
1138
+ missing_keys = sorted(list(set(expected_keys) - set(loaded_keys)))
1139
+ unexpected_keys = sorted(list(set(loaded_keys) - set(expected_keys)))
1140
+
1141
+ if len(error_msgs) > 0:
1142
+ error_msg = "\n\t".join(error_msgs)
1143
+ if "size mismatch" in error_msg:
1144
+ error_msg += (
1145
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
1146
+ )
1147
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
1148
+
1149
+ if len(unexpected_keys) > 0:
1150
+ logger.warning(
1151
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
1152
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
1153
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
1154
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
1155
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
1156
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
1157
+ " identical (initializing a BertForSequenceClassification model from a"
1158
+ " BertForSequenceClassification model)."
1159
+ )
1160
+ else:
1161
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
1162
+ if len(missing_keys) > 0:
1163
+ logger.warning(
1164
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1165
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
1166
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
1167
+ )
1168
+ elif len(mismatched_keys) == 0:
1169
+ logger.info(
1170
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
1171
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
1172
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
1173
+ " without further training."
1174
+ )
1175
+ if len(mismatched_keys) > 0:
1176
+ mismatched_warning = "\n".join(
1177
+ [
1178
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
1179
+ for key, shape1, shape2 in mismatched_keys
1180
+ ]
1181
+ )
1182
+ logger.warning(
1183
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1184
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
1185
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
1186
+ " able to use it for predictions and inference."
1187
+ )
1188
+ del state_dict
1189
+ gc.collect()
1190
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
1191
+
1192
+ @property
1193
+ def device(self):
1194
+ """
1195
+ `paddle.place`: The device on which the module is (assuming that all the module parameters are on the same
1196
+ device).
1197
+ """
1198
+ return get_parameter_device(self)
1199
+
1200
+ @property
1201
+ def dtype(self) -> paddle.dtype:
1202
+ """
1203
+ `paddle.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
1204
+ """
1205
+ return get_parameter_dtype(self)
1206
+
1207
+ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
1208
+ """
1209
+ Get number of (trainable or non-embedding) parameters in the module.
1210
+
1211
+ Args:
1212
+ only_trainable (`bool`, *optional*, defaults to `False`):
1213
+ Whether or not to return only the number of trainable parameters.
1214
+
1215
+ exclude_embeddings (`bool`, *optional*, defaults to `False`):
1216
+ Whether or not to return only the number of non-embedding parameters.
1217
+
1218
+ Returns:
1219
+ `int`: The number of parameters.
1220
+
1221
+ Example:
1222
+ ```py
1223
+ from ppdiffusers import UNet2DConditionModel
1224
+ model_id = "runwayml/stable-diffusion-v1-5"
1225
+ unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
1226
+ unet.num_parameters(only_trainable=True)
1227
+ 859520964
1228
+ ```
1229
+ """
1230
+
1231
+ if exclude_embeddings:
1232
+ embedding_param_names = [
1233
+ f"{name}.weight"
1234
+ for name, module_type in self.named_sublayers(include_self=True)
1235
+ if isinstance(module_type, nn.Embedding)
1236
+ ]
1237
+ non_embedding_parameters = [
1238
+ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
1239
+ ]
1240
+ return sum(p.numel() for p in non_embedding_parameters if not p.stop_gradient or not only_trainable)
1241
+ else:
1242
+ return sum(p.numel() for p in self.parameters() if not p.stop_gradient or not only_trainable)
1243
+
1244
+ def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
1245
+ deprecated_attention_block_paths = []
1246
+
1247
+ def recursive_find_attn_block(name, module):
1248
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
1249
+ deprecated_attention_block_paths.append(name)
1250
+
1251
+ for sub_name, sub_module in module.named_children():
1252
+ sub_name = sub_name if name == "" else f"{name}.{sub_name}"
1253
+ recursive_find_attn_block(sub_name, sub_module)
1254
+
1255
+ recursive_find_attn_block("", self)
1256
+
1257
+ # NOTE: we have to check if the deprecated parameters are in the state dict
1258
+ # because it is possible we are loading from a state dict that was already
1259
+ # converted
1260
+
1261
+ for path in deprecated_attention_block_paths:
1262
+ # group_norm path stays the same
1263
+
1264
+ # query -> to_q
1265
+ if f"{path}.query.weight" in state_dict:
1266
+ state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
1267
+ if f"{path}.query.bias" in state_dict:
1268
+ state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
1269
+
1270
+ # key -> to_k
1271
+ if f"{path}.key.weight" in state_dict:
1272
+ state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
1273
+ if f"{path}.key.bias" in state_dict:
1274
+ state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
1275
+
1276
+ # value -> to_v
1277
+ if f"{path}.value.weight" in state_dict:
1278
+ state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
1279
+ if f"{path}.value.bias" in state_dict:
1280
+ state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
1281
+
1282
+ # proj_attn -> to_out.0
1283
+ if f"{path}.proj_attn.weight" in state_dict:
1284
+ state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
1285
+ if f"{path}.proj_attn.bias" in state_dict:
1286
+ state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
1287
+
1288
+ def _temp_convert_self_to_deprecated_attention_blocks(self) -> None:
1289
+ deprecated_attention_block_modules = []
1290
+
1291
+ def recursive_find_attn_block(module):
1292
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
1293
+ deprecated_attention_block_modules.append(module)
1294
+
1295
+ for sub_module in module.children():
1296
+ recursive_find_attn_block(sub_module)
1297
+
1298
+ recursive_find_attn_block(self)
1299
+
1300
+ for module in deprecated_attention_block_modules:
1301
+ module.query = module.to_q
1302
+ module.key = module.to_k
1303
+ module.value = module.to_v
1304
+ module.proj_attn = module.to_out[0]
1305
+
1306
+ # We don't _have_ to delete the old attributes, but it's helpful to ensure
1307
+ # that _all_ the weights are loaded into the new attributes and we're not
1308
+ # making an incorrect assumption that this model should be converted when
1309
+ # it really shouldn't be.
1310
+ del module.to_q
1311
+ del module.to_k
1312
+ del module.to_v
1313
+ del module.to_out
1314
+
1315
+ def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None:
1316
+ deprecated_attention_block_modules = []
1317
+
1318
+ def recursive_find_attn_block(module) -> None:
1319
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
1320
+ deprecated_attention_block_modules.append(module)
1321
+
1322
+ for sub_module in module.children():
1323
+ recursive_find_attn_block(sub_module)
1324
+
1325
+ recursive_find_attn_block(self)
1326
+
1327
+ for module in deprecated_attention_block_modules:
1328
+ module.to_q = module.query
1329
+ module.to_k = module.key
1330
+ module.to_v = module.value
1331
+ module.to_out = nn.LayerList([module.proj_attn, nn.Dropout(module.dropout)])
1332
+
1333
+ del module.query
1334
+ del module.key
1335
+ del module.value
1336
+ del module.proj_attn
1337
+
1338
+ @classmethod
1339
+ def _update_deprecated_state_dict(cls, state_dict=None, loaded_keys=None, model=None):
1340
+ if state_dict is None:
1341
+ return loaded_keys
1342
+ _deprecated_dict = getattr(cls, "_deprecated_dict", None)
1343
+ from_deprecated_state_dict = _deprecated_dict is not None and any(
1344
+ cls._deprecated_dict.get("key", "NONE") in all_key for all_key in state_dict.keys()
1345
+ )
1346
+ if from_deprecated_state_dict:
1347
+ logger.warning(
1348
+ "Loading from deprecated state_dict, please load new state_dict via setting `use_safetensors=True`."
1349
+ )
1350
+ for name in list(state_dict.keys()):
1351
+ deprecated_name = name
1352
+ for old_name, new_name in cls._deprecated_dict.get("name_mapping", {}).items():
1353
+ name = name.replace(old_name, new_name)
1354
+ state_dict[name] = state_dict.pop(deprecated_name)
1355
+ loaded_keys = list(state_dict.keys())
1356
+ return loaded_keys
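
The loading path above resolves, in order, a sharded safetensors index, a sharded paddle/torch index, a single safetensors file, and finally a single weights file, retrying across variants before giving up. A minimal sketch of how this surfaces through `from_pretrained` (the model id and subfolder reuse the docstring example above; the flags are optional):

```python
import paddle
from ppdiffusers import UNet2DConditionModel

# `use_safetensors` steers the file-resolution fallbacks implemented above;
# `output_loading_info=True` returns the missing/unexpected/mismatched-key report
# assembled by `_load_pretrained_model`.
unet, loading_info = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet",
    paddle_dtype=paddle.float16,
    use_safetensors=True,
    output_loading_info=True,
)
print(loading_info["missing_keys"], loading_info["unexpected_keys"])
```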
PaddleMIX/ppdiffusers/ppdiffusers/models/modelscope_gaussion_sdedit.py ADDED
@@ -0,0 +1,451 @@
1
+ # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ import random
17
+
18
+ import paddle
19
+ from tqdm.auto import trange
20
+
21
+
22
+ def _logsnr_cosine(n, logsnr_min=-15, logsnr_max=15):
23
+ t_min = math.atan(math.exp(-0.5 * logsnr_min))
24
+ t_max = math.atan(math.exp(-0.5 * logsnr_max))
25
+ t = paddle.linspace(1, 0, n)
26
+ logsnrs = -2 * paddle.log(paddle.tan(t_min + t * (t_max - t_min)))
27
+ return logsnrs
28
+
29
+
30
+ def _logsnr_cosine_shifted(n, logsnr_min=-15, logsnr_max=15, scale=2):
31
+ logsnrs = _logsnr_cosine(n, logsnr_min, logsnr_max)
32
+ logsnrs += 2 * math.log(1 / scale)
33
+ return logsnrs
34
+
35
+
36
+ def logsnrs_to_sigmas(logsnrs):
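+ # With a variance-preserving schedule (alpha^2 + sigma^2 = 1) and logSNR = log(alpha^2 / sigma^2), sigma = sqrt(sigmoid(-logSNR)).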
37
+ return paddle.sqrt(paddle.nn.functional.sigmoid(-logsnrs))
38
+
39
+
40
+ def _logsnr_cosine_interp(n, logsnr_min=-15, logsnr_max=15, scale_min=2, scale_max=4):
41
+ t = paddle.linspace(1, 0, n)
42
+ logsnrs_min = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_min)
43
+ logsnrs_max = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_max)
44
+ logsnrs = t * logsnrs_min + (1 - t) * logsnrs_max
45
+ return logsnrs
46
+
47
+
48
+ def logsnr_cosine_interp_schedule(n, logsnr_min=-15, logsnr_max=15, scale_min=2, scale_max=4):
49
+ return logsnrs_to_sigmas(_logsnr_cosine_interp(n, logsnr_min, logsnr_max, scale_min, scale_max))
50
+
51
+
52
+ def noise_schedule(schedule="logsnr_cosine_interp", n=1000, zero_terminal_snr=False, **kwargs):
53
+ # compute sigmas
54
+ sigmas = {"logsnr_cosine_interp": logsnr_cosine_interp_schedule}[schedule](n, **kwargs)
55
+
56
+ # post-processing: optionally rescale so the largest sigma reaches 1.0 (zero terminal SNR)
57
+ if zero_terminal_snr and sigmas.max() != 1.0:
58
+ scale = (1.0 - sigmas.min()) / (sigmas.max() - sigmas.min())
59
+ sigmas = sigmas.min() + scale * (sigmas - sigmas.min())
60
+ return sigmas
61
+
62
+
63
+ def _i(tensor, t, x):
64
+ r"""Index tensor using t and format the output according to x."""
65
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
66
+ if tensor.place != x.place:
67
+ tensor = paddle.to_tensor(tensor, place=x.place)
68
+ return tensor[t].reshape(shape).astype(x.dtype)
69
+
70
+
71
+ def get_scalings(sigma):
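+ # c_in = 1 / sqrt(sigma^2 + 1) rescales the noisy input before it is fed to the denoiser; c_out is unused by the samplers in this file.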
72
+ c_out = -sigma
73
+ c_in = 1 / (sigma**2 + 1.0**2) ** 0.5
74
+ return c_out, c_in
75
+
76
+
77
+ def karras_schedule(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
78
+ ramp = paddle.linspace(1, 0, n)
79
+ min_inv_rho = sigma_min ** (1 / rho)
80
+ max_inv_rho = sigma_max ** (1 / rho)
81
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
82
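+ # Map the Karras (variance-exploding) sigmas into the variance-preserving range: sigma / sqrt(1 + sigma^2).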
+ sigmas = paddle.sqrt(sigmas**2 / (1 + sigmas**2))
83
+ return sigmas
84
+
85
+
86
+ @paddle.no_grad()
87
+ def sample_heun(noise, model, sigmas, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, show_progress=True):
88
+ """
89
+ Implements Algorithm 2 (Heun steps) from Karras et al. (2022).
90
+ """
91
+ x = noise * sigmas[0]
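+ # Per step: optionally "churn" extra noise into the sample (gamma > 0 lifts sigma to sigma_hat),
+ # take an Euler step toward the denoised estimate, then apply Heun's second-order correction
+ # unless the next sigma is zero.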
92
+ for i in trange(len(sigmas) - 1, disable=not show_progress):
93
+ gamma = 0.0
94
+ if s_tmin <= sigmas[i] <= s_tmax and sigmas[i] < float("inf"):
95
+ gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
96
+ eps = paddle.randn(shape=x.shape, dtype=x.dtype) * s_noise
97
+ sigma_hat = sigmas[i] * (gamma + 1)
98
+ if gamma > 0:
99
+ x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
100
+ if sigmas[i] == float("inf"):
101
+ # Euler method
102
+ denoised = model(noise, sigma_hat)
103
+ x = denoised + sigmas[i + 1] * (gamma + 1) * noise
104
+ else:
105
+ _, c_in = get_scalings(sigma_hat)
106
+ denoised = model(x * c_in, sigma_hat)
107
+ d = (x - denoised) / sigma_hat
108
+ dt = sigmas[i + 1] - sigma_hat
109
+ if sigmas[i + 1] == 0:
110
+ # Euler method
111
+ x = x + d * dt
112
+ else:
113
+ # Heun's method
114
+ x_2 = x + d * dt
115
+ _, c_in = get_scalings(sigmas[i + 1])
116
+ denoised_2 = model(x_2 * c_in, sigmas[i + 1])
117
+ d_2 = (x_2 - denoised_2) / sigmas[i + 1]
118
+ d_prime = (d + d_2) / 2
119
+ x = x + d_prime * dt
120
+ return x
121
+
122
+
123
+ class BatchedBrownianTree:
124
+ """
125
+ A wrapper around paddlesde.BrownianTree that enables batches of entropy.
126
+ """
127
+
128
+ def __init__(self, x, t0, t1, seed=None, **kwargs):
129
+ import paddlesde
130
+
131
+ t0, t1, self.sign = self.sort(t0, t1)
132
+ w0 = kwargs.get("w0", paddle.zeros_like(x))
133
+ if seed is None:
134
+ seed = paddle.randint(0, 2**31 - 1, []).item()
135
+ self.batched = True
136
+ try:
137
+ assert len(seed) == x.shape[0]
138
+ w0 = w0[0]
139
+ except TypeError:
140
+ seed = [seed]
141
+ self.batched = False
142
+ self.trees = [paddlesde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
143
+
144
+ @staticmethod
145
+ def sort(a, b):
146
+ return (a, b, 1) if a < b else (b, a, -1)
147
+
148
+ def __call__(self, t0, t1):
149
+ t0, t1, sign = self.sort(t0, t1)
150
+ w = paddle.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
151
+ return w if self.batched else w[0]
152
+
153
+
154
+ class BrownianTreeNoiseSampler:
155
+ """
156
+ A noise sampler backed by a paddlesde.BrownianTree.
157
+
158
+ Args:
159
+ x (Tensor): The tensor whose shape, device and dtype to use to generate
160
+ random samples.
161
+ sigma_min (float): The low end of the valid interval.
162
+ sigma_max (float): The high end of the valid interval.
163
+ seed (int or List[int]): The random seed. If a list of seeds is
164
+ supplied instead of a single integer, then the noise sampler will
165
+ use one BrownianTree per batch item, each with its own seed.
166
+ transform (callable): A function that maps sigma to the sampler's
167
+ internal timestep.
168
+ """
169
+
170
+ def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
171
+ self.transform = transform
172
+ t0 = self.transform(paddle.to_tensor(sigma_min))
173
+ t1 = self.transform(paddle.to_tensor(sigma_max))
174
+ self.tree = BatchedBrownianTree(x, t0, t1, seed)
175
+
176
+ def __call__(self, sigma, sigma_next):
177
+ t0 = self.transform(paddle.to_tensor(sigma))
178
+ t1 = self.transform(paddle.to_tensor(sigma_next))
179
+ return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
180
+
181
+
182
+ @paddle.no_grad()
183
+ def sample_dpmpp_2m_sde(noise, model, sigmas, eta=1.0, s_noise=1.0, solver_type="midpoint", show_progress=True):
184
+ """
185
+ DPM-Solver++ (2M) SDE.
186
+ """
187
+ assert solver_type in {"heun", "midpoint"}
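+ # Multistep exponential integrator in log-sigma time: each update combines the current and
+ # previous denoised estimates ("heun" or "midpoint" correction), and `eta` scales the
+ # Brownian-tree noise injected per step.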
188
+
189
+ x = noise * sigmas[0]
190
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas[sigmas < float("inf")].max()
191
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max)
192
+ old_denoised = None
193
+ h_last = None
194
+
195
+ for i in trange(len(sigmas) - 1, disable=not show_progress):
196
+ if sigmas[i] == float("inf"):
197
+ # Euler method
198
+ denoised = model(noise, sigmas[i])
199
+ x = denoised + sigmas[i + 1] * noise
200
+ else:
201
+ _, c_in = get_scalings(sigmas[i])
202
+ denoised = model(x * c_in, sigmas[i])
203
+ if sigmas[i + 1] == 0:
204
+ # Denoising step
205
+ x = denoised
206
+ else:
207
+ # DPM-Solver++(2M) SDE
208
+ t, s = -sigmas[i].log(), -sigmas[i + 1].log()
209
+ h = s - t
210
+ eta_h = eta * h
211
+
212
+ x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
213
+ if old_denoised is not None:
214
+ r = h_last / h
215
+ if solver_type == "heun":
216
+ x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
217
+ elif solver_type == "midpoint":
218
+ x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
219
+
220
+ x = (
221
+ x
222
+ + noise_sampler(sigmas[i], sigmas[i + 1])
223
+ * sigmas[i + 1]
224
+ * (-2 * eta_h).expm1().neg().sqrt()
225
+ * s_noise
226
+ )
227
+
228
+ old_denoised = denoised
229
+ h_last = h
230
+ return x
231
+
232
+
233
+ class GaussianDiffusion_SDEdit(object):
234
+ def __init__(self, sigmas, prediction_type="eps"):
235
+ assert prediction_type in {"x0", "eps", "v"}
236
+ self.sigmas = sigmas
237
+ self.alphas = paddle.sqrt(1 - sigmas**2)
238
+ self.num_timesteps = len(sigmas)
239
+ self.prediction_type = prediction_type
240
+
241
+ def diffuse(self, x0, t, noise=None):
242
+ noise = paddle.randn(shape=x0.shape, dtype=x0.dtype) if noise is None else noise
243
+ xt = _i(self.alphas, t, x0) * x0 + _i(self.sigmas, t, x0) * noise
244
+ return xt
245
+
246
+ def denoise(
247
+ self, xt, t, s, model, model_kwargs={}, guide_scale=None, guide_rescale=None, clamp=None, percentile=None
248
+ ):
249
+ s = t - 1 if s is None else s
250
+
251
+ # hyperparams
252
+ sigmas = _i(self.sigmas, t, xt)
253
+ alphas = _i(self.alphas, t, xt)
254
+ alphas_s = _i(self.alphas, s.clip(0), xt)
255
+ alphas_s[s < 0] = 1.0
256
+ sigmas_s = paddle.sqrt(1 - alphas_s**2)
257
+
258
+ # precompute the coefficients of the posterior q(x_s | x_t, x_0)
259
+ betas = 1 - (alphas / alphas_s) ** 2
260
+ coef1 = betas * alphas_s / sigmas**2
261
+ coef2 = (alphas * sigmas_s**2) / (alphas_s * sigmas**2)
262
+ var = betas * (sigmas_s / sigmas) ** 2
263
+ log_var = paddle.log(var).clip_(-20, 20)
264
+
265
+ # prediction
266
+ if guide_scale is None:
267
+ assert isinstance(model_kwargs, dict)
268
+ out = model(xt, t=t, **model_kwargs).sample
269
+ else:
270
+ # classifier-free guidance
271
+ assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
272
+ y_out = model(xt, t=t, **model_kwargs[0]).sample
273
+ if guide_scale == 1.0:
274
+ out = y_out
275
+ else:
276
+ u_out = model(xt, t=t, **model_kwargs[1]).sample
277
+ out = u_out + guide_scale * (y_out - u_out)
278
+
279
+ if guide_rescale is not None:
280
+ assert 0 <= guide_rescale <= 1
281
+ ratio = (
282
+ paddle.std(y_out.flatten(1), axis=1) / (paddle.std(out.flatten(1), axis=1) + 1e-12) # noqa
283
+ ).reshape(list((-1,) + (1,) * (y_out.ndim - 1)))
284
+ out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0
285
+
286
+ # compute x0
287
+ if self.prediction_type == "x0":
288
+ x0 = out
289
+ elif self.prediction_type == "eps":
290
+ x0 = (xt - sigmas * out) / alphas
291
+ elif self.prediction_type == "v":
292
+ x0 = alphas * xt - sigmas * out
293
+ else:
294
+ raise NotImplementedError(f"prediction_type {self.prediction_type} not implemented")
295
+
296
+ # restrict the range of x0
297
+ if percentile is not None:
298
+ assert 0 < percentile <= 1
299
+ s = paddle.quantile(x0.flatten(1).abs(), percentile, axis=1).clip_(1.0).reshape([-1, 1, 1, 1])
300
+ x0 = paddle.minimum(s, paddle.maximum(-s, x0)) / s
301
+ elif clamp is not None:
302
+ x0 = x0.clip_(-clamp, clamp)
303
+
304
+ # recompute eps using the restricted x0
305
+ eps = (xt - alphas * x0) / sigmas
306
+
307
+ # compute mu (mean of posterior distribution) using the restricted x0
308
+ mu = coef1 * x0 + coef2 * xt
309
+ return mu, var, log_var, x0, eps
310
+
311
+ @paddle.no_grad()
312
+ def sample(
313
+ self,
314
+ noise,
315
+ model,
316
+ model_kwargs={},
317
+ condition_fn=None,
318
+ guide_scale=None,
319
+ guide_rescale=None,
320
+ clamp=None,
321
+ percentile=None,
322
+ solver="euler_a",
323
+ steps=20,
324
+ t_max=None,
325
+ t_min=None,
326
+ discretization=None,
327
+ discard_penultimate_step=None,
328
+ return_intermediate=None,
329
+ show_progress=False,
330
+ seed=-1,
331
+ **kwargs
332
+ ):
333
+ # sanity check
334
+ assert isinstance(steps, (int, paddle.Tensor))
335
+ assert t_max is None or (0 < t_max <= self.num_timesteps - 1)
336
+ assert t_min is None or (0 <= t_min < self.num_timesteps - 1)
337
+ assert discretization in (None, "leading", "linspace", "trailing")
338
+ assert discard_penultimate_step in (None, True, False)
339
+ assert return_intermediate in (None, "x0", "xt")
340
+
341
+ # function of diffusion solver
342
+ solver_fn = {"heun": sample_heun, "dpmpp_2m_sde": sample_dpmpp_2m_sde}[solver]
343
+
344
+ # options
345
+ schedule = "karras" if "karras" in solver else None
346
+ discretization = discretization or "linspace"
347
+ seed = seed if seed >= 0 else random.randint(0, 2**31)
348
+
349
+ if isinstance(steps, paddle.Tensor):
350
+ discard_penultimate_step = False
351
+ if discard_penultimate_step is None:
352
+ discard_penultimate_step = (
353
+ True
354
+ if solver
355
+ in (
356
+ "dpm2",
357
+ "dpm2_ancestral",
358
+ "dpmpp_2m_sde",
359
+ "dpm2_karras",
360
+ "dpm2_ancestral_karras",
361
+ "dpmpp_2m_sde_karras",
362
+ )
363
+ else False
364
+ )
365
+
366
+ # function for denoising xt to get x0
367
+ intermediates = []
368
+
369
+ def model_fn(xt, sigma):
370
+ # denoising
371
+ t = self._sigma_to_t(sigma).tile(len(xt)).round().astype("int64")
372
+ x0 = self.denoise(xt, t, None, model, model_kwargs, guide_scale, guide_rescale, clamp, percentile)[-2]
373
+
374
+ # collect intermediate outputs
375
+ if return_intermediate == "xt":
376
+ intermediates.append(xt)
377
+ elif return_intermediate == "x0":
378
+ intermediates.append(x0)
379
+ return x0
380
+
381
+ # get timesteps
382
+ if isinstance(steps, int):
383
+ steps += 1 if discard_penultimate_step else 0
384
+ t_max = self.num_timesteps - 1 if t_max is None else t_max
385
+ t_min = 0 if t_min is None else t_min
386
+
387
+ # discretize timesteps
388
+ if discretization == "leading":
389
+ steps = paddle.arange(t_min, t_max + 1, (t_max - t_min + 1) / steps).flip(0)
390
+ elif discretization == "linspace":
391
+ steps = paddle.linspace(t_max, t_min, steps)
392
+ elif discretization == "trailing":
393
+ steps = paddle.arange(t_max, t_min - 1, -((t_max - t_min + 1) / steps))
394
+ else:
395
+ raise NotImplementedError(f"{discretization} discretization not implemented")
396
+ steps = steps.clip_(t_min, t_max)
397
+ steps = paddle.to_tensor(steps, dtype=paddle.float32, place=noise.place)
398
+
399
+ # get sigmas
400
+ sigmas = self._t_to_sigma(steps)
401
+ sigmas = paddle.concat([sigmas, paddle.zeros([1]).astype(sigmas.dtype)])
402
+ if schedule == "karras":
403
+ if sigmas[0] == float("inf"):
404
+ sigmas = karras_schedule(
405
+ n=len(steps) - 1,
406
+ sigma_min=sigmas[sigmas > 0].min().item(),
407
+ sigma_max=sigmas[sigmas < float("inf")].max().item(),
408
+ rho=7.0,
409
+ ).to(sigmas)
410
+ sigmas = paddle.concat(
411
+ [paddle.full([1], float("inf"), dtype=sigmas.dtype), sigmas, paddle.zeros([1]).astype(sigmas.dtype)]
412
+ )
413
+ else:
414
+ sigmas = karras_schedule(
415
+ n=len(steps), sigma_min=sigmas[sigmas > 0].min().item(), sigma_max=sigmas.max().item(), rho=7.0
416
+ ).to(sigmas)
417
+ sigmas = paddle.concat([sigmas, paddle.zeros([1]).astype(sigmas.dtype)])
418
+ if discard_penultimate_step:
419
+ sigmas = paddle.concat([sigmas[:-2], sigmas[-1:]])
420
+
421
+ # sampling
422
+ x0 = solver_fn(noise, model_fn, sigmas, show_progress=show_progress, **kwargs)
423
+ return (x0, intermediates) if return_intermediate is not None else x0
424
+
425
+ def _sigma_to_t(self, sigma):
426
+ if sigma == float("inf"):
427
+ t = paddle.full_like(sigma, len(self.sigmas) - 1)
428
+ else:
429
+ log_sigmas = paddle.sqrt(self.sigmas**2 / (1 - self.sigmas**2)).log().astype(sigma.dtype) # noqa
430
+ log_sigma = sigma.log()
431
+ dists = log_sigma - log_sigmas[:, None]
432
+
433
+ low_idx = dists.greater_equal(paddle.to_tensor(0, dtype=dists.dtype)).astype(dists.dtype)
434
+ low_idx = paddle.cumsum(low_idx, axis=0).argmax(axis=0).clip_(max=log_sigmas.shape[0] - 2)
435
+ high_idx = low_idx + 1
436
+ low, high = log_sigmas[low_idx], log_sigmas[high_idx]
437
+ w = (low - log_sigma) / (low - high)
438
+ w = w.clip_(0, 1)
439
+ t = (1 - w) * low_idx + w * high_idx
440
+ t = t.reshape(sigma.shape)
441
+ if t.ndim == 0:
442
+ t = t.unsqueeze(0)
443
+ return t
444
+
445
+ def _t_to_sigma(self, t):
446
+ t = t.astype("float32")
447
+ low_idx, high_idx, w = t.floor().astype("int64"), t.ceil().astype("int64"), t.frac()
448
+ log_sigmas = paddle.sqrt(self.sigmas**2 / (1 - self.sigmas**2)).log().astype(t.dtype) # noqa
449
+ log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx]
450
+ log_sigma[paddle.isnan(log_sigma) | paddle.isinf(log_sigma)] = float("inf")
451
+ return log_sigma.exp()
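
Taken together, `noise_schedule` produces the sigma table and `GaussianDiffusion_SDEdit` consumes it for forward diffusion and sampling. A minimal sketch under two assumptions: the import path simply mirrors this file's location, and `unet` stands in for any hypothetical model whose forward returns an output with a `.sample` field (such as the ST-UNet below):

```python
import paddle
from ppdiffusers.models.modelscope_gaussion_sdedit import (
    GaussianDiffusion_SDEdit,
    noise_schedule,
)

sigmas = noise_schedule("logsnr_cosine_interp", n=1000, zero_terminal_snr=True)
diffusion = GaussianDiffusion_SDEdit(sigmas, prediction_type="v")

# Forward diffusion of clean video latents (b, c, f, h, w) to timestep 900.
latents = paddle.randn([1, 4, 16, 32, 32])
noisy = diffusion.diffuse(latents, t=paddle.to_tensor([900]))

# Reverse sampling (commented out: `unet` and `text_embeddings` are placeholders).
# video = diffusion.sample(noise=noisy, model=unet, model_kwargs={"y": text_embeddings},
#                          solver="dpmpp_2m_sde", steps=30)
```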
PaddleMIX/ppdiffusers/ppdiffusers/models/modelscope_st_unet_video2video.py ADDED
@@ -0,0 +1,409 @@
1
+ # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import paddle
16
+ import paddle.nn as nn
17
+ import paddle.nn.functional as F
18
+ from einops import rearrange
19
+
20
+ from ..configuration_utils import register_to_config
21
+ from .lvdm_util import avg_pool_nd
22
+ from .modelscope_st_unet import (
23
+ ResBlock,
24
+ SpatialTransformer,
25
+ STUNetModel,
26
+ STUNetOutput,
27
+ TemporalAttentionMultiBlock,
28
+ TemporalTransformer,
29
+ default,
30
+ prob_mask_like,
31
+ sinusoidal_embedding_paddle,
32
+ )
33
+
34
+ USE_TEMPORAL_TRANSFORMER = True
35
+
36
+
37
+ class Downsample(nn.Layer):
38
+ """
39
+ A downsampling layer with an optional convolution.
40
+ :param channels: channels in the inputs and outputs.
41
+ :param use_conv: a bool determining if a convolution is applied.
42
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
43
+ downsampling occurs in the inner-two dimensions.
44
+ """
45
+
46
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=(2, 1)):
47
+ super().__init__()
48
+ self.channels = channels
49
+ self.out_channels = out_channels or channels
50
+ self.use_conv = use_conv
51
+ self.dims = dims
52
+ stride = 2 if dims != 3 else (1, 2, 2)
53
+ if use_conv:
54
+ self.op = nn.Conv2D(self.channels, self.out_channels, 3, stride=stride, padding=padding)
55
+ else:
56
+ assert self.channels == self.out_channels
57
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
58
+
59
+ def forward(self, x):
60
+ assert x.shape[1] == self.channels
61
+ return self.op(x)
62
+
63
+
64
+ class Upsample(nn.Layer):
65
+ """
66
+ An upsampling layer with an optional convolution.
67
+ :param channels: channels in the inputs and outputs.
68
+ :param use_conv: a bool determining if a convolution is applied.
69
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
70
+ upsampling occurs in the inner-two dimensions.
71
+ """
72
+
73
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
74
+ super().__init__()
75
+ self.channels = channels
76
+ self.out_channels = out_channels or channels
77
+ self.use_conv = use_conv
78
+ self.dims = dims
79
+ if use_conv:
80
+ self.conv = nn.Conv2D(self.channels, self.out_channels, 3, padding=padding)
81
+
82
+ def forward(self, x):
83
+ assert x.shape[1] == self.channels
84
+ if self.dims == 3:
85
+ x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
86
+ else:
87
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
88
+ x = x[..., 1:-1, :]
89
+ if self.use_conv:
90
+ x = self.conv(x)
91
+ return x
92
+
93
+
94
+ class Vid2VidSTUNet(STUNetModel):
95
+ @register_to_config
96
+ def __init__(
97
+ self,
98
+ in_channels=4,
99
+ out_channels=4,
100
+ dim=320,
101
+ y_dim=1024,
102
+ context_channels=1024,
103
+ dim_mult=[1, 2, 4, 4],
104
+ num_heads=8,
105
+ head_dim=64,
106
+ num_res_blocks=2,
107
+ attn_scales=[1 / 1, 1 / 2, 1 / 4],
108
+ use_scale_shift_norm=True,
109
+ dropout=0.1,
110
+ temporal_attn_times=1,
111
+ temporal_attention=True,
112
+ use_checkpoint=True,
113
+ use_image_dataset=False,
114
+ use_fps_condition=False,
115
+ use_sim_mask=False,
116
+ training=False,
117
+ inpainting=True,
118
+ **kwargs
119
+ ):
120
+ super(Vid2VidSTUNet, self).__init__(
121
+ in_channels=in_channels,
122
+ out_channels=out_channels,
123
+ dim=dim,
124
+ y_dim=y_dim,
125
+ context_channels=context_channels,
126
+ dim_mult=dim_mult,
127
+ num_heads=num_heads,
128
+ head_dim=head_dim,
129
+ num_res_blocks=num_res_blocks,
130
+ attn_scales=attn_scales,
131
+ use_scale_shift_norm=use_scale_shift_norm,
132
+ dropout=dropout,
133
+ temporal_attn_times=temporal_attn_times,
134
+ temporal_attention=temporal_attention,
135
+ )
136
+ embed_dim = dim * 4
137
+ num_heads = num_heads if num_heads else dim // 32
138
+ self.in_dim = in_channels
139
+ self.dim = dim
140
+ self.y_dim = y_dim
141
+ self.context_dim = context_channels
142
+ self.embed_dim = embed_dim
143
+ self.out_dim = out_channels
144
+ self.dim_mult = dim_mult
145
+ # for temporal attention
146
+ self.num_heads = num_heads
147
+ # for spatial attention
148
+ self.head_dim = head_dim
149
+ self.num_res_blocks = num_res_blocks
150
+ self.attn_scales = attn_scales
151
+ self.use_scale_shift_norm = use_scale_shift_norm
152
+ self.temporal_attn_times = temporal_attn_times
153
+ self.temporal_attention = temporal_attention
154
+ self.inpainting = inpainting
155
+ self.use_fps_condition = use_fps_condition
156
+
157
+ use_linear_in_temporal = False
158
+ transformer_depth = 1
159
+ disabled_sa = False
160
+
161
+ enc_dims = [dim * u for u in [1] + dim_mult]
162
+ dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
163
+ shortcut_dims = []
164
+ scale = 1.0
165
+
166
+ if self.use_fps_condition:
167
+ self.fps_embedding = nn.Sequential(
168
+ nn.Linear(dim, embed_dim),
169
+ nn.Silu(),
170
+ nn.Linear(
171
+ embed_dim,
172
+ embed_dim,
173
+ weight_attr=nn.initializer.Constant(value=0.0),
174
+ bias_attr=nn.initializer.Constant(value=0.0),
175
+ ),
176
+ )
177
+
178
+ # encoder
179
+ self.input_blocks = nn.LayerList()
180
+ init_block = nn.LayerList([nn.Conv2D(self.in_dim, dim, 3, padding=1)])
181
+ if temporal_attention:
182
+ if USE_TEMPORAL_TRANSFORMER:
183
+ init_block.append(
184
+ TemporalTransformer(
185
+ dim,
186
+ num_heads,
187
+ head_dim,
188
+ depth=transformer_depth,
189
+ context_dim=context_channels,
190
+ disable_self_attn=disabled_sa,
191
+ use_linear=use_linear_in_temporal,
192
+ multiply_zero=use_image_dataset,
193
+ )
194
+ )
195
+ else:
196
+ init_block.append(
197
+ TemporalAttentionMultiBlock(
198
+ dim,
199
+ num_heads,
200
+ head_dim,
201
+ rotary_emb=self.rotary_emb,
202
+ temporal_attn_times=temporal_attn_times,
203
+ use_image_dataset=use_image_dataset,
204
+ )
205
+ )
206
+
207
+ self.input_blocks.append(init_block)
208
+ shortcut_dims.append(dim)
209
+ for i, (in_dim, out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
210
+ for j in range(num_res_blocks):
211
+ block = nn.LayerList(
212
+ [
213
+ ResBlock(
214
+ in_dim,
215
+ embed_dim,
216
+ dropout,
217
+ out_channels=out_dim,
218
+ use_scale_shift_norm=False,
219
+ use_image_dataset=use_image_dataset,
220
+ )
221
+ ]
222
+ )
223
+ if scale in attn_scales:
224
+ block.append(
225
+ SpatialTransformer(
226
+ out_dim,
227
+ out_dim // head_dim,
228
+ head_dim,
229
+ depth=1,
230
+ context_dim=self.context_dim,
231
+ disable_self_attn=False,
232
+ use_linear=True,
233
+ )
234
+ )
235
+ if self.temporal_attention:
236
+ if USE_TEMPORAL_TRANSFORMER:
237
+ block.append(
238
+ TemporalTransformer(
239
+ out_dim,
240
+ out_dim // head_dim,
241
+ head_dim,
242
+ depth=transformer_depth,
243
+ context_dim=context_channels,
244
+ disable_self_attn=disabled_sa,
245
+ use_linear=use_linear_in_temporal,
246
+ multiply_zero=use_image_dataset,
247
+ )
248
+ )
249
+ else:
250
+ block.append(
251
+ TemporalAttentionMultiBlock(
252
+ out_dim,
253
+ num_heads,
254
+ head_dim,
255
+ rotary_emb=self.rotary_emb,
256
+ use_image_dataset=use_image_dataset,
257
+ use_sim_mask=use_sim_mask,
258
+ temporal_attn_times=temporal_attn_times,
259
+ )
260
+ )
261
+ in_dim = out_dim
262
+ self.input_blocks.append(block)
263
+ shortcut_dims.append(out_dim)
264
+
265
+ # downsample
266
+ if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
267
+ downsample = Downsample(out_dim, True, dims=2, out_channels=out_dim)
268
+ shortcut_dims.append(out_dim)
269
+ scale /= 2.0
270
+ self.input_blocks.append(downsample)
271
+
272
+ # decoder
273
+ self.output_blocks = nn.LayerList()
274
+ for i, (in_dim, out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
275
+ for j in range(num_res_blocks + 1):
276
+ block = nn.LayerList(
277
+ [
278
+ ResBlock(
279
+ in_dim + shortcut_dims.pop(),
280
+ embed_dim,
281
+ dropout,
282
+ out_dim,
283
+ use_scale_shift_norm=False,
284
+ use_image_dataset=use_image_dataset,
285
+ )
286
+ ]
287
+ )
288
+ if scale in attn_scales:
289
+ block.append(
290
+ SpatialTransformer(
291
+ out_dim,
292
+ out_dim // head_dim,
293
+ head_dim,
294
+ depth=1,
295
+ context_dim=1024,
296
+ disable_self_attn=False,
297
+ use_linear=True,
298
+ )
299
+ )
300
+
301
+ if self.temporal_attention:
302
+ if USE_TEMPORAL_TRANSFORMER:
303
+ block.append(
304
+ TemporalTransformer(
305
+ out_dim,
306
+ out_dim // head_dim,
307
+ head_dim,
308
+ depth=transformer_depth,
309
+ context_dim=context_channels,
310
+ disable_self_attn=disabled_sa,
311
+ use_linear=use_linear_in_temporal,
312
+ multiply_zero=use_image_dataset,
313
+ )
314
+ )
315
+ else:
316
+ block.append(
317
+ TemporalAttentionMultiBlock(
318
+ out_dim,
319
+ num_heads,
320
+ head_dim,
321
+ rotary_emb=self.rotary_emb,
322
+ use_image_dataset=use_image_dataset,
323
+ use_sim_mask=use_sim_mask,
324
+ temporal_attn_times=temporal_attn_times,
325
+ )
326
+ )
327
+
328
+ in_dim = out_dim
329
+
330
+ # upsample
331
+ if i != len(dim_mult) - 1 and j == num_res_blocks:
332
+ upsample = Upsample(out_dim, True, dims=2, out_channels=out_dim)
333
+ scale *= 2.0
334
+ block.append(upsample)
335
+ self.output_blocks.append(block)
336
+
337
+ def forward(
338
+ self,
339
+ x,
340
+ t,
341
+ y,
342
+ x_lr=None,
343
+ fps=None,
344
+ video_mask=None,
345
+ focus_present_mask=None,
346
+ prob_focus_present=0.0,
347
+ mask_last_frame_num=0,
348
+ return_dict: bool = True,
349
+ **kwargs
350
+ ):
351
+ batch, x_c, x_f, x_h, x_w = x.shape
352
+ device = x.place
353
+ self.batch = batch
354
+
355
+ # image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
356
+ if mask_last_frame_num > 0:
357
+ focus_present_mask = None
358
+ video_mask[-mask_last_frame_num:] = False
359
+ else:
360
+ focus_present_mask = default(
361
+ focus_present_mask, lambda: prob_mask_like((batch,), prob_focus_present, device=device)
362
+ )
363
+
364
+ time_rel_pos_bias = None
365
+
366
+ # embeddings
367
+ e = self.time_embed(sinusoidal_embedding_paddle(t, self.dim))
368
+ context = y
369
+
370
+ # repeat f times for spatial e and context
371
+ e = e.repeat_interleave(repeats=x_f, axis=0)
372
+ context = context.repeat_interleave(repeats=x_f, axis=0)
373
+
374
+ # always in shape (b f) c h w, except for temporal layer
375
+ x = rearrange(x, "b c f h w -> (b f) c h w")
376
+ # encoder
377
+ xs = []
378
+ for block in self.input_blocks:
379
+ x = self._forward_single(block, x, e, context, time_rel_pos_bias, focus_present_mask, video_mask)
380
+ xs.append(x)
381
+
382
+ # middle
383
+ for block in self.middle_block:
384
+ x = self._forward_single(block, x, e, context, time_rel_pos_bias, focus_present_mask, video_mask)
385
+
386
+ # decoder
387
+ for block in self.output_blocks:
388
+ x = paddle.concat([x, xs.pop()], axis=1)
389
+ x = self._forward_single(
390
+ block,
391
+ x,
392
+ e,
393
+ context,
394
+ time_rel_pos_bias,
395
+ focus_present_mask,
396
+ video_mask,
397
+ reference=xs[-1] if len(xs) > 0 else None,
398
+ )
399
+
400
+ # head
401
+ x = self.out(x)
402
+
403
+ # reshape back to (b c f h w)
404
+ sample = rearrange(x, "(b f) c h w -> b c f h w", b=batch)
405
+
406
+ if not return_dict:
407
+ return (sample,)
408
+
409
+ return STUNetOutput(sample=sample)
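
The forward pass above flattens frames into the batch dimension for every spatial block and only restores the `(b, c, f, h, w)` layout at the end. A minimal, weight-free sketch of that layout convention using the same `einops` patterns:

```python
import paddle
from einops import rearrange

b, c, f, h, w = 1, 4, 16, 32, 32
x = paddle.randn([b, c, f, h, w])

# Spatial layers operate on (b*f, c, h, w); temporal layers later regroup frames per sample.
x_flat = rearrange(x, "b c f h w -> (b f) c h w")
assert list(x_flat.shape) == [b * f, c, h, w]

# The output head reverses the flattening before wrapping the result in STUNetOutput.
sample = rearrange(x_flat, "(b f) c h w -> b c f h w", b=b)
assert list(sample.shape) == [b, c, f, h, w]
```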
PaddleMIX/ppdiffusers/ppdiffusers/models/prior_transformer.py ADDED
@@ -0,0 +1,398 @@
1
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Dict, Optional, Union
17
+
18
+ import paddle
19
+ import paddle.nn.functional as F
20
+ from paddle import nn
21
+
22
+ from ..configuration_utils import ConfigMixin, register_to_config
23
+ from ..loaders import UNet2DConditionLoadersMixin
24
+ from ..utils import BaseOutput
25
+ from .attention import BasicTransformerBlock
26
+ from .attention_processor import (
27
+ ADDED_KV_ATTENTION_PROCESSORS,
28
+ CROSS_ATTENTION_PROCESSORS,
29
+ AttentionProcessor,
30
+ AttnAddedKVProcessor,
31
+ AttnProcessor,
32
+ )
33
+ from .embeddings import TimestepEmbedding, Timesteps
34
+ from .modeling_utils import ModelMixin
35
+
36
+
37
+ @dataclass
38
+ class PriorTransformerOutput(BaseOutput):
39
+ """
40
+ The output of [`PriorTransformer`].
41
+
42
+ Args:
43
+ predicted_image_embedding (`paddle.Tensor` of shape `(batch_size, embedding_dim)`):
44
+ The predicted CLIP image embedding conditioned on the CLIP text embedding input.
45
+ """
46
+
47
+ predicted_image_embedding: paddle.Tensor
48
+
49
+
50
+ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
51
+ """
52
+ A Prior Transformer model.
53
+
54
+ Parameters:
55
+ num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
56
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
57
+ num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
58
+ embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
59
+ num_embeddings (`int`, *optional*, defaults to 77):
60
+ The number of embeddings of the model input `hidden_states`
61
+ additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
62
+ projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
63
+ additional_embeddings`.
64
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
65
+ time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
66
+ The activation function to use to create timestep embeddings.
67
+ norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
68
+ passing to Transformer blocks. Set it to `None` if normalization is not needed.
69
+ embedding_proj_norm_type (`str`, *optional*, defaults to None):
70
+ The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
71
+ needed.
72
+ encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
73
+ The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
74
+ `encoder_hidden_states` is `None`.
75
+ added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
76
+ Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot
77
+ product between the text embedding and image embedding as proposed in the unclip paper
78
+ https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
79
+ time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
80
+ If None, will be set to `num_attention_heads * attention_head_dim`
81
+ embedding_proj_dim (`int`, *optional*, default to None):
82
+ The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
83
+ clip_embed_dim (`int`, *optional*, default to None):
84
+ The dimension of the output. If None, will be set to `embedding_dim`.
85
+ """
86
+
87
+ @register_to_config
88
+ def __init__(
89
+ self,
90
+ num_attention_heads: int = 32,
91
+ attention_head_dim: int = 64,
92
+ num_layers: int = 20,
93
+ embedding_dim: int = 768,
94
+ num_embeddings=77,
95
+ additional_embeddings=4,
96
+ dropout: float = 0.0,
97
+ time_embed_act_fn: str = "silu",
98
+ norm_in_type: Optional[str] = None, # layer
99
+ embedding_proj_norm_type: Optional[str] = None, # layer
100
+ encoder_hid_proj_type: Optional[str] = "linear", # linear
101
+ added_emb_type: Optional[str] = "prd", # prd
102
+ time_embed_dim: Optional[int] = None,
103
+ embedding_proj_dim: Optional[int] = None,
104
+ clip_embed_dim: Optional[int] = None,
105
+ ):
106
+ super().__init__()
107
+ self.num_attention_heads = num_attention_heads
108
+ self.attention_head_dim = attention_head_dim
109
+ inner_dim = num_attention_heads * attention_head_dim
110
+ self.additional_embeddings = additional_embeddings
111
+
112
+ time_embed_dim = time_embed_dim or inner_dim
113
+ embedding_proj_dim = embedding_proj_dim or embedding_dim
114
+ clip_embed_dim = clip_embed_dim or embedding_dim
115
+
116
+ self.time_proj = Timesteps(inner_dim, True, 0)
117
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)
118
+
119
+ self.proj_in = nn.Linear(embedding_dim, inner_dim)
120
+
121
+ if embedding_proj_norm_type is None:
122
+ self.embedding_proj_norm = None
123
+ elif embedding_proj_norm_type == "layer":
124
+ self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
125
+ else:
126
+ raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
127
+
128
+ self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
129
+
130
+ if encoder_hid_proj_type is None:
131
+ self.encoder_hidden_states_proj = None
132
+ elif encoder_hid_proj_type == "linear":
133
+ self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
134
+ else:
135
+ raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")
136
+
137
+ self.positional_embedding = nn.Parameter(paddle.zeros([1, num_embeddings + additional_embeddings, inner_dim]))
138
+
139
+ if added_emb_type == "prd":
140
+ self.prd_embedding = nn.Parameter(paddle.zeros([1, 1, inner_dim]))
141
+ elif added_emb_type is None:
142
+ self.prd_embedding = None
143
+ else:
144
+ raise ValueError(
145
+ f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
146
+ )
147
+
148
+ self.transformer_blocks = nn.LayerList(
149
+ [
150
+ BasicTransformerBlock(
151
+ inner_dim,
152
+ num_attention_heads,
153
+ attention_head_dim,
154
+ dropout=dropout,
155
+ activation_fn="gelu",
156
+ attention_bias=True,
157
+ )
158
+ for d in range(num_layers)
159
+ ]
160
+ )
161
+
162
+ if norm_in_type == "layer":
163
+ self.norm_in = nn.LayerNorm(inner_dim)
164
+ elif norm_in_type is None:
165
+ self.norm_in = None
166
+ else:
167
+ raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")
168
+
169
+ self.norm_out = nn.LayerNorm(inner_dim)
170
+
171
+ self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)
172
+
173
+ causal_attention_mask = paddle.triu(
174
+ paddle.full([num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -1e4), 1
175
+ )
176
+ causal_attention_mask = causal_attention_mask[None, ...]
177
+ self.register_buffer("causal_attention_mask", causal_attention_mask, persistable=False)
178
+
179
+ self.clip_mean = nn.Parameter(paddle.zeros([1, clip_embed_dim]))
180
+ self.clip_std = nn.Parameter(paddle.zeros([1, clip_embed_dim]))
181
+
182
+ @property
183
+ # Copied from ppdiffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
184
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
185
+ r"""
186
+ Returns:
187
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
188
+ indexed by its weight name.
189
+ """
190
+ # set recursively
191
+ processors = {}
192
+
193
+ def fn_recursive_add_processors(name: str, module: nn.Layer, processors: Dict[str, AttentionProcessor]):
194
+ if hasattr(module, "get_processor"):
195
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
196
+
197
+ for sub_name, child in module.named_children():
198
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
199
+
200
+ return processors
201
+
202
+ for name, module in self.named_children():
203
+ fn_recursive_add_processors(name, module, processors)
204
+
205
+ return processors
206
+
207
+ # Copied from ppdiffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
208
+ def set_attn_processor(
209
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
210
+ ):
211
+ r"""
212
+ Sets the attention processor to use to compute attention.
213
+
214
+ Parameters:
215
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
216
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
217
+ for **all** `Attention` layers.
218
+
219
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
220
+ processor. This is strongly recommended when setting trainable attention processors.
221
+
222
+ """
223
+ count = len(self.attn_processors.keys())
224
+
225
+ if isinstance(processor, dict) and len(processor) != count:
226
+ raise ValueError(
227
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
228
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
229
+ )
230
+
231
+ def fn_recursive_attn_processor(name: str, module: nn.Layer, processor):
232
+ if hasattr(module, "set_processor"):
233
+ if not isinstance(processor, dict):
234
+ module.set_processor(processor, _remove_lora=_remove_lora)
235
+ else:
236
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
237
+
238
+ for sub_name, child in module.named_children():
239
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
240
+
241
+ for name, module in self.named_children():
242
+ fn_recursive_attn_processor(name, module, processor)
243
+
244
+ # Copied from ppdiffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
245
+ def set_default_attn_processor(self):
246
+ """
247
+ Disables custom attention processors and sets the default attention implementation.
248
+ """
249
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
250
+ processor = AttnAddedKVProcessor()
251
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
252
+ processor = AttnProcessor()
253
+ else:
254
+ raise ValueError(
255
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
256
+ )
257
+
258
+ self.set_attn_processor(processor, _remove_lora=True)
259
+
260
+ def forward(
261
+ self,
262
+ hidden_states,
263
+ timestep: Union[paddle.Tensor, float, int],
264
+ proj_embedding: paddle.Tensor,
265
+ encoder_hidden_states: Optional[paddle.Tensor] = None,
266
+ attention_mask: Optional[paddle.Tensor] = None,
267
+ return_dict: bool = True,
268
+ ):
269
+ """
270
+ The [`PriorTransformer`] forward method.
271
+
272
+ Args:
273
+ hidden_states (`paddle.Tensor` of shape `(batch_size, embedding_dim)`):
274
+ The currently predicted image embeddings.
275
+ timestep (`paddle.Tensor`):
276
+ Current denoising step.
277
+ proj_embedding (`paddle.Tensor` of shape `(batch_size, embedding_dim)`):
278
+ Projected embedding vector the denoising process is conditioned on.
279
+ encoder_hidden_states (`paddle.Tensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
280
+ Hidden states of the text embeddings the denoising process is conditioned on.
281
+ attention_mask (`paddle.Tensor` of shape `(batch_size, num_embeddings)`):
282
+ Text mask for the text embeddings.
283
+ return_dict (`bool`, *optional*, defaults to `True`):
284
+ Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
285
+ tuple.
286
+
287
+ Returns:
288
+ [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
289
+ If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a
290
+ tuple is returned where the first element is the sample tensor.
291
+ """
292
+ # TODO junnyu, add this to support pure fp16
293
+ hidden_states = hidden_states.cast(self.dtype)
294
+ batch_size = hidden_states.shape[0]
295
+
296
+ timesteps = timestep
297
+ if not paddle.is_tensor(timesteps):
298
+ timesteps = paddle.to_tensor([timesteps], dtype=paddle.int64)
299
+ elif paddle.is_tensor(timesteps) and len(timesteps.shape) == 0:
300
+ timesteps = timesteps[None]
301
+
302
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
303
+ timesteps = timesteps * paddle.ones((batch_size,), dtype=timesteps.dtype)
304
+
305
+ timesteps_projected = self.time_proj(timesteps)
306
+
307
+ # timesteps does not contain any weights and will always return f32 tensors
308
+ # but time_embedding might be fp16, so we need to cast here.
309
+ timesteps_projected = timesteps_projected.cast(hidden_states.dtype)
310
+ time_embeddings = self.time_embedding(timesteps_projected)
311
+
312
+ if self.embedding_proj_norm is not None:
313
+ proj_embedding = self.embedding_proj_norm(proj_embedding)
314
+
315
+ proj_embeddings = self.embedding_proj(proj_embedding)
316
+ if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
317
+ encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
318
+ elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
319
+ raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
320
+
321
+ hidden_states = self.proj_in(hidden_states)
322
+
323
+ positional_embeddings = self.positional_embedding.cast(hidden_states.dtype)
324
+
325
+ additional_embeds = []
326
+ additional_embeddings_len = 0
327
+
328
+ if encoder_hidden_states is not None:
329
+ additional_embeds.append(encoder_hidden_states)
330
+ additional_embeddings_len += encoder_hidden_states.shape[1]
331
+
332
+ if len(proj_embeddings.shape) == 2:
333
+ proj_embeddings = proj_embeddings[:, None, :]
334
+
335
+ if len(hidden_states.shape) == 2:
336
+ hidden_states = hidden_states[:, None, :]
337
+
338
+ additional_embeds = additional_embeds + [
339
+ proj_embeddings,
340
+ time_embeddings[:, None, :],
341
+ hidden_states,
342
+ ]
343
+
344
+ if self.prd_embedding is not None:
345
+ prd_embedding = self.prd_embedding.cast(hidden_states.dtype).expand([batch_size, -1, -1])
346
+ additional_embeds.append(prd_embedding)
347
+
348
+ hidden_states = paddle.concat(
349
+ additional_embeds,
350
+ axis=1,
351
+ )
352
+
353
+ # Allow positional_embedding to not include the `additional_embeddings` and instead pad it with zeros for these additional tokens
354
+ additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
355
+ if positional_embeddings.shape[1] < hidden_states.shape[1]:
356
+ positional_embeddings = F.pad(
357
+ positional_embeddings,
358
+ (
359
+ additional_embeddings_len,
360
+ self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
361
+ ),
362
+ value=0.0,
363
+ data_format="NLC",
364
+ )
365
+
366
+ hidden_states = hidden_states + positional_embeddings
367
+
368
+ if attention_mask is not None:
369
+ attention_mask = (1 - attention_mask.cast(hidden_states.dtype)) * -1e4
370
+ attention_mask = F.pad(
371
+ attention_mask.unsqueeze(0), (0, self.additional_embeddings), value=0.0, data_format="NCL"
372
+ ).squeeze(0)
373
+ attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).cast(hidden_states.dtype)
374
+ attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, axis=0)
375
+
376
+ if self.norm_in is not None:
377
+ hidden_states = self.norm_in(hidden_states)
378
+
379
+ for block in self.transformer_blocks:
380
+ hidden_states = block(hidden_states, attention_mask=attention_mask)
381
+
382
+ hidden_states = self.norm_out(hidden_states)
383
+
384
+ if self.prd_embedding is not None:
385
+ hidden_states = hidden_states[:, -1]
386
+ else:
387
+ hidden_states = hidden_states[:, additional_embeddings_len:]
388
+
389
+ predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
390
+
391
+ if not return_dict:
392
+ return (predicted_image_embedding,)
393
+
394
+ return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
395
+
396
+ def post_process_latents(self, prior_latents):
397
+ prior_latents = (prior_latents * self.clip_std) + self.clip_mean
398
+ return prior_latents
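A minimal usage sketch for the prior defined above (not part of the committed diff): the sizes, timestep, and tensors below are illustrative assumptions chosen so the shapes line up with this forward pass, and the constructor arguments mirror the config fields the code reads (num_attention_heads, attention_head_dim, num_layers, embedding_dim, num_embeddings).

import paddle
from ppdiffusers.models.prior_transformer import PriorTransformer

# Tiny illustrative config: inner_dim = 2 * 4 = 8, text length = num_embeddings = 4 (assumed values).
prior = PriorTransformer(
    num_attention_heads=2, attention_head_dim=4, num_layers=2, embedding_dim=8, num_embeddings=4
)

batch = 2
latents = paddle.randn([batch, 8])             # current image-embedding estimate (hidden_states)
proj = paddle.randn([batch, 8])                # pooled text embedding (proj_embedding)
text_states = paddle.randn([batch, 4, 8])      # per-token text hidden states
mask = paddle.ones([batch, 4], dtype="int64")  # 1 = keep token, 0 = drop token

out = prior(latents, timestep=10, proj_embedding=proj,
            encoder_hidden_states=text_states, attention_mask=mask)
pred = out.predicted_image_embedding           # [batch, clip_embed_dim]
image_embedding = prior.post_process_latents(pred)  # undoes the clip_mean / clip_std normalization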
PaddleMIX/ppdiffusers/ppdiffusers/models/simplified_sd3.py ADDED
@@ -0,0 +1,216 @@
1
+ # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import paddle
16
+ import paddle.nn.functional as F
17
+ from paddle import nn
18
+ from paddle.distributed.fleet.meta_parallel import ColumnParallelLinear as CPLinear
19
+ from paddle.distributed.fleet.meta_parallel import RowParallelLinear as RPLinear
20
+ from paddle.nn import LayerList as LayerList
21
+
22
+
23
+ class SimplifiedSD3(nn.Layer):
24
+ def __init__(self, num_layers: int, dim: int, num_attention_heads: int, attention_head_dim: int, mp_degree: int):
25
+ super().__init__()
26
+ self.num_layers = num_layers
27
+ self.dim = dim
28
+ self.head_dim = 64
29
+
30
+ self.mp_degree = mp_degree
31
+
32
+ self.silu = nn.Silu()
33
+ self.linear1 = LayerList([nn.Linear(self.dim, 6 * self.dim) for i in range(num_layers)])
34
+ self.linear_context = LayerList(
35
+ [nn.Linear(self.dim, (6 if i < num_layers - 1 else 2) * self.dim) for i in range(num_layers)]
36
+ )
37
+ self.norm_last_context = nn.LayerNorm(self.dim, epsilon=1e-6, weight_attr=False, bias_attr=True)
38
+
39
+ if mp_degree > 1:
40
+ self.qkv_mp = LayerList(
41
+ [CPLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)]
42
+ )
43
+ self.eqkv_mp = LayerList(
44
+ [CPLinear(self.dim, 3 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)]
45
+ )
46
+ self.to_out_linear_mp = LayerList(
47
+ [RPLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)]
48
+ )
49
+ # When using Model Parallel, for the symmetry of GEMM, we change num_layers-1 here to num_layers, which has no effect on the results.
50
+ self.to_add_out_linear_mp = LayerList(
51
+ [RPLinear(self.dim, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)]
52
+ )
53
+
54
+ self.ffn1_mp = LayerList(
55
+ [CPLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers)]
56
+ )
57
+ self.ffn2_mp = LayerList(
58
+ [RPLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True) for i in range(num_layers)]
59
+ )
60
+ self.ffn1_context_mp = LayerList(
61
+ [CPLinear(self.dim, 4 * self.dim, gather_output=False, has_bias=True) for i in range(num_layers - 1)]
62
+ )
63
+ self.ffn2_context_mp = LayerList(
64
+ [
65
+ RPLinear(self.dim * 4, self.dim, input_is_parallel=True, has_bias=True)
66
+ for i in range(num_layers - 1)
67
+ ]
68
+ )
69
+ else:
70
+ self.qkv = LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)])
71
+ self.eqkv = LayerList([nn.Linear(self.dim, self.dim * 3) for i in range(num_layers)])
72
+ self.to_out_linear = LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)])
73
+ # When using Model Parallel, for the symmetry of GEMM, we change num_layers-1 here to num_layers, which has no effect on the results.
74
+ self.to_add_out_linear = LayerList([nn.Linear(self.dim, self.dim) for i in range(num_layers)])
75
+
76
+ self.ffn1 = LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers)])
77
+ self.ffn2 = LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers)])
78
+ self.ffn1_context = LayerList([nn.Linear(self.dim, self.dim * 4) for i in range(num_layers - 1)])
79
+ self.ffn2_context = LayerList([nn.Linear(self.dim * 4, self.dim) for i in range(num_layers - 1)])
80
+
81
+ def forward(self, hidden_states, encoder_hidden_states, temb):
82
+ print("--------------------this is simplified_sd3------------------------")
83
+ temb_silu = self.silu(temb)
84
+
85
+ last_ffn_output = None
86
+ last_hidden_states = None
87
+ last_gate_mlp = None
88
+
89
+ last_context_ffn_output = None
90
+ last_context_hidden_states = None
91
+ last_context_gate_mlp = None
92
+
93
+ seq1 = hidden_states.shape[1]
94
+ seq2 = encoder_hidden_states.shape[1]
95
+
96
+ for i in range(self.num_layers):
97
+ context_pre_only = i == self.num_layers - 1
98
+
99
+ emb = self.linear1[i](temb_silu)
100
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, axis=1)
101
+
102
+ import paddlemix
103
+
104
+ if last_ffn_output is None:
105
+ norm_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(
106
+ hidden_states, scale_msa, shift_msa, epsilon=1e-06
107
+ )
108
+ else:
109
+ hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(
110
+ last_hidden_states, last_ffn_output, last_gate_mlp, scale_msa, shift_msa, epsilon=1e-06
111
+ )
112
+
113
+ emb = self.linear_context[i](temb_silu)
114
+ if not context_pre_only:
115
+ shift_msa, scale_msa, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = emb.chunk(6, axis=1)
116
+ if last_context_ffn_output is None:
117
+ norm_encoder_hidden_states = paddlemix.triton_ops.adaptive_layer_norm(
118
+ encoder_hidden_states, scale_msa, shift_msa, epsilon=1e-06
119
+ )
120
+ else:
121
+ (
122
+ encoder_hidden_states,
123
+ norm_encoder_hidden_states,
124
+ ) = paddlemix.triton_ops.fused_adaLN_scale_residual(
125
+ last_context_hidden_states,
126
+ last_context_ffn_output,
127
+ last_context_gate_mlp,
128
+ scale_msa,
129
+ shift_msa,
130
+ epsilon=1e-06,
131
+ )
132
+ else:
133
+ # the last layer.
134
+ scale, shift = paddle.chunk(emb, 2, axis=1)
135
+ (encoder_hidden_states, norm_encoder_hidden_states,) = paddlemix.triton_ops.fused_adaLN_scale_residual(
136
+ last_context_hidden_states,
137
+ last_context_ffn_output,
138
+ last_context_gate_mlp,
139
+ scale,
140
+ shift,
141
+ epsilon=1e-06,
142
+ )
143
+
144
+ if self.mp_degree > 1:
145
+ qkv = self.qkv_mp[i](norm_hidden_states)
146
+ eqkv = self.eqkv_mp[i](norm_encoder_hidden_states)
147
+
148
+ else:
149
+ qkv = self.qkv[i](norm_hidden_states)
150
+ eqkv = self.eqkv[i](norm_encoder_hidden_states)
151
+
152
+ q, k, v = paddlemix.triton_ops.split_concat(qkv, eqkv)
153
+
154
+ bs = hidden_states.shape[0]
155
+ head_nums = q.shape[2] // self.head_dim
156
+ q = q.reshape([bs, -1, head_nums, self.head_dim])
157
+ k = k.reshape([bs, -1, head_nums, self.head_dim])
158
+ v = v.reshape([bs, -1, head_nums, self.head_dim])
159
+
160
+ norm_hidden_states1 = F.scaled_dot_product_attention_(q, k, v, dropout_p=0.0, is_causal=False)
161
+ norm_hidden_states1 = norm_hidden_states1.reshape([bs, -1, head_nums * self.head_dim])
162
+ attn_output, context_attn_output = paddle.split(norm_hidden_states1, num_or_sections=[seq1, seq2], axis=1)
163
+
164
+ # attn_output, context_attn_output = paddlemix.triton_ops.triton_split(
165
+ # norm_hidden_states1, num_or_sections=[1024, 154], axis=1
166
+ # )
167
+
168
+ if self.mp_degree > 1:
169
+ attn_output = self.to_out_linear_mp[i](attn_output)
170
+ context_attn_output = self.to_add_out_linear_mp[i](context_attn_output)
171
+ else:
172
+ attn_output = self.to_out_linear[i](attn_output)
173
+ context_attn_output = self.to_add_out_linear[i](context_attn_output)
174
+
175
+ hidden_states, norm_hidden_states = paddlemix.triton_ops.fused_adaLN_scale_residual(
176
+ hidden_states, attn_output, gate_msa, scale_mlp, shift_mlp, epsilon=1e-06
177
+ )
178
+
179
+ # ffn1
180
+ if self.mp_degree > 1:
181
+ ffn_output = self.ffn1_mp[i](norm_hidden_states)
182
+ ffn_output = F.gelu(ffn_output, approximate=True)
183
+ ffn_output = self.ffn2_mp[i](ffn_output)
184
+ else:
185
+ ffn_output = self.ffn1[i](norm_hidden_states)
186
+ ffn_output = F.gelu(ffn_output, approximate=True)
187
+ ffn_output = self.ffn2[i](ffn_output)
188
+
189
+ if context_pre_only:
190
+ ffn_output = gate_mlp.unsqueeze(1) * ffn_output
191
+ hidden_states = hidden_states + ffn_output
192
+ else:
193
+ last_ffn_output = ffn_output
194
+ last_hidden_states = hidden_states
195
+ last_gate_mlp = gate_mlp
196
+
197
+ # ffn2
198
+ if not context_pre_only:
199
+ (encoder_hidden_states, norm_encoder_hidden_states,) = paddlemix.triton_ops.fused_adaLN_scale_residual(
200
+ encoder_hidden_states, context_attn_output, c_gate_msa, c_scale_mlp, c_shift_mlp, epsilon=1e-06
201
+ )
202
+
203
+ if self.mp_degree > 1:
204
+ context_ffn_output = self.ffn1_context_mp[i](norm_encoder_hidden_states)
205
+ context_ffn_output = F.gelu(context_ffn_output, approximate=True)
206
+ context_ffn_output = self.ffn2_context_mp[i](context_ffn_output)
207
+ else:
208
+ context_ffn_output = self.ffn1_context[i](norm_encoder_hidden_states)
209
+ context_ffn_output = F.gelu(context_ffn_output, approximate=True)
210
+ context_ffn_output = self.ffn2_context[i](context_ffn_output)
211
+
212
+ last_context_ffn_output = context_ffn_output
213
+ last_context_hidden_states = encoder_hidden_states
214
+ last_context_gate_mlp = c_gate_mlp
215
+
216
+ return hidden_states
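A hedged construction sketch for the fused block stack above (not part of the diff). Only construction and the expected shapes are shown, because forward() relies on the paddlemix.triton_ops custom GPU kernels; the layer count and width below are assumptions loosely modeled on SD3-sized configs.

from ppdiffusers.models.simplified_sd3 import SimplifiedSD3

dim = 24 * 64  # 24 heads of the hard-coded head_dim = 64; attention_head_dim is accepted but not stored
block = SimplifiedSD3(num_layers=2, dim=dim, num_attention_heads=24, attention_head_dim=64, mp_degree=1)

# Expected inputs, inferred from the forward() code above (shapes are assumptions):
#   hidden_states:         [batch, image_seq_len, dim]   latent patch tokens
#   encoder_hidden_states: [batch, text_seq_len, dim]    projected text tokens
#   temb:                  [batch, dim]                  pooled timestep/text embedding
# forward() returns only the updated image stream: [batch, image_seq_len, dim].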
PaddleMIX/ppdiffusers/ppdiffusers/models/transformer_2d.py ADDED
@@ -0,0 +1,538 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, Optional
17
+
18
+ import paddle
19
+ import paddle.nn.functional as F
20
+ from paddle import nn
21
+ from paddle.distributed.fleet.utils import recompute
22
+
23
+ from ..configuration_utils import ConfigMixin, register_to_config
24
+ from ..models.embeddings import ImagePositionalEmbeddings
25
+ from ..utils import (
26
+ USE_PEFT_BACKEND,
27
+ BaseOutput,
28
+ deprecate,
29
+ recompute_use_reentrant,
30
+ use_old_recompute,
31
+ )
32
+ from .attention import BasicTransformerBlock
33
+ from .embeddings import CaptionProjection, PatchEmbed
34
+ from .lora import LoRACompatibleConv, LoRACompatibleLinear
35
+ from .modeling_utils import ModelMixin
36
+ from .normalization import AdaLayerNormSingle
37
+ from .simplified_facebook_dit import SimplifiedFacebookDIT
38
+
39
+
40
+ @dataclass
41
+ class Transformer2DModelOutput(BaseOutput):
42
+ """
43
+ The output of [`Transformer2DModel`].
44
+
45
+ Args:
46
+ sample (`paddle.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
47
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
48
+ distributions for the unnoised latent pixels.
49
+ """
50
+
51
+ sample: paddle.Tensor
52
+
53
+
54
+ class Transformer2DModel(ModelMixin, ConfigMixin):
55
+ """
56
+ A 2D Transformer model for image-like data.
57
+
58
+ Parameters:
59
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
60
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
61
+ in_channels (`int`, *optional*):
62
+ The number of channels in the input and output (specify if the input is **continuous**).
63
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
64
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
65
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
66
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
67
+ This is fixed during training since it is used to learn a number of position embeddings.
68
+ num_vector_embeds (`int`, *optional*):
69
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
70
+ Includes the class for the masked latent pixel.
71
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
72
+ num_embeds_ada_norm ( `int`, *optional*):
73
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
74
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
75
+ added to the hidden states.
76
+
77
+ During inference, you can denoise for up to, but not more than, `num_embeds_ada_norm` steps.
78
+ attention_bias (`bool`, *optional*):
79
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
80
+ """
81
+
82
+ _supports_gradient_checkpointing = True
83
+
84
+ @register_to_config
85
+ def __init__(
86
+ self,
87
+ num_attention_heads: int = 16,
88
+ attention_head_dim: int = 88,
89
+ in_channels: Optional[int] = None,
90
+ out_channels: Optional[int] = None,
91
+ num_layers: int = 1,
92
+ dropout: float = 0.0,
93
+ norm_num_groups: int = 32,
94
+ cross_attention_dim: Optional[int] = None,
95
+ attention_bias: bool = False,
96
+ sample_size: Optional[int] = None,
97
+ num_vector_embeds: Optional[int] = None,
98
+ patch_size: Optional[int] = None,
99
+ activation_fn: str = "geglu",
100
+ num_embeds_ada_norm: Optional[int] = None,
101
+ use_linear_projection: bool = False,
102
+ only_cross_attention: bool = False,
103
+ double_self_attention: bool = False,
104
+ upcast_attention: bool = False,
105
+ norm_type: str = "layer_norm",
106
+ norm_elementwise_affine: bool = True,
107
+ norm_eps: float = 1e-5,
108
+ attention_type: str = "default",
109
+ caption_channels: int = None,
110
+ data_format: str = "NCHW",
111
+ ):
112
+ super().__init__()
113
+ self.use_linear_projection = use_linear_projection
114
+ self.num_attention_heads = num_attention_heads
115
+ self.attention_head_dim = attention_head_dim
116
+ self.inner_dim = inner_dim = num_attention_heads * attention_head_dim
117
+ self.data_format = data_format
118
+
119
+ self.inference_optimize = os.getenv("INFERENCE_OPTIMIZE") == "True"
120
+
121
+ conv_cls = nn.Conv2D if USE_PEFT_BACKEND else LoRACompatibleConv
122
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
123
+
124
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
125
+ # Define whether input is continuous or discrete depending on configuration
126
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
127
+ self.is_input_vectorized = num_vector_embeds is not None
128
+ self.is_input_patches = in_channels is not None and patch_size is not None
129
+
130
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
131
+ deprecation_message = (
133
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
134
+ " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
135
+ " Please make sure to update the config accordingly, as leaving `norm_type` unset might lead to incorrect"
135
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
136
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
137
+ )
138
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
139
+ norm_type = "ada_norm"
140
+
141
+ if self.is_input_continuous and self.is_input_vectorized:
142
+ raise ValueError(
143
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
144
+ " sure that either `in_channels` or `num_vector_embeds` is None."
145
+ )
146
+ elif self.is_input_vectorized and self.is_input_patches:
147
+ raise ValueError(
148
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
149
+ " sure that either `num_vector_embeds` or `num_patches` is None."
150
+ )
151
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
152
+ raise ValueError(
153
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
154
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
155
+ )
156
+
157
+ # 2. Define input layers
158
+ if self.is_input_continuous:
159
+ self.in_channels = in_channels
160
+
161
+ self.norm = nn.GroupNorm(
162
+ num_groups=norm_num_groups, num_channels=in_channels, epsilon=1e-6, data_format=data_format
163
+ )
164
+ if use_linear_projection:
165
+ self.proj_in = linear_cls(in_channels, inner_dim)
166
+ else:
167
+ self.proj_in = conv_cls(
168
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0, data_format=data_format
169
+ )
170
+ elif self.is_input_vectorized:
171
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
172
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
173
+
174
+ self.height = sample_size
175
+ self.width = sample_size
176
+ self.num_vector_embeds = num_vector_embeds
177
+ self.num_latent_pixels = self.height * self.width
178
+
179
+ self.latent_image_embedding = ImagePositionalEmbeddings(
180
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
181
+ )
182
+ elif self.is_input_patches:
183
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
184
+
185
+ self.height = sample_size
186
+ self.width = sample_size
187
+
188
+ self.patch_size = patch_size
189
+ interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
190
+ interpolation_scale = max(interpolation_scale, 1)
191
+ self.pos_embed = PatchEmbed(
192
+ height=sample_size,
193
+ width=sample_size,
194
+ patch_size=patch_size,
195
+ in_channels=in_channels,
196
+ embed_dim=inner_dim,
197
+ interpolation_scale=interpolation_scale,
198
+ data_format=data_format,
199
+ )
200
+
201
+ # 3. Define transformers blocks
202
+ self.transformer_blocks = nn.LayerList(
203
+ [
204
+ BasicTransformerBlock(
205
+ inner_dim,
206
+ num_attention_heads,
207
+ attention_head_dim,
208
+ dropout=dropout,
209
+ cross_attention_dim=cross_attention_dim,
210
+ activation_fn=activation_fn,
211
+ num_embeds_ada_norm=num_embeds_ada_norm,
212
+ attention_bias=attention_bias,
213
+ only_cross_attention=only_cross_attention,
214
+ double_self_attention=double_self_attention,
215
+ upcast_attention=upcast_attention,
216
+ norm_type=norm_type,
217
+ norm_elementwise_affine=norm_elementwise_affine,
218
+ norm_eps=norm_eps,
219
+ attention_type=attention_type,
220
+ )
221
+ for d in range(num_layers)
222
+ ]
223
+ )
224
+ if self.inference_optimize:
225
+ self.simplified_facebookdit = SimplifiedFacebookDIT(
226
+ num_layers, inner_dim, num_attention_heads, attention_head_dim
227
+ )
228
+
229
+ # 4. Define output layers
230
+ self.out_channels = in_channels if out_channels is None else out_channels
231
+ if self.is_input_continuous:
232
+ # TODO: should use out_channels for continuous projections
233
+ if use_linear_projection:
234
+ self.proj_out = linear_cls(inner_dim, in_channels)
235
+ else:
236
+ self.proj_out = conv_cls(
237
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0, data_format=data_format
238
+ )
239
+ elif self.is_input_vectorized:
240
+ self.norm_out = nn.LayerNorm(inner_dim)
241
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
242
+ elif self.is_input_patches and norm_type != "ada_norm_single":
243
+ norm_elementwise_affine_kwargs = dict(weight_attr=False, bias_attr=False)
244
+ self.norm_out = nn.LayerNorm(inner_dim, epsilon=1e-6, **norm_elementwise_affine_kwargs)
245
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
246
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
247
+ elif self.is_input_patches and norm_type == "ada_norm_single":
248
+ norm_elementwise_affine_kwargs = dict(weight_attr=False, bias_attr=False)
249
+ self.norm_out = nn.LayerNorm(inner_dim, epsilon=1e-6, **norm_elementwise_affine_kwargs)
250
+ self.scale_shift_table = nn.Parameter(paddle.randn([2, inner_dim]) / inner_dim**0.5)
251
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
252
+
253
+ # 5. PixArt-Alpha blocks.
254
+ self.adaln_single = None
255
+ self.use_additional_conditions = False
256
+ if norm_type == "ada_norm_single":
257
+ self.use_additional_conditions = self.config.sample_size == 128
258
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
259
+ # additional conditions until we find better name
260
+ self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
261
+
262
+ self.caption_projection = None
263
+ if caption_channels is not None:
264
+ self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)
265
+
266
+ self.gradient_checkpointing = False
267
+
268
+ def _set_gradient_checkpointing(self, module, value=False):
269
+ if hasattr(module, "gradient_checkpointing"):
270
+ module.gradient_checkpointing = value
271
+
272
+ def forward(
273
+ self,
274
+ hidden_states: paddle.Tensor,
275
+ encoder_hidden_states: Optional[paddle.Tensor] = None,
276
+ timestep: Optional[paddle.Tensor] = None,
277
+ added_cond_kwargs: Dict[str, paddle.Tensor] = None,
278
+ class_labels: Optional[paddle.Tensor] = None,
279
+ cross_attention_kwargs: Dict[str, Any] = None,
280
+ attention_mask: Optional[paddle.Tensor] = None,
281
+ encoder_attention_mask: Optional[paddle.Tensor] = None,
282
+ return_dict: bool = True,
283
+ ):
284
+ """
285
+ The [`Transformer2DModel`] forward method.
286
+
287
+ Args:
288
+ hidden_states (`paddle.Tensor` of shape `(batch size, num latent pixels)` if discrete, `paddle.Tensor` of shape `(batch size, channel, height, width)` if continuous):
289
+ Input `hidden_states`.
290
+ encoder_hidden_states ( `paddle.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
291
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
292
+ self-attention.
293
+ timestep ( `paddle.Tensor`, *optional*):
294
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
295
+ class_labels ( `paddle.Tensor` of shape `(batch size, num classes)`, *optional*):
296
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
297
+ `AdaLayerZeroNorm`.
298
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
299
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
300
+ `self.processor` in
301
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
302
+ attention_mask ( `paddle.Tensor`, *optional*):
303
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
304
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
305
+ negative values to the attention scores corresponding to "discard" tokens.
306
+ encoder_attention_mask ( `paddle.Tensor`, *optional*):
307
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
308
+
309
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
310
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
311
+
312
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
313
+ above. This bias will be added to the cross-attention scores.
314
+ return_dict (`bool`, *optional*, defaults to `True`):
315
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
316
+ tuple.
317
+
318
+ Returns:
319
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
320
+ `tuple` where the first element is the sample tensor.
321
+ """
322
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
323
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
324
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
325
+ # expects mask of shape:
326
+ # [batch, key_tokens]
327
+ # adds singleton query_tokens dimension:
328
+ # [batch, 1, key_tokens]
329
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
330
+ # [batch, query_tokens, heads, key_tokens] (e.g. paddle sdp or ppxformers attn)
331
+ # [batch, heads, query_tokens, key_tokens] (e.g. classic attn)
332
+ # pure fp16
333
+ hidden_states = hidden_states.cast(self.dtype)
334
+ if attention_mask is not None and attention_mask.ndim == 2:
335
+ # assume that mask is expressed as:
336
+ # (1 = keep, 0 = discard)
337
+ # convert mask into a bias that can be added to attention scores:
338
+ # (keep = +0, discard = -10000.0)
339
+ attention_mask = (1 - attention_mask.cast(hidden_states.dtype)) * -10000.0
340
+ attention_mask = attention_mask.unsqueeze(1)
341
+
342
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
343
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
344
+ encoder_attention_mask = (1 - encoder_attention_mask.cast(hidden_states.dtype)) * -10000.0
345
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
346
+
347
+ # Retrieve lora scale.
348
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
349
+
350
+ # 1. Input
351
+ if self.is_input_continuous:
352
+ if self.data_format == "NCHW":
353
+ # NOTE(zhoukangkang): shaped this way so Paddle Inference can hit the elementwiseadd_transpose_pass.
354
+ batch, _, height, width = hidden_states.shape
355
+ else:
356
+ batch, height, width, _ = hidden_states.shape
357
+ residual = hidden_states
358
+ shape = paddle.shape(hidden_states)
359
+ hidden_states = self.norm(hidden_states)
360
+ if not self.use_linear_projection:
361
+ hidden_states = (
362
+ self.proj_in(hidden_states, scale=lora_scale)
363
+ if not USE_PEFT_BACKEND
364
+ else self.proj_in(hidden_states)
365
+ )
366
+ if self.data_format == "NCHW":
367
+ hidden_states = hidden_states.transpose([0, 2, 3, 1]).flatten(1, 2)
368
+ else:
369
+ hidden_states = hidden_states.flatten(1, 2)
370
+ else:
371
+ if self.data_format == "NCHW":
372
+ hidden_states = hidden_states.transpose([0, 2, 3, 1]).flatten(1, 2)
373
+ else:
374
+ hidden_states = hidden_states.flatten(1, 2)
375
+ hidden_states = (
376
+ self.proj_in(hidden_states, scale=lora_scale)
377
+ if not USE_PEFT_BACKEND
378
+ else self.proj_in(hidden_states)
379
+ )
380
+
381
+ elif self.is_input_vectorized:
382
+ hidden_states = self.latent_image_embedding(hidden_states.cast("int64")) # NEW ADD
383
+ elif self.is_input_patches:
384
+ height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
385
+ hidden_states = self.pos_embed(hidden_states)
386
+
387
+ if self.adaln_single is not None:
388
+ if self.use_additional_conditions and added_cond_kwargs is None:
389
+ raise ValueError(
390
+ "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
391
+ )
392
+ batch_size = hidden_states.shape[0]
393
+ timestep, embedded_timestep = self.adaln_single(
394
+ timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
395
+ )
396
+
397
+ # 2. Blocks
398
+ if self.caption_projection is not None:
399
+ batch_size = hidden_states.shape[0]
400
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
401
+ encoder_hidden_states = encoder_hidden_states.reshape([batch_size, -1, hidden_states.shape[-1]])
402
+
403
+ if self.inference_optimize:
404
+ hidden_states = self.simplified_facebookdit(hidden_states, timestep, class_labels)
405
+ else:
406
+ for block in self.transformer_blocks:
407
+ if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute():
408
+
409
+ def create_custom_forward(module, return_dict=None):
410
+ def custom_forward(*inputs):
411
+ if return_dict is not None:
412
+ return module(*inputs, return_dict=return_dict)
413
+ else:
414
+ return module(*inputs)
415
+
416
+ return custom_forward
417
+
418
+ ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False}
419
+ hidden_states = recompute(
420
+ create_custom_forward(block),
421
+ hidden_states,
422
+ attention_mask,
423
+ encoder_hidden_states,
424
+ encoder_attention_mask,
425
+ timestep,
426
+ cross_attention_kwargs,
427
+ class_labels,
428
+ **ckpt_kwargs,
429
+ )
430
+ else:
431
+ hidden_states = block(
432
+ hidden_states,
433
+ attention_mask=attention_mask,
434
+ encoder_hidden_states=encoder_hidden_states,
435
+ encoder_attention_mask=encoder_attention_mask,
436
+ timestep=timestep,
437
+ cross_attention_kwargs=cross_attention_kwargs,
438
+ class_labels=class_labels,
439
+ )
440
+
441
+ # 3. Output
442
+ if self.is_input_continuous:
443
+ if not self.use_linear_projection:
444
+ if self.data_format == "NCHW":
445
+ hidden_states = hidden_states.reshape([shape[0], shape[2], shape[3], self.inner_dim])
446
+ else:
447
+ hidden_states = hidden_states.reshape([shape[0], shape[1], shape[2], self.inner_dim])
448
+ if self.data_format == "NCHW":
449
+ hidden_states = hidden_states.transpose([0, 3, 1, 2])
450
+ hidden_states = (
451
+ self.proj_out(hidden_states, scale=lora_scale)
452
+ if not USE_PEFT_BACKEND
453
+ else self.proj_out(hidden_states)
454
+ )
455
+ else:
456
+ hidden_states = (
457
+ self.proj_out(hidden_states, scale=lora_scale)
458
+ if not USE_PEFT_BACKEND
459
+ else self.proj_out(hidden_states)
460
+ )
461
+ if self.data_format == "NCHW":
462
+ hidden_states = hidden_states.reshape([shape[0], shape[2], shape[3], self.inner_dim])
463
+ else:
464
+ hidden_states = hidden_states.reshape([shape[0], shape[1], shape[2], self.inner_dim])
465
+ if self.data_format == "NCHW":
466
+ hidden_states = hidden_states.transpose([0, 3, 1, 2])
467
+
468
+ output = hidden_states + residual
469
+ elif self.is_input_vectorized:
470
+ hidden_states = self.norm_out(hidden_states)
471
+ logits = self.out(hidden_states)
472
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
473
+ logits = logits.transpose([0, 2, 1])
474
+
475
+ # log(p(x_0))
476
+ output = F.log_softmax(logits.cast("float64"), axis=1).cast("float32")
477
+
478
+ if self.is_input_patches:
479
+ if self.config.norm_type != "ada_norm_single":
480
+ conditioning = self.transformer_blocks[0].norm1.emb(
481
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
482
+ )
483
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, axis=1)
484
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
485
+ hidden_states = self.proj_out_2(hidden_states)
486
+ elif self.config.norm_type == "ada_norm_single":
487
+ shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, axis=1)
488
+ hidden_states = self.norm_out(hidden_states)
489
+ # Modulation
490
+ hidden_states = hidden_states * (1 + scale) + shift
491
+ hidden_states = self.proj_out(hidden_states)
492
+ hidden_states = hidden_states.squeeze(1)
493
+
494
+ # unpatchify
495
+ if self.adaln_single is None:
496
+ height = width = int(hidden_states.shape[1] ** 0.5)
497
+ hidden_states = hidden_states.reshape(
498
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
499
+ )
500
+ # hidden_states = paddle.einsum("nhwpqc->nchpwq", hidden_states)
501
+ hidden_states = hidden_states.transpose([0, 5, 1, 3, 2, 4])
502
+ output = hidden_states.reshape(
503
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
504
+ )
505
+
506
+ if not return_dict:
507
+ return (output,)
508
+
509
+ return Transformer2DModelOutput(sample=output)
510
+
511
+ @classmethod
512
+ def custom_modify_weight(cls, model_to_load, state_dict):
513
+ if not model_to_load.inference_optimize:
514
+ return
515
+ for i in range(28):
516
+ map_from_my_dit = [
517
+ (f"q.{i}.weight", f"{i}.attn1.to_q.weight"),
518
+ (f"k.{i}.weight", f"{i}.attn1.to_k.weight"),
519
+ (f"v.{i}.weight", f"{i}.attn1.to_v.weight"),
520
+ (f"q.{i}.bias", f"{i}.attn1.to_q.bias"),
521
+ (f"k.{i}.bias", f"{i}.attn1.to_k.bias"),
522
+ (f"v.{i}.bias", f"{i}.attn1.to_v.bias"),
523
+ (f"out_proj.{i}.weight", f"{i}.attn1.to_out.0.weight"),
524
+ (f"out_proj.{i}.bias", f"{i}.attn1.to_out.0.bias"),
525
+ (f"ffn1.{i}.weight", f"{i}.ff.net.0.proj.weight"),
526
+ (f"ffn1.{i}.bias", f"{i}.ff.net.0.proj.bias"),
527
+ (f"ffn2.{i}.weight", f"{i}.ff.net.2.weight"),
528
+ (f"ffn2.{i}.bias", f"{i}.ff.net.2.bias"),
529
+ (f"fcs0.{i}.weight", f"{i}.norm1.emb.timestep_embedder.linear_1.weight"),
530
+ (f"fcs0.{i}.bias", f"{i}.norm1.emb.timestep_embedder.linear_1.bias"),
531
+ (f"fcs1.{i}.weight", f"{i}.norm1.emb.timestep_embedder.linear_2.weight"),
532
+ (f"fcs1.{i}.bias", f"{i}.norm1.emb.timestep_embedder.linear_2.bias"),
533
+ (f"fcs2.{i}.weight", f"{i}.norm1.linear.weight"),
534
+ (f"fcs2.{i}.bias", f"{i}.norm1.linear.bias"),
535
+ (f"embs.{i}.weight", f"{i}.norm1.emb.class_embedder.embedding_table.weight"),
536
+ ]
537
+ for to_, from_ in map_from_my_dit:
538
+ state_dict["simplified_facebookdit." + to_] = paddle.assign(state_dict["transformer_blocks." + from_])
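A minimal usage sketch for the continuous-input path of the model above (not part of the diff); the channel counts, group count, and sequence lengths are illustrative assumptions.

import paddle
from ppdiffusers.models.transformer_2d import Transformer2DModel

# inner_dim = 2 * 8 = 16 matches in_channels, so proj_in / proj_out are 1x1 convs of width 16.
model = Transformer2DModel(
    num_attention_heads=2, attention_head_dim=8, in_channels=16,
    num_layers=1, cross_attention_dim=32, norm_num_groups=4,
)

latents = paddle.randn([1, 16, 8, 8])    # NCHW continuous input (default data_format)
text = paddle.randn([1, 7, 32])          # cross-attention context of width cross_attention_dim
out = model(latents, encoder_hidden_states=text).sample  # same shape as the input, with the residual added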
PaddleMIX/ppdiffusers/ppdiffusers/models/unet_1d_blocks.py ADDED
@@ -0,0 +1,752 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from typing import Optional, Tuple, Union
16
+
17
+ import paddle
18
+ import paddle.nn.functional as F
19
+ from paddle import nn
20
+
21
+ from ..utils import is_ppxformers_available
22
+ from .activations import get_activation
23
+ from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
24
+
25
+
26
+ class DownResnetBlock1D(nn.Layer):
27
+ def __init__(
28
+ self,
29
+ in_channels: int,
30
+ out_channels: Optional[int] = None,
31
+ num_layers: int = 1,
32
+ conv_shortcut: bool = False,
33
+ temb_channels: int = 32,
34
+ groups: int = 32,
35
+ groups_out: Optional[int] = None,
36
+ non_linearity: Optional[str] = None,
37
+ time_embedding_norm: str = "default",
38
+ output_scale_factor: float = 1.0,
39
+ add_downsample: bool = True,
40
+ ):
41
+ super().__init__()
42
+ self.in_channels = in_channels
43
+ out_channels = in_channels if out_channels is None else out_channels
44
+ self.out_channels = out_channels
45
+ self.use_conv_shortcut = conv_shortcut
46
+ self.time_embedding_norm = time_embedding_norm
47
+ self.add_downsample = add_downsample
48
+ self.output_scale_factor = output_scale_factor
49
+
50
+ if groups_out is None:
51
+ groups_out = groups
52
+
53
+ # there will always be at least one resnet
54
+ resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]
55
+
56
+ for _ in range(num_layers):
57
+ resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
58
+
59
+ self.resnets = nn.LayerList(resnets)
60
+
61
+ if non_linearity is None:
62
+ self.nonlinearity = None
63
+ else:
64
+ self.nonlinearity = get_activation(non_linearity)
65
+
66
+ self.downsample = None
67
+ if add_downsample:
68
+ self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
69
+
70
+ def forward(self, hidden_states: paddle.Tensor, temb: Optional[paddle.Tensor] = None) -> paddle.Tensor:
71
+ output_states = ()
72
+
73
+ hidden_states = self.resnets[0](hidden_states, temb)
74
+ for resnet in self.resnets[1:]:
75
+ hidden_states = resnet(hidden_states, temb)
76
+
77
+ output_states += (hidden_states,)
78
+
79
+ if self.nonlinearity is not None:
80
+ hidden_states = self.nonlinearity(hidden_states)
81
+
82
+ if self.downsample is not None:
83
+ hidden_states = self.downsample(hidden_states)
84
+
85
+ return hidden_states, output_states
86
+
87
+
88
+ class UpResnetBlock1D(nn.Layer):
89
+ def __init__(
90
+ self,
91
+ in_channels: int,
92
+ out_channels: Optional[int] = None,
93
+ num_layers: int = 1,
94
+ temb_channels: int = 32,
95
+ groups: int = 32,
96
+ groups_out: Optional[int] = None,
97
+ non_linearity: Optional[str] = None,
98
+ time_embedding_norm: str = "default",
99
+ output_scale_factor: float = 1.0,
100
+ add_upsample: bool = True,
101
+ ):
102
+ super().__init__()
103
+ self.in_channels = in_channels
104
+ out_channels = in_channels if out_channels is None else out_channels
105
+ self.out_channels = out_channels
106
+ self.time_embedding_norm = time_embedding_norm
107
+ self.add_upsample = add_upsample
108
+ self.output_scale_factor = output_scale_factor
109
+
110
+ if groups_out is None:
111
+ groups_out = groups
112
+
113
+ # there will always be at least one resnet
114
+ resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
115
+
116
+ for _ in range(num_layers):
117
+ resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
118
+
119
+ self.resnets = nn.LayerList(resnets)
120
+
121
+ if non_linearity is None:
122
+ self.nonlinearity = None
123
+ else:
124
+ self.nonlinearity = get_activation(non_linearity)
125
+
126
+ self.upsample = None
127
+ if add_upsample:
128
+ self.upsample = Upsample1D(out_channels, use_conv_transpose=True)
129
+
130
+ def forward(
131
+ self,
132
+ hidden_states: paddle.Tensor,
133
+ res_hidden_states_tuple: Optional[Tuple[paddle.Tensor, ...]] = None,
134
+ temb: Optional[paddle.Tensor] = None,
135
+ ) -> paddle.Tensor:
136
+ if res_hidden_states_tuple is not None:
137
+ res_hidden_states = res_hidden_states_tuple[-1]
138
+ hidden_states = paddle.concat((hidden_states, res_hidden_states), axis=1)
139
+
140
+ hidden_states = self.resnets[0](hidden_states, temb)
141
+ for resnet in self.resnets[1:]:
142
+ hidden_states = resnet(hidden_states, temb)
143
+
144
+ if self.nonlinearity is not None:
145
+ hidden_states = self.nonlinearity(hidden_states)
146
+
147
+ if self.upsample is not None:
148
+ hidden_states = self.upsample(hidden_states)
149
+
150
+ return hidden_states
151
+
152
+
153
+ class ValueFunctionMidBlock1D(nn.Layer):
154
+ def __init__(self, in_channels: int, out_channels: int, embed_dim: int):
155
+ super().__init__()
156
+ self.in_channels = in_channels
157
+ self.out_channels = out_channels
158
+ self.embed_dim = embed_dim
159
+
160
+ self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
161
+ self.down1 = Downsample1D(out_channels // 2, use_conv=True)
162
+ self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
163
+ self.down2 = Downsample1D(out_channels // 4, use_conv=True)
164
+
165
+ def forward(self, x: paddle.Tensor, temb: Optional[paddle.Tensor] = None) -> paddle.Tensor:
166
+ x = self.res1(x, temb)
167
+ x = self.down1(x)
168
+ x = self.res2(x, temb)
169
+ x = self.down2(x)
170
+ return x
171
+
172
+
173
+ class MidResTemporalBlock1D(nn.Layer):
174
+ def __init__(
175
+ self,
176
+ in_channels: int,
177
+ out_channels: int,
178
+ embed_dim: int,
179
+ num_layers: int = 1,
180
+ add_downsample: bool = False,
181
+ add_upsample: bool = False,
182
+ non_linearity: Optional[str] = None,
183
+ ):
184
+ super().__init__()
185
+ self.in_channels = in_channels
186
+ self.out_channels = out_channels
187
+ self.add_downsample = add_downsample
188
+
189
+ # there will always be at least one resnet
190
+ resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
191
+
192
+ for _ in range(num_layers):
193
+ resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
194
+
195
+ self.resnets = nn.LayerList(resnets)
196
+
197
+ if non_linearity is None:
198
+ self.nonlinearity = None
199
+ else:
200
+ self.nonlinearity = get_activation(non_linearity)
201
+
202
+ self.upsample = None
203
+ if add_upsample:
204
+ self.upsample = Downsample1D(out_channels, use_conv=True)
205
+
206
+ self.downsample = None
207
+ if add_downsample:
208
+ self.downsample = Downsample1D(out_channels, use_conv=True)
209
+
210
+ if self.upsample and self.downsample:
211
+ raise ValueError("Block cannot downsample and upsample")
212
+
213
+ def forward(self, hidden_states: paddle.Tensor, temb: paddle.Tensor) -> paddle.Tensor:
214
+ hidden_states = self.resnets[0](hidden_states, temb)
215
+ for resnet in self.resnets[1:]:
216
+ hidden_states = resnet(hidden_states, temb)
217
+
218
+ if self.upsample:
219
+ hidden_states = self.upsample(hidden_states)
220
+ if self.downsample:
221
+ hidden_states = self.downsample(hidden_states)
222
+
223
+ return hidden_states
224
+
225
+
226
+ class OutConv1DBlock(nn.Layer):
227
+ def __init__(self, num_groups_out: int, out_channels: int, embed_dim: int, act_fn: str):
228
+ super().__init__()
229
+ self.final_conv1d_1 = nn.Conv1D(embed_dim, embed_dim, 5, padding=2)
230
+ self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
231
+ self.final_conv1d_act = get_activation(act_fn)
232
+ self.final_conv1d_2 = nn.Conv1D(embed_dim, out_channels, 1)
233
+
234
+ def forward(self, hidden_states: paddle.Tensor, temb: Optional[paddle.Tensor] = None) -> paddle.Tensor:
235
+ hidden_states = self.final_conv1d_1(hidden_states)
236
+ hidden_states = rearrange_dims(hidden_states)
237
+ hidden_states = self.final_conv1d_gn(hidden_states)
238
+ hidden_states = rearrange_dims(hidden_states)
239
+ hidden_states = self.final_conv1d_act(hidden_states)
240
+ hidden_states = self.final_conv1d_2(hidden_states)
241
+ return hidden_states
242
+
243
+
244
+ class OutValueFunctionBlock(nn.Layer):
245
+ def __init__(self, fc_dim: int, embed_dim: int, act_fn: str = "mish"):
246
+ super().__init__()
247
+ self.final_block = nn.LayerList(
248
+ [
249
+ nn.Linear(fc_dim + embed_dim, fc_dim // 2),
250
+ get_activation(act_fn),
251
+ nn.Linear(fc_dim // 2, 1),
252
+ ]
253
+ )
254
+
255
+ def forward(self, hidden_states: paddle.Tensor, temb: paddle.Tensor) -> paddle.Tensor:
256
+ hidden_states = hidden_states.reshape([hidden_states.shape[0], -1])
257
+ hidden_states = paddle.concat((hidden_states, temb), axis=-1)
258
+ for layer in self.final_block:
259
+ hidden_states = layer(hidden_states)
260
+
261
+ return hidden_states
262
+
263
+
264
+_kernels = {
+    "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
+    "cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875],
+    "lanczos3": [
+        0.003689131001010537,
+        0.015056144446134567,
+        -0.03399861603975296,
+        -0.066637322306633,
+        0.13550527393817902,
+        0.44638532400131226,
+        0.44638532400131226,
+        0.13550527393817902,
+        -0.066637322306633,
+        -0.03399861603975296,
+        0.015056144446134567,
+        0.003689131001010537,
+    ],
+}
+
+
+class Downsample1d(nn.Layer):
+    def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
+        super().__init__()
+        self.pad_mode = pad_mode
+        kernel_1d = paddle.to_tensor(_kernels[kernel])
+        self.pad = kernel_1d.shape[0] // 2 - 1
+        self.register_buffer("kernel", kernel_1d)
+
+    def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
+        hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode, data_format="NCL")
+        weight = paddle.zeros(
+            [hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]], dtype=hidden_states.dtype
+        )
+        indices = paddle.arange(hidden_states.shape[1])
+        weight[indices, indices] = self.kernel.cast(weight.dtype).expand([hidden_states.shape[1], -1])
+        return F.conv1d(hidden_states, weight, stride=2)
+
+
+class Upsample1d(nn.Layer):
+    def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
+        super().__init__()
+        self.pad_mode = pad_mode
+        kernel_1d = paddle.to_tensor(_kernels[kernel])
+        self.pad = kernel_1d.shape[0] // 2 - 1
+        self.register_buffer("kernel", kernel_1d)
+
+    def forward(self, hidden_states: paddle.Tensor, temb: Optional[paddle.Tensor] = None) -> paddle.Tensor:
+        hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode, data_format="NCL")
+        weight = paddle.zeros(
+            [hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]], dtype=hidden_states.dtype
+        )
+        indices = paddle.arange(hidden_states.shape[1])
+        weight[indices, indices] = self.kernel.cast(weight.dtype).expand([hidden_states.shape[1], -1])
+        return F.conv1d_transpose(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)
+
+
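Both resamplers build their filter from the _kernels taps at run time: Downsample1d halves the length dimension with a stride-2 FIR convolution, and Upsample1d doubles it with the matching transposed convolution. A small shape sketch, assuming the module path ppdiffusers.models.unet_1d_blocks and illustrative sizes:

import paddle

from ppdiffusers.models.unet_1d_blocks import Downsample1d, Upsample1d

x = paddle.randn([2, 8, 64])         # [batch, channels, length]
down = Downsample1d(kernel="cubic")  # stride-2 anti-aliased downsampling
up = Upsample1d(kernel="cubic")      # stride-2 transposed-conv upsampling

y = down(x)
print(y.shape)                       # [2, 8, 32]: length halved, channels unchanged
z = up(y)
print(z.shape)                       # [2, 8, 64]: length restored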
+class SelfAttention1d(nn.Layer):
+    def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0):
+        super().__init__()
+        self.channels = in_channels
+        self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
+        self.num_heads = n_head
+        self.head_size = in_channels // n_head
+        self.scale = 1 / math.sqrt(self.head_size)
+
+        self.query = nn.Linear(self.channels, self.channels)
+        self.key = nn.Linear(self.channels, self.channels)
+        self.value = nn.Linear(self.channels, self.channels)
+
+        self.proj_attn = nn.Linear(self.channels, self.channels)
+
+        self.dropout = nn.Dropout(dropout_rate)
+
+        self._use_memory_efficient_attention_xformers = False
+        self._attention_op = None
+
+    def reshape_heads_to_batch_dim(self, tensor, transpose=True):
+        tensor = tensor.reshape([0, 0, self.num_heads, self.head_size])
+        if transpose:
+            tensor = tensor.transpose([0, 2, 1, 3])
+        return tensor
+
+    def reshape_batch_dim_to_heads(self, tensor, transpose=True):
+        if transpose:
+            tensor = tensor.transpose([0, 2, 1, 3])
+        tensor = tensor.reshape([0, 0, tensor.shape[2] * tensor.shape[3]])
+        return tensor
+
+    def set_use_memory_efficient_attention_xformers(
+        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[str] = None
+    ):
+        if use_memory_efficient_attention_xformers:
+            if not is_ppxformers_available():
+                raise NotImplementedError(
+                    "requires scaled_dot_product_attention, but your PaddlePaddle does not have it. Check out the instructions on the installation page: https://www.paddlepaddle.org.cn/install/quick and follow the ones that match your environment."
+                )
+            else:
+                try:
+                    _ = F.scaled_dot_product_attention_(
+                        paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
+                        paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
+                        paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
+                        attention_op=attention_op,
+                    )
+                except Exception as e:
+                    raise e
+
+        self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+        self._attention_op = attention_op
+
+    def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
+        residual = hidden_states
+
+        hidden_states = self.group_norm(hidden_states)
+        hidden_states = hidden_states.transpose([0, 2, 1])
+
+        query_proj = self.query(hidden_states)
+        key_proj = self.key(hidden_states)
+        value_proj = self.value(hidden_states)
+
+        query_proj = self.reshape_heads_to_batch_dim(
+            query_proj, transpose=not self._use_memory_efficient_attention_xformers
+        )
+        key_proj = self.reshape_heads_to_batch_dim(
+            key_proj, transpose=not self._use_memory_efficient_attention_xformers
+        )
+        value_proj = self.reshape_heads_to_batch_dim(
+            value_proj, transpose=not self._use_memory_efficient_attention_xformers
+        )
+
+        if self._use_memory_efficient_attention_xformers:
+            hidden_states = F.scaled_dot_product_attention_(
+                query_proj,
+                key_proj,
+                value_proj,
+                attn_mask=None,
+                scale=self.scale,
+                dropout_p=0.0,
+                training=self.training,
+                attention_op=self._attention_op,
+            )
+        else:
+            attention_scores = paddle.matmul(query_proj, key_proj, transpose_y=True) * self.scale
+            attention_probs = F.softmax(attention_scores.cast("float32"), axis=-1).cast(attention_scores.dtype)
+            hidden_states = paddle.matmul(attention_probs, value_proj)
+
+        # reshape hidden_states
+        hidden_states = self.reshape_batch_dim_to_heads(
+            hidden_states, transpose=not self._use_memory_efficient_attention_xformers
+        )
+
+        # compute next hidden_states
+        hidden_states = self.proj_attn(hidden_states)
+        hidden_states = hidden_states.transpose([0, 2, 1])
+        hidden_states = self.dropout(hidden_states)
+
+        output = hidden_states + residual
+
+        return output
+
+
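SelfAttention1d normalizes across all channels (a single GroupNorm group), attends over the length dimension, and adds the result back to its input, so it is shape-preserving; n_head must divide in_channels, since head_size = in_channels // n_head. A minimal sketch, assuming the module path ppdiffusers.models.unet_1d_blocks:

import paddle

from ppdiffusers.models.unet_1d_blocks import SelfAttention1d

attn = SelfAttention1d(in_channels=32, n_head=1)
x = paddle.randn([2, 32, 16])   # [batch, channels, length]
y = attn(x)
print(y.shape)                  # [2, 32, 16]: residual attention keeps the shape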
+class ResConvBlock(nn.Layer):
+    def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False):
+        super().__init__()
+        self.is_last = is_last
+        self.has_conv_skip = in_channels != out_channels
+
+        if self.has_conv_skip:
+            self.conv_skip = nn.Conv1D(in_channels, out_channels, 1, bias_attr=False)
+
+        self.conv_1 = nn.Conv1D(in_channels, mid_channels, 5, padding=2)
+        self.group_norm_1 = nn.GroupNorm(1, mid_channels)
+        self.gelu_1 = nn.GELU()
+        self.conv_2 = nn.Conv1D(mid_channels, out_channels, 5, padding=2)
+
+        if not self.is_last:
+            self.group_norm_2 = nn.GroupNorm(1, out_channels)
+            self.gelu_2 = nn.GELU()
+
+    def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
+        residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
+
+        hidden_states = self.conv_1(hidden_states)
+        hidden_states = self.group_norm_1(hidden_states)
+        hidden_states = self.gelu_1(hidden_states)
+        hidden_states = self.conv_2(hidden_states)
+
+        if not self.is_last:
+            hidden_states = self.group_norm_2(hidden_states)
+            hidden_states = self.gelu_2(hidden_states)
+
+        output = hidden_states + residual
+        return output
+
+
+class UNetMidBlock1D(nn.Layer):
+    def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None):
+        super().__init__()
+
+        out_channels = in_channels if out_channels is None else out_channels
+
+        # there is always at least one resnet
+        self.down = Downsample1d("cubic")
+        resnets = [
+            ResConvBlock(in_channels, mid_channels, mid_channels),
+            ResConvBlock(mid_channels, mid_channels, mid_channels),
+            ResConvBlock(mid_channels, mid_channels, mid_channels),
+            ResConvBlock(mid_channels, mid_channels, mid_channels),
+            ResConvBlock(mid_channels, mid_channels, mid_channels),
+            ResConvBlock(mid_channels, mid_channels, out_channels),
+        ]
+        attentions = [
+            SelfAttention1d(mid_channels, mid_channels // 32),
+            SelfAttention1d(mid_channels, mid_channels // 32),
+            SelfAttention1d(mid_channels, mid_channels // 32),
+            SelfAttention1d(mid_channels, mid_channels // 32),
+            SelfAttention1d(mid_channels, mid_channels // 32),
+            SelfAttention1d(out_channels, out_channels // 32),
+        ]
+        self.up = Upsample1d(kernel="cubic")
+
+        self.attentions = nn.LayerList(attentions)
+        self.resnets = nn.LayerList(resnets)
+
+    def forward(self, hidden_states: paddle.Tensor, temb: Optional[paddle.Tensor] = None) -> paddle.Tensor:
+        hidden_states = self.down(hidden_states)
+        for attn, resnet in zip(self.attentions, self.resnets):
+            hidden_states = resnet(hidden_states)
+            hidden_states = attn(hidden_states)
+
+        hidden_states = self.up(hidden_states)
+
+        return hidden_states
+
+
+class AttnDownBlock1D(nn.Layer):
+    def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
+        super().__init__()
+        mid_channels = out_channels if mid_channels is None else mid_channels
+
+        self.down = Downsample1d("cubic")
+        resnets = [
+            ResConvBlock(in_channels, mid_channels, mid_channels),
+            ResConvBlock(mid_channels, mid_channels, mid_channels),
+            ResConvBlock(mid_channels, mid_channels, out_channels),
+        ]
+        attentions = [
+            SelfAttention1d(mid_channels, mid_channels // 32),
+            SelfAttention1d(mid_channels, mid_channels // 32),
+            SelfAttention1d(out_channels, out_channels // 32),
+        ]
+
+        self.attentions = nn.LayerList(attentions)
+        self.resnets = nn.LayerList(resnets)
+
+    def forward(self, hidden_states: paddle.Tensor, temb: Optional[paddle.Tensor] = None) -> paddle.Tensor:
+        hidden_states = self.down(hidden_states)
+
+        for resnet, attn in zip(self.resnets, self.attentions):
+            hidden_states = resnet(hidden_states)
+            hidden_states = attn(hidden_states)
+
+        return hidden_states, (hidden_states,)
+
+
+class DownBlock1D(nn.Layer):
+    def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
+        super().__init__()
+        mid_channels = out_channels if mid_channels is None else mid_channels
+
+        self.down = Downsample1d("cubic")
+        resnets = [
+            ResConvBlock(in_channels, mid_channels, mid_channels),
+            ResConvBlock(mid_channels, mid_channels, mid_channels),
+            ResConvBlock(mid_channels, mid_channels, out_channels),
+        ]
+
+        self.resnets = nn.LayerList(resnets)
+
+    def forward(self, hidden_states: paddle.Tensor, temb: Optional[paddle.Tensor] = None) -> paddle.Tensor:
+        hidden_states = self.down(hidden_states)
+
+        for resnet in self.resnets:
+            hidden_states = resnet(hidden_states)
+
+        return hidden_states, (hidden_states,)
+
+
+class DownBlock1DNoSkip(nn.Layer):
+    def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
+        super().__init__()
+        mid_channels = out_channels if mid_channels is None else mid_channels
+
+        resnets = [
+            ResConvBlock(in_channels, mid_channels, mid_channels),
+            ResConvBlock(mid_channels, mid_channels, mid_channels),
+            ResConvBlock(mid_channels, mid_channels, out_channels),
+        ]
+
+        self.resnets = nn.LayerList(resnets)
+
+    def forward(self, hidden_states: paddle.Tensor, temb: Optional[paddle.Tensor] = None) -> paddle.Tensor:
+        hidden_states = paddle.concat([hidden_states, temb], axis=1)
+        for resnet in self.resnets:
+            hidden_states = resnet(hidden_states)
+
+        return hidden_states, (hidden_states,)
+
+
+class AttnUpBlock1D(nn.Layer):
+    def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
+        super().__init__()
+        mid_channels = out_channels if mid_channels is None else mid_channels
+
+        resnets = [
+            ResConvBlock(2 * in_channels, mid_channels, mid_channels),
+            ResConvBlock(mid_channels, mid_channels, mid_channels),
+            ResConvBlock(mid_channels, mid_channels, out_channels),
+        ]
+        attentions = [
+            SelfAttention1d(mid_channels, mid_channels // 32),
+            SelfAttention1d(mid_channels, mid_channels // 32),
+            SelfAttention1d(out_channels, out_channels // 32),
+        ]
+
+        self.attentions = nn.LayerList(attentions)
+        self.resnets = nn.LayerList(resnets)
+        self.up = Upsample1d(kernel="cubic")
+
+    def forward(
+        self,
+        hidden_states: paddle.Tensor,
+        res_hidden_states_tuple: Tuple[paddle.Tensor, ...],
+        temb: Optional[paddle.Tensor] = None,
+    ) -> paddle.Tensor:
+        res_hidden_states = res_hidden_states_tuple[-1]
+        hidden_states = paddle.concat([hidden_states, res_hidden_states], axis=1)
+
+        for resnet, attn in zip(self.resnets, self.attentions):
+            hidden_states = resnet(hidden_states)
+            hidden_states = attn(hidden_states)
+
+        hidden_states = self.up(hidden_states)
+
+        return hidden_states
+
+
+class UpBlock1D(nn.Layer):
+    def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
+        super().__init__()
+        mid_channels = in_channels if mid_channels is None else mid_channels
+
+        resnets = [
+            ResConvBlock(2 * in_channels, mid_channels, mid_channels),
+            ResConvBlock(mid_channels, mid_channels, mid_channels),
+            ResConvBlock(mid_channels, mid_channels, out_channels),
+        ]
+
+        self.resnets = nn.LayerList(resnets)
+        self.up = Upsample1d(kernel="cubic")
+
+    def forward(
+        self,
+        hidden_states: paddle.Tensor,
+        res_hidden_states_tuple: Tuple[paddle.Tensor, ...],
+        temb: Optional[paddle.Tensor] = None,
+    ) -> paddle.Tensor:
+        res_hidden_states = res_hidden_states_tuple[-1]
+        hidden_states = paddle.concat([hidden_states, res_hidden_states], axis=1)
+
+        for resnet in self.resnets:
+            hidden_states = resnet(hidden_states)
+
+        hidden_states = self.up(hidden_states)
+
+        return hidden_states
+
+
+class UpBlock1DNoSkip(nn.Layer):
+    def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
+        super().__init__()
+        mid_channels = in_channels if mid_channels is None else mid_channels
+
+        resnets = [
+            ResConvBlock(2 * in_channels, mid_channels, mid_channels),
+            ResConvBlock(mid_channels, mid_channels, mid_channels),
+            ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True),
+        ]
+
+        self.resnets = nn.LayerList(resnets)
+
+    def forward(
+        self,
+        hidden_states: paddle.Tensor,
+        res_hidden_states_tuple: Tuple[paddle.Tensor, ...],
+        temb: Optional[paddle.Tensor] = None,
+    ) -> paddle.Tensor:
+        res_hidden_states = res_hidden_states_tuple[-1]
+        hidden_states = paddle.concat([hidden_states, res_hidden_states], axis=1)
+
+        for resnet in self.resnets:
+            hidden_states = resnet(hidden_states)
+
+        return hidden_states
+
+
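Every down block above returns its output together with a one-element skip tuple, and the matching up block concatenates that skip back onto its input (hence the 2 * in_channels of its first ResConvBlock) before upsampling. A pairing sketch, assuming the module path ppdiffusers.models.unet_1d_blocks and illustrative sizes:

import paddle

from ppdiffusers.models.unet_1d_blocks import DownBlock1D, UpBlock1D

down = DownBlock1D(out_channels=64, in_channels=32)
up = UpBlock1D(in_channels=64, out_channels=32)

x = paddle.randn([2, 32, 64])             # [batch, channels, length]
h, skips = down(x)                        # h: [2, 64, 32], skips: (h,)
y = up(h, res_hidden_states_tuple=skips)  # concat -> [2, 128, 32], resnets, then upsample
print(y.shape)                            # [2, 32, 64]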
+DownBlockType = Union[DownResnetBlock1D, DownBlock1D, AttnDownBlock1D, DownBlock1DNoSkip]
+MidBlockType = Union[MidResTemporalBlock1D, ValueFunctionMidBlock1D, UNetMidBlock1D]
+OutBlockType = Union[OutConv1DBlock, OutValueFunctionBlock]
+UpBlockType = Union[UpResnetBlock1D, UpBlock1D, AttnUpBlock1D, UpBlock1DNoSkip]
+
+
+def get_down_block(
+    down_block_type: str,
+    num_layers: int,
+    in_channels: int,
+    out_channels: int,
+    temb_channels: int,
+    add_downsample: bool,
+) -> DownBlockType:
+    if down_block_type == "DownResnetBlock1D":
+        return DownResnetBlock1D(
+            in_channels=in_channels,
+            num_layers=num_layers,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            add_downsample=add_downsample,
+        )
+    elif down_block_type == "DownBlock1D":
+        return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
+    elif down_block_type == "AttnDownBlock1D":
+        return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
+    elif down_block_type == "DownBlock1DNoSkip":
+        return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
+    raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_up_block(
+    up_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_upsample: bool
+) -> UpBlockType:
+    if up_block_type == "UpResnetBlock1D":
+        return UpResnetBlock1D(
+            in_channels=in_channels,
+            num_layers=num_layers,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            add_upsample=add_upsample,
+        )
+    elif up_block_type == "UpBlock1D":
+        return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
+    elif up_block_type == "AttnUpBlock1D":
+        return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
+    elif up_block_type == "UpBlock1DNoSkip":
+        return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
+    raise ValueError(f"{up_block_type} does not exist.")
+
+
+def get_mid_block(
+    mid_block_type: str,
+    num_layers: int,
+    in_channels: int,
+    mid_channels: int,
+    out_channels: int,
+    embed_dim: int,
+    add_downsample: bool,
+) -> MidBlockType:
+    if mid_block_type == "MidResTemporalBlock1D":
+        return MidResTemporalBlock1D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            embed_dim=embed_dim,
+            add_downsample=add_downsample,
+        )
+    elif mid_block_type == "ValueFunctionMidBlock1D":
+        return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)
+    elif mid_block_type == "UNetMidBlock1D":
+        return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
+    raise ValueError(f"{mid_block_type} does not exist.")
+
+
+def get_out_block(
+    *, out_block_type: str, num_groups_out: int, embed_dim: int, out_channels: int, act_fn: str, fc_dim: int
+) -> Optional[OutBlockType]:
+    if out_block_type == "OutConv1DBlock":
+        return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
+    elif out_block_type == "ValueFunction":
+        return OutValueFunctionBlock(fc_dim, embed_dim, act_fn)
+    return None
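The factory helpers map the block-type strings used in UNet1D configs onto the classes above; unknown names raise a ValueError, and get_out_block returns None when no output head is configured. A sketch of a config-driven build, assuming the module path ppdiffusers.models.unet_1d_blocks and illustrative argument values (the num_layers/temb_channels arguments are only consumed by the resnet-style blocks):

from ppdiffusers.models.unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block

down = get_down_block(
    "DownBlock1D", num_layers=1, in_channels=32, out_channels=64, temb_channels=128, add_downsample=True
)
mid = get_mid_block(
    "UNetMidBlock1D",
    num_layers=1,
    in_channels=64,
    mid_channels=64,
    out_channels=64,
    embed_dim=128,
    add_downsample=False,
)
up = get_up_block(
    "UpBlock1D", num_layers=1, in_channels=64, out_channels=32, temb_channels=128, add_upsample=True
)
head = get_out_block(
    out_block_type="OutConv1DBlock", num_groups_out=8, embed_dim=32, out_channels=14, act_fn="mish", fc_dim=128
)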