caohy666 committed
Commit 842e89f · 1 Parent(s): 05cbda5

<fix> add tp ip attn enhancement control in forward.

app.py CHANGED
@@ -264,6 +264,7 @@ def process_image_and_text(condition_image, target_prompt, condition_image_promp
         frame_gap=48,
         mixup=True,
         mixup_num_imgs=2,
+        enhance_tp=task in ['subject_driven', 'style_transfer'],
     ).frames

     gen_img = gen_img[:, 0:1, :, :, :]
models/hyvideo/transformer_hunyuan_video_i2v.py CHANGED
@@ -64,6 +64,7 @@ class HunyuanVideoAttnProcessor2_0:
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
+        enhance_tp: bool = False,
     ) -> torch.Tensor:
         if attn.add_q_proj is None and encoder_hidden_states is not None:
             hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
@@ -154,7 +155,7 @@ class HunyuanVideoAttnProcessor2_0:
         k_lens = torch.tensor([sum([u[seg_start[seg]:seg_end[seg]].long().sum().item() for seg in segs]) for u in valid_indices for segs in k_segs],
                               dtype=torch.int32, device=valid_indices.device)
         query = torch.cat([u[i:j][v[i:j]] for u,v in zip(query, valid_indices) for i,j in zip(seg_start, seg_end)], dim=0)
-        if self.inference_subject_driven:
+        if self.inference_subject_driven or enhance_tp:
             key = torch.cat([torch.cat([ torch.cat([u[seg_start[seg]:seg_end[seg]][v[seg_start[seg]:seg_end[seg]]][:144], u[seg_start[seg]:seg_end[seg]][v[seg_start[seg]:seg_end[seg]]][144:] + 0.6 * u[seg_start[seg]:seg_end[seg]][v[seg_start[seg]:seg_end[seg]]][144:].abs().mean()], dim=0) if segs == [0, 1, 2] and seg == 2 else u[seg_start[seg]:seg_end[seg]][v[seg_start[seg]:seg_end[seg]]] for seg in segs], dim=0) \
                 for u,v in zip(key, valid_indices) for segs in k_segs], dim=0)
         else:
@@ -756,6 +757,7 @@ class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         token_replace_emb: torch.Tensor = None,
         num_tokens: int = None,
+        enhance_tp: bool = False,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.shape[1]
         hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
@@ -777,6 +779,7 @@ class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
             encoder_hidden_states=norm_encoder_hidden_states,
             attention_mask=attention_mask,
             image_rotary_emb=image_rotary_emb,
+            enhance_tp=enhance_tp,
         )
         attn_output = torch.cat([attn_output, context_attn_output], dim=1)

@@ -841,6 +844,7 @@ class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
         freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         token_replace_emb: torch.Tensor = None,
         num_tokens: int = None,
+        enhance_tp: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # 1. Input normalization
         (
@@ -864,6 +868,7 @@ class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
             encoder_hidden_states=norm_encoder_hidden_states,
             attention_mask=attention_mask,
             image_rotary_emb=freqs_cis,
+            enhance_tp=enhance_tp,
         )

         # 3. Modulation and residual connection
@@ -1109,6 +1114,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
         frame_gap: Union[int, None] = None,
+        enhance_tp: bool = False,
     ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
@@ -1181,6 +1187,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
                     image_rotary_emb,
                     token_replace_emb,
                     first_frame_num_tokens,
+                    enhance_tp,
                 )

             for block in self.single_transformer_blocks:
@@ -1193,6 +1200,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
                     image_rotary_emb,
                     token_replace_emb,
                     first_frame_num_tokens,
+                    enhance_tp,
                 )

         else:
@@ -1205,6 +1213,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
                     image_rotary_emb,
                     token_replace_emb,
                     first_frame_num_tokens,
+                    enhance_tp,
                 )

             for block in self.single_transformer_blocks:
@@ -1216,6 +1225,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
                     image_rotary_emb,
                     token_replace_emb,
                     first_frame_num_tokens,
+                    enhance_tp,
                 )

         # 5. Output projection
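
The substantive change in this file is in HunyuanVideoAttnProcessor2_0: the key-enhancement branch that previously ran only when self.inference_subject_driven was set now also runs when the new enhance_tp argument is True; the remaining hunks are plumbing that threads enhance_tp from HunyuanVideoTransformer3DModel.forward through both block types into the processor. Unrolled from the nested comprehension, the per-segment operation in that branch amounts to the sketch below (the helper name is mine; 144 and 0.6 are the constants from the diff, and the shift only applies when segs == [0, 1, 2] and seg == 2):

    import torch

    def enhance_segment_keys(segment_keys: torch.Tensor,
                             keep_tokens: int = 144,
                             scale: float = 0.6) -> torch.Tensor:
        # segment_keys: the valid key tokens of one segment, tokens on dim 0.
        # The first `keep_tokens` tokens pass through unchanged; the remaining
        # tokens are shifted by `scale` times their mean absolute value.
        head, tail = segment_keys[:keep_tokens], segment_keys[keep_tokens:]
        tail = tail + scale * tail.abs().mean()
        return torch.cat([head, tail], dim=0)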
pipelines/pipeline_hunyuan_video_i2v.py CHANGED
@@ -649,6 +649,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
         frame_gap: Union[int, None] = None,
         mixup: bool = False,
         mixup_num_imgs: Union[int, None] = None,
+        enhance_tp: bool = False,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -899,6 +900,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
                     attention_kwargs=attention_kwargs,
                     return_dict=False,
                     frame_gap=int(frame_gap / 4) if frame_gap is not None else frame_gap,
+                    enhance_tp=enhance_tp,
                 )[0]

                 if do_true_cfg:
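
End to end, the flag travels caller → pipeline __call__ → transformer forward → attention processor. A minimal call sketch mirroring the app.py hunk above, assuming a HunyuanVideoImageToVideoPipeline instance bound to pipe and that condition_image, target_prompt, and task are prepared as in the app (other required pipeline arguments are left at their defaults):

    # Hypothetical invocation; the argument order for the image and prompt
    # inputs is assumed, everything else is taken from the app.py diff.
    gen_img = pipe(
        condition_image,
        target_prompt,
        frame_gap=48,
        mixup=True,
        mixup_num_imgs=2,
        enhance_tp=task in ['subject_driven', 'style_transfer'],
    ).frames
    gen_img = gen_img[:, 0:1, :, :, :]  # keep only the first frame, as in app.py

Because enhance_tp defaults to False and the processor condition is an `or`, existing callers keep the previous behavior, which remains reachable through self.inference_subject_driven alone.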