Update pipeline_stable_diffusion_3_ipa_clip.py
pipeline_stable_diffusion_3_ipa_clip.py
CHANGED
@@ -14,6 +14,7 @@
 
 import inspect
 from typing import Any, Callable, Dict, List, Optional, Union
+from PIL import Image
 
 import torch
 import torch.nn as nn
@@ -922,8 +923,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
 
     @torch.inference_mode()
    def encode_clip_image_emb(self, clip_image, device, dtype):
-
-
+        if isinstance(clip_image, Image.Image):
+            clip_image = [clip_image]
+        # clip
         clip_image_tensor = self.clip_image_processor(images=clip_image, return_tensors="pt").pixel_values
         clip_image_tensor = clip_image_tensor.to(device, dtype=dtype)
         clip_image_embeds = self.image_encoder(clip_image_tensor, output_hidden_states=True).hidden_states[-2]
@@ -965,7 +967,17 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
 
         # ipa
         clip_image=None,
+        clip_image_2=None,
+        clip_image_3=None,
+        clip_image_4=None,
+        clip_image_5=None,
+        text_scale=1.0,
         ipadapter_scale=1.0,
+        scale_1=1.0,
+        scale_2=1.0,
+        scale_3=1.0,
+        scale_4=1.0,
+        scale_5=1.0,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -1126,11 +1138,82 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
         if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
-
+
+        prompt_embeds = prompt_embeds * text_scale
+
+        image_prompt_embeds_list = []
+
         # 3. prepare clip emb
-        clip_image
-
-
+        if clip_image != None:
+            print('Using primary image.')
+            clip_image = clip_image.resize((max(clip_image.size), max(clip_image.size)))
+            #clip_image_embeds_1 = self.encode_clip_image_emb(clip_image, device, dtype)
+            #with torch.no_grad():
+            clip_image_embeds_1 = self.clip_image_processor(images=clip_image, return_tensors="pt").pixel_values
+            print('clip output shape: ', clip_image_embeds_1.shape)
+            clip_image_embeds_1 = clip_image_embeds_1.to(device, dtype=dtype)
+            clip_image_embeds_1 = self.image_encoder(clip_image_embeds_1, output_hidden_states=True).hidden_states[-2]
+            print('encoder output shape: ', clip_image_embeds_1.shape)
+            clip_image_embeds_1 = clip_image_embeds_1 * scale_1
+            image_prompt_embeds_list.append(clip_image_embeds_1)
+        if clip_image_2 != None:
+            print('Using secondary image.')
+            clip_image_2 = clip_image_2.resize((max(clip_image_2.size), max(clip_image_2.size)))
+            #with torch.no_grad():
+            clip_image_embeds_2 = self.clip_image_processor(images=clip_image_2, return_tensors="pt").pixel_values
+            clip_image_embeds_2 = clip_image_embeds_2.to(device, dtype=dtype)
+            clip_image_embeds_2 = self.image_encoder(clip_image_embeds_2, output_hidden_states=True).hidden_states[-2]
+            clip_image_embeds_2 = clip_image_embeds_2 * scale_2
+            image_prompt_embeds_list.append(clip_image_embeds_2)
+        if clip_image_3 != None:
+            print('Using tertiary image.')
+            clip_image_3 = clip_image_3.resize((max(clip_image_3.size), max(clip_image_3.size)))
+            #with torch.no_grad():
+            clip_image_embeds_3 = self.clip_image_processor(images=clip_image_3, return_tensors="pt").pixel_values
+            clip_image_embeds_3 = clip_image_embeds_3.to(device, dtype=dtype)
+            clip_image_embeds_3 = self.image_encoder(clip_image_embeds_3, output_hidden_states=True).hidden_states[-2]
+            clip_image_embeds_3 = clip_image_embeds_3 * scale_3
+            image_prompt_embeds_list.append(clip_image_embeds_3)
+        if clip_image_4 != None:
+            print('Using quaternary image.')
+            clip_image_4 = clip_image_4.resize((max(clip_image_4.size), max(clip_image_4.size)))
+            #with torch.no_grad():
+            clip_image_embeds_4 = self.clip_image_processor(images=clip_image_4, return_tensors="pt").pixel_values
+            clip_image_embeds_4 = clip_image_embeds_4.to(device, dtype=dtype)
+            clip_image_embeds_4 = self.image_encoder(clip_image_embeds_4, output_hidden_states=True).hidden_states[-2]
+            clip_image_embeds_4 = clip_image_embeds_4 * scale_4
+            image_prompt_embeds_list.append(clip_image_embeds_4)
+        if clip_image_5 != None:
+            print('Using quinary image.')
+            clip_image_5 = clip_image_5.resize((max(clip_image_5.size), max(clip_image_5.size)))
+            #with torch.no_grad():
+            clip_image_embeds_5 = self.clip_image_processor(images=clip_image_5, return_tensors="pt").pixel_values
+            clip_image_embeds_5 = clip_image_embeds_5.to(device, dtype=dtype)
+            clip_image_embeds_5 = self.image_encoder(clip_image_embeds_5, output_hidden_states=True).hidden_states[-2]
+            clip_image_embeds_5 = clip_image_embeds_5 * scale_5
+            image_prompt_embeds_list.append(clip_image_embeds_5)
+
+        # with cat and mean
+        clip_image_embeds_cat_list = torch.cat(image_prompt_embeds_list)
+        clip_image_embeds_cat_list = torch.mean(clip_image_embeds_cat_list, dim=0, keepdim=True)
+        print('catted embeds list: ', clip_image_embeds_cat_list.shape)
+        zeros_tensor = torch.zeros_like(clip_image_embeds_cat_list)
+        clip_image_embeds = torch.cat([zeros_tensor, clip_image_embeds_cat_list], dim=0)
+        print('catted embeds: ', clip_image_embeds.shape)
+
+        '''
+        clip_image_embeds_cat_list = torch.cat(image_prompt_embeds_list).mean(dim=0)
+        print('catted embeds list with mean: ', clip_image_embeds_cat_list.shape)
+        seq_len, _ = clip_image_embeds_cat_list.shape
+        clip_image_embeds_cat_list_repeat = clip_image_embeds_cat_list.repeat(1, 1, 1)
+        print('catted embeds repeat: ', clip_image_embeds_cat_list_repeat.shape)
+        clip_image_embeds_view = clip_image_embeds_cat_list_repeat.view(1, seq_len, -1)
+        print('catted viewed: ', clip_image_embeds_view.shape)
+        zeros_tensor = torch.zeros_like(clip_image_embeds_view)
+        print('zeros: ', zeros_tensor.shape)
+        clip_image_embeds = torch.cat([zeros_tensor, clip_image_embeds_view], dim=0)
+        print('embeds shape: ', clip_image_embeds.shape)
+        '''
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@@ -1223,7 +1306,8 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
         else:
             latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
 
-            image = self.vae.decode(latents, return_dict=False)[0]
+            image = self.vae.decode(latents.to(torch.float32), return_dict=False)[0]
+
             image = self.image_processor.postprocess(image, output_type=output_type)
 
         # Offload all models
|