Commit
·
14ce2e6
1
Parent(s):
5b19a88
Upload modeling_ddpm.py
Browse files
- modeling_ddpm.py +24 -26
modeling_ddpm.py
CHANGED
@@ -14,15 +14,13 @@
|
|
14 |
# limitations under the License.
|
15 |
|
16 |
|
17 |
-
from diffusers import DiffusionPipeline
|
18 |
-
import tqdm
|
19 |
import torch
|
20 |
|
|
|
|
|
21 |
|
22 |
-
class DDPM(DiffusionPipeline):
|
23 |
-
|
24 |
-
modeling_file = "modeling_ddpm.py"
|
25 |
|
|
|
26 |
def __init__(self, unet, noise_scheduler):
|
27 |
super().__init__()
|
28 |
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
|
@@ -32,30 +30,30 @@ class DDPM(DiffusionPipeline):
|
|
32 |
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
33 |
|
34 |
self.unet.to(torch_device)
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
45 |
with torch.no_grad():
|
46 |
-
|
47 |
|
48 |
-
#
|
49 |
-
|
50 |
-
pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual
|
51 |
-
pred_mean = torch.clamp(pred_mean, -1, 1)
|
52 |
-
prev_image = clip_coeff * pred_mean + image_coeff * image
|
53 |
|
54 |
-
#
|
55 |
-
|
|
|
|
|
|
|
56 |
|
57 |
-
#
|
58 |
-
|
59 |
-
image = sampled_prev_image
|
60 |
|
61 |
return image
|
|
|
14 |
# limitations under the License.
|
15 |
|
16 |
|
|
|
|
|
17 |
import torch
|
18 |
|
19 |
+
import tqdm
|
20 |
+
from diffusers import DiffusionPipeline
|
21 |
|
|
|
|
|
|
|
22 |
|
23 |
+
class DDPM(DiffusionPipeline):
|
24 |
def __init__(self, unet, noise_scheduler):
|
25 |
super().__init__()
|
26 |
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
|
|
|
30 |
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
31 |
|
32 |
self.unet.to(torch_device)
|
33 |
+
|
34 |
+
# Sample gaussian noise to begin loop
|
35 |
+
image = self.noise_scheduler.sample_noise(
|
36 |
+
(batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
|
37 |
+
device=torch_device,
|
38 |
+
generator=generator,
|
39 |
+
)
|
40 |
+
|
41 |
+
num_prediction_steps = len(self.noise_scheduler)
|
42 |
+
for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
|
43 |
+
# 1. predict noise residual
|
44 |
with torch.no_grad():
|
45 |
+
residual = self.unet(image, t)
|
46 |
|
47 |
+
# 2. predict previous mean of image x_t-1
|
48 |
+
pred_prev_image = self.noise_scheduler.compute_prev_image_step(residual, image, t)
|
|
|
|
|
|
|
49 |
|
50 |
+
# 3. optionally sample variance
|
51 |
+
variance = 0
|
52 |
+
if t > 0:
|
53 |
+
noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
|
54 |
+
variance = self.noise_scheduler.get_variance(t).sqrt() * noise
|
55 |
|
56 |
+
# 4. set current image to prev_image: x_t -> x_t-1
|
57 |
+
image = pred_prev_image + variance
|
|
|
58 |
|
59 |
return image
|