<fix> add tp ip attn enhancement control in forward.
app.py CHANGED
@@ -264,6 +264,7 @@ def process_image_and_text(condition_image, target_prompt, condition_image_promp
         frame_gap=48,
         mixup=True,
         mixup_num_imgs=2,
+        enhance_tp=task in ['subject_driven', 'style_transfer'],
     ).frames
 
     gen_img = gen_img[:, 0:1, :, :, :]
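In app.py the enhancement is enabled per task rather than globally: enhance_tp is true only for subject-driven and style-transfer generation. A minimal sketch of that toggle, where the helper name is hypothetical and only the membership test comes from the diff:

def should_enhance(task: str) -> bool:
    # Hypothetical helper mirroring the expression passed as enhance_tp in app.py:
    # the token-replace attention enhancement is only wanted for tasks that
    # condition on a reference image.
    return task in ['subject_driven', 'style_transfer']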
models/hyvideo/transformer_hunyuan_video_i2v.py CHANGED
@@ -64,6 +64,7 @@ class HunyuanVideoAttnProcessor2_0:
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
+        enhance_tp: bool = False,
     ) -> torch.Tensor:
         if attn.add_q_proj is None and encoder_hidden_states is not None:
             hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
@@ -154,7 +155,7 @@ class HunyuanVideoAttnProcessor2_0:
         k_lens = torch.tensor([sum([u[seg_start[seg]:seg_end[seg]].long().sum().item() for seg in segs]) for u in valid_indices for segs in k_segs],
                               dtype=torch.int32, device=valid_indices.device)
         query = torch.cat([u[i:j][v[i:j]] for u,v in zip(query, valid_indices) for i,j in zip(seg_start, seg_end)], dim=0)
-        if self.inference_subject_driven:
+        if self.inference_subject_driven or enhance_tp:
             key = torch.cat([torch.cat([ torch.cat([u[seg_start[seg]:seg_end[seg]][v[seg_start[seg]:seg_end[seg]]][:144], u[seg_start[seg]:seg_end[seg]][v[seg_start[seg]:seg_end[seg]]][144:] + 0.6 * u[seg_start[seg]:seg_end[seg]][v[seg_start[seg]:seg_end[seg]]][144:].abs().mean()], dim=0) if segs == [0, 1, 2] and seg == 2 else u[seg_start[seg]:seg_end[seg]][v[seg_start[seg]:seg_end[seg]]] for seg in segs], dim=0) \
                 for u,v in zip(key, valid_indices) for segs in k_segs], dim=0)
         else:
@@ -756,6 +757,7 @@ class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         token_replace_emb: torch.Tensor = None,
         num_tokens: int = None,
+        enhance_tp: bool = False,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.shape[1]
         hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
@@ -777,6 +779,7 @@ class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
             encoder_hidden_states=norm_encoder_hidden_states,
             attention_mask=attention_mask,
             image_rotary_emb=image_rotary_emb,
+            enhance_tp=enhance_tp,
         )
         attn_output = torch.cat([attn_output, context_attn_output], dim=1)
 
@@ -841,6 +844,7 @@ class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
         freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         token_replace_emb: torch.Tensor = None,
         num_tokens: int = None,
+        enhance_tp: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # 1. Input normalization
         (
@@ -864,6 +868,7 @@ class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
             encoder_hidden_states=norm_encoder_hidden_states,
             attention_mask=attention_mask,
             image_rotary_emb=freqs_cis,
+            enhance_tp=enhance_tp,
         )
 
         # 3. Modulation and residual connection
@@ -1109,6 +1114,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
         frame_gap: Union[int, None] = None,
+        enhance_tp: bool = False,
     ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
@@ -1181,6 +1187,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
                     image_rotary_emb,
                     token_replace_emb,
                     first_frame_num_tokens,
+                    enhance_tp,
                 )
 
             for block in self.single_transformer_blocks:
@@ -1193,6 +1200,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
                     image_rotary_emb,
                     token_replace_emb,
                     first_frame_num_tokens,
+                    enhance_tp,
                 )
 
         else:
@@ -1205,6 +1213,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
                     image_rotary_emb,
                     token_replace_emb,
                     first_frame_num_tokens,
+                    enhance_tp,
                 )
 
             for block in self.single_transformer_blocks:
@@ -1216,6 +1225,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
                     image_rotary_emb,
                     token_replace_emb,
                     first_frame_num_tokens,
+                    enhance_tp,
                 )
 
         # 5. Output projection
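The long one-liner guarded by inference_subject_driven (and now also by enhance_tp) rescales part of the attention keys. Read apart, it keeps the first 144 key tokens of the affected segment unchanged and shifts every later key token by 0.6 times the mean absolute value of those later tokens, biasing attention toward them. A readable sketch of that per-segment step, using the 144 split point and 0.6 factor from the diff; the function name and wrapper are illustrative, not part of the repository:

import torch

def enhance_segment_keys(seg_keys: torch.Tensor, keep: int = 144, scale: float = 0.6) -> torch.Tensor:
    # Leave the first `keep` key tokens untouched; shift every later key token
    # by scale * mean(|tail|) so those tokens receive more attention mass.
    head, tail = seg_keys[:keep], seg_keys[keep:]
    return torch.cat([head, tail + scale * tail.abs().mean()], dim=0)

In the processor this adjustment is applied only when segs == [0, 1, 2] and seg == 2, i.e. to the last key segment of the fully joint attention group.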
pipelines/pipeline_hunyuan_video_i2v.py CHANGED
@@ -649,6 +649,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
         frame_gap: Union[int, None] = None,
         mixup: bool = False,
         mixup_num_imgs: Union[int, None] = None,
+        enhance_tp: bool = False,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -899,6 +900,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
                     attention_kwargs=attention_kwargs,
                     return_dict=False,
                     frame_gap=int(frame_gap / 4) if frame_gap is not None else frame_gap,
+                    enhance_tp=enhance_tp,
                 )[0]
 
                 if do_true_cfg:
|