bilalfaye committed
Commit 9da9224 · verified · 1 Parent(s): 9a254c3

Upload dpm_unet.py

Files changed (1)
  1. unet/dpm_unet.py +189 -0
unet/dpm_unet.py ADDED
@@ -0,0 +1,189 @@
+ from diffusers import UNet2DModel
+ from diffusers.utils import is_torch_version
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Optional, Tuple, Union
+ from collections import OrderedDict
+ from dataclasses import dataclass
+ from datasets import load_dataset
+ import matplotlib.pyplot as plt
+ from torchvision import transforms
+ from functools import partial
+ from torch.utils.data import DataLoader
+ from PIL import Image
+ from diffusers import DDPMScheduler
+
+
+ class BaseOutput(OrderedDict):
+     """
+     Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
+     tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
+     Python dictionary.
+     """
+     def __init_subclass__(cls) -> None:
+         # Register subclasses as pytree nodes; use is_torch_version instead of a
+         # lexicographic string comparison on torch.__version__.
+         if is_torch_version(">=", "2.2"):
+             import torch.utils._pytree as pytree
+             pytree.register_pytree_node(
+                 cls,
+                 pytree._dict_flatten,
+                 lambda values, context: cls(**pytree._dict_unflatten(values, context)),
+                 serialized_type_name=f"{cls.__module__}.{cls.__name__}",
+             )
+         else:
+             import torch.utils._pytree as pytree
+             pytree._register_pytree_node(
+                 cls,
+                 pytree._dict_flatten,
+                 lambda values, context: cls(**pytree._dict_unflatten(values, context)),
+             )
+
+
+ @dataclass
+ class UNet2DOutput(BaseOutput):
+     """
+     The output of [`UNet2DModel`].
+
+     Args:
+         sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
+             The hidden states output from the last layer of the model.
+     """
+     sample: torch.Tensor
+
+
+ class DPM(UNet2DModel):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+         # Cross-attention over prototype tokens, applied at the bottleneck resolution.
+         hidden_size = self.config.block_out_channels[-1]
+         self.bottleneck_attn = nn.MultiheadAttention(
+             embed_dim=hidden_size,
+             num_heads=8,  # or adjust as needed
+             batch_first=True
+         )
+
+     def forward(
+         self,
+         sample: torch.Tensor,
+         timestep: Union[torch.Tensor, float, int],
+         class_labels: Optional[torch.Tensor] = None,
+         return_dict: bool = True,
+         prototype: Optional[torch.Tensor] = None,  # <--- added here
+     ) -> Union[UNet2DOutput, Tuple]:
+         r"""
+         The [`UNet2DModel`] forward method.
+
+         Args:
+             sample (`torch.Tensor`):
+                 The noisy input tensor with the following shape `(batch, channel, height, width)`.
+             timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
+             class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+                 Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~models.unets.unet_2d.UNet2DOutput`] instead of a plain tuple.
+             prototype (`torch.Tensor` of shape `(batch, num_tokens, channels)`):
+                 Prototype tokens used as keys and values in the cross-attention applied after the down blocks.
+
+         Returns:
+             [`~models.unets.unet_2d.UNet2DOutput`] or `tuple`:
+                 If `return_dict` is True, an [`~models.unets.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
+                 returned where the first element is the sample tensor.
+         """
+         # 0. center input if necessary
+         if self.config.center_input_sample:
+             sample = 2 * sample - 1.0
+
+         # 1. time
+         timesteps = timestep
+         if not torch.is_tensor(timesteps):
+             timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
+         elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+             timesteps = timesteps[None].to(sample.device)
+
+         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+         timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
+
+         t_emb = self.time_proj(timesteps)
+
+         # timesteps does not contain any weights and will always return f32 tensors
+         # but time_embedding might actually be running in fp16. so we need to cast here.
+         # there might be better ways to encapsulate this.
+         t_emb = t_emb.to(dtype=self.dtype)
+         emb = self.time_embedding(t_emb)
+
+         if self.class_embedding is not None:
+             if class_labels is None:
+                 raise ValueError("class_labels should be provided when doing class conditioning")
+
+             if self.config.class_embed_type == "timestep":
+                 class_labels = self.time_proj(class_labels)
+
+             class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+             emb = emb + class_emb
+         elif self.class_embedding is None and class_labels is not None:
+             raise ValueError("class_embedding needs to be initialized in order to use class conditioning")
+
+         # 2. pre-process
+         skip_sample = sample
+         sample = self.conv_in(sample)
+
+         # 3. down
+         down_block_res_samples = (sample,)
+         for downsample_block in self.down_blocks:
+             if hasattr(downsample_block, "skip_conv"):
+                 sample, res_samples, skip_sample = downsample_block(
+                     hidden_states=sample, temb=emb, skip_sample=skip_sample
+                 )
+             else:
+                 sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+             down_block_res_samples += res_samples
+
+         # ----------- Cross-Attention after downsampling ------------------
+         if prototype is None:
+             raise ValueError("You must provide a `prototype` tensor for cross-attention")
+
+         # Flatten the spatial map so every position becomes a query token.
+         b, c, h, w = sample.shape
+         query = sample.view(b, c, h * w).transpose(1, 2)  # (B, HW, C)
+
+         # prototype: expected shape (B, N, C)
+         key = value = prototype.to(dtype=sample.dtype)
+
+         attn_output, _ = self.bottleneck_attn(query, key, value)
+         attn_output = attn_output.transpose(1, 2).view(b, c, h, w)  # (B, C, H, W)
+
+         # Residual connection
+         sample = sample + attn_output
+         # ---------------------------------------------------------------
+
+         # 4. mid
+         if self.mid_block is not None:
+             sample = self.mid_block(sample, emb)
+
+         # 5. up
+         skip_sample = None
+         for upsample_block in self.up_blocks:
+             res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+             down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+             if hasattr(upsample_block, "skip_conv"):
+                 sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
+             else:
+                 sample = upsample_block(sample, res_samples, emb)
+
+         # 6. post-process
+         sample = self.conv_norm_out(sample)
+         sample = self.conv_act(sample)
+         sample = self.conv_out(sample)
+
+         if skip_sample is not None:
+             sample += skip_sample
+
+         if self.config.time_embedding_type == "fourier":
+             timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
+             sample = sample / timesteps
+
+         if not return_dict:
+             return (sample,)
+
+         return UNet2DOutput(sample=sample)
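
For context, a minimal usage sketch (not part of the uploaded file): it shows how the `prototype` argument flows into `DPM.forward`. The UNet configuration, the number of prototype tokens (10), and the `unet.dpm_unet` import path are illustrative assumptions; the only constraint taken from the code above is that the prototype's last dimension must equal `block_out_channels[-1]`, since the prototypes act as keys/values for the bottleneck `MultiheadAttention`.

```python
import torch
from unet.dpm_unet import DPM  # assumes the repository root is on PYTHONPATH

# Hypothetical small configuration, chosen for illustration only.
model = DPM(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    block_out_channels=(64, 128, 256),
    down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "UpBlock2D", "UpBlock2D"),
)

batch = 4
noisy_images = torch.randn(batch, 3, 32, 32)
timesteps = torch.randint(0, 1000, (batch,))

# (B, N, C): N prototype tokens per sample; C must match block_out_channels[-1] (256 here)
prototype = torch.randn(batch, 10, 256)

out = model(noisy_images, timesteps, prototype=prototype)
print(out.sample.shape)  # torch.Size([4, 3, 32, 32])
```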