1inkusFace committed
Commit 255ddb1 · verified · 1 Parent(s): 344f6ab

Update pipeline_stable_diffusion_3_ipa_clip.py

pipeline_stable_diffusion_3_ipa_clip.py CHANGED
@@ -14,6 +14,7 @@
 
 import inspect
 from typing import Any, Callable, Dict, List, Optional, Union
+from PIL import Image
 
 import torch
 import torch.nn as nn
@@ -922,8 +923,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
 
     @torch.inference_mode()
     def encode_clip_image_emb(self, clip_image, device, dtype):
-
-        # clip
+        if isinstance(clip_image, Image.Image):
+            clip_image = [clip_image]
+        # clip
         clip_image_tensor = self.clip_image_processor(images=clip_image, return_tensors="pt").pixel_values
         clip_image_tensor = clip_image_tensor.to(device, dtype=dtype)
         clip_image_embeds = self.image_encoder(clip_image_tensor, output_hidden_states=True).hidden_states[-2]
@@ -965,7 +967,17 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
 
         # ipa
         clip_image=None,
+        clip_image_2=None,
+        clip_image_3=None,
+        clip_image_4=None,
+        clip_image_5=None,
+        text_scale=1.0,
         ipadapter_scale=1.0,
+        scale_1=1.0,
+        scale_2=1.0,
+        scale_3=1.0,
+        scale_4=1.0,
+        scale_5=1.0,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -1126,11 +1138,82 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
         if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
-
+
+        prompt_embeds = prompt_embeds * text_scale
+
+        image_prompt_embeds_list = []
+
         # 3. prepare clip emb
-        clip_image = clip_image.resize((max(clip_image.size), max(clip_image.size)))
-        clip_image_embeds = self.encode_clip_image_emb(clip_image, device, dtype)
-
+        if clip_image != None:
+            print('Using primary image.')
+            clip_image = clip_image.resize((max(clip_image.size), max(clip_image.size)))
+            #clip_image_embeds_1 = self.encode_clip_image_emb(clip_image, device, dtype)
+            #with torch.no_grad():
+            clip_image_embeds_1 = self.clip_image_processor(images=clip_image, return_tensors="pt").pixel_values
+            print('clip output shape: ', clip_image_embeds_1.shape)
+            clip_image_embeds_1 = clip_image_embeds_1.to(device, dtype=dtype)
+            clip_image_embeds_1 = self.image_encoder(clip_image_embeds_1, output_hidden_states=True).hidden_states[-2]
+            print('encoder output shape: ', clip_image_embeds_1.shape)
+            clip_image_embeds_1 = clip_image_embeds_1 * scale_1
+            image_prompt_embeds_list.append(clip_image_embeds_1)
+        if clip_image_2 != None:
+            print('Using secondary image.')
+            clip_image_2 = clip_image_2.resize((max(clip_image_2.size), max(clip_image_2.size)))
+            #with torch.no_grad():
+            clip_image_embeds_2 = self.clip_image_processor(images=clip_image_2, return_tensors="pt").pixel_values
+            clip_image_embeds_2 = clip_image_embeds_2.to(device, dtype=dtype)
+            clip_image_embeds_2 = self.image_encoder(clip_image_embeds_2, output_hidden_states=True).hidden_states[-2]
+            clip_image_embeds_2 = clip_image_embeds_2 * scale_2
+            image_prompt_embeds_list.append(clip_image_embeds_2)
+        if clip_image_3 != None:
+            print('Using tertiary image.')
+            clip_image_3 = clip_image_3.resize((max(clip_image_3.size), max(clip_image_3.size)))
+            #with torch.no_grad():
+            clip_image_embeds_3 = self.clip_image_processor(images=clip_image_3, return_tensors="pt").pixel_values
+            clip_image_embeds_3 = clip_image_embeds_3.to(device, dtype=dtype)
+            clip_image_embeds_3 = self.image_encoder(clip_image_embeds_3, output_hidden_states=True).hidden_states[-2]
+            clip_image_embeds_3 = clip_image_embeds_3 * scale_3
+            image_prompt_embeds_list.append(clip_image_embeds_3)
+        if clip_image_4 != None:
+            print('Using quaternary image.')
+            clip_image_4 = clip_image_4.resize((max(clip_image_4.size), max(clip_image_4.size)))
+            #with torch.no_grad():
+            clip_image_embeds_4 = self.clip_image_processor(images=clip_image_4, return_tensors="pt").pixel_values
+            clip_image_embeds_4 = clip_image_embeds_4.to(device, dtype=dtype)
+            clip_image_embeds_4 = self.image_encoder(clip_image_embeds_4, output_hidden_states=True).hidden_states[-2]
+            clip_image_embeds_4 = clip_image_embeds_4 * scale_4
+            image_prompt_embeds_list.append(clip_image_embeds_4)
+        if clip_image_5 != None:
+            print('Using quinary image.')
+            clip_image_5 = clip_image_5.resize((max(clip_image_5.size), max(clip_image_5.size)))
+            #with torch.no_grad():
+            clip_image_embeds_5 = self.clip_image_processor(images=clip_image_5, return_tensors="pt").pixel_values
+            clip_image_embeds_5 = clip_image_embeds_5.to(device, dtype=dtype)
+            clip_image_embeds_5 = self.image_encoder(clip_image_embeds_5, output_hidden_states=True).hidden_states[-2]
+            clip_image_embeds_5 = clip_image_embeds_5 * scale_5
+            image_prompt_embeds_list.append(clip_image_embeds_5)
+
+        # with cat and mean
+        clip_image_embeds_cat_list = torch.cat(image_prompt_embeds_list)
+        clip_image_embeds_cat_list = torch.mean(clip_image_embeds_cat_list,dim=0,keepdim=True)
+        print('catted embeds list: ',clip_image_embeds_cat_list.shape)
+        zeros_tensor = torch.zeros_like(clip_image_embeds_cat_list)
+        clip_image_embeds = torch.cat([zeros_tensor, clip_image_embeds_cat_list], dim=0)
+        print('catted embeds: ',clip_image_embeds.shape)
+
+        '''
+        clip_image_embeds_cat_list = torch.cat(image_prompt_embeds_list).mean(dim=0)
+        print('catted embeds list with mean: ',clip_image_embeds_cat_list.shape)
+        seq_len, _ = clip_image_embeds_cat_list.shape
+        clip_image_embeds_cat_list_repeat = clip_image_embeds_cat_list.repeat(1, 1, 1)
+        print('catted embeds repeat: ',clip_image_embeds_cat_list_repeat.shape)
+        clip_image_embeds_view = clip_image_embeds_cat_list_repeat.view(1, seq_len, -1)
+        print('catted viewed: ',clip_image_embeds_view.shape)
+        zeros_tensor = torch.zeros_like(clip_image_embeds_view)
+        print('zeros: ',zeros_tensor.shape)
+        clip_image_embeds = torch.cat([zeros_tensor, clip_image_embeds_view], dim=0)
+        print('embeds shape: ', clip_image_embeds.shape)
+        '''
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@@ -1223,7 +1306,8 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
         else:
             latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
 
-        image = self.vae.decode(latents, return_dict=False)[0]
+        image = self.vae.decode(latents.to(torch.float32), return_dict=False)[0]
+
        image = self.image_processor.postprocess(image, output_type=output_type)
 
        # Offload all models
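
For context, here is a minimal usage sketch of the call signature introduced by this commit. It assumes the class defined in this file (`StableDiffusion3Pipeline`) is importable and already wired up with its CLIP image encoder and IP-Adapter weights; the base-model ID, image paths, and output handling below are illustrative assumptions, not part of the commit. Up to five reference images can be passed via `clip_image` through `clip_image_5`: each image's CLIP hidden states are multiplied by its `scale_*` weight, concatenated, and averaged into a single image prompt (with a zero tensor prepended for the unconditional branch), while `text_scale` multiplies the text prompt embeddings and `ipadapter_scale` keeps its previous role as the overall adapter strength.

# Hypothetical usage sketch for the updated __call__ signature.
# Base-model ID, reference image paths, and output handling are assumptions,
# not taken from this commit.
import torch
from PIL import Image

from pipeline_stable_diffusion_3_ipa_clip import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # assumed base checkpoint
    torch_dtype=torch.bfloat16,
).to("cuda")
# IP-Adapter / image-encoder setup is project-specific and omitted here.

ref_a = Image.open("reference_a.png").convert("RGB")  # placeholder reference images
ref_b = Image.open("reference_b.png").convert("RGB")

result = pipe(
    prompt="a portrait photo, studio lighting",
    num_inference_steps=28,
    guidance_scale=7.0,
    clip_image=ref_a,        # primary reference image
    clip_image_2=ref_b,      # clip_image_3..5 default to None and are skipped
    scale_1=1.0,             # weight of ref_a in the averaged image embedding
    scale_2=0.6,             # weight of ref_b
    text_scale=1.0,          # multiplies the CFG-concatenated prompt embeddings
    ipadapter_scale=0.8,     # overall IP-Adapter strength, unchanged behavior
)
result.images[0].save("out.png")  # output container assumed to match the upstream SD3 pipeline

Because the per-image embeddings are averaged rather than concatenated along the sequence dimension, the `scale_*` values act as relative weights within that average rather than as independent adapter strengths.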