This Pull Request upgrades the Space to the modern FramePack (FramePack/HunyuanVideo) AI video generator, replacing the previous unofficial OpenAI Sora demo.

#2
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Example1.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ img_examples/Example1.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ img_examples/Example1.png filter=lfs diff=lfs merge=lfs -text
39
+ img_examples/Example2.webp filter=lfs diff=lfs merge=lfs -text
40
+ img_examples/Example3.jpg filter=lfs diff=lfs merge=lfs -text
41
+ img_examples/Example4.webp filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,14 +1,21 @@
1
- ---
2
- title: Openai Sora
3
- emoji: 🚀🎥
4
- colorFrom: red
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 4.31.4
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- short_description: Generate incredible videos using Openai Sora
12
- ---
13
-
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: FramePack/HunyuanVideo
3
+ emoji: 🎥
4
+ colorFrom: pink
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 5.29.1
8
+ app_file: app.py
9
+ license: apache-2.0
10
+ short_description: Text-to-Video/Image-to-Video/Video extender (timed prompt)
11
+ tags:
12
+ - Image-to-Video
13
+ - Image-2-Video
14
+ - Img-to-Vid
15
+ - Img-2-Vid
16
+ - language models
17
+ - LLMs
18
+ suggested_hardware: zero-a10g
19
+ ---
20
+
21
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,52 +1,1393 @@
1
  import gradio as gr
2
- import spaces
3
  import torch
4
- from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
5
- from diffusers.utils import export_to_video
6
- import cv2
7
  import numpy as np
8
 
9
- pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16")
10
- pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
11
- pipe.enable_model_cpu_offload()
12
- pipe.enable_vae_slicing()
13
-
14
- @spaces.GPU(duration=250)
15
- def generate(prompt, num_inference_steps, num_frames):
16
- video_frames = pipe(prompt, num_inference_steps=num_inference_steps, num_frames=num_frames).frames[0]
17
- video_path = export_to_video(video_frames, fps=10)
18
- return video_path
19
-
20
- prompt = gr.Textbox(label="Enter prompt to generate a video", info="Based on this prompt ai will generate a video")
21
- description="""
22
- 🚀 This is **unofficial** demo of Openai's Sora that haven't been released yet.\n
23
- ✔ This space made using [ali-vilab/text-to-video-ms-1.7b](https://huggingface.co/ali-vilab/text-to-video-ms-1.7b)\n
24
- Estimated generation time is **150 seconds**\n
25
- 🎁 Space is running on ZeroGPU, if you want faster generation, duplicate space and choose faster GPU
26
  """
27
- num_inference_steps=gr.Slider(8, 128, step=1, value=24, label="Num Inference Steps", info="More steps then better quality")
28
- num_frames=gr.Slider(8, 1000, step=1, value=200, label="Num of Frames", info="It is duration of video")
29
-
30
- interface = gr.Interface(
31
- generate,
32
- inputs=[prompt],
33
- additional_inputs=[num_inference_steps, num_frames],
34
- examples=[
35
- ["Astronaut riding a horse", 60, 100],
36
- ["Darth vader surfing in waves", 30, 200],
37
- ["A house in the woods in ocean", 70, 100],
38
- ["A car in the forest", 70, 100],
39
- ["A house firing", 60, 150],
40
- ["A plane firing and falling down", 100, 20],
41
- ["Campfire", 50, 50],
42
- ["Zombie apocalypse", 100, 20],
43
- ["A New Yourk City", 100, 20],
44
- ["A man running in beautiufl forest", 100, 20],
45
- ["A cup of tea with fog", 100, 20]
46
- ],
47
- outputs="video",
48
- title="Openai Sora (Unofficial)",
49
- description=description,
50
- cache_examples=False,
51
- theme="soft"
52
- ).launch()
1
+ from diffusers_helper.hf_login import login
2
+
3
+ import os
4
+
5
+ os.environ['HF_HOME'] = os.path.abspath(os.path.realpath(os.path.join(os.path.dirname(__file__), './hf_download')))
6
+
7
+ try:
8
+ import spaces
9
+ except:
10
+ print("Not on HuggingFace")
11
  import gradio as gr
 
12
  import torch
13
+ import traceback
14
+ import einops
15
+ import safetensors.torch as sf
16
  import numpy as np
17
+ import random
18
+ import time
19
+ import math
20
+ # 20250506 pftq: Added for video input loading
21
+ import decord
22
+ # 20250506 pftq: Added for progress bars in video_encode
23
+ from tqdm import tqdm
24
+ # 20250506 pftq: Normalize file paths for Windows compatibility
25
+ import pathlib
26
+ # 20250506 pftq: for easier to read timestamp
27
+ from datetime import datetime
28
+ # 20250508 pftq: for saving prompt to mp4 comments metadata
29
+ import imageio_ffmpeg
30
+ import tempfile
31
+ import shutil
32
+ import subprocess
33
+
34
+ from PIL import Image
35
+ from diffusers import AutoencoderKLHunyuanVideo
36
+ from transformers import LlamaModel, CLIPTextModel, LlamaTokenizerFast, CLIPTokenizer
37
+ from diffusers_helper.hunyuan import encode_prompt_conds, vae_decode, vae_encode, vae_decode_fake
38
+ from diffusers_helper.utils import save_bcthw_as_mp4, crop_or_pad_yield_mask, soft_append_bcthw, resize_and_center_crop, state_dict_weighted_merge, state_dict_offset_merge, generate_timestamp
39
+ from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
40
+ from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan
41
+ if torch.cuda.device_count() > 0:
42
+ from diffusers_helper.memory import cpu, gpu, get_cuda_free_memory_gb, move_model_to_device_with_memory_preservation, offload_model_from_device_for_memory_preservation, fake_diffusers_current_device, DynamicSwapInstaller, unload_complete_models, load_model_as_complete
43
+ from diffusers_helper.thread_utils import AsyncStream, async_run
44
+ from diffusers_helper.gradio.progress_bar import make_progress_bar_css, make_progress_bar_html
45
+ from transformers import SiglipImageProcessor, SiglipVisionModel
46
+ from diffusers_helper.clip_vision import hf_clip_vision_encode
47
+ from diffusers_helper.bucket_tools import find_nearest_bucket
48
+ from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, HunyuanVideoTransformer3DModel, HunyuanVideoPipeline
49
+ import pillow_heif
50
+
51
+ pillow_heif.register_heif_opener()
52
+
53
+ high_vram = False
54
+ free_mem_gb = 0
55
+
56
+ if torch.cuda.device_count() > 0:
57
+ free_mem_gb = get_cuda_free_memory_gb(gpu)
58
+ high_vram = free_mem_gb > 60
59
+
60
+ print(f'Free VRAM {free_mem_gb} GB')
61
+ print(f'High-VRAM Mode: {high_vram}')
62
+
63
+ text_encoder = LlamaModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder', torch_dtype=torch.float16).cpu()
64
+ text_encoder_2 = CLIPTextModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder_2', torch_dtype=torch.float16).cpu()
65
+ tokenizer = LlamaTokenizerFast.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer')
66
+ tokenizer_2 = CLIPTokenizer.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer_2')
67
+ vae = AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='vae', torch_dtype=torch.float16).cpu()
68
+
69
+ feature_extractor = SiglipImageProcessor.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='feature_extractor')
70
+ image_encoder = SiglipVisionModel.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='image_encoder', torch_dtype=torch.float16).cpu()
71
+
72
+ transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained('lllyasviel/FramePack_F1_I2V_HY_20250503', torch_dtype=torch.bfloat16).cpu()
73
+
74
+ vae.eval()
75
+ text_encoder.eval()
76
+ text_encoder_2.eval()
77
+ image_encoder.eval()
78
+ transformer.eval()
79
+
80
+ if not high_vram:
81
+ vae.enable_slicing()
82
+ vae.enable_tiling()
83
+
84
+ transformer.high_quality_fp32_output_for_inference = True
85
+ print('transformer.high_quality_fp32_output_for_inference = True')
86
+
87
+ transformer.to(dtype=torch.bfloat16)
88
+ vae.to(dtype=torch.float16)
89
+ image_encoder.to(dtype=torch.float16)
90
+ text_encoder.to(dtype=torch.float16)
91
+ text_encoder_2.to(dtype=torch.float16)
92
+
93
+ vae.requires_grad_(False)
94
+ text_encoder.requires_grad_(False)
95
+ text_encoder_2.requires_grad_(False)
96
+ image_encoder.requires_grad_(False)
97
+ transformer.requires_grad_(False)
98
+
99
+ if not high_vram:
100
+ # DynamicSwapInstaller is same as huggingface's enable_sequential_offload but 3x faster
101
+ DynamicSwapInstaller.install_model(transformer, device=gpu)
102
+ DynamicSwapInstaller.install_model(text_encoder, device=gpu)
103
+ else:
104
+ text_encoder.to(gpu)
105
+ text_encoder_2.to(gpu)
106
+ image_encoder.to(gpu)
107
+ vae.to(gpu)
108
+ transformer.to(gpu)
109
+
110
+ stream = AsyncStream()
111
+
112
+ outputs_folder = './outputs/'
113
+ os.makedirs(outputs_folder, exist_ok=True)
114
+
115
+ default_local_storage = {
116
+ "generation-mode": "image",
117
+ }
118
+
119
+ @torch.no_grad()
120
+ def video_encode(video_path, resolution, no_resize, vae, vae_batch_size=16, device="cuda", width=None, height=None):
121
+ """
122
+ Encode a video into latent representations using the VAE.
123
+
124
+ Args:
125
+ video_path: Path to the input video file.
126
+ vae: AutoencoderKLHunyuanVideo model.
127
+ height, width: Target resolution for resizing frames.
128
+ vae_batch_size: Number of frames to process per batch.
129
+ device: Device for computation (e.g., "cuda").
130
+
131
+ Returns:
132
+ start_latent: Latent of the first frame (for compatibility with original code).
133
+ input_image_np: First frame as numpy array (for CLIP vision encoding).
134
+ history_latents: Latents of all frames (shape: [1, channels, frames, height//8, width//8]).
135
+ fps: Frames per second of the input video.
136
+ """
137
+ # 20250506 pftq: Normalize video path for Windows compatibility
138
+ video_path = str(pathlib.Path(video_path).resolve())
139
+ print(f"Processing video: {video_path}")
140
+
141
+ # 20250506 pftq: Check CUDA availability and fallback to CPU if needed
142
+ if device == "cuda" and not torch.cuda.is_available():
143
+ print("CUDA is not available, falling back to CPU")
144
+ device = "cpu"
145
+
146
+ try:
147
+ # 20250506 pftq: Load video and get FPS
148
+ print("Initializing VideoReader...")
149
+ vr = decord.VideoReader(video_path)
150
+ fps = vr.get_avg_fps() # Get input video FPS
151
+ num_real_frames = len(vr)
152
+ print(f"Video loaded: {num_real_frames} frames, FPS: {fps}")
153
+
154
+ # Truncate to nearest latent size (multiple of 4)
155
+ latent_size_factor = 4
156
+ num_frames = (num_real_frames // latent_size_factor) * latent_size_factor
157
+ if num_frames != num_real_frames:
158
+ print(f"Truncating video from {num_real_frames} to {num_frames} frames for latent size compatibility")
159
+ num_real_frames = num_frames
160
+
161
+ # 20250506 pftq: Read frames
162
+ print("Reading video frames...")
163
+ frames = vr.get_batch(range(num_real_frames)).asnumpy() # Shape: (num_real_frames, height, width, channels)
164
+ print(f"Frames read: {frames.shape}")
165
+
166
+ # 20250506 pftq: Get native video resolution
167
+ native_height, native_width = frames.shape[1], frames.shape[2]
168
+ print(f"Native video resolution: {native_width}x{native_height}")
169
+
170
+ # 20250506 pftq: Use native resolution if height/width not specified, otherwise use provided values
171
+ target_height = native_height if height is None else height
172
+ target_width = native_width if width is None else width
173
+
174
+ # 20250506 pftq: Adjust to nearest bucket for model compatibility
175
+ if not no_resize:
176
+ target_height, target_width = find_nearest_bucket(target_height, target_width, resolution=resolution)
177
+ print(f"Adjusted resolution: {target_width}x{target_height}")
178
+ else:
179
+ print(f"Using native resolution without resizing: {target_width}x{target_height}")
180
+
181
+ # 20250506 pftq: Preprocess frames to match original image processing
182
+ processed_frames = []
183
+ for i, frame in enumerate(frames):
184
+ #print(f"Preprocessing frame {i+1}/{num_frames}")
185
+ frame_np = resize_and_center_crop(frame, target_width=target_width, target_height=target_height)
186
+ processed_frames.append(frame_np)
187
+ processed_frames = np.stack(processed_frames) # Shape: (num_real_frames, height, width, channels)
188
+ print(f"Frames preprocessed: {processed_frames.shape}")
189
+
190
+ # 20250506 pftq: Save first frame for CLIP vision encoding
191
+ input_image_np = processed_frames[0]
192
+
193
+ # 20250506 pftq: Convert to tensor and normalize to [-1, 1]
194
+ print("Converting frames to tensor...")
195
+ frames_pt = torch.from_numpy(processed_frames).float() / 127.5 - 1
196
+ frames_pt = frames_pt.permute(0, 3, 1, 2) # Shape: (num_real_frames, channels, height, width)
197
+ frames_pt = frames_pt.unsqueeze(0) # Shape: (1, num_real_frames, channels, height, width)
198
+ frames_pt = frames_pt.permute(0, 2, 1, 3, 4) # Shape: (1, channels, num_real_frames, height, width)
199
+ print(f"Tensor shape: {frames_pt.shape}")
200
+
201
+ # 20250507 pftq: Save pixel frames for use in worker
202
+ input_video_pixels = frames_pt.cpu()
203
+
204
+ # 20250506 pftq: Move to device
205
+ print(f"Moving tensor to device: {device}")
206
+ frames_pt = frames_pt.to(device)
207
+ print("Tensor moved to device")
208
+
209
+ # 20250506 pftq: Move VAE to device
210
+ print(f"Moving VAE to device: {device}")
211
+ vae.to(device)
212
+ print("VAE moved to device")
213
+
214
+ # 20250506 pftq: Encode frames in batches
215
+ print(f"Encoding input video frames in VAE batch size {vae_batch_size} (reduce if memory issues here or if forcing video resolution)")
216
+ latents = []
217
+ vae.eval()
218
+ with torch.no_grad():
219
+ for i in tqdm(range(0, frames_pt.shape[2], vae_batch_size), desc="Encoding video frames", mininterval=0.1):
220
+ #print(f"Encoding batch {i//vae_batch_size + 1}: frames {i} to {min(i + vae_batch_size, frames_pt.shape[2])}")
221
+ batch = frames_pt[:, :, i:i + vae_batch_size] # Shape: (1, channels, batch_size, height, width)
222
+ try:
223
+ # 20250506 pftq: Log GPU memory before encoding
224
+ if device == "cuda":
225
+ free_mem = torch.cuda.memory_allocated() / 1024**3
226
+ #print(f"GPU memory before encoding: {free_mem:.2f} GB")
227
+ batch_latent = vae_encode(batch, vae)
228
+ # 20250506 pftq: Synchronize CUDA to catch issues
229
+ if device == "cuda":
230
+ torch.cuda.synchronize()
231
+ #print(f"GPU memory after encoding: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
232
+ latents.append(batch_latent)
233
+ #print(f"Batch encoded, latent shape: {batch_latent.shape}")
234
+ except RuntimeError as e:
235
+ print(f"Error during VAE encoding: {str(e)}")
236
+ if device == "cuda" and "out of memory" in str(e).lower():
237
+ print("CUDA out of memory, try reducing vae_batch_size or using CPU")
238
+ raise
239
+
240
+ # 20250506 pftq: Concatenate latents
241
+ print("Concatenating latents...")
242
+ history_latents = torch.cat(latents, dim=2) # Shape: (1, channels, frames, height//8, width//8)
243
+ print(f"History latents shape: {history_latents.shape}")
244
+
245
+ # 20250506 pftq: Get first frame's latent
246
+ start_latent = history_latents[:, :, :1] # Shape: (1, channels, 1, height//8, width//8)
247
+ print(f"Start latent shape: {start_latent.shape}")
248
+
249
+ # 20250506 pftq: Move VAE back to CPU to free GPU memory
250
+ if device == "cuda":
251
+ vae.to(cpu)
252
+ torch.cuda.empty_cache()
253
+ print("VAE moved back to CPU, CUDA cache cleared")
254
+
255
+ return start_latent, input_image_np, history_latents, fps, target_height, target_width, input_video_pixels
256
+
257
+ except Exception as e:
258
+ print(f"Error in video_encode: {str(e)}")
259
+ raise
260
+
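Note (not part of the diff): a minimal usage sketch of `video_encode` as defined above. The file name and argument values are illustrative assumptions only.

```python
# Hypothetical call to video_encode(); "clip.mp4" is an assumed local file and vae is the
# HunyuanVideo VAE loaded at module level. Return values follow the docstring above.
start_latent, first_frame_np, history_latents, fps, height, width, input_pixels = video_encode(
    "clip.mp4",
    resolution=640,      # bucket target, same default the UI uses
    no_resize=False,     # snap to the nearest resolution bucket
    vae=vae,
    vae_batch_size=16,   # frames encoded per VAE batch; lower this if VRAM is tight
    device="cuda",
)
# history_latents has shape (1, channels, frames, height // 8, width // 8).
```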
261
+ # 20250508 pftq: for saving prompt to mp4 metadata comments
262
+ def set_mp4_comments_imageio_ffmpeg(input_file, comments):
263
+ try:
264
+ # Get the path to the bundled FFmpeg binary from imageio-ffmpeg
265
+ ffmpeg_path = imageio_ffmpeg.get_ffmpeg_exe()
266
+
267
+ # Check if input file exists
268
+ if not os.path.exists(input_file):
269
+ print(f"Error: Input file {input_file} does not exist")
270
+ return False
271
+
272
+ # Create a temporary file path
273
+ temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False).name
274
+
275
+ # FFmpeg command using the bundled binary
276
+ command = [
277
+ ffmpeg_path, # Use imageio-ffmpeg's FFmpeg
278
+ '-i', input_file, # input file
279
+ '-metadata', f'comment={comments}', # set comment metadata
280
+ '-c:v', 'copy', # copy video stream without re-encoding
281
+ '-c:a', 'copy', # copy audio stream without re-encoding
282
+ '-y', # overwrite output file if it exists
283
+ temp_file # temporary output file
284
+ ]
285
+
286
+ # Run the FFmpeg command
287
+ result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
288
+
289
+ if result.returncode == 0:
290
+ # Replace the original file with the modified one
291
+ shutil.move(temp_file, input_file)
292
+ print(f"Successfully added comments to {input_file}")
293
+ return True
294
+ else:
295
+ # Clean up temp file if FFmpeg fails
296
+ if os.path.exists(temp_file):
297
+ os.remove(temp_file)
298
+ print(f"Error: FFmpeg failed with message:\n{result.stderr}")
299
+ return False
300
+
301
+ except Exception as e:
302
+ # Clean up temp file in case of other errors
303
+ if 'temp_file' in locals() and os.path.exists(temp_file):
304
+ os.remove(temp_file)
305
+ print(f"Error saving prompt to video metadata, ffmpeg may be required: "+str(e))
306
+ return False
307
+
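Note (not part of the diff): a hedged sketch of how the comment written by `set_mp4_comments_imageio_ffmpeg` could be read back with the same bundled FFmpeg binary, e.g. to verify the stored prompt. The helper name and file path are assumptions.

```python
import subprocess
import imageio_ffmpeg

def read_mp4_comment(path):
    # Dump global metadata with FFmpeg's ffmetadata muxer and return the "comment" value, if any.
    ffmpeg_path = imageio_ffmpeg.get_ffmpeg_exe()
    result = subprocess.run(
        [ffmpeg_path, "-i", path, "-f", "ffmetadata", "-"],
        stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True,
    )
    for line in result.stdout.splitlines():
        if line.lower().startswith("comment="):
            return line.split("=", 1)[1]
    return None

# e.g. read_mp4_comment("./outputs/some_job.mp4") after a generation has finished.
```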
308
+ @torch.no_grad()
309
+ def worker(input_image, image_position, prompts, n_prompt, seed, resolution, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, enable_preview, use_teacache, mp4_crf):
310
+ is_last_frame = (image_position == 100)
311
+ def encode_prompt(prompt, n_prompt):
312
+ llama_vec, clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
313
+
314
+ if cfg == 1:
315
+ llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
316
+ else:
317
+ llama_vec_n, clip_l_pooler_n = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
318
+
319
+ llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
320
+ llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)
321
+
322
+ llama_vec = llama_vec.to(transformer.dtype)
323
+ llama_vec_n = llama_vec_n.to(transformer.dtype)
324
+ clip_l_pooler = clip_l_pooler.to(transformer.dtype)
325
+ clip_l_pooler_n = clip_l_pooler_n.to(transformer.dtype)
326
+ return [llama_vec, clip_l_pooler, llama_vec_n, clip_l_pooler_n, llama_attention_mask, llama_attention_mask_n]
327
+
328
+ total_latent_sections = (total_second_length * 30) / (latent_window_size * 4)
329
+ total_latent_sections = int(max(round(total_latent_sections), 1))
330
+
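Note (not part of the diff): a quick sanity check of the section count computed above, assuming the default latent_window_size of 9, 30 fps output, and a 5-second request; purely illustrative.

```python
# (5 s * 30 fps) / (9 * 4) = 150 / 36 ≈ 4.17 -> 4 sampling sections.
# Each section later samples latent_window_size * 4 - 3 = 33 pixel frames.
assert int(max(round((5 * 30) / (9 * 4)), 1)) == 4
```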
331
+ job_id = generate_timestamp()
332
+
333
+ stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Starting ...'))))
334
+
335
+ try:
336
+ # Clean GPU
337
+ if not high_vram:
338
+ unload_complete_models(
339
+ text_encoder, text_encoder_2, image_encoder, vae, transformer
340
+ )
341
+
342
+ # Text encoding
343
+
344
+ stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Text encoding ...'))))
345
+
346
+ if not high_vram:
347
+ fake_diffusers_current_device(text_encoder, gpu) # since we only encode one text - that is one model move and one encode, offload is same time consumption since it is also one load and one encode.
348
+ load_model_as_complete(text_encoder_2, target_device=gpu)
349
+
350
+ prompt_parameters = []
351
+
352
+ for prompt_part in prompts:
353
+ prompt_parameters.append(encode_prompt(prompt_part, n_prompt))
354
+
355
+ # Processing input image
356
+
357
+ stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Image processing ...'))))
358
+
359
+ H, W, C = input_image.shape
360
+ height, width = find_nearest_bucket(H, W, resolution=resolution)
361
+
362
+ def get_start_latent(input_image, height, width, vae, gpu, image_encoder, high_vram):
363
+ input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)
364
+
365
+ #Image.fromarray(input_image_np).save(os.path.join(outputs_folder, f'{job_id}.png'))
366
+
367
+ input_image_pt = torch.from_numpy(input_image_np).float() / 127.5 - 1
368
+ input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]
369
+
370
+ # VAE encoding
371
+
372
+ stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'VAE encoding ...'))))
373
+
374
+ if not high_vram:
375
+ load_model_as_complete(vae, target_device=gpu)
376
+
377
+ start_latent = vae_encode(input_image_pt, vae)
378
+
379
+ # CLIP Vision
380
+
381
+ stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'CLIP Vision encoding ...'))))
382
+
383
+ if not high_vram:
384
+ load_model_as_complete(image_encoder, target_device=gpu)
385
+
386
+ image_encoder_last_hidden_state = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder).last_hidden_state
387
+
388
+ return [start_latent, image_encoder_last_hidden_state]
389
+
390
+ [start_latent, image_encoder_last_hidden_state] = get_start_latent(input_image, height, width, vae, gpu, image_encoder, high_vram)
391
+
392
+ # Dtype
393
+
394
+ image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(transformer.dtype)
395
+
396
+ # Sampling
397
+
398
+ stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Start sampling ...'))))
399
+
400
+ rnd = torch.Generator("cpu").manual_seed(seed)
401
+
402
+ history_latents = torch.zeros(size=(1, 16, 16 + 2 + 1, height // 8, width // 8), dtype=torch.float32).cpu()
403
+ start_latent = start_latent.to(history_latents)
404
+ history_pixels = None
405
+
406
+ history_latents = torch.cat([start_latent, history_latents] if is_last_frame else [history_latents, start_latent], dim=2)
407
+ total_generated_latent_frames = 1
408
+
409
+ if enable_preview:
410
+ def callback(d):
411
+ preview = d['denoised']
412
+ preview = vae_decode_fake(preview)
413
+
414
+ preview = (preview * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8)
415
+ preview = einops.rearrange(preview, 'b c t h w -> (b h) (t w) c')
416
+
417
+ if stream.input_queue.top() == 'end':
418
+ stream.output_queue.push(('end', None))
419
+ raise KeyboardInterrupt('User ends the task.')
420
+
421
+ current_step = d['i'] + 1
422
+ percentage = int(100.0 * current_step / steps)
423
+ hint = f'Sampling {current_step}/{steps}'
424
+ desc = f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, Video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30) :.2f} seconds (FPS-30), Resolution: {height}px * {width}px. The video is being extended now ...'
425
+ stream.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, hint))))
426
+ return
427
+ else:
428
+ def callback(d):
429
+ return
430
+
431
+ indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
432
+ if is_last_frame:
433
+ latent_indices, clean_latent_1x_indices, clean_latent_2x_indices, clean_latent_4x_indices, clean_latent_indices_start = indices.split([latent_window_size, 1, 2, 16, 1], dim=1)
434
+ clean_latent_indices = torch.cat([clean_latent_1x_indices, clean_latent_indices_start], dim=1)
435
+ else:
436
+ clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
437
+ clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
438
+
439
+ def post_process(generated_latents, total_generated_latent_frames, history_latents, high_vram, transformer, gpu, vae, history_pixels, latent_window_size, enable_preview, section_index, total_latent_sections, outputs_folder, mp4_crf, stream):
440
+ total_generated_latent_frames += int(generated_latents.shape[2])
441
+ history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2) if is_last_frame else torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
442
+
443
+ if not high_vram:
444
+ offload_model_from_device_for_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=8)
445
+ load_model_as_complete(vae, target_device=gpu)
446
+
447
+ if history_pixels is None:
448
+ real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :] if is_last_frame else history_latents[:, :, -total_generated_latent_frames:, :, :]
449
+ history_pixels = vae_decode(real_history_latents, vae).cpu()
450
+ else:
451
+ section_latent_frames = latent_window_size * 2
452
+ overlapped_frames = latent_window_size * 4 - 3
453
+
454
+ if is_last_frame:
455
+ real_history_latents = history_latents[:, :, :min(section_latent_frames, total_generated_latent_frames), :, :]
456
+ history_pixels = soft_append_bcthw(vae_decode(real_history_latents, vae).cpu(), history_pixels, overlapped_frames)
457
+ else:
458
+ real_history_latents = history_latents[:, :, -min(section_latent_frames, total_generated_latent_frames):, :, :]
459
+ history_pixels = soft_append_bcthw(history_pixels, vae_decode(real_history_latents, vae).cpu(), overlapped_frames)
460
+
461
+ if not high_vram:
462
+ unload_complete_models()
463
+
464
+ if enable_preview or section_index == (0 if is_last_frame else (total_latent_sections - 1)):
465
+ output_filename = os.path.join(outputs_folder, f'{job_id}_{total_generated_latent_frames}.mp4')
466
+
467
+ save_bcthw_as_mp4(history_pixels, output_filename, fps=30, crf=mp4_crf)
468
+
469
+ print(f'Decoded. Current latent shape pixel shape {history_pixels.shape}')
470
+
471
+ stream.output_queue.push(('file', output_filename))
472
+ return [total_generated_latent_frames, history_latents, history_pixels]
473
+
474
+ for section_index in range(total_latent_sections - 1, -1, -1) if is_last_frame else range(total_latent_sections):
475
+ if stream.input_queue.top() == 'end':
476
+ stream.output_queue.push(('end', None))
477
+ return
478
+
479
+ print(f'section_index = {section_index}, total_latent_sections = {total_latent_sections}')
480
+
481
+ if len(prompt_parameters) > 0:
482
+ [llama_vec, clip_l_pooler, llama_vec_n, clip_l_pooler_n, llama_attention_mask, llama_attention_mask_n] = prompt_parameters.pop((len(prompt_parameters) - 1) if is_last_frame else 0)
483
+
484
+ if not high_vram:
485
+ unload_complete_models()
486
+ move_model_to_device_with_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=gpu_memory_preservation)
487
+
488
+ if use_teacache:
489
+ transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
490
+ else:
491
+ transformer.initialize_teacache(enable_teacache=False)
492
+
493
+ if is_last_frame:
494
+ clean_latents_1x, clean_latents_2x, clean_latents_4x = history_latents[:, :, :sum([1, 2, 16]), :, :].split([1, 2, 16], dim=2)
495
+ clean_latents = torch.cat([clean_latents_1x, start_latent], dim=2)
496
+ else:
497
+ clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[:, :, -sum([16, 2, 1]):, :, :].split([16, 2, 1], dim=2)
498
+ clean_latents = torch.cat([start_latent, clean_latents_1x], dim=2)
499
+
500
+ generated_latents = sample_hunyuan(
501
+ transformer=transformer,
502
+ sampler='unipc',
503
+ width=width,
504
+ height=height,
505
+ frames=latent_window_size * 4 - 3,
506
+ real_guidance_scale=cfg,
507
+ distilled_guidance_scale=gs,
508
+ guidance_rescale=rs,
509
+ # shift=3.0,
510
+ num_inference_steps=steps,
511
+ generator=rnd,
512
+ prompt_embeds=llama_vec,
513
+ prompt_embeds_mask=llama_attention_mask,
514
+ prompt_poolers=clip_l_pooler,
515
+ negative_prompt_embeds=llama_vec_n,
516
+ negative_prompt_embeds_mask=llama_attention_mask_n,
517
+ negative_prompt_poolers=clip_l_pooler_n,
518
+ device=gpu,
519
+ dtype=torch.bfloat16,
520
+ image_embeddings=image_encoder_last_hidden_state,
521
+ latent_indices=latent_indices,
522
+ clean_latents=clean_latents,
523
+ clean_latent_indices=clean_latent_indices,
524
+ clean_latents_2x=clean_latents_2x,
525
+ clean_latent_2x_indices=clean_latent_2x_indices,
526
+ clean_latents_4x=clean_latents_4x,
527
+ clean_latent_4x_indices=clean_latent_4x_indices,
528
+ callback=callback,
529
+ )
530
+
531
+ [total_generated_latent_frames, history_latents, history_pixels] = post_process(generated_latents, total_generated_latent_frames, history_latents, high_vram, transformer, gpu, vae, history_pixels, latent_window_size, enable_preview, section_index, total_latent_sections, outputs_folder, mp4_crf, stream)
532
+ except:
533
+ traceback.print_exc()
534
+
535
+ if not high_vram:
536
+ unload_complete_models(
537
+ text_encoder, text_encoder_2, image_encoder, vae, transformer
538
+ )
539
+
540
+ stream.output_queue.push(('end', None))
541
+ return
542
+
543
+ # 20250506 pftq: Modified worker to accept video input and clean frame count
544
+ @torch.no_grad()
545
+ def worker_video(input_video, prompts, n_prompt, seed, batch, resolution, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, enable_preview, use_teacache, no_resize, mp4_crf, num_clean_frames, vae_batch):
546
+ def encode_prompt(prompt, n_prompt):
547
+ llama_vec, clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
548
+
549
+ if cfg == 1:
550
+ llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
551
+ else:
552
+ llama_vec_n, clip_l_pooler_n = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
553
+
554
+ llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
555
+ llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)
556
+
557
+ llama_vec = llama_vec.to(transformer.dtype)
558
+ llama_vec_n = llama_vec_n.to(transformer.dtype)
559
+ clip_l_pooler = clip_l_pooler.to(transformer.dtype)
560
+ clip_l_pooler_n = clip_l_pooler_n.to(transformer.dtype)
561
+ return [llama_vec, clip_l_pooler, llama_vec_n, clip_l_pooler_n, llama_attention_mask, llama_attention_mask_n]
562
+
563
+ stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Starting ...'))))
564
+
565
+ try:
566
+ # Clean GPU
567
+ if not high_vram:
568
+ unload_complete_models(
569
+ text_encoder, text_encoder_2, image_encoder, vae, transformer
570
+ )
571
+
572
+ # Text encoding
573
+ stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Text encoding ...'))))
574
+
575
+ if not high_vram:
576
+ fake_diffusers_current_device(text_encoder, gpu) # since we only encode one text - that is one model move and one encode, offload is same time consumption since it is also one load and one encode.
577
+ load_model_as_complete(text_encoder_2, target_device=gpu)
578
+
579
+ prompt_parameters = []
580
+
581
+ for prompt_part in prompts:
582
+ prompt_parameters.append(encode_prompt(prompt_part, n_prompt))
583
+
584
+ # 20250506 pftq: Processing input video instead of image
585
+ stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Video processing ...'))))
586
+
587
+ # 20250506 pftq: Encode video
588
+ start_latent, input_image_np, video_latents, fps, height, width = video_encode(input_video, resolution, no_resize, vae, vae_batch_size=vae_batch, device=gpu)[:6]
589
+ start_latent = start_latent.to(dtype=torch.float32).cpu()
590
+ video_latents = video_latents.cpu()
591
 
592
+ # CLIP Vision
593
+ stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'CLIP Vision encoding ...'))))
594
+
595
+ if not high_vram:
596
+ load_model_as_complete(image_encoder, target_device=gpu)
597
+
598
+ image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder)
599
+ image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
600
+
601
+ # Dtype
602
+ image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(transformer.dtype)
603
+
604
+ total_latent_sections = (total_second_length * fps) / (latent_window_size * 4)
605
+ total_latent_sections = int(max(round(total_latent_sections), 1))
606
+
607
+ if enable_preview:
608
+ def callback(d):
609
+ preview = d['denoised']
610
+ preview = vae_decode_fake(preview)
611
+
612
+ preview = (preview * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8)
613
+ preview = einops.rearrange(preview, 'b c t h w -> (b h) (t w) c')
614
+
615
+ if stream.input_queue.top() == 'end':
616
+ stream.output_queue.push(('end', None))
617
+ raise KeyboardInterrupt('User ends the task.')
618
+
619
+ current_step = d['i'] + 1
620
+ percentage = int(100.0 * current_step / steps)
621
+ hint = f'Sampling {current_step}/{steps}'
622
+ desc = f'Total frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, Video length: {max(0, (total_generated_latent_frames * 4 - 3) / fps) :.2f} seconds (FPS-{fps}), Resolution: {height}px * {width}px, Seed: {seed}, Video {idx+1} of {batch}. The video is generating part {section_index+1} of {total_latent_sections}...'
623
+ stream.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, hint))))
624
+ return
625
+ else:
626
+ def callback(d):
627
+ return
628
+
629
+ def compute_latent(history_latents, latent_window_size, num_clean_frames, start_latent):
630
+ # 20250506 pftq: Use user-specified number of context frames, matching original allocation for num_clean_frames=2
631
+ available_frames = history_latents.shape[2] # Number of latent frames
632
+ max_pixel_frames = min(latent_window_size * 4 - 3, available_frames * 4) # Cap at available pixel frames
633
+ adjusted_latent_frames = max(1, (max_pixel_frames + 3) // 4) # Convert back to latent frames
634
+ # Adjust num_clean_frames to match original behavior: num_clean_frames=2 means 1 frame for clean_latents_1x
635
+ effective_clean_frames = max(0, num_clean_frames - 1)
636
+ effective_clean_frames = min(effective_clean_frames, available_frames - 2) if available_frames > 2 else 0 # 20250507 pftq: changed 1 to 2 for edge case for <=1 sec videos
637
+ num_2x_frames = min(2, max(1, available_frames - effective_clean_frames - 1)) if available_frames > effective_clean_frames + 1 else 0 # 20250507 pftq: subtracted 1 for edge case for <=1 sec videos
638
+ num_4x_frames = min(16, max(1, available_frames - effective_clean_frames - num_2x_frames)) if available_frames > effective_clean_frames + num_2x_frames else 0 # 20250507 pftq: Edge case for <=1 sec
639
+
640
+ total_context_frames = num_4x_frames + num_2x_frames + effective_clean_frames
641
+ total_context_frames = min(total_context_frames, available_frames) # 20250507 pftq: Edge case for <=1 sec videos
642
+
643
+ indices = torch.arange(0, sum([1, num_4x_frames, num_2x_frames, effective_clean_frames, adjusted_latent_frames])).unsqueeze(0) # 20250507 pftq: latent_window_size to adjusted_latent_frames for edge case for <=1 sec videos
644
+ clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = indices.split(
645
+ [1, num_4x_frames, num_2x_frames, effective_clean_frames, adjusted_latent_frames], dim=1 # 20250507 pftq: latent_window_size to adjusted_latent_frames for edge case for <=1 sec videos
646
+ )
647
+ clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
648
+
649
+ # 20250506 pftq: Split history_latents dynamically based on available frames
650
+ fallback_frame_count = 2 # 20250507 pftq: Changed 0 to 2 Edge case for <=1 sec videos
651
+ context_frames = clean_latents_4x = clean_latents_2x = clean_latents_1x = history_latents[:, :, :fallback_frame_count, :, :]
652
+
653
+ if total_context_frames > 0:
654
+ context_frames = history_latents[:, :, -total_context_frames:, :, :]
655
+ split_sizes = [num_4x_frames, num_2x_frames, effective_clean_frames]
656
+ split_sizes = [s for s in split_sizes if s > 0] # Remove zero sizes
657
+ if split_sizes:
658
+ splits = context_frames.split(split_sizes, dim=2)
659
+ split_idx = 0
660
+
661
+ if num_4x_frames > 0:
662
+ clean_latents_4x = splits[split_idx]
663
+ split_idx = 1
664
+ if clean_latents_4x.shape[2] < 2: # 20250507 pftq: edge case for <=1 sec videos
665
+ print("Edge case for <=1 sec videos 4x")
666
+ clean_latents_4x = clean_latents_4x.expand(-1, -1, 2, -1, -1)
667
+
668
+ if num_2x_frames > 0 and split_idx < len(splits):
669
+ clean_latents_2x = splits[split_idx]
670
+ if clean_latents_2x.shape[2] < 2: # 20250507 pftq: edge case for <=1 sec videos
671
+ print("Edge case for <=1 sec videos 2x")
672
+ clean_latents_2x = clean_latents_2x.expand(-1, -1, 2, -1, -1)
673
+ split_idx += 1
674
+ elif clean_latents_2x.shape[2] < 2: # 20250507 pftq: edge case for <=1 sec videos
675
+ clean_latents_2x = clean_latents_4x
676
+
677
+ if effective_clean_frames > 0 and split_idx < len(splits):
678
+ clean_latents_1x = splits[split_idx]
679
+
680
+ clean_latents = torch.cat([start_latent, clean_latents_1x], dim=2)
681
+
682
+ # 20250507 pftq: Fix for <=1 sec videos.
683
+ max_frames = min(latent_window_size * 4 - 3, history_latents.shape[2] * 4)
684
+ return [max_frames, clean_latents, clean_latents_2x, clean_latents_4x, latent_indices, clean_latents, clean_latent_indices, clean_latent_2x_indices, clean_latent_4x_indices]
685
+
686
+ for idx in range(batch):
687
+ if batch > 1:
688
+ print(f"Beginning video {idx+1} of {batch} with seed {seed} ")
689
+
690
+ #job_id = generate_timestamp()
691
+ job_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")+f"_framepackf1-videoinput_{width}-{total_second_length}sec_seed-{seed}_steps-{steps}_distilled-{gs}_cfg-{cfg}" # 20250506 pftq: easier to read timestamp and filename
692
+
693
+ # Sampling
694
+ stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Start sampling ...'))))
695
+
696
+ rnd = torch.Generator("cpu").manual_seed(seed)
697
+
698
+ # 20250506 pftq: Initialize history_latents with video latents
699
+ history_latents = video_latents
700
+ total_generated_latent_frames = history_latents.shape[2]
701
+ # 20250506 pftq: Initialize history_pixels to fix UnboundLocalError
702
+ history_pixels = None
703
+ previous_video = None
704
+
705
+ for section_index in range(total_latent_sections):
706
+ if stream.input_queue.top() == 'end':
707
+ stream.output_queue.push(('end', None))
708
+ return
709
+
710
+ print(f'section_index = {section_index}, total_latent_sections = {total_latent_sections}')
711
+
712
+ if len(prompt_parameters) > 0:
713
+ [llama_vec, clip_l_pooler, llama_vec_n, clip_l_pooler_n, llama_attention_mask, llama_attention_mask_n] = prompt_parameters.pop(0)
714
+
715
+ if not high_vram:
716
+ unload_complete_models()
717
+ move_model_to_device_with_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=gpu_memory_preservation)
718
+
719
+ if use_teacache:
720
+ transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
721
+ else:
722
+ transformer.initialize_teacache(enable_teacache=False)
723
+
724
+ [max_frames, clean_latents, clean_latents_2x, clean_latents_4x, latent_indices, clean_latents, clean_latent_indices, clean_latent_2x_indices, clean_latent_4x_indices] = compute_latent(history_latents, latent_window_size, num_clean_frames, start_latent)
725
+
726
+ generated_latents = sample_hunyuan(
727
+ transformer=transformer,
728
+ sampler='unipc',
729
+ width=width,
730
+ height=height,
731
+ frames=max_frames,
732
+ real_guidance_scale=cfg,
733
+ distilled_guidance_scale=gs,
734
+ guidance_rescale=rs,
735
+ num_inference_steps=steps,
736
+ generator=rnd,
737
+ prompt_embeds=llama_vec,
738
+ prompt_embeds_mask=llama_attention_mask,
739
+ prompt_poolers=clip_l_pooler,
740
+ negative_prompt_embeds=llama_vec_n,
741
+ negative_prompt_embeds_mask=llama_attention_mask_n,
742
+ negative_prompt_poolers=clip_l_pooler_n,
743
+ device=gpu,
744
+ dtype=torch.bfloat16,
745
+ image_embeddings=image_encoder_last_hidden_state,
746
+ latent_indices=latent_indices,
747
+ clean_latents=clean_latents,
748
+ clean_latent_indices=clean_latent_indices,
749
+ clean_latents_2x=clean_latents_2x,
750
+ clean_latent_2x_indices=clean_latent_2x_indices,
751
+ clean_latents_4x=clean_latents_4x,
752
+ clean_latent_4x_indices=clean_latent_4x_indices,
753
+ callback=callback,
754
+ )
755
+
756
+ total_generated_latent_frames += int(generated_latents.shape[2])
757
+ history_latents = torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
758
+
759
+ if not high_vram:
760
+ offload_model_from_device_for_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=8)
761
+ load_model_as_complete(vae, target_device=gpu)
762
+
763
+ if history_pixels is None:
764
+ real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :]
765
+ history_pixels = vae_decode(real_history_latents, vae).cpu()
766
+ else:
767
+ section_latent_frames = latent_window_size * 2
768
+ overlapped_frames = min(latent_window_size * 4 - 3, history_pixels.shape[2])
769
+
770
+ real_history_latents = history_latents[:, :, -min(total_generated_latent_frames, section_latent_frames):, :, :]
771
+ history_pixels = soft_append_bcthw(history_pixels, vae_decode(real_history_latents, vae).cpu(), overlapped_frames)
772
+
773
+ if not high_vram:
774
+ unload_complete_models()
775
+
776
+ if enable_preview or section_index == total_latent_sections - 1:
777
+ output_filename = os.path.join(outputs_folder, f'{job_id}_{total_generated_latent_frames}.mp4')
778
+
779
+ # 20250506 pftq: Use input video FPS for output
780
+ save_bcthw_as_mp4(history_pixels, output_filename, fps=fps, crf=mp4_crf)
781
+ print(f"Latest video saved: {output_filename}")
782
+ # 20250508 pftq: Save prompt to mp4 metadata comments
783
+ set_mp4_comments_imageio_ffmpeg(output_filename, f"Prompt: {prompts} | Negative Prompt: {n_prompt}");
784
+ print(f"Prompt saved to mp4 metadata comments: {output_filename}")
785
+
786
+ # 20250506 pftq: Clean up previous partial files
787
+ if previous_video is not None and os.path.exists(previous_video):
788
+ try:
789
+ os.remove(previous_video)
790
+ print(f"Previous partial video deleted: {previous_video}")
791
+ except Exception as e:
792
+ print(f"Error deleting previous partial video {previous_video}: {e}")
793
+ previous_video = output_filename
794
+
795
+ print(f'Decoded. Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}')
796
+
797
+ stream.output_queue.push(('file', output_filename))
798
+
799
+ seed = (seed + 1) % np.iinfo(np.int32).max
800
+
801
+ except:
802
+ traceback.print_exc()
803
+
804
+ if not high_vram:
805
+ unload_complete_models(
806
+ text_encoder, text_encoder_2, image_encoder, vae, transformer
807
+ )
808
+
809
+ stream.output_queue.push(('end', None))
810
+ return
811
+
812
+ def get_duration(input_image, image_position, prompts, generation_mode, n_prompt, seed, resolution, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, enable_preview, use_teacache, mp4_crf):
813
+ return total_second_length * 60 * (0.9 if use_teacache else 1.5) * (1 + ((steps - 25) / 100))
814
+
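Note (not part of the diff): a worked example of the ZeroGPU duration estimate returned above, using total_second_length=5 and steps=25; illustrative only.

```python
# With TeaCache enabled:  5 * 60 * 0.9 * (1 + (25 - 25) / 100) = 270 seconds requested from @spaces.GPU.
# With TeaCache disabled: 5 * 60 * 1.5 * 1.0                   = 450 seconds requested.
assert round(5 * 60 * (0.9 if True else 1.5) * (1 + ((25 - 25) / 100))) == 270
```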
815
+ # Remove this decorator if you run on local
816
+ @spaces.GPU(duration=get_duration)
817
+ def process_on_gpu(input_image,
818
+ image_position=0,
819
+ prompts=[""],
820
+ generation_mode="image",
821
+ n_prompt="",
822
+ seed=31337,
823
+ resolution=640,
824
+ total_second_length=5,
825
+ latent_window_size=9,
826
+ steps=25,
827
+ cfg=1.0,
828
+ gs=10.0,
829
+ rs=0.0,
830
+ gpu_memory_preservation=6,
831
+ enable_preview=True,
832
+ use_teacache=False,
833
+ mp4_crf=16
834
+ ):
835
+ start = time.time()
836
+ global stream
837
+ stream = AsyncStream()
838
+
839
+ async_run(worker, input_image, image_position, prompts, n_prompt, seed, resolution, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, enable_preview, use_teacache, mp4_crf)
840
+
841
+ output_filename = None
842
+
843
+ while True:
844
+ flag, data = stream.output_queue.next()
845
+
846
+ if flag == 'file':
847
+ output_filename = data
848
+ yield gr.update(value=output_filename, label="Previewed Frames"), gr.update(), gr.update(), gr.update(), gr.update(interactive=False), gr.update(interactive=True), gr.update()
849
+
850
+ if flag == 'progress':
851
+ preview, desc, html = data
852
+ yield gr.update(label="Previewed Frames"), gr.update(visible=True, value=preview), desc, html, gr.update(interactive=False), gr.update(interactive=True), gr.update()
853
+
854
+ if flag == 'end':
855
+ end = time.time()
856
+ secondes = int(end - start)
857
+ minutes = math.floor(secondes / 60)
858
+ secondes = secondes - (minutes * 60)
859
+ hours = math.floor(minutes / 60)
860
+ minutes = minutes - (hours * 60)
861
+ yield gr.update(value=output_filename, label="Finished Frames"), gr.update(visible=False), gr.update(), "The process has lasted " + \
862
+ ((str(hours) + " h, ") if hours != 0 else "") + \
863
+ ((str(minutes) + " min, ") if hours != 0 or minutes != 0 else "") + \
864
+ str(secondes) + " sec. " + \
865
+ "You can upscale the result with RIFE. To make all your generated scenes consistent, you can then apply a face swap on the main character. If you do not see the generated video above, the process may have failed. See the logs for more information. If you see an error like ''NVML_SUCCESS == r INTERNAL ASSERT FAILED'', you probably haven't enough VRAM. Test an example or other options to compare. You can share your inputs to the original space or set your space in public for a peer review.", gr.update(interactive=True), gr.update(interactive=False), gr.update(visible = False)
866
+ break
867
+
868
+ def process(input_image,
869
+ image_position=0,
870
+ prompt="",
871
+ generation_mode="image",
872
+ n_prompt="",
873
+ randomize_seed=True,
874
+ seed=31337,
875
+ resolution=640,
876
+ total_second_length=5,
877
+ latent_window_size=9,
878
+ steps=25,
879
+ cfg=1.0,
880
+ gs=10.0,
881
+ rs=0.0,
882
+ gpu_memory_preservation=6,
883
+ enable_preview=True,
884
+ use_teacache=False,
885
+ mp4_crf=16
886
+ ):
887
+
888
+ if torch.cuda.device_count() == 0:
889
+ gr.Warning('Set this space to GPU config to make it work.')
890
+ yield gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(visible = False)
891
+ return
892
+
893
+ if randomize_seed:
894
+ seed = random.randint(0, np.iinfo(np.int32).max)
895
+
896
+ prompts = prompt.split(";")
897
+
898
+ # assert input_image is not None, 'No input image!'
899
+ if generation_mode == "text":
900
+ default_height, default_width = 640, 640
901
+ input_image = np.ones((default_height, default_width, 3), dtype=np.uint8) * 255
902
+ print("No input image provided. Using a blank white image.")
903
+
904
+ yield gr.update(label="Previewed Frames"), None, '', '', gr.update(interactive=False), gr.update(interactive=True), gr.update()
905
+
906
+ yield from process_on_gpu(input_image,
907
+ image_position,
908
+ prompts,
909
+ generation_mode,
910
+ n_prompt,
911
+ seed,
912
+ resolution,
913
+ total_second_length,
914
+ latent_window_size,
915
+ steps,
916
+ cfg,
917
+ gs,
918
+ rs,
919
+ gpu_memory_preservation,
920
+ enable_preview,
921
+ use_teacache,
922
+ mp4_crf
923
+ )
924
+
925
+ def get_duration_video(input_video, prompts, n_prompt, seed, batch, resolution, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, enable_preview, use_teacache, no_resize, mp4_crf, num_clean_frames, vae_batch):
926
+ return total_second_length * 60 * (1.5 if use_teacache else 2.5) * (1 + ((steps - 25) / 100))
927
+
928
+ # Remove this decorator if you run on local
929
+ @spaces.GPU(duration=get_duration_video)
930
+ def process_video_on_gpu(input_video, prompts, n_prompt, seed, batch, resolution, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, enable_preview, use_teacache, no_resize, mp4_crf, num_clean_frames, vae_batch):
931
+ start = time.time()
932
+ global stream
933
+ stream = AsyncStream()
934
+
935
+ # 20250506 pftq: Pass num_clean_frames, vae_batch, etc
936
+ async_run(worker_video, input_video, prompts, n_prompt, seed, batch, resolution, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, enable_preview, use_teacache, no_resize, mp4_crf, num_clean_frames, vae_batch)
937
+
938
+ output_filename = None
939
+
940
+ while True:
941
+ flag, data = stream.output_queue.next()
942
+
943
+ if flag == 'file':
944
+ output_filename = data
945
+ yield gr.update(value=output_filename, label="Previewed Frames"), gr.update(), gr.update(), gr.update(), gr.update(interactive=False), gr.update(interactive=True), gr.update()
946
+
947
+ if flag == 'progress':
948
+ preview, desc, html = data
949
+ yield gr.update(label="Previewed Frames"), gr.update(visible=True, value=preview), desc, html, gr.update(interactive=False), gr.update(interactive=True), gr.update() # 20250506 pftq: Keep refreshing the video in case it got hidden when the tab was in the background
950
+
951
+ if flag == 'end':
952
+ end = time.time()
953
+ secondes = int(end - start)
954
+ minutes = math.floor(secondes / 60)
955
+ secondes = secondes - (minutes * 60)
956
+ hours = math.floor(minutes / 60)
957
+ minutes = minutes - (hours * 60)
958
+ yield gr.update(value=output_filename, label="Finished Frames"), gr.update(visible=False), desc + \
959
+ " The process has lasted " + \
960
+ ((str(hours) + " h, ") if hours != 0 else "") + \
961
+ ((str(minutes) + " min, ") if hours != 0 or minutes != 0 else "") + \
962
+ str(secondes) + " sec. " + \
963
+ " You can upscale the result with RIFE. To make all your generated scenes consistent, you can then apply a face swap on the main character. If you do not see the generated video above, the process may have failed. See the logs for more information. If you see an error like ''NVML_SUCCESS == r INTERNAL ASSERT FAILED'', you probably haven't enough VRAM. Test an example or other options to compare. You can share your inputs to the original space or set your space in public for a peer review.", '', gr.update(interactive=True), gr.update(interactive=False), gr.update(visible = False)
964
+ break
965
+
966
+ def process_video(input_video, prompt, n_prompt, randomize_seed, seed, batch, resolution, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, enable_preview, use_teacache, no_resize, mp4_crf, num_clean_frames, vae_batch):
967
+ global high_vram
968
+
969
+ if torch.cuda.device_count() == 0:
970
+ gr.Warning('Set this space to GPU config to make it work.')
971
+ yield gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(visible = False)
972
+ return
973
+
974
+ if randomize_seed:
975
+ seed = random.randint(0, np.iinfo(np.int32).max)
976
+
977
+ prompts = prompt.split(";")
978
+
979
+ # 20250506 pftq: Updated assertion for video input
980
+ assert input_video is not None, 'No input video!'
981
+
982
+ yield gr.update(label="Previewed Frames"), None, '', '', gr.update(interactive=False), gr.update(interactive=True), gr.update()
983
+
984
+ # 20250507 pftq: Even the H100 needs offloading if the video dimensions are 720p or higher
985
+ if high_vram and (no_resize or resolution>640):
986
+ print("Disabling high vram mode due to no resize and/or potentially higher resolution...")
987
+ high_vram = False
988
+ vae.enable_slicing()
989
+ vae.enable_tiling()
990
+ DynamicSwapInstaller.install_model(transformer, device=gpu)
991
+ DynamicSwapInstaller.install_model(text_encoder, device=gpu)
992
+
993
+ # 20250508 pftq: automatically set distilled cfg to 1 if cfg is used
994
+ if cfg > 1:
995
+ gs = 1
996
+
997
+ yield from process_video_on_gpu(input_video, prompts, n_prompt, seed, batch, resolution, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, enable_preview, use_teacache, no_resize, mp4_crf, num_clean_frames, vae_batch)
998
+
999
+ def end_process():
1000
+ stream.input_queue.push('end')
1001
+
1002
+ timeless_prompt_value = [""]
1003
+ timed_prompts = {}
1004
+
1005
+ def handle_prompt_number_change():
1006
+ timed_prompts.clear()
1007
+ return []
1008
+
1009
+ def handle_timeless_prompt_change(timeless_prompt):
1010
+ timeless_prompt_value[0] = timeless_prompt
1011
+ return refresh_prompt()
1012
+
1013
+ def handle_timed_prompt_change(timed_prompt_id, timed_prompt):
1014
+ timed_prompts[timed_prompt_id] = timed_prompt
1015
+ return refresh_prompt()
1016
+
1017
+ def refresh_prompt():
1018
+ dict_values = {k: v for k, v in timed_prompts.items()}
1019
+ sorted_dict_values = sorted(dict_values.items(), key=lambda x: x[0])
1020
+ array = []
1021
+ for sorted_dict_value in sorted_dict_values:
1022
+ if timeless_prompt_value[0] is not None and len(timeless_prompt_value[0]) and sorted_dict_value[1] is not None and len(sorted_dict_value[1]):
1023
+ array.append(timeless_prompt_value[0] + ". " + sorted_dict_value[1])
1024
+ else:
1025
+ array.append(timeless_prompt_value[0] + sorted_dict_value[1])
1026
+ print(str(array))
1027
+ return ";".join(array)
1028
+
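Note (not part of the diff): a minimal illustration of how `refresh_prompt` above composes the final prompt from the timeless prompt and the timed prompts; the example strings are made up.

```python
# Each timed prompt is prefixed with the timeless prompt, and the parts are joined
# with ";", which process() later splits on to get one prompt per section.
timeless_prompt_value[0] = "A cat on a sofa"
timed_prompts["timed_prompt_0"] = "starts to stretch"
timed_prompts["timed_prompt_1"] = "jumps to the floor"
print(refresh_prompt())
# -> A cat on a sofa. starts to stretch;A cat on a sofa. jumps to the floor
```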
1029
+ title_html = """
1030
+ <h1><center>FramePack</center></h1>
1031
+ <big><center>Generate videos from text/image/video freely, without an account or watermark, and download the result</center></big>
1032
+ <br/>
1033
+
1034
+ <p>This space is ready to work on ZeroGPU and GPU and has been tested successfully on ZeroGPU. Please leave a <a href="https://huggingface.co/spaces/Fabrice-TIERCELIN/FramePack/discussions/new">message in discussion</a> if you encounter issues.</p>
1035
+ """
1036
+
1037
+ js = """
1038
+ function createGradioAnimation() {
1039
+ window.addEventListener("beforeunload", function (e) {
1040
+ if (document.getElementById('end-button') && !document.getElementById('end-button').disabled) {
1041
+ var confirmationMessage = 'A process is still running. '
1042
+ + 'If you leave before saving, your changes will be lost.';
1043
+
1044
+ (e || window.event).returnValue = confirmationMessage;
1045
+ }
1046
+ return confirmationMessage;
1047
+ });
1048
+ return 'Animation created';
1049
+ }
1050
  """
1051
+
1052
+ css = make_progress_bar_css()
1053
+ block = gr.Blocks(css=css, js=js).queue()
1054
+ with block:
1055
+ if torch.cuda.device_count() == 0:
1056
+ with gr.Row():
1057
+ gr.HTML("""
1058
+ <p style="background-color: red;"><big><big><big><b>⚠️To use FramePack, <a href="https://huggingface.co/spaces/Fabrice-TIERCELIN/FramePack?duplicate=true">duplicate this space</a> and set a GPU with 30 GB VRAM.</b>
1059
+
1060
+ You can't use FramePack directly here because this space runs on a CPU, which is not enough for FramePack. Please provide <a href="https://huggingface.co/spaces/Fabrice-TIERCELIN/FramePack/discussions/new">feedback</a> if you have issues.
1061
+ </big></big></big></p>
1062
+ """)
1063
+ gr.HTML(title_html)
1064
+ local_storage = gr.BrowserState(default_local_storage)
1065
+ with gr.Row():
1066
+ with gr.Column():
1067
+ generation_mode = gr.Radio([["Text-to-Video", "text"], ["Image-to-Video", "image"], ["Video Extension", "video"]], elem_id="generation-mode", label="Generation mode", value = "image")
1068
+ text_to_video_hint = gr.HTML("I discourage using the Text-to-Video feature. You should rather generate an image with Flux and then use Image-to-Video; you will save time.")
1069
+ input_image = gr.Image(sources='upload', type="numpy", label="Image", height=320)
1070
+ image_position = gr.Slider(label="Image position", minimum=0, maximum=100, value=0, step=100, info='0=Video start; 100=Video end (lower quality)')
1071
+ input_video = gr.Video(sources='upload', label="Input Video", height=320)
1072
+ timeless_prompt = gr.Textbox(label="Timeless prompt", info='Used on the whole duration of the generation', value='', placeholder="The creature starts to move, fast motion, fixed camera, focus motion, consistent arm, consistent position, mute colors, insanely detailed")
1073
+ prompt_number = gr.Slider(label="Timed prompt number", minimum=0, maximum=1000, value=0, step=1, info='Prompts will automatically appear')
1074
+
1075
+ @gr.render(inputs=prompt_number)
1076
+ def show_split(prompt_number):
1077
+ for digit in range(prompt_number):
1078
+ timed_prompt_id = gr.Textbox(value="timed_prompt_" + str(digit), visible=False)
1079
+ timed_prompt = gr.Textbox(label="Timed prompt #" + str(digit + 1), elem_id="timed_prompt_" + str(digit), value="")
1080
+ timed_prompt.change(fn=handle_timed_prompt_change, inputs=[timed_prompt_id, timed_prompt], outputs=[final_prompt])
1081
+
1082
+ final_prompt = gr.Textbox(label="Final prompt", value='', info='Use ; to separate the prompts in time; remember to explicitly write what stops the previous action')
1083
+ prompt_hint = gr.HTML("Video extension barely follows the prompt; to force it to follow the prompt, set the Distilled CFG Scale to 3.0 and the Context Frames to 2, but the video quality will be poor.")
1084
+ total_second_length = gr.Slider(label="Video Length to Generate (seconds)", minimum=1, maximum=120, value=2, step=0.1)
1085
+
1086
+ with gr.Row():
1087
+ start_button = gr.Button(value="🎥 Generate", variant="primary")
1088
+ start_button_video = gr.Button(value="🎥 Generate", variant="primary")
1089
+ end_button = gr.Button(elem_id="end-button", value="End Generation", variant="stop", interactive=False)
1090
+
1091
+ with gr.Accordion("Advanced settings", open=False):
1092
+ enable_preview = gr.Checkbox(label='Enable preview', value=True, info='Display a preview around each generated second, but it costs 2 extra seconds for each generated second.')
1093
+ use_teacache = gr.Checkbox(label='Use TeaCache', value=False, info='Faster speed and no break in brightness, but often makes hands and fingers slightly worse.')
1094
+
1095
+ n_prompt = gr.Textbox(label="Negative Prompt", value="Missing arm, unrealistic position, impossible contortion, visible bone, muscle contraction, blurred, blurry", info='Requires using normal CFG (undistilled) instead of Distilled (set Distilled=1 and CFG > 1).')
1096
+
1097
+ latent_window_size = gr.Slider(label="Latent Window Size", minimum=1, maximum=33, value=9, step=1, info='Generate more frames at a time (larger chunks). Less degradation and better blending but higher VRAM cost. Should not change.')
1098
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=30, step=1, info='Increase for more quality, especially if using high non-distilled CFG. If your animation has very little motion, you may get abrupt brightness changes; increasing the steps can fix this.')
1099
+
1100
+ with gr.Row():
1101
+ no_resize = gr.Checkbox(label='Force Original Video Resolution (no Resizing)', value=False, info='Might run out of VRAM (720p requires > 24GB VRAM).')
1102
+ resolution = gr.Dropdown([
1103
+ ["409,600 px (working)", 640],
1104
+ ["451,584 px (working)", 672],
1105
+ ["495,616 px (VRAM issues on HF)", 704],
1106
+ ["589,824 px (not tested)", 768],
1107
+ ["692,224 px (not tested)", 832],
1108
+ ["746,496 px (not tested)", 864],
1109
+ ["921,600 px (not tested)", 960]
1110
+ ], value=672, label="Resolution (width x height)", info="Does not affect the generation time")
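The pixel counts in the dropdown labels are simply the bucket size squared; a quick illustrative check:

for bucket in (640, 672, 704, 768, 832, 864, 960):
    print(f"{bucket} -> {bucket * bucket:,} px")   # 640 -> 409,600 px ... 960 -> 921,600 px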
1111
+
1112
+ # 20250506 pftq: Reduced default distilled guidance scale to improve adherence to input video
1113
+ cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=1.0, step=0.01, info='Use this instead of Distilled for more detail/control + Negative Prompt (make sure Distilled set to 1). Doubles render time. Should not change.')
1114
+ gs = gr.Slider(label="Distilled CFG Scale", minimum=1.0, maximum=32.0, value=10.0, step=0.01, info='Prompt adherence at the cost of fewer details from the input video, but to a lesser extent than Context Frames; 3=follows the prompt but with blurred, unsharpened motion, 10=focuses on motion; changing this value is not recommended')
1115
+ rs = gr.Slider(label="CFG Re-Scale", minimum=0.0, maximum=1.0, value=0.0, step=0.01, info='Should not change')
1116
+
1117
+
1118
+ # 20250506 pftq: Renamed slider to Number of Context Frames and updated description
1119
+ num_clean_frames = gr.Slider(label="Number of Context Frames", minimum=2, maximum=10, value=5, step=1, info="Retain more video details but increase memory use. Reduce to 2 to avoid memory issues or to give more weight to the prompt.")
1120
+
1121
+ default_vae = 32
1122
+ if high_vram:
1123
+ default_vae = 128
1124
+ elif free_mem_gb>=20:
1125
+ default_vae = 64
1126
+
1127
+ vae_batch = gr.Slider(label="VAE Batch Size for Input Video", minimum=4, maximum=256, value=default_vae, step=4, info="Reduce if running out of memory. Increase for better quality frames during fast motion.")
1128
+
1129
+
1130
+ gpu_memory_preservation = gr.Slider(label="GPU Inference Preserved Memory (GB) (larger means slower)", minimum=6, maximum=128, value=6, step=0.1, info="Set this number to a larger value if you encounter OOM. Larger value causes slower speed.")
1131
+
1132
+ mp4_crf = gr.Slider(label="MP4 Compression", minimum=0, maximum=100, value=16, step=1, info="Lower means better quality. 0 is uncompressed. Change to 16 if you get black outputs. ")
1133
+ batch = gr.Slider(label="Batch Size (Number of Videos)", minimum=1, maximum=1000, value=1, step=1, info='Generate multiple videos each with a different seed.')
1134
+ with gr.Row():
1135
+ randomize_seed = gr.Checkbox(label='Randomize seed', value=True, info='If checked, the seed is always different')
1136
+ seed = gr.Slider(label="Seed", minimum=0, maximum=np.iinfo(np.int32).max, step=1, randomize=True)
1137
+
1138
+ with gr.Column():
1139
+ warning = gr.HTML(value = "<center><big>Your computer must <u>not</u> enter standby mode.</big><br/>On Chrome, you can force a tab to stay alive in <code>chrome://discards/</code></center>", visible = False)
1140
+ result_video = gr.Video(label="Generated Frames", autoplay=True, show_share_button=False, height=512, loop=True)
1141
+ preview_image = gr.Image(label="Next Latents", height=200, visible=False)
1142
+ progress_desc = gr.Markdown('', elem_classes='no-generating-animation')
1143
+ progress_bar = gr.HTML('', elem_classes='no-generating-animation')
1144
+
1145
+ # 20250506 pftq: Updated inputs to include num_clean_frames
1146
+ ips = [input_image, image_position, final_prompt, generation_mode, n_prompt, randomize_seed, seed, resolution, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, enable_preview, use_teacache, mp4_crf]
1147
+ ips_video = [input_video, final_prompt, n_prompt, randomize_seed, seed, batch, resolution, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, enable_preview, use_teacache, no_resize, mp4_crf, num_clean_frames, vae_batch]
1148
+
1149
+ gr.Examples(
1150
+ label = "Examples from text",
1151
+ examples = [
1152
+ [
1153
+ None, # input_image
1154
+ 0, # image_position
1155
+ "Overcrowded street in Japan, photorealistic, realistic, intricate details, 8k, insanely detailed",
1156
+ "text", # generation_mode
1157
+ "Missing arm, unrealistic position, impossible contortion, visible bone, muscle contraction, blurred, blurry", # n_prompt
1158
+ True, # randomize_seed
1159
+ 42, # seed
1160
+ 672, # resolution
1161
+ 1, # total_second_length
1162
+ 9, # latent_window_size
1163
+ 30, # steps
1164
+ 1.0, # cfg
1165
+ 10.0, # gs
1166
+ 0.0, # rs
1167
+ 6, # gpu_memory_preservation
1168
+ False, # enable_preview
1169
+ False, # use_teacache
1170
+ 16 # mp4_crf
1171
+ ]
1172
+ ],
1173
+ run_on_click = True,
1174
+ fn = process,
1175
+ inputs = ips,
1176
+ outputs = [result_video, preview_image, progress_desc, progress_bar, start_button, end_button],
1177
+ cache_examples = False,
1178
+ )
1179
+
1180
+ gr.Examples(
1181
+ label = "Examples from image",
1182
+ examples = [
1183
+ [
1184
+ "./img_examples/Example1.png", # input_image
1185
+ 0, # image_position
1186
+ "A dolphin emerges from the water, photorealistic, realistic, intricate details, 8k, insanely detailed",
1187
+ "image", # generation_mode
1188
+ "Missing arm, unrealistic position, impossible contortion, visible bone, muscle contraction, blurred, blurry", # n_prompt
1189
+ True, # randomize_seed
1190
+ 42, # seed
1191
+ 672, # resolution
1192
+ 1, # total_second_length
1193
+ 9, # latent_window_size
1194
+ 30, # steps
1195
+ 1.0, # cfg
1196
+ 10.0, # gs
1197
+ 0.0, # rs
1198
+ 6, # gpu_memory_preservation
1199
+ False, # enable_preview
1200
+ True, # use_teacache
1201
+ 16 # mp4_crf
1202
+ ],
1203
+ [
1204
+ "./img_examples/Example2.webp", # input_image
1205
+ 0, # image_position
1206
+ "A man on the left and a woman on the right face each other ready to start a conversation, large space between the persons, full view, full-length view, 3D, pixar, 3D render, CGI. The man talks and the woman listens; A man on the left and a woman on the right face each other ready to start a conversation, large space between the persons, full view, full-length view, 3D, pixar, 3D render, CGI. The woman talks, the man stops talking and the man listens; A man on the left and a woman on the right face each other ready to start a conversation, large space between the persons, full view, full-length view, 3D, pixar, 3D render, CGI. The woman talks and the man listens",
1207
+ "image", # generation_mode
1208
+ "Missing arm, unrealistic position, impossible contortion, visible bone, muscle contraction, blurred, blurry", # n_prompt
1209
+ True, # randomize_seed
1210
+ 42, # seed
1211
+ 672, # resolution
1212
+ 2, # total_second_length
1213
+ 9, # latent_window_size
1214
+ 30, # steps
1215
+ 1.0, # cfg
1216
+ 10.0, # gs
1217
+ 0.0, # rs
1218
+ 6, # gpu_memory_preservation
1219
+ False, # enable_preview
1220
+ True, # use_teacache
1221
+ 16 # mp4_crf
1222
+ ],
1223
+ [
1224
+ "./img_examples/Example2.webp", # input_image
1225
+ 0, # image_position
1226
+ "A man on the left and a woman on the right face each other ready to start a conversation, large space between the persons, full view, full-length view, 3D, pixar, 3D render, CGI. The woman talks and the man listens; A man on the left and a woman on the right face each other ready to start a conversation, large space between the persons, full view, full-length view, 3D, pixar, 3D render, CGI. The man talks, the woman stops talking and the woman listens; A man on the left and a woman on the right face each other ready to start a conversation, large space between the persons, full view, full-length view, 3D, pixar, 3D render, CGI. The man talks and the woman listens",
1227
+ "image", # generation_mode
1228
+ "Missing arm, unrealistic position, impossible contortion, visible bone, muscle contraction, blurred, blurry", # n_prompt
1229
+ True, # randomize_seed
1230
+ 42, # seed
1231
+ 672, # resolution
1232
+ 2, # total_second_length
1233
+ 9, # latent_window_size
1234
+ 30, # steps
1235
+ 1.0, # cfg
1236
+ 10.0, # gs
1237
+ 0.0, # rs
1238
+ 6, # gpu_memory_preservation
1239
+ False, # enable_preview
1240
+ True, # use_teacache
1241
+ 16 # mp4_crf
1242
+ ],
1243
+ [
1244
+ "./img_examples/Example3.jpg", # input_image
1245
+ 0, # image_position
1246
+ "A boy is walking to the right, full view, full-length view, cartoon",
1247
+ "image", # generation_mode
1248
+ "Missing arm, unrealistic position, impossible contortion, visible bone, muscle contraction, blurred, blurry", # n_prompt
1249
+ True, # randomize_seed
1250
+ 42, # seed
1251
+ 672, # resolution
1252
+ 1, # total_second_length
1253
+ 9, # latent_window_size
1254
+ 30, # steps
1255
+ 1.0, # cfg
1256
+ 10.0, # gs
1257
+ 0.0, # rs
1258
+ 6, # gpu_memory_preservation
1259
+ False, # enable_preview
1260
+ True, # use_teacache
1261
+ 16 # mp4_crf
1262
+ ],
1263
+ [
1264
+ "./img_examples/Example4.webp", # input_image
1265
+ 100, # image_position
1266
+ "A building starting to explode, photorealistic, realistic, 8k, insanely detailed",
1267
+ "image", # generation_mode
1268
+ "Missing arm, unrealistic position, impossible contortion, visible bone, muscle contraction, blurred, blurry", # n_prompt
1269
+ True, # randomize_seed
1270
+ 42, # seed
1271
+ 672, # resolution
1272
+ 1, # total_second_length
1273
+ 9, # latent_window_size
1274
+ 30, # steps
1275
+ 1.0, # cfg
1276
+ 10.0, # gs
1277
+ 0.0, # rs
1278
+ 6, # gpu_memory_preservation
1279
+ False, # enable_preview
1280
+ False, # use_teacache
1281
+ 16 # mp4_crf
1282
+ ]
1283
+ ],
1284
+ run_on_click = True,
1285
+ fn = process,
1286
+ inputs = ips,
1287
+ outputs = [result_video, preview_image, progress_desc, progress_bar, start_button, end_button],
1288
+ cache_examples = False,
1289
+ )
1290
+
1291
+ gr.Examples(
1292
+ label = "Examples from video",
1293
+ examples = [
1294
+ [
1295
+ "./img_examples/Example1.mp4", # input_video
1296
+ "View of the sea as far as the eye can see, from the seaside, a piece of land is barely visible on the horizon at the middle, the sky is radiant, reflections of the sun in the water, photorealistic, realistic, intricate details, 8k, insanely detailed",
1297
+ "Missing arm, unrealistic position, impossible contortion, visible bone, muscle contraction, blurred, blurry", # n_prompt
1298
+ True, # randomize_seed
1299
+ 42, # seed
1300
+ 1, # batch
1301
+ 672, # resolution
1302
+ 1, # total_second_length
1303
+ 9, # latent_window_size
1304
+ 30, # steps
1305
+ 1.0, # cfg
1306
+ 10.0, # gs
1307
+ 0.0, # rs
1308
+ 6, # gpu_memory_preservation
1309
+ False, # enable_preview
1310
+ True, # use_teacache
1311
+ False, # no_resize
1312
+ 16, # mp4_crf
1313
+ 5, # num_clean_frames
1314
+ default_vae
1315
+ ]
1316
+ ],
1317
+ run_on_click = True,
1318
+ fn = process_video,
1319
+ inputs = ips_video,
1320
+ outputs = [result_video, preview_image, progress_desc, progress_bar, start_button_video, end_button],
1321
+ cache_examples = False,
1322
+ )
1323
+
1324
+ def save_preferences(preferences, value):
1325
+ preferences["generation-mode"] = value
1326
+ return preferences
1327
+
1328
+ def load_preferences(saved_prefs):
1329
+ saved_prefs = init_preferences(saved_prefs)
1330
+ return saved_prefs["generation-mode"]
1331
+
1332
+ def init_preferences(saved_prefs):
1333
+ if saved_prefs is None:
1334
+ saved_prefs = default_local_storage
1335
+ return saved_prefs
1336
+
1337
+ def check_parameters(generation_mode, input_image, input_video):
1338
+ if generation_mode == "image" and input_image is None:
1339
+ raise gr.Error("Please provide an image to extend.")
1340
+ if generation_mode == "video" and input_video is None:
1341
+ raise gr.Error("Please provide a video to extend.")
1342
+ return [gr.update(interactive=True), gr.update(visible = True)]
1343
+
1344
+ def handle_generation_mode_change(generation_mode_data):
1345
+ if generation_mode_data == "text":
1346
+ return [gr.update(visible = True), gr.update(visible = False), gr.update(visible = False), gr.update(visible = False), gr.update(visible = True), gr.update(visible = False), gr.update(visible = False), gr.update(visible = False), gr.update(visible = False), gr.update(visible = False), gr.update(visible = False)]
1347
+ elif generation_mode_data == "image":
1348
+ return [gr.update(visible = False), gr.update(visible = True), gr.update(visible = True), gr.update(visible = False), gr.update(visible = True), gr.update(visible = False), gr.update(visible = False), gr.update(visible = False), gr.update(visible = False), gr.update(visible = False), gr.update(visible = False)]
1349
+ elif generation_mode_data == "video":
1350
+ return [gr.update(visible = False), gr.update(visible = False), gr.update(visible = False), gr.update(visible = True), gr.update(visible = False), gr.update(visible = True), gr.update(visible = True), gr.update(visible = True), gr.update(visible = True), gr.update(visible = True), gr.update(visible = True)]
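The three return branches above are positional gr.update(visible=...) lists that map one-to-one onto the outputs list wired up below (text_to_video_hint, image_position, input_image, input_video, start_button, start_button_video, no_resize, batch, num_clean_frames, vae_batch, prompt_hint). A hedged, equivalent sketch of the same mapping as a table, for readability only (the function name is hypothetical):

VISIBILITY = {   # same component order as the outputs list below
    "text":  [True,  False, False, False, True,  False, False, False, False, False, False],
    "image": [False, True,  True,  False, True,  False, False, False, False, False, False],
    "video": [False, False, False, True,  False, True,  True,  True,  True,  True,  True],
}
def handle_generation_mode_change_table(generation_mode_data):
    return [gr.update(visible=v) for v in VISIBILITY[generation_mode_data]]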
1351
+
1352
+ prompt_number.change(fn=handle_prompt_number_change, inputs=[], outputs=[])
1353
+ timeless_prompt.change(fn=handle_timeless_prompt_change, inputs=[timeless_prompt], outputs=[final_prompt])
1354
+ start_button.click(fn = check_parameters, inputs = [
1355
+ generation_mode, input_image, input_video
1356
+ ], outputs = [end_button, warning], queue = False, show_progress = False).success(fn=process, inputs=ips, outputs=[result_video, preview_image, progress_desc, progress_bar, start_button, end_button, warning], scroll_to_output = True)
1357
+ start_button_video.click(fn = check_parameters, inputs = [
1358
+ generation_mode, input_image, input_video
1359
+ ], outputs = [end_button, warning], queue = False, show_progress = False).success(fn=process_video, inputs=ips_video, outputs=[result_video, preview_image, progress_desc, progress_bar, start_button_video, end_button, warning], scroll_to_output = True)
1360
+ end_button.click(fn=end_process)
1361
+
1362
+ generation_mode.change(fn = save_preferences, inputs = [
1363
+ local_storage,
1364
+ generation_mode,
1365
+ ], outputs = [
1366
+ local_storage
1367
+ ])
1368
+
1369
+ generation_mode.change(
1370
+ fn=handle_generation_mode_change,
1371
+ inputs=[generation_mode],
1372
+ outputs=[text_to_video_hint, image_position, input_image, input_video, start_button, start_button_video, no_resize, batch, num_clean_frames, vae_batch, prompt_hint]
1373
+ )
1374
+
1375
+ # Update display when the page loads
1376
+ block.load(
1377
+ fn=handle_generation_mode_change, inputs = [
1378
+ generation_mode
1379
+ ], outputs = [
1380
+ text_to_video_hint, image_position, input_image, input_video, start_button, start_button_video, no_resize, batch, num_clean_frames, vae_batch, prompt_hint
1381
+ ]
1382
+ )
1383
+
1384
+ # Load saved preferences when the page loads
1385
+ block.load(
1386
+ fn=load_preferences, inputs = [
1387
+ local_storage
1388
+ ], outputs = [
1389
+ generation_mode
1390
+ ]
1391
+ )
1392
+
1393
+ block.launch(mcp_server=True, ssr_mode=False)
app_endframe.py ADDED
@@ -0,0 +1,822 @@
1
+ from diffusers_helper.hf_login import login
2
+
3
+ import os
4
+
5
+ os.environ['HF_HOME'] = os.path.abspath(os.path.realpath(os.path.join(os.path.dirname(__file__), './hf_download')))
6
+
7
+ import gradio as gr
8
+ import torch
9
+ import traceback
10
+ import einops
11
+ import safetensors.torch as sf
12
+ import numpy as np
13
+ import argparse
14
+ import random
15
+ import math
16
+ # 20250506 pftq: Added for video input loading
17
+ import decord
18
+ # 20250506 pftq: Added for progress bars in video_encode
19
+ from tqdm import tqdm
20
+ # 20250506 pftq: Normalize file paths for Windows compatibility
21
+ import pathlib
22
+ # 20250506 pftq: for easier to read timestamp
23
+ from datetime import datetime
24
+ # 20250508 pftq: for saving prompt to mp4 comments metadata
25
+ import imageio_ffmpeg
26
+ import tempfile
27
+ import shutil
28
+ import subprocess
29
+ import spaces
30
+ from PIL import Image
31
+ from diffusers import AutoencoderKLHunyuanVideo
32
+ from transformers import LlamaModel, CLIPTextModel, LlamaTokenizerFast, CLIPTokenizer
33
+ from diffusers_helper.hunyuan import encode_prompt_conds, vae_decode, vae_encode, vae_decode_fake
34
+ from diffusers_helper.utils import save_bcthw_as_mp4, crop_or_pad_yield_mask, soft_append_bcthw, resize_and_center_crop, state_dict_weighted_merge, state_dict_offset_merge, generate_timestamp
35
+ from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
36
+ from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan
37
+ from diffusers_helper.memory import cpu, gpu, get_cuda_free_memory_gb, move_model_to_device_with_memory_preservation, offload_model_from_device_for_memory_preservation, fake_diffusers_current_device, DynamicSwapInstaller, unload_complete_models, load_model_as_complete
38
+ from diffusers_helper.thread_utils import AsyncStream, async_run
39
+ from diffusers_helper.gradio.progress_bar import make_progress_bar_css, make_progress_bar_html
40
+ from transformers import SiglipImageProcessor, SiglipVisionModel
41
+ from diffusers_helper.clip_vision import hf_clip_vision_encode
42
+ from diffusers_helper.bucket_tools import find_nearest_bucket
43
+
44
+ parser = argparse.ArgumentParser()
45
+ parser.add_argument('--share', action='store_true')
46
+ parser.add_argument("--server", type=str, default='0.0.0.0')
47
+ parser.add_argument("--port", type=int, required=False)
48
+ parser.add_argument("--inbrowser", action='store_true')
49
+ args = parser.parse_args()
50
+
51
+ print(args)
52
+
53
+ free_mem_gb = get_cuda_free_memory_gb(gpu)
54
+ high_vram = free_mem_gb > 60
55
+
56
+ print(f'Free VRAM {free_mem_gb} GB')
57
+ print(f'High-VRAM Mode: {high_vram}')
58
+
59
+ text_encoder = LlamaModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder', torch_dtype=torch.float16).cpu()
60
+ text_encoder_2 = CLIPTextModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder_2', torch_dtype=torch.float16).cpu()
61
+ tokenizer = LlamaTokenizerFast.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer')
62
+ tokenizer_2 = CLIPTokenizer.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer_2')
63
+ vae = AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='vae', torch_dtype=torch.float16).cpu()
64
+
65
+ feature_extractor = SiglipImageProcessor.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='feature_extractor')
66
+ image_encoder = SiglipVisionModel.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='image_encoder', torch_dtype=torch.float16).cpu()
67
+
68
+ transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained('lllyasviel/FramePackI2V_HY', torch_dtype=torch.bfloat16).cpu()
69
+
70
+ vae.eval()
71
+ text_encoder.eval()
72
+ text_encoder_2.eval()
73
+ image_encoder.eval()
74
+ transformer.eval()
75
+
76
+ if not high_vram:
77
+ vae.enable_slicing()
78
+ vae.enable_tiling()
79
+
80
+ transformer.high_quality_fp32_output_for_inference = True
81
+ print('transformer.high_quality_fp32_output_for_inference = True')
82
+
83
+ transformer.to(dtype=torch.bfloat16)
84
+ vae.to(dtype=torch.float16)
85
+ image_encoder.to(dtype=torch.float16)
86
+ text_encoder.to(dtype=torch.float16)
87
+ text_encoder_2.to(dtype=torch.float16)
88
+
89
+ vae.requires_grad_(False)
90
+ text_encoder.requires_grad_(False)
91
+ text_encoder_2.requires_grad_(False)
92
+ image_encoder.requires_grad_(False)
93
+ transformer.requires_grad_(False)
94
+
95
+ if not high_vram:
96
+ # DynamicSwapInstaller is the same as huggingface's enable_sequential_offload but 3x faster
97
+ DynamicSwapInstaller.install_model(transformer, device=gpu)
98
+ DynamicSwapInstaller.install_model(text_encoder, device=gpu)
99
+ else:
100
+ text_encoder.to(gpu)
101
+ text_encoder_2.to(gpu)
102
+ image_encoder.to(gpu)
103
+ vae.to(gpu)
104
+ transformer.to(gpu)
105
+
106
+ stream = AsyncStream()
107
+
108
+ outputs_folder = './outputs/'
109
+ os.makedirs(outputs_folder, exist_ok=True)
110
+
111
+ # 20250506 pftq: Added function to encode input video frames into latents
112
+ @torch.no_grad()
113
+ def video_encode(video_path, resolution, no_resize, vae, vae_batch_size=16, device="cuda", width=None, height=None):
114
+ """
115
+ Encode a video into latent representations using the VAE.
116
+
117
+ Args:
118
+ video_path: Path to the input video file.
119
+ vae: AutoencoderKLHunyuanVideo model.
120
+ height, width: Target resolution for resizing frames.
121
+ vae_batch_size: Number of frames to process per batch.
122
+ device: Device for computation (e.g., "cuda").
123
+
124
+ Returns:
125
+ start_latent: Latent of the first frame (for compatibility with original code).
126
+ input_image_np: First frame as numpy array (for CLIP vision encoding).
127
+ history_latents: Latents of all frames (shape: [1, channels, frames, height//8, width//8]).
128
+ fps: Frames per second of the input video.
129
+ """
130
+ # 20250506 pftq: Normalize video path for Windows compatibility
131
+ video_path = str(pathlib.Path(video_path).resolve())
132
+ print(f"Processing video: {video_path}")
133
+
134
+ # 20250506 pftq: Check CUDA availability and fallback to CPU if needed
135
+ if device == "cuda" and not torch.cuda.is_available():
136
+ print("CUDA is not available, falling back to CPU")
137
+ device = "cpu"
138
+
139
+ try:
140
+ # 20250506 pftq: Load video and get FPS
141
+ print("Initializing VideoReader...")
142
+ vr = decord.VideoReader(video_path)
143
+ fps = vr.get_avg_fps() # Get input video FPS
144
+ num_real_frames = len(vr)
145
+ print(f"Video loaded: {num_real_frames} frames, FPS: {fps}")
146
+
147
+ # Truncate to nearest latent size (multiple of 4)
148
+ latent_size_factor = 4
149
+ num_frames = (num_real_frames // latent_size_factor) * latent_size_factor
150
+ if num_frames != num_real_frames:
151
+ print(f"Truncating video from {num_real_frames} to {num_frames} frames for latent size compatibility")
152
+ num_real_frames = num_frames
153
+
154
+ # 20250506 pftq: Read frames
155
+ print("Reading video frames...")
156
+ frames = vr.get_batch(range(num_real_frames)).asnumpy() # Shape: (num_real_frames, height, width, channels)
157
+ print(f"Frames read: {frames.shape}")
158
+
159
+ # 20250506 pftq: Get native video resolution
160
+ native_height, native_width = frames.shape[1], frames.shape[2]
161
+ print(f"Native video resolution: {native_width}x{native_height}")
162
+
163
+ # 20250506 pftq: Use native resolution if height/width not specified, otherwise use provided values
164
+ target_height = native_height if height is None else height
165
+ target_width = native_width if width is None else width
166
+
167
+ # 20250506 pftq: Adjust to nearest bucket for model compatibility
168
+ if not no_resize:
169
+ target_height, target_width = find_nearest_bucket(target_height, target_width, resolution=resolution)
170
+ print(f"Adjusted resolution: {target_width}x{target_height}")
171
+ else:
172
+ print(f"Using native resolution without resizing: {target_width}x{target_height}")
173
+
174
+ # 20250506 pftq: Preprocess frames to match original image processing
175
+ processed_frames = []
176
+ for i, frame in enumerate(frames):
177
+ #print(f"Preprocessing frame {i+1}/{num_frames}")
178
+ frame_np = resize_and_center_crop(frame, target_width=target_width, target_height=target_height)
179
+ processed_frames.append(frame_np)
180
+ processed_frames = np.stack(processed_frames) # Shape: (num_real_frames, height, width, channels)
181
+ print(f"Frames preprocessed: {processed_frames.shape}")
182
+
183
+ # 20250506 pftq: Save first frame for CLIP vision encoding
184
+ input_image_np = processed_frames[0]
185
+ end_of_input_video_image_np = processed_frames[-1]
186
+
187
+ # 20250506 pftq: Convert to tensor and normalize to [-1, 1]
188
+ print("Converting frames to tensor...")
189
+ frames_pt = torch.from_numpy(processed_frames).float() / 127.5 - 1
190
+ frames_pt = frames_pt.permute(0, 3, 1, 2) # Shape: (num_real_frames, channels, height, width)
191
+ frames_pt = frames_pt.unsqueeze(0) # Shape: (1, num_real_frames, channels, height, width)
192
+ frames_pt = frames_pt.permute(0, 2, 1, 3, 4) # Shape: (1, channels, num_real_frames, height, width)
193
+ print(f"Tensor shape: {frames_pt.shape}")
194
+
195
+ # 20250507 pftq: Save pixel frames for use in worker
196
+ input_video_pixels = frames_pt.cpu()
197
+
198
+ # 20250506 pftq: Move to device
199
+ print(f"Moving tensor to device: {device}")
200
+ frames_pt = frames_pt.to(device)
201
+ print("Tensor moved to device")
202
+
203
+ # 20250506 pftq: Move VAE to device
204
+ print(f"Moving VAE to device: {device}")
205
+ vae.to(device)
206
+ print("VAE moved to device")
207
+
208
+ # 20250506 pftq: Encode frames in batches
209
+ print(f"Encoding input video frames in VAE batch size {vae_batch_size} (reduce if memory issues here or if forcing video resolution)")
210
+ latents = []
211
+ vae.eval()
212
+ with torch.no_grad():
213
+ for i in tqdm(range(0, frames_pt.shape[2], vae_batch_size), desc="Encoding video frames", mininterval=0.1):
214
+ #print(f"Encoding batch {i//vae_batch_size + 1}: frames {i} to {min(i + vae_batch_size, frames_pt.shape[2])}")
215
+ batch = frames_pt[:, :, i:i + vae_batch_size] # Shape: (1, channels, batch_size, height, width)
216
+ try:
217
+ # 20250506 pftq: Log GPU memory before encoding
218
+ if device == "cuda":
219
+ free_mem = torch.cuda.memory_allocated() / 1024**3
220
+ #print(f"GPU memory before encoding: {free_mem:.2f} GB")
221
+ batch_latent = vae_encode(batch, vae)
222
+ # 20250506 pftq: Synchronize CUDA to catch issues
223
+ if device == "cuda":
224
+ torch.cuda.synchronize()
225
+ #print(f"GPU memory after encoding: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
226
+ latents.append(batch_latent)
227
+ #print(f"Batch encoded, latent shape: {batch_latent.shape}")
228
+ except RuntimeError as e:
229
+ print(f"Error during VAE encoding: {str(e)}")
230
+ if device == "cuda" and "out of memory" in str(e).lower():
231
+ print("CUDA out of memory, try reducing vae_batch_size or using CPU")
232
+ raise
233
+
234
+ # 20250506 pftq: Concatenate latents
235
+ print("Concatenating latents...")
236
+ history_latents = torch.cat(latents, dim=2) # Shape: (1, channels, frames, height//8, width//8)
237
+ print(f"History latents shape: {history_latents.shape}")
238
+
239
+ # 20250506 pftq: Get first frame's latent
240
+ start_latent = history_latents[:, :, :1] # Shape: (1, channels, 1, height//8, width//8)
241
+ end_of_input_video_latent = history_latents[:, :, -1:] # Shape: (1, channels, 1, height//8, width//8)
242
+ print(f"Start latent shape: {start_latent.shape}")
243
+
244
+ # 20250506 pftq: Move VAE back to CPU to free GPU memory
245
+ if device == "cuda":
246
+ vae.to(cpu)
247
+ torch.cuda.empty_cache()
248
+ print("VAE moved back to CPU, CUDA cache cleared")
249
+
250
+ return start_latent, input_image_np, history_latents, fps, target_height, target_width, input_video_pixels, end_of_input_video_latent, end_of_input_video_image_np
251
+
252
+ except Exception as e:
253
+ print(f"Error in video_encode: {str(e)}")
254
+ raise
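A hypothetical call to video_encode (using the bundled example clip; batch size and device are illustrative), unpacking the nine return values in the documented order:

start_latent, first_frame_np, history_latents, fps, height, width, input_pixels, end_latent, end_frame_np = video_encode(
    "./img_examples/Example1.mp4", resolution=640, no_resize=False,
    vae=vae, vae_batch_size=32, device="cuda",
)
print(history_latents.shape, fps)   # latents are (1, channels, latent_frames, height//8, width//8)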
255
+
256
+
257
+ # 20250507 pftq: New function to encode a single image (end frame)
258
+ @torch.no_grad()
259
+ def image_encode(image_np, target_width, target_height, vae, image_encoder, feature_extractor, device="cuda"):
260
+ """
261
+ Encode a single image into a latent and compute its CLIP vision embedding.
262
+
263
+ Args:
264
+ image_np: Input image as numpy array.
265
+ target_width, target_height: Exact resolution to resize the image to (matches start frame).
266
+ vae: AutoencoderKLHunyuanVideo model.
267
+ image_encoder: SiglipVisionModel for CLIP vision encoding.
268
+ feature_extractor: SiglipImageProcessor for preprocessing.
269
+ device: Device for computation (e.g., "cuda").
270
+
271
+ Returns:
272
+ latent: Latent representation of the image (shape: [1, channels, 1, height//8, width//8]).
273
+ clip_embedding: CLIP vision embedding of the image.
274
+ processed_image_np: Processed image as numpy array (after resizing).
275
+ """
276
+ # 20250507 pftq: Process end frame with exact start frame dimensions
277
+ print("Processing end frame...")
278
+ try:
279
+ print(f"Using exact start frame resolution for end frame: {target_width}x{target_height}")
280
+
281
+ # Resize and preprocess image to match start frame
282
+ processed_image_np = resize_and_center_crop(image_np, target_width=target_width, target_height=target_height)
283
+
284
+ # Convert to tensor and normalize
285
+ image_pt = torch.from_numpy(processed_image_np).float() / 127.5 - 1
286
+ image_pt = image_pt.permute(2, 0, 1).unsqueeze(0).unsqueeze(2) # Shape: [1, channels, 1, height, width]
287
+ image_pt = image_pt.to(device)
288
+
289
+ # Move VAE to device
290
+ vae.to(device)
291
+
292
+ # Encode to latent
293
+ latent = vae_encode(image_pt, vae)
294
+ print(f"image_encode vae output shape: {latent.shape}")
295
+
296
+ # Move image encoder to device
297
+ image_encoder.to(device)
298
+
299
+ # Compute CLIP vision embedding
300
+ clip_embedding = hf_clip_vision_encode(processed_image_np, feature_extractor, image_encoder).last_hidden_state
301
+
302
+ # Move models back to CPU and clear cache
303
+ if device == "cuda":
304
+ vae.to(cpu)
305
+ image_encoder.to(cpu)
306
+ torch.cuda.empty_cache()
307
+ print("VAE and image encoder moved back to CPU, CUDA cache cleared")
308
+
309
+ print(f"End latent shape: {latent.shape}")
310
+ return latent, clip_embedding, processed_image_np
311
+
312
+ except Exception as e:
313
+ print(f"Error in image_encode: {str(e)}")
314
+ raise
315
+
316
+ # 20250508 pftq: for saving prompt to mp4 metadata comments
317
+ def set_mp4_comments_imageio_ffmpeg(input_file, comments):
318
+ try:
319
+ # Get the path to the bundled FFmpeg binary from imageio-ffmpeg
320
+ ffmpeg_path = imageio_ffmpeg.get_ffmpeg_exe()
321
+
322
+ # Check if input file exists
323
+ if not os.path.exists(input_file):
324
+ print(f"Error: Input file {input_file} does not exist")
325
+ return False
326
+
327
+ # Create a temporary file path
328
+ temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False).name
329
+
330
+ # FFmpeg command using the bundled binary
331
+ command = [
332
+ ffmpeg_path, # Use imageio-ffmpeg's FFmpeg
333
+ '-i', input_file, # input file
334
+ '-metadata', f'comment={comments}', # set comment metadata
335
+ '-c:v', 'copy', # copy video stream without re-encoding
336
+ '-c:a', 'copy', # copy audio stream without re-encoding
337
+ '-y', # overwrite output file if it exists
338
+ temp_file # temporary output file
339
+ ]
340
+
341
+ # Run the FFmpeg command
342
+ result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
343
+
344
+ if result.returncode == 0:
345
+ # Replace the original file with the modified one
346
+ shutil.move(temp_file, input_file)
347
+ print(f"Successfully added comments to {input_file}")
348
+ return True
349
+ else:
350
+ # Clean up temp file if FFmpeg fails
351
+ if os.path.exists(temp_file):
352
+ os.remove(temp_file)
353
+ print(f"Error: FFmpeg failed with message:\n{result.stderr}")
354
+ return False
355
+
356
+ except Exception as e:
357
+ # Clean up temp file in case of other errors
358
+ if 'temp_file' in locals() and os.path.exists(temp_file):
359
+ os.remove(temp_file)
360
+ print(f"Error saving prompt to video metadata, ffmpeg may be required: "+str(e))
361
+ return False
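Hypothetical usage of the helper above (the file name is invented); the optional verification assumes a system ffprobe is available, since imageio-ffmpeg only bundles ffmpeg:

ok = set_mp4_comments_imageio_ffmpeg("outputs/demo.mp4", "Prompt: a dolphin jumps | Negative Prompt: blurry")
if ok:
    subprocess.run(["ffprobe", "-v", "quiet", "-show_entries", "format_tags=comment",
                    "-of", "default=noprint_wrappers=1", "outputs/demo.mp4"])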
362
+
363
+ # 20250506 pftq: Modified worker to accept video input, and clean frame count
364
+ @torch.no_grad()
365
+ def worker(input_video, end_frame, end_frame_weight, prompt, n_prompt, seed, batch, resolution, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, no_resize, mp4_crf, num_clean_frames, vae_batch):
366
+
367
+ stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Starting ...'))))
368
+
369
+ try:
370
+ # Clean GPU
371
+ if not high_vram:
372
+ unload_complete_models(
373
+ text_encoder, text_encoder_2, image_encoder, vae, transformer
374
+ )
375
+
376
+ # Text encoding
377
+ stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Text encoding ...'))))
378
+
379
+ if not high_vram:
380
+ fake_diffusers_current_device(text_encoder, gpu) # since we only encode one text - that is one model move and one encode, offload is same time consumption since it is also one load and one encode.
381
+ load_model_as_complete(text_encoder_2, target_device=gpu)
382
+
383
+ llama_vec, clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
384
+
385
+ if cfg == 1:
386
+ llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
387
+ else:
388
+ llama_vec_n, clip_l_pooler_n = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
389
+
390
+ llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
391
+ llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)
392
+
393
+ # 20250506 pftq: Processing input video instead of image
394
+ stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Video processing ...'))))
395
+
396
+ # 20250506 pftq: Encode video
397
+ start_latent, input_image_np, video_latents, fps, height, width, input_video_pixels, end_of_input_video_latent, end_of_input_video_image_np = video_encode(input_video, resolution, no_resize, vae, vae_batch_size=vae_batch, device=gpu)
398
+
399
+ #Image.fromarray(input_image_np).save(os.path.join(outputs_folder, f'{job_id}.png'))
400
+
401
+ # CLIP Vision
402
+ stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'CLIP Vision encoding ...'))))
403
+
404
+ if not high_vram:
405
+ load_model_as_complete(image_encoder, target_device=gpu)
406
+
407
+ image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder)
408
+ image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
409
+ start_embedding = image_encoder_last_hidden_state
410
+
411
+ end_of_input_video_output = hf_clip_vision_encode(end_of_input_video_image_np, feature_extractor, image_encoder)
412
+ end_of_input_video_last_hidden_state = end_of_input_video_output.last_hidden_state
413
+ end_of_input_video_embedding = end_of_input_video_last_hidden_state
414
+
415
+ # 20250507 pftq: Process end frame if provided
416
+ end_latent = None
417
+ end_clip_embedding = None
418
+ if end_frame is not None:
419
+ stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'End frame encoding ...'))))
420
+ end_latent, end_clip_embedding, _ = image_encode(
421
+ end_frame, target_width=width, target_height=height, vae=vae,
422
+ image_encoder=image_encoder, feature_extractor=feature_extractor, device=gpu
423
+ )
424
+
425
+ # Dtype
426
+ llama_vec = llama_vec.to(transformer.dtype)
427
+ llama_vec_n = llama_vec_n.to(transformer.dtype)
428
+ clip_l_pooler = clip_l_pooler.to(transformer.dtype)
429
+ clip_l_pooler_n = clip_l_pooler_n.to(transformer.dtype)
430
+ image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(transformer.dtype)
431
+ end_of_input_video_embedding = end_of_input_video_embedding.to(transformer.dtype)
432
+
433
+ # 20250509 pftq: Restored original placement of total_latent_sections after video_encode
434
+ total_latent_sections = (total_second_length * fps) / (latent_window_size * 4)
435
+ total_latent_sections = int(max(round(total_latent_sections), 1))
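A worked example of the section count, with illustrative values (a 2-second extension of a 30 fps input and the default latent window of 9):

total_second_length, fps, latent_window_size = 2, 30, 9
sections = int(max(round((total_second_length * fps) / (latent_window_size * 4)), 1))
print(sections)   # (2 * 30) / (9 * 4) = 1.67 -> 2 sections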
436
+
437
+ for idx in range(batch):
438
+ if batch > 1:
439
+ print(f"Beginning video {idx+1} of {batch} with seed {seed} ")
440
+
441
+ job_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")+f"_framepack-videoinput-endframe_{width}-{total_second_length}sec_seed-{seed}_steps-{steps}_distilled-{gs}_cfg-{cfg}"
442
+
443
+ stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Start sampling ...'))))
444
+
445
+ rnd = torch.Generator("cpu").manual_seed(seed)
446
+
447
+ history_latents = video_latents.cpu()
448
+ history_pixels = None
449
+ total_generated_latent_frames = 0
450
+ previous_video = None
451
+
452
+
453
+ # 20250509 Generate backwards with end frame for better end frame anchoring
454
+ if total_latent_sections > 4:
455
+ latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
456
+ else:
457
+ latent_paddings = list(reversed(range(total_latent_sections)))
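Illustrative output of the two scheduling branches above for six sections:

n = 6
print([3] + [2] * (n - 3) + [1, 0])   # [3, 2, 2, 2, 1, 0]  (branch taken when n > 4)
print(list(reversed(range(n))))       # [5, 4, 3, 2, 1, 0]  (form used when n <= 4)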
458
+
459
+ for section_index, latent_padding in enumerate(latent_paddings):
460
+ is_start_of_video = latent_padding == 0
461
+ is_end_of_video = latent_padding == latent_paddings[0]
462
+ latent_padding_size = latent_padding * latent_window_size
463
+
464
+ if stream.input_queue.top() == 'end':
465
+ stream.output_queue.push(('end', None))
466
+ return
467
+
468
+ if not high_vram:
469
+ unload_complete_models()
470
+ move_model_to_device_with_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=gpu_memory_preservation)
471
+
472
+ if use_teacache:
473
+ transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
474
+ else:
475
+ transformer.initialize_teacache(enable_teacache=False)
476
+
477
+ def callback(d):
478
+ try:
479
+ preview = d['denoised']
480
+ preview = vae_decode_fake(preview)
481
+ preview = (preview * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8)
482
+ preview = einops.rearrange(preview, 'b c t h w -> (b h) (t w) c')
483
+ if stream.input_queue.top() == 'end':
484
+ stream.output_queue.push(('end', None))
485
+ raise KeyboardInterrupt('User ends the task.')
486
+ current_step = d['i'] + 1
487
+ percentage = int(100.0 * current_step / steps)
488
+ hint = f'Sampling {current_step}/{steps}'
489
+ desc = f'Total frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, Video length: {max(0, (total_generated_latent_frames * 4 - 3) / fps) :.2f} seconds (FPS-{fps}), Seed: {seed}, Video {idx+1} of {batch}. Generating part {total_latent_sections - section_index} of {total_latent_sections} backward...'
490
+ stream.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, hint))))
491
+ except ConnectionResetError as e:
492
+ print(f"Suppressed ConnectionResetError in callback: {e}")
493
+ return
494
+
495
+ # 20250509 pftq: Dynamic frame allocation like original num_clean_frames, fix split error
496
+ available_frames = video_latents.shape[2] if is_start_of_video else history_latents.shape[2]
497
+ if is_start_of_video:
498
+ effective_clean_frames = 1 # avoid jumpcuts from input video
499
+ else:
500
+ effective_clean_frames = max(0, num_clean_frames - 1) if num_clean_frames > 1 else 1
501
+ clean_latent_pre_frames = effective_clean_frames
502
+ num_2x_frames = min(2, max(1, available_frames - clean_latent_pre_frames - 1)) if available_frames > clean_latent_pre_frames + 1 else 1
503
+ num_4x_frames = min(16, max(1, available_frames - clean_latent_pre_frames - num_2x_frames)) if available_frames > clean_latent_pre_frames + num_2x_frames else 1
504
+ total_context_frames = num_2x_frames + num_4x_frames
505
+ total_context_frames = min(total_context_frames, available_frames - clean_latent_pre_frames)
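A worked example of the context-frame budget above, simplified to the common case where plenty of history frames are available (values illustrative, not from the code's runtime):

available_frames, num_clean_frames = 30, 5
effective_clean_frames = max(0, num_clean_frames - 1)                                        # 4
num_2x_frames = min(2, max(1, available_frames - effective_clean_frames - 1))                # 2
num_4x_frames = min(16, max(1, available_frames - effective_clean_frames - num_2x_frames))   # 16
total_context_frames = min(num_2x_frames + num_4x_frames,
                           available_frames - effective_clean_frames)                        # 18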
506
+
507
+ # 20250511 pftq: Dynamically adjust post_frames based on clean_latents_post
508
+ post_frames = 1 if is_end_of_video and end_latent is not None else effective_clean_frames # 20250511 pftq: Single frame for end_latent, otherwise padding causes still image
509
+ indices = torch.arange(0, clean_latent_pre_frames + latent_padding_size + latent_window_size + post_frames + num_2x_frames + num_4x_frames).unsqueeze(0)
510
+ clean_latent_indices_pre, blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = indices.split(
511
+ [clean_latent_pre_frames, latent_padding_size, latent_window_size, post_frames, num_2x_frames, num_4x_frames], dim=1
512
+ )
513
+ clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
514
+
515
+ # 20250509 pftq: Split context frames dynamically for 2x and 4x only
516
+ context_frames = history_latents[:, :, -(total_context_frames + clean_latent_pre_frames):-clean_latent_pre_frames, :, :] if total_context_frames > 0 else history_latents[:, :, :1, :, :]
517
+ split_sizes = [num_4x_frames, num_2x_frames]
518
+ split_sizes = [s for s in split_sizes if s > 0]
519
+ if split_sizes and context_frames.shape[2] >= sum(split_sizes):
520
+ splits = context_frames.split(split_sizes, dim=2)
521
+ split_idx = 0
522
+ clean_latents_4x = splits[split_idx] if num_4x_frames > 0 else history_latents[:, :, :1, :, :]
523
+ split_idx += 1 if num_4x_frames > 0 else 0
524
+ clean_latents_2x = splits[split_idx] if num_2x_frames > 0 and split_idx < len(splits) else history_latents[:, :, :1, :, :]
525
+ else:
526
+ clean_latents_4x = clean_latents_2x = history_latents[:, :, :1, :, :]
527
+
528
+ clean_latents_pre = video_latents[:, :, -min(effective_clean_frames, video_latents.shape[2]):].to(history_latents) # smoother motion but jumpcuts if end frame is too different, must change clean_latent_pre_frames to effective_clean_frames also
529
+ clean_latents_post = history_latents[:, :, :min(effective_clean_frames, history_latents.shape[2]), :, :] # smoother motion, must change post_frames to effective_clean_frames also
530
+
531
+ if is_end_of_video:
532
+ clean_latents_post = torch.zeros_like(end_of_input_video_latent).to(history_latents)
533
+
534
+ # 20250509 pftq: handle end frame if available
535
+ if end_latent is not None:
536
+ #current_end_frame_weight = end_frame_weight * (latent_padding / latent_paddings[0])
537
+ #current_end_frame_weight = current_end_frame_weight * 0.5 + 0.5
538
+ current_end_frame_weight = end_frame_weight # changing this over time introduces discontinuity
539
+ # 20250511 pftq: Removed end frame weight adjustment as it has no effect
540
+ image_encoder_last_hidden_state = (1 - current_end_frame_weight) * end_of_input_video_embedding + end_clip_embedding * current_end_frame_weight
541
+ image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(transformer.dtype)
542
+
543
+ # 20250511 pftq: Use end_latent only
544
+ if is_end_of_video:
545
+ clean_latents_post = end_latent.to(history_latents)[:, :, :1, :, :] # Ensure single frame
546
+
547
+ # 20250511 pftq: Pad clean_latents_pre to match clean_latent_pre_frames if needed
548
+ if clean_latents_pre.shape[2] < clean_latent_pre_frames:
549
+ clean_latents_pre = clean_latents_pre.repeat(1, 1, clean_latent_pre_frames // clean_latents_pre.shape[2], 1, 1)
550
+ # 20250511 pftq: Pad clean_latents_post to match post_frames if needed
551
+ if clean_latents_post.shape[2] < post_frames:
552
+ clean_latents_post = clean_latents_post.repeat(1, 1, post_frames // clean_latents_post.shape[2], 1, 1)
553
+
554
+ clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
555
+
556
+ max_frames = min(latent_window_size * 4 - 3, history_latents.shape[2] * 4)
557
+ print(f"Generating video {idx+1} of {batch} with seed {seed}, part {total_latent_sections - section_index} of {total_latent_sections} backward")
558
+ generated_latents = sample_hunyuan(
559
+ transformer=transformer,
560
+ sampler='unipc',
561
+ width=width,
562
+ height=height,
563
+ frames=max_frames,
564
+ real_guidance_scale=cfg,
565
+ distilled_guidance_scale=gs,
566
+ guidance_rescale=rs,
567
+ num_inference_steps=steps,
568
+ generator=rnd,
569
+ prompt_embeds=llama_vec,
570
+ prompt_embeds_mask=llama_attention_mask,
571
+ prompt_poolers=clip_l_pooler,
572
+ negative_prompt_embeds=llama_vec_n,
573
+ negative_prompt_embeds_mask=llama_attention_mask_n,
574
+ negative_prompt_poolers=clip_l_pooler_n,
575
+ device=gpu,
576
+ dtype=torch.bfloat16,
577
+ image_embeddings=image_encoder_last_hidden_state,
578
+ latent_indices=latent_indices,
579
+ clean_latents=clean_latents,
580
+ clean_latent_indices=clean_latent_indices,
581
+ clean_latents_2x=clean_latents_2x,
582
+ clean_latent_2x_indices=clean_latent_2x_indices,
583
+ clean_latents_4x=clean_latents_4x,
584
+ clean_latent_4x_indices=clean_latent_4x_indices,
585
+ callback=callback,
586
+ )
587
+
588
+ if is_start_of_video:
589
+ generated_latents = torch.cat([video_latents[:, :, -1:].to(generated_latents), generated_latents], dim=2)
590
+
591
+ total_generated_latent_frames += int(generated_latents.shape[2])
592
+ history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)
593
+
594
+ if not high_vram:
595
+ offload_model_from_device_for_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=8)
596
+ load_model_as_complete(vae, target_device=gpu)
597
+
598
+ real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]
599
+ if history_pixels is None:
600
+ history_pixels = vae_decode(real_history_latents, vae).cpu()
601
+ else:
602
+ section_latent_frames = (latent_window_size * 2 + 1) if is_start_of_video else (latent_window_size * 2)
603
+ overlapped_frames = latent_window_size * 4 - 3
604
+ current_pixels = vae_decode(real_history_latents[:, :, :section_latent_frames], vae).cpu()
605
+ history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
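With the default latent window of 9, the decode/stitch step above works on these sizes (illustrative arithmetic only):

latent_window_size = 9
section_latent_frames = latent_window_size * 2     # 18 latent frames decoded per non-start section
overlapped_frames = latent_window_size * 4 - 3     # 33 pixel frames cross-faded by soft_append_bcthw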
606
+
607
+ if not high_vram:
608
+ unload_complete_models()
609
+
610
+ output_filename = os.path.join(outputs_folder, f'{job_id}_{total_generated_latent_frames}.mp4')
611
+ save_bcthw_as_mp4(history_pixels, output_filename, fps=fps, crf=mp4_crf)
612
+ print(f"Latest video saved: {output_filename}")
613
+ set_mp4_comments_imageio_ffmpeg(output_filename, f"Prompt: {prompt} | Negative Prompt: {n_prompt}")
614
+ print(f"Prompt saved to mp4 metadata comments: {output_filename}")
615
+
616
+ if previous_video is not None and os.path.exists(previous_video):
617
+ try:
618
+ os.remove(previous_video)
619
+ print(f"Previous partial video deleted: {previous_video}")
620
+ except Exception as e:
621
+ print(f"Error deleting previous partial video {previous_video}: {e}")
622
+ previous_video = output_filename
623
+
624
+ print(f'Decoded. Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}')
625
+ stream.output_queue.push(('file', output_filename))
626
+
627
+ if is_start_of_video:
628
+ break
629
+
630
+ history_pixels = torch.cat([input_video_pixels, history_pixels], dim=2)
631
+ #overlapped_frames = latent_window_size * 4 - 3
632
+ #history_pixels = soft_append_bcthw(input_video_pixels, history_pixels, overlapped_frames)
633
+
634
+ output_filename = os.path.join(outputs_folder, f'{job_id}_final.mp4')
635
+ save_bcthw_as_mp4(history_pixels, output_filename, fps=fps, crf=mp4_crf)
636
+ print(f"Final video with input blend saved: {output_filename}")
637
+ set_mp4_comments_imageio_ffmpeg(output_filename, f"Prompt: {prompt} | Negative Prompt: {n_prompt}")
638
+ print(f"Prompt saved to mp4 metadata comments: {output_filename}")
639
+ stream.output_queue.push(('file', output_filename))
640
+
641
+ if previous_video is not None and os.path.exists(previous_video):
642
+ try:
643
+ os.remove(previous_video)
644
+ print(f"Previous partial video deleted: {previous_video}")
645
+ except Exception as e:
646
+ print(f"Error deleting previous partial video {previous_video}: {e}")
647
+ previous_video = output_filename
648
+
649
+ print(f'Decoded. Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}')
650
+
651
+ stream.output_queue.push(('file', output_filename))
652
+
653
+ seed = (seed + 1) % np.iinfo(np.int32).max
654
+
655
+ except:
656
+ traceback.print_exc()
657
+
658
+ if not high_vram:
659
+ unload_complete_models(
660
+ text_encoder, text_encoder_2, image_encoder, vae, transformer
661
+ )
662
+
663
+ stream.output_queue.push(('end', None))
664
+ return
665
+
666
+ # 20250506 pftq: Modified process to pass clean frame count, etc
667
+ def get_duration(
668
+ input_video, end_frame, end_frame_weight, prompt, n_prompt,
669
+ randomize_seed,
670
+ seed, batch, resolution, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache,
671
+ no_resize, mp4_crf, num_clean_frames, vae_batch):
672
+ return total_second_length * 60 * 2
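So, for example, a 2-second request reserves 240 seconds of GPU time for the @spaces.GPU decorator below (its duration is expressed in seconds):

total_second_length = 2                  # illustrative
print(total_second_length * 60 * 2)      # 240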
673
+
674
+ @spaces.GPU(duration=get_duration)
675
+ def process(
676
+ input_video, end_frame, end_frame_weight, prompt, n_prompt,
677
+ randomize_seed,
678
+ seed, batch, resolution, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache,
679
+ no_resize, mp4_crf, num_clean_frames, vae_batch):
680
+ global stream, high_vram
681
+
682
+ if torch.cuda.device_count() == 0:
683
+ gr.Warning('Set this space to GPU config to make it work.')
684
+ return None, None, None, None, None, None
685
+
686
+ if randomize_seed:
687
+ seed = random.randint(0, np.iinfo(np.int32).max)
688
+
689
+ # 20250506 pftq: Updated assertion for video input
690
+ assert input_video is not None, 'No input video!'
691
+
692
+ yield None, None, '', '', gr.update(interactive=False), gr.update(interactive=True)
693
+
694
+ # 20250507 pftq: Even the H100 needs offloading if the video dimensions are 720p or higher
695
+ if high_vram and (no_resize or resolution>640):
696
+ print("Disabling high vram mode due to no resize and/or potentially higher resolution...")
697
+ high_vram = False
698
+ vae.enable_slicing()
699
+ vae.enable_tiling()
700
+ DynamicSwapInstaller.install_model(transformer, device=gpu)
701
+ DynamicSwapInstaller.install_model(text_encoder, device=gpu)
702
+
703
+ # 20250508 pftq: automatically set distilled cfg to 1 if cfg is used
704
+ if cfg > 1:
705
+ gs = 1
706
+
707
+ stream = AsyncStream()
708
+
709
+ # 20250506 pftq: Pass num_clean_frames, vae_batch, etc
710
+ async_run(worker, input_video, end_frame, end_frame_weight, prompt, n_prompt, seed, batch, resolution, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, no_resize, mp4_crf, num_clean_frames, vae_batch)
711
+
712
+ output_filename = None
713
+
714
+ while True:
715
+ flag, data = stream.output_queue.next()
716
+
717
+ if flag == 'file':
718
+ output_filename = data
719
+ yield output_filename, gr.update(), gr.update(), gr.update(), gr.update(interactive=False), gr.update(interactive=True)
720
+
721
+ if flag == 'progress':
722
+ preview, desc, html = data
723
+ #yield gr.update(), gr.update(visible=True, value=preview), desc, html, gr.update(interactive=False), gr.update(interactive=True)
724
+ yield output_filename, gr.update(visible=True, value=preview), desc, html, gr.update(interactive=False), gr.update(interactive=True) # 20250506 pftq: Keep refreshing the video in case it got hidden when the tab was in the background
725
+
726
+ if flag == 'end':
727
+ yield output_filename, gr.update(visible=False), desc+' Video complete.', '', gr.update(interactive=True), gr.update(interactive=False)
728
+ break
729
+
730
+ def end_process():
731
+ stream.input_queue.push('end')
732
+
733
+ quick_prompts = [
734
+ 'The girl dances gracefully, with clear movements, full of charm.',
735
+ 'A character doing some simple body movements.',
736
+ ]
737
+ quick_prompts = [[x] for x in quick_prompts]
738
+
739
+ css = make_progress_bar_css()
740
+ block = gr.Blocks(css=css).queue(
741
+ max_size=10 # 20250507 pftq: Limit queue size
742
+ )
743
+ with block:
744
+ if torch.cuda.device_count() == 0:
745
+ with gr.Row():
746
+ gr.HTML("""
747
+ <p style="background-color: red;"><big><big><big><b>⚠️To use FramePack, <a href="https://huggingface.co/spaces/Fabrice-TIERCELIN/SUPIR?duplicate=true">duplicate this space</a> and set a GPU with 30 GB VRAM.</b>
748
+
749
+ You can't use FramePack directly here because this space runs on a CPU, which is not enough for FramePack. Please provide <a href="https://huggingface.co/spaces/Fabrice-TIERCELIN/SUPIR/discussions/new">feedback</a> if you have issues.
750
+ </big></big></big></p>
751
+ """)
752
+ # 20250506 pftq: Updated title to reflect video input functionality
753
+ gr.Markdown('# Framepack with Video Input (Video Extension) + End Frame')
754
+ with gr.Row():
755
+ with gr.Column():
756
+
757
+ # 20250506 pftq: Changed to Video input from Image
758
+ with gr.Row():
759
+ input_video = gr.Video(sources='upload', label="Input Video", height=320)
760
+ with gr.Column():
761
+ # 20250507 pftq: Added end_frame + weight
762
+ end_frame = gr.Image(sources='upload', type="numpy", label="End Frame (Optional) - Reduce the number of context frames if it is very different from the input video or if the result jump-cuts or slows to a still image.", height=320)
763
+ end_frame_weight = gr.Slider(label="End Frame Weight", minimum=0.0, maximum=1.0, value=1.0, step=0.01, info='Reduce to treat the end frame more as a reference image; no effect if no end frame is provided.')
764
+
765
+ prompt = gr.Textbox(label="Prompt", value='')
766
+
767
+ with gr.Row():
768
+ start_button = gr.Button(value="Start Generation", variant="primary")
769
+ end_button = gr.Button(value="End Generation", variant="stop", interactive=False)
770
+
771
+ with gr.Accordion("Advanced settings", open=False):
772
+ with gr.Row():
773
+ use_teacache = gr.Checkbox(label='Use TeaCache', value=True, info='Faster speed, but often makes hands and fingers slightly worse.')
774
+ no_resize = gr.Checkbox(label='Force Original Video Resolution (No Resizing)', value=False, info='Might run out of VRAM (720p requires > 24GB VRAM).')
775
+
776
+ randomize_seed = gr.Checkbox(label='Randomize seed', value=True, info='If checked, the seed is always different')
777
+ seed = gr.Slider(label="Seed", minimum=0, maximum=np.iinfo(np.int32).max, step=1, randomize=True)
778
+
779
+ batch = gr.Slider(label="Batch Size (Number of Videos)", minimum=1, maximum=1000, value=1, step=1, info='Generate multiple videos each with a different seed.')
780
+
781
+ resolution = gr.Number(label="Resolution (max width or height)", value=640, precision=0)
782
+
783
+ total_second_length = gr.Slider(label="Additional Video Length to Generate (Seconds)", minimum=1, maximum=120, value=5, step=0.1)
784
+
785
+ # 20250506 pftq: Reduced default distilled guidance scale to improve adherence to input video
786
+ gs = gr.Slider(label="Distilled CFG Scale", minimum=1.0, maximum=32.0, value=10.0, step=0.01, info='Prompt adherence at the cost of less details from the input video, but to a lesser extent than Context Frames.')
787
+ cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=1.0, step=0.01, info='Use instead of Distilled for more detail/control + Negative Prompt (make sure Distilled=1). Doubles render time.') # Should not change
788
+ rs = gr.Slider(label="CFG Re-Scale", minimum=0.0, maximum=1.0, value=0.0, step=0.01) # Should not change
789
+
790
+ n_prompt = gr.Textbox(label="Negative Prompt", value="Missing arm, unrealistic position, blurred, blurry", info='Requires using normal CFG (undistilled) instead of Distilled (set Distilled=1 and CFG > 1).')
791
+
792
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1, info='Expensive. Increase for more quality, especially if using high non-distilled CFG.')
793
+
794
+ # 20250506 pftq: Renamed slider to Number of Context Frames and updated description
795
+ num_clean_frames = gr.Slider(label="Number of Context Frames (Adherence to Video)", minimum=2, maximum=10, value=5, step=1, info="Expensive. Retains more details from the input video. Reduce if you hit memory issues or if motion is too restricted (jump cuts, ignored prompt, still image).")
796
+
797
+ default_vae = 32
798
+ if high_vram:
799
+ default_vae = 128
800
+ elif free_mem_gb>=20:
801
+ default_vae = 64
802
+
803
+ vae_batch = gr.Slider(label="VAE Batch Size for Input Video", minimum=4, maximum=256, value=default_vae, step=4, info="Expensive. Increase for better quality frames during fast motion. Reduce if running out of memory.")
804
+
805
+ latent_window_size = gr.Slider(label="Latent Window Size", minimum=9, maximum=49, value=9, step=1, info='Expensive. Generate more frames at a time (larger chunks). Less degradation but higher VRAM cost.')
806
+
807
+ gpu_memory_preservation = gr.Slider(label="GPU Inference Preserved Memory (GB) (larger means slower)", minimum=6, maximum=128, value=6, step=0.1, info="Set this number to a larger value if you encounter OOM. Larger value causes slower speed.")
808
+
809
+ mp4_crf = gr.Slider(label="MP4 Compression", minimum=0, maximum=100, value=16, step=1, info="Lower means better quality. 0 is uncompressed. Change to 16 if you get black outputs.")
810
+
811
+ with gr.Column():
812
+ preview_image = gr.Image(label="Next Latents", height=200, visible=False)
813
+ result_video = gr.Video(label="Finished Frames", autoplay=True, show_share_button=False, height=512, loop=True)
814
+ progress_desc = gr.Markdown('', elem_classes='no-generating-animation')
815
+ progress_bar = gr.HTML('', elem_classes='no-generating-animation')
816
+
817
+ # 20250506 pftq: Updated inputs to include num_clean_frames
818
+ ips = [input_video, end_frame, end_frame_weight, prompt, n_prompt, randomize_seed, seed, batch, resolution, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, no_resize, mp4_crf, num_clean_frames, vae_batch]
819
+ start_button.click(fn=process, inputs=ips, outputs=[result_video, preview_image, progress_desc, progress_bar, start_button, end_button])
820
+ end_button.click(fn=end_process)
821
+
822
+ block.launch(share=True)
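
For reference, the process() generator above only consumes a small flag protocol from the background worker: the worker pushes ('progress', ...) updates, ('file', path) for each partial/final mp4, and finally ('end', None), and the UI loop turns each flag into Gradio updates. Below is a minimal sketch of that producer/consumer contract using a plain queue and thread instead of the AsyncStream helper; only the flag names come from the diff, everything else is illustrative.

import queue
import threading
import time

output_queue = queue.Queue()

def fake_worker():
    # stand-in for worker(): report progress, publish a partial file, then finish
    for step in range(3):
        output_queue.put(('progress', f'step {step + 1}/3'))
        time.sleep(0.1)
    output_queue.put(('file', 'outputs/job_partial.mp4'))
    output_queue.put(('end', None))

threading.Thread(target=fake_worker, daemon=True).start()

while True:
    flag, data = output_queue.get()
    if flag == 'progress':
        print('progress ->', data)   # process() yields a preview image and an HTML bar here
    elif flag == 'file':
        print('new video ->', data)  # process() yields the mp4 path to result_video
    elif flag == 'end':
        print('done')                # process() re-enables the Start button and stops
        break
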
diffusers_helper/bucket_tools.py ADDED
@@ -0,0 +1,103 @@
1
+ bucket_options = {
2
+ 640: [
3
+ (416, 960),
4
+ (448, 864),
5
+ (480, 832),
6
+ (512, 768),
7
+ (544, 704),
8
+ (576, 672),
9
+ (608, 640),
10
+ (640, 608),
11
+ (672, 576),
12
+ (704, 544),
13
+ (768, 512),
14
+ (832, 480),
15
+ (864, 448),
16
+ (960, 416),
17
+ ],
18
+ 672: [
19
+ (480, 864),
20
+ (512, 832),
21
+ (544, 768),
22
+ (576, 704),
23
+ (608, 672),
24
+ (640, 640),
25
+ (672, 608),
26
+ (704, 576),
27
+ (768, 544),
28
+ (832, 512),
29
+ (864, 480),
30
+ ],
31
+ 704: [
32
+ (480, 960),
33
+ (512, 864),
34
+ (544, 832),
35
+ (576, 768),
36
+ (608, 704),
37
+ (640, 672),
38
+ (672, 640),
39
+ (704, 608),
40
+ (768, 576),
41
+ (832, 544),
42
+ (864, 512),
43
+ (960, 480),
44
+ ],
45
+ 768: [
46
+ (512, 960),
47
+ (544, 864),
48
+ (576, 832),
49
+ (608, 768),
50
+ (640, 704),
51
+ (672, 672),
52
+ (704, 640),
53
+ (768, 608),
54
+ (832, 576),
55
+ (864, 544),
56
+ (960, 512),
57
+ ],
58
+ 832: [
59
+ (544, 960),
60
+ (576, 864),
61
+ (608, 832),
62
+ (640, 768),
63
+ (672, 704),
64
+ (704, 672),
65
+ (768, 640),
66
+ (832, 608),
67
+ (864, 576),
68
+ (960, 544),
69
+ ],
70
+ 864: [
71
+ (576, 960),
72
+ (608, 864),
73
+ (640, 832),
74
+ (672, 768),
75
+ (704, 704),
76
+ (768, 672),
77
+ (832, 640),
78
+ (864, 608),
79
+ (960, 576),
80
+ ],
81
+ 960: [
82
+ (608, 960),
83
+ (640, 864),
84
+ (672, 832),
85
+ (704, 768),
86
+ (768, 704),
87
+ (832, 672),
88
+ (864, 640),
89
+ (960, 608),
90
+ ],
91
+ }
92
+
93
+
94
+ def find_nearest_bucket(h, w, resolution=640):
95
+ min_metric = float('inf')
96
+ best_bucket = None
97
+ for (bucket_h, bucket_w) in bucket_options[resolution]:
98
+ metric = abs(h * bucket_w - w * bucket_h)
99
+ if metric <= min_metric:
100
+ min_metric = metric
101
+ best_bucket = (bucket_h, bucket_w)
102
+ print("The resolution of the generated video will be " + str(best_bucket))
103
+ return best_bucket
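
A quick usage sketch for the helper above: find_nearest_bucket() keeps the pixel budget of the chosen resolution group fixed and only searches for the aspect ratio with the smallest cross-product error, so a 1080x1920 landscape frame at the default resolution=640 should land on a wide bucket instead of being squashed square. The expected value in the comment is my reading of the table, not something asserted by the diff.

from diffusers_helper.bucket_tools import find_nearest_bucket

# 1080p landscape input: height=1080, width=1920, default 640 bucket group
bucket_h, bucket_w = find_nearest_bucket(1080, 1920, resolution=640)
print(bucket_h, bucket_w)  # picks an aspect-preserving bucket, e.g. (480, 832)
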
diffusers_helper/clip_vision.py ADDED
@@ -0,0 +1,12 @@
1
+ import numpy as np
2
+
3
+
4
+ def hf_clip_vision_encode(image, feature_extractor, image_encoder):
5
+ assert isinstance(image, np.ndarray)
6
+ assert image.ndim == 3 and image.shape[2] == 3
7
+ assert image.dtype == np.uint8
8
+
9
+ preprocessed = feature_extractor.preprocess(images=image, return_tensors="pt").to(device=image_encoder.device, dtype=image_encoder.dtype)
10
+ image_encoder_output = image_encoder(**preprocessed)
11
+
12
+ return image_encoder_output
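
hf_clip_vision_encode() only checks that it receives an HxWx3 uint8 array and leaves all resizing and normalisation to the feature extractor. A minimal sketch of a call, assuming a SigLIP-style processor and vision-model pair like the image encoder app.py loads elsewhere in this PR; the checkpoint name here is only an example.

import numpy as np
from transformers import SiglipImageProcessor, SiglipVisionModel
from diffusers_helper.clip_vision import hf_clip_vision_encode

feature_extractor = SiglipImageProcessor.from_pretrained("google/siglip-base-patch16-224")
image_encoder = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")

image = np.zeros((480, 640, 3), dtype=np.uint8)          # any HxWx3 uint8 frame
output = hf_clip_vision_encode(image, feature_extractor, image_encoder)
print(output.last_hidden_state.shape)                    # per-patch image embeddings
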
diffusers_helper/dit_common.py ADDED
@@ -0,0 +1,53 @@
1
+ import torch
2
+ import accelerate.accelerator
3
+
4
+ from diffusers.models.normalization import RMSNorm, LayerNorm, FP32LayerNorm, AdaLayerNormContinuous
5
+
6
+
7
+ accelerate.accelerator.convert_outputs_to_fp32 = lambda x: x
8
+
9
+
10
+ def LayerNorm_forward(self, x):
11
+ return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps).to(x)
12
+
13
+
14
+ LayerNorm.forward = LayerNorm_forward
15
+ torch.nn.LayerNorm.forward = LayerNorm_forward
16
+
17
+
18
+ def FP32LayerNorm_forward(self, x):
19
+ origin_dtype = x.dtype
20
+ return torch.nn.functional.layer_norm(
21
+ x.float(),
22
+ self.normalized_shape,
23
+ self.weight.float() if self.weight is not None else None,
24
+ self.bias.float() if self.bias is not None else None,
25
+ self.eps,
26
+ ).to(origin_dtype)
27
+
28
+
29
+ FP32LayerNorm.forward = FP32LayerNorm_forward
30
+
31
+
32
+ def RMSNorm_forward(self, hidden_states):
33
+ input_dtype = hidden_states.dtype
34
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
35
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
36
+
37
+ if self.weight is None:
38
+ return hidden_states.to(input_dtype)
39
+
40
+ return hidden_states.to(input_dtype) * self.weight.to(input_dtype)
41
+
42
+
43
+ RMSNorm.forward = RMSNorm_forward
44
+
45
+
46
+ def AdaLayerNormContinuous_forward(self, x, conditioning_embedding):
47
+ emb = self.linear(self.silu(conditioning_embedding))
48
+ scale, shift = emb.chunk(2, dim=1)
49
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
50
+ return x
51
+
52
+
53
+ AdaLayerNormContinuous.forward = AdaLayerNormContinuous_forward
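
All of the patched forwards above follow the same mixed-precision pattern: compute the normalisation statistics in float32, then cast the result back to the activation dtype so bf16/fp16 inference stays numerically stable without leaving float32 tensors behind. A small self-contained illustration of that pattern (it does not use the patched classes themselves):

import torch

x = torch.randn(2, 8, dtype=torch.bfloat16)

# RMSNorm-style statistics in float32, result cast back to the input dtype
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
y = (x * torch.rsqrt(variance + 1e-6)).to(x.dtype)

print(y.dtype)  # torch.bfloat16, even though the intermediate math ran in float32
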
diffusers_helper/gradio/progress_bar.py ADDED
@@ -0,0 +1,86 @@
1
+ progress_html = '''
2
+ <div class="loader-container">
3
+ <div class="loader"></div>
4
+ <div class="progress-container">
5
+ <progress value="*number*" max="100"></progress>
6
+ </div>
7
+ <span>*text*</span>
8
+ </div>
9
+ '''
10
+
11
+ css = '''
12
+ .loader-container {
13
+ display: flex; /* Use flex to align items horizontally */
14
+ align-items: center; /* Center items vertically within the container */
15
+ white-space: nowrap; /* Prevent line breaks within the container */
16
+ }
17
+
18
+ .loader {
19
+ border: 8px solid #f3f3f3; /* Light grey */
20
+ border-top: 8px solid #3498db; /* Blue */
21
+ border-radius: 50%;
22
+ width: 30px;
23
+ height: 30px;
24
+ animation: spin 2s linear infinite;
25
+ }
26
+
27
+ @keyframes spin {
28
+ 0% { transform: rotate(0deg); }
29
+ 100% { transform: rotate(360deg); }
30
+ }
31
+
32
+ /* Style the progress bar */
33
+ progress {
34
+ appearance: none; /* Remove default styling */
35
+ height: 20px; /* Set the height of the progress bar */
36
+ border-radius: 5px; /* Round the corners of the progress bar */
37
+ background-color: #f3f3f3; /* Light grey background */
38
+ width: 100%;
39
+ vertical-align: middle !important;
40
+ }
41
+
42
+ /* Style the progress bar container */
43
+ .progress-container {
44
+ margin-left: 20px;
45
+ margin-right: 20px;
46
+ flex-grow: 1; /* Allow the progress container to take up remaining space */
47
+ }
48
+
49
+ /* Set the color of the progress bar fill */
50
+ progress::-webkit-progress-value {
51
+ background-color: #3498db; /* Blue color for the fill */
52
+ }
53
+
54
+ progress::-moz-progress-bar {
55
+ background-color: #3498db; /* Blue color for the fill in Firefox */
56
+ }
57
+
58
+ /* Style the text on the progress bar */
59
+ progress::after {
60
+ content: attr(value '%'); /* Display the progress value followed by '%' */
61
+ position: absolute;
62
+ top: 50%;
63
+ left: 50%;
64
+ transform: translate(-50%, -50%);
65
+ color: white; /* Set text color */
66
+ font-size: 14px; /* Set font size */
67
+ }
68
+
69
+ /* Style other texts */
70
+ .loader-container > span {
71
+ margin-left: 5px; /* Add spacing between the progress bar and the text */
72
+ }
73
+
74
+ .no-generating-animation > .generating {
75
+ display: none !important;
76
+ }
77
+
78
+ '''
79
+
80
+
81
+ def make_progress_bar_html(number, text):
82
+ return progress_html.replace('*number*', str(number)).replace('*text*', text)
83
+
84
+
85
+ def make_progress_bar_css():
86
+ return css
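
make_progress_bar_html() is a plain string template: the worker formats a percentage plus a status string, and process() pushes the resulting HTML into the progress_bar component. For example (the numbers are made up):

from diffusers_helper.gradio.progress_bar import make_progress_bar_html

html = make_progress_bar_html(40, 'Sampling 10/25, total generated frames: 37')
print(html)  # a <progress value="40" max="100"> element plus the status text, styled by the CSS above
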
diffusers_helper/hf_login.py ADDED
@@ -0,0 +1,21 @@
1
+ import os
2
+
3
+
4
+ def login(token):
5
+ from huggingface_hub import login
6
+ import time
7
+
8
+ while True:
9
+ try:
10
+ login(token)
11
+ print('HF login ok.')
12
+ break
13
+ except Exception as e:
14
+ print(f'HF login failed: {e}. Retrying')
15
+ time.sleep(0.5)
16
+
17
+
18
+ hf_token = os.environ.get('HF_TOKEN', None)
19
+
20
+ if hf_token is not None:
21
+ login(hf_token)
diffusers_helper/hunyuan.py ADDED
@@ -0,0 +1,111 @@
1
+ import torch
2
+
3
+ from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE
4
+ from diffusers_helper.utils import crop_or_pad_yield_mask
5
+
6
+
7
+ @torch.no_grad()
8
+ def encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, max_length=256):
9
+ assert isinstance(prompt, str)
10
+
11
+ prompt = [prompt]
12
+
13
+ # LLAMA
14
+
15
+ prompt_llama = [DEFAULT_PROMPT_TEMPLATE["template"].format(p) for p in prompt]
16
+ crop_start = DEFAULT_PROMPT_TEMPLATE["crop_start"]
17
+
18
+ llama_inputs = tokenizer(
19
+ prompt_llama,
20
+ padding="max_length",
21
+ max_length=max_length + crop_start,
22
+ truncation=True,
23
+ return_tensors="pt",
24
+ return_length=False,
25
+ return_overflowing_tokens=False,
26
+ return_attention_mask=True,
27
+ )
28
+
29
+ llama_input_ids = llama_inputs.input_ids.to(text_encoder.device)
30
+ llama_attention_mask = llama_inputs.attention_mask.to(text_encoder.device)
31
+ llama_attention_length = int(llama_attention_mask.sum())
32
+
33
+ llama_outputs = text_encoder(
34
+ input_ids=llama_input_ids,
35
+ attention_mask=llama_attention_mask,
36
+ output_hidden_states=True,
37
+ )
38
+
39
+ llama_vec = llama_outputs.hidden_states[-3][:, crop_start:llama_attention_length]
40
+ # llama_vec_remaining = llama_outputs.hidden_states[-3][:, llama_attention_length:]
41
+ llama_attention_mask = llama_attention_mask[:, crop_start:llama_attention_length]
42
+
43
+ assert torch.all(llama_attention_mask.bool())
44
+
45
+ # CLIP
46
+
47
+ clip_l_input_ids = tokenizer_2(
48
+ prompt,
49
+ padding="max_length",
50
+ max_length=77,
51
+ truncation=True,
52
+ return_overflowing_tokens=False,
53
+ return_length=False,
54
+ return_tensors="pt",
55
+ ).input_ids
56
+ clip_l_pooler = text_encoder_2(clip_l_input_ids.to(text_encoder_2.device), output_hidden_states=False).pooler_output
57
+
58
+ return llama_vec, clip_l_pooler
59
+
60
+
61
+ @torch.no_grad()
62
+ def vae_decode_fake(latents):
63
+ latent_rgb_factors = [
64
+ [-0.0395, -0.0331, 0.0445],
65
+ [0.0696, 0.0795, 0.0518],
66
+ [0.0135, -0.0945, -0.0282],
67
+ [0.0108, -0.0250, -0.0765],
68
+ [-0.0209, 0.0032, 0.0224],
69
+ [-0.0804, -0.0254, -0.0639],
70
+ [-0.0991, 0.0271, -0.0669],
71
+ [-0.0646, -0.0422, -0.0400],
72
+ [-0.0696, -0.0595, -0.0894],
73
+ [-0.0799, -0.0208, -0.0375],
74
+ [0.1166, 0.1627, 0.0962],
75
+ [0.1165, 0.0432, 0.0407],
76
+ [-0.2315, -0.1920, -0.1355],
77
+ [-0.0270, 0.0401, -0.0821],
78
+ [-0.0616, -0.0997, -0.0727],
79
+ [0.0249, -0.0469, -0.1703]
80
+ ] # From comfyui
81
+
82
+ latent_rgb_factors_bias = [0.0259, -0.0192, -0.0761]
83
+
84
+ weight = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)[:, :, None, None, None]
85
+ bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)
86
+
87
+ images = torch.nn.functional.conv3d(latents, weight, bias=bias, stride=1, padding=0, dilation=1, groups=1)
88
+ images = images.clamp(0.0, 1.0)
89
+
90
+ return images
91
+
92
+
93
+ @torch.no_grad()
94
+ def vae_decode(latents, vae, image_mode=False):
95
+ latents = latents / vae.config.scaling_factor
96
+
97
+ if image_mode:
98
+ latents = latents.to(device=vae.device, dtype=vae.dtype).unbind(2)
99
+ image = [vae.decode(l.unsqueeze(2)).sample for l in latents]
100
+ image = torch.cat(image, dim=2)
101
+ else:
102
+ image = vae.decode(latents.to(device=vae.device, dtype=vae.dtype)).sample
103
+
104
+ return image
105
+
106
+
107
+ @torch.no_grad()
108
+ def vae_encode(image, vae):
109
+ latents = vae.encode(image.to(device=vae.device, dtype=vae.dtype)).latent_dist.sample()
110
+ latents = latents * vae.config.scaling_factor
111
+ return latents
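
vae_decode_fake() is what feeds the "Next Latents" preview: instead of running the real VAE it projects the 16 latent channels to RGB with a fixed 16x3 matrix (a 1x1x1 conv3d), so it is cheap enough to call on every sampling step. A minimal sketch of its shape contract, with a random tensor standing in for real latents (the temporal/spatial sizes are arbitrary):

import torch
from diffusers_helper.hunyuan import vae_decode_fake

# (batch, 16 latent channels, frames, latent height, latent width)
latents = torch.randn(1, 16, 9, 64, 96)
preview = vae_decode_fake(latents)

print(preview.shape)                           # torch.Size([1, 3, 9, 64, 96])
print(preview.min() >= 0, preview.max() <= 1)  # values are clamped to [0, 1] for display
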
diffusers_helper/k_diffusion/uni_pc_fm.py ADDED
@@ -0,0 +1,141 @@
1
+ # Better Flow Matching UniPC by Lvmin Zhang
2
+ # (c) 2025
3
+ # CC BY-SA 4.0
4
+ # Attribution-ShareAlike 4.0 International Licence
5
+
6
+
7
+ import torch
8
+
9
+ from tqdm.auto import trange
10
+
11
+
12
+ def expand_dims(v, dims):
13
+ return v[(...,) + (None,) * (dims - 1)]
14
+
15
+
16
+ class FlowMatchUniPC:
17
+ def __init__(self, model, extra_args, variant='bh1'):
18
+ self.model = model
19
+ self.variant = variant
20
+ self.extra_args = extra_args
21
+
22
+ def model_fn(self, x, t):
23
+ return self.model(x, t, **self.extra_args)
24
+
25
+ def update_fn(self, x, model_prev_list, t_prev_list, t, order):
26
+ assert order <= len(model_prev_list)
27
+ dims = x.dim()
28
+
29
+ t_prev_0 = t_prev_list[-1]
30
+ lambda_prev_0 = - torch.log(t_prev_0)
31
+ lambda_t = - torch.log(t)
32
+ model_prev_0 = model_prev_list[-1]
33
+
34
+ h = lambda_t - lambda_prev_0
35
+
36
+ rks = []
37
+ D1s = []
38
+ for i in range(1, order):
39
+ t_prev_i = t_prev_list[-(i + 1)]
40
+ model_prev_i = model_prev_list[-(i + 1)]
41
+ lambda_prev_i = - torch.log(t_prev_i)
42
+ rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
43
+ rks.append(rk)
44
+ D1s.append((model_prev_i - model_prev_0) / rk)
45
+
46
+ rks.append(1.)
47
+ rks = torch.tensor(rks, device=x.device)
48
+
49
+ R = []
50
+ b = []
51
+
52
+ hh = -h[0]
53
+ h_phi_1 = torch.expm1(hh)
54
+ h_phi_k = h_phi_1 / hh - 1
55
+
56
+ factorial_i = 1
57
+
58
+ if self.variant == 'bh1':
59
+ B_h = hh
60
+ elif self.variant == 'bh2':
61
+ B_h = torch.expm1(hh)
62
+ else:
63
+ raise NotImplementedError('Bad variant!')
64
+
65
+ for i in range(1, order + 1):
66
+ R.append(torch.pow(rks, i - 1))
67
+ b.append(h_phi_k * factorial_i / B_h)
68
+ factorial_i *= (i + 1)
69
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
70
+
71
+ R = torch.stack(R)
72
+ b = torch.tensor(b, device=x.device)
73
+
74
+ use_predictor = len(D1s) > 0
75
+
76
+ if use_predictor:
77
+ D1s = torch.stack(D1s, dim=1)
78
+ if order == 2:
79
+ rhos_p = torch.tensor([0.5], device=b.device)
80
+ else:
81
+ rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
82
+ else:
83
+ D1s = None
84
+ rhos_p = None
85
+
86
+ if order == 1:
87
+ rhos_c = torch.tensor([0.5], device=b.device)
88
+ else:
89
+ rhos_c = torch.linalg.solve(R, b)
90
+
91
+ x_t_ = expand_dims(t / t_prev_0, dims) * x - expand_dims(h_phi_1, dims) * model_prev_0
92
+
93
+ if use_predictor:
94
+ pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0]))
95
+ else:
96
+ pred_res = 0
97
+
98
+ x_t = x_t_ - expand_dims(B_h, dims) * pred_res
99
+ model_t = self.model_fn(x_t, t)
100
+
101
+ if D1s is not None:
102
+ corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0]))
103
+ else:
104
+ corr_res = 0
105
+
106
+ D1_t = (model_t - model_prev_0)
107
+ x_t = x_t_ - expand_dims(B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
108
+
109
+ return x_t, model_t
110
+
111
+ def sample(self, x, sigmas, callback=None, disable_pbar=False):
112
+ order = min(3, len(sigmas) - 2)
113
+ model_prev_list, t_prev_list = [], []
114
+ for i in trange(len(sigmas) - 1, disable=disable_pbar):
115
+ vec_t = sigmas[i].expand(x.shape[0])
116
+
117
+ if i == 0:
118
+ model_prev_list = [self.model_fn(x, vec_t)]
119
+ t_prev_list = [vec_t]
120
+ elif i < order:
121
+ init_order = i
122
+ x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, init_order)
123
+ model_prev_list.append(model_x)
124
+ t_prev_list.append(vec_t)
125
+ else:
126
+ x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, order)
127
+ model_prev_list.append(model_x)
128
+ t_prev_list.append(vec_t)
129
+
130
+ model_prev_list = model_prev_list[-order:]
131
+ t_prev_list = t_prev_list[-order:]
132
+
133
+ if callback is not None:
134
+ callback({'x': x, 'i': i, 'denoised': model_prev_list[-1]})
135
+
136
+ return model_prev_list[-1]
137
+
138
+
139
+ def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
140
+ assert variant in ['bh1', 'bh2']
141
+ return FlowMatchUniPC(model, extra_args=extra_args, variant=variant).sample(noise, sigmas=sigmas, callback=callback, disable_pbar=disable)
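
sample_unipc() expects a callable model(x, t, **extra_args) that returns a denoised (x0-style) prediction, an initial noise tensor, and a decreasing sigma schedule; fm_wrapper() in wrapper.py below builds exactly such a callable from the transformer. The toy call here only demonstrates the interface: a constant predictor and an arbitrary schedule kept strictly positive so -log(t) stays finite.

import torch
from diffusers_helper.k_diffusion.uni_pc_fm import sample_unipc

def toy_model(x, t, **kwargs):
    # pretend the clean sample is always zero; real use passes fm_wrapper(transformer)
    return torch.zeros_like(x)

noise = torch.randn(1, 4)
sigmas = torch.linspace(1.0, 0.01, 11)   # strictly positive, decreasing schedule

denoised = sample_unipc(toy_model, noise, sigmas, extra_args={}, disable=True)
print(denoised.shape)                    # same shape as the input noise
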
diffusers_helper/k_diffusion/wrapper.py ADDED
@@ -0,0 +1,51 @@
1
+ import torch
2
+
3
+
4
+ def append_dims(x, target_dims):
5
+ return x[(...,) + (None,) * (target_dims - x.ndim)]
6
+
7
+
8
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0):
9
+ if guidance_rescale == 0:
10
+ return noise_cfg
11
+
12
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
13
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
14
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
15
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1.0 - guidance_rescale) * noise_cfg
16
+ return noise_cfg
17
+
18
+
19
+ def fm_wrapper(transformer, t_scale=1000.0):
20
+ def k_model(x, sigma, **extra_args):
21
+ dtype = extra_args['dtype']
22
+ cfg_scale = extra_args['cfg_scale']
23
+ cfg_rescale = extra_args['cfg_rescale']
24
+ concat_latent = extra_args['concat_latent']
25
+
26
+ original_dtype = x.dtype
27
+ sigma = sigma.float()
28
+
29
+ x = x.to(dtype)
30
+ timestep = (sigma * t_scale).to(dtype)
31
+
32
+ if concat_latent is None:
33
+ hidden_states = x
34
+ else:
35
+ hidden_states = torch.cat([x, concat_latent.to(x)], dim=1)
36
+
37
+ pred_positive = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['positive'])[0].float()
38
+
39
+ if cfg_scale == 1.0:
40
+ pred_negative = torch.zeros_like(pred_positive)
41
+ else:
42
+ pred_negative = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['negative'])[0].float()
43
+
44
+ pred_cfg = pred_negative + cfg_scale * (pred_positive - pred_negative)
45
+ pred = rescale_noise_cfg(pred_cfg, pred_positive, guidance_rescale=cfg_rescale)
46
+
47
+ x0 = x.float() - pred.float() * append_dims(sigma, x.ndim)
48
+
49
+ return x0.to(dtype=original_dtype)
50
+
51
+ return k_model
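
fm_wrapper() adapts the transformer to the sampler above: it runs the positive branch (and the negative one when cfg_scale > 1), applies classifier-free guidance with optional rescaling, and turns the prediction into a denoised estimate via x0 = x - sigma * pred. The stand-in transformer below is only there to show which keys extra_args must carry; it is obviously not the real model.

import torch
from diffusers_helper.k_diffusion.wrapper import fm_wrapper

def fake_transformer(hidden_states=None, timestep=None, return_dict=False, **kwargs):
    # the real packed HunyuanVideo transformer returns a tuple whose first item is the prediction
    return (torch.zeros_like(hidden_states),)

k_model = fm_wrapper(fake_transformer)

x = torch.randn(1, 4)
sigma = torch.tensor([0.5])
extra_args = dict(
    dtype=torch.float32, cfg_scale=1.0, cfg_rescale=0.0, concat_latent=None,
    positive={},   # in app.py these dicts carry the prompt embeddings and other kwargs
    negative={},
)

x0 = k_model(x, sigma, **extra_args)
print(x0.shape)  # a denoised estimate with the same shape as x
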
diffusers_helper/memory.py ADDED
@@ -0,0 +1,134 @@
1
+ # By lllyasviel
2
+
3
+
4
+ import torch
5
+
6
+
7
+ cpu = torch.device('cpu')
8
+ gpu = torch.device(f'cuda:{torch.cuda.current_device()}')
9
+ gpu_complete_modules = []
10
+
11
+
12
+ class DynamicSwapInstaller:
13
+ @staticmethod
14
+ def _install_module(module: torch.nn.Module, **kwargs):
15
+ original_class = module.__class__
16
+ module.__dict__['forge_backup_original_class'] = original_class
17
+
18
+ def hacked_get_attr(self, name: str):
19
+ if '_parameters' in self.__dict__:
20
+ _parameters = self.__dict__['_parameters']
21
+ if name in _parameters:
22
+ p = _parameters[name]
23
+ if p is None:
24
+ return None
25
+ if p.__class__ == torch.nn.Parameter:
26
+ return torch.nn.Parameter(p.to(**kwargs), requires_grad=p.requires_grad)
27
+ else:
28
+ return p.to(**kwargs)
29
+ if '_buffers' in self.__dict__:
30
+ _buffers = self.__dict__['_buffers']
31
+ if name in _buffers:
32
+ return _buffers[name].to(**kwargs)
33
+ return super(original_class, self).__getattr__(name)
34
+
35
+ module.__class__ = type('DynamicSwap_' + original_class.__name__, (original_class,), {
36
+ '__getattr__': hacked_get_attr,
37
+ })
38
+
39
+ return
40
+
41
+ @staticmethod
42
+ def _uninstall_module(module: torch.nn.Module):
43
+ if 'forge_backup_original_class' in module.__dict__:
44
+ module.__class__ = module.__dict__.pop('forge_backup_original_class')
45
+ return
46
+
47
+ @staticmethod
48
+ def install_model(model: torch.nn.Module, **kwargs):
49
+ for m in model.modules():
50
+ DynamicSwapInstaller._install_module(m, **kwargs)
51
+ return
52
+
53
+ @staticmethod
54
+ def uninstall_model(model: torch.nn.Module):
55
+ for m in model.modules():
56
+ DynamicSwapInstaller._uninstall_module(m)
57
+ return
58
+
59
+
60
+ def fake_diffusers_current_device(model: torch.nn.Module, target_device: torch.device):
61
+ if hasattr(model, 'scale_shift_table'):
62
+ model.scale_shift_table.data = model.scale_shift_table.data.to(target_device)
63
+ return
64
+
65
+ for k, p in model.named_modules():
66
+ if hasattr(p, 'weight'):
67
+ p.to(target_device)
68
+ return
69
+
70
+
71
+ def get_cuda_free_memory_gb(device=None):
72
+ if device is None:
73
+ device = gpu
74
+
75
+ memory_stats = torch.cuda.memory_stats(device)
76
+ bytes_active = memory_stats['active_bytes.all.current']
77
+ bytes_reserved = memory_stats['reserved_bytes.all.current']
78
+ bytes_free_cuda, _ = torch.cuda.mem_get_info(device)
79
+ bytes_inactive_reserved = bytes_reserved - bytes_active
80
+ bytes_total_available = bytes_free_cuda + bytes_inactive_reserved
81
+ return bytes_total_available / (1024 ** 3)
82
+
83
+
84
+ def move_model_to_device_with_memory_preservation(model, target_device, preserved_memory_gb=0):
85
+ print(f'Moving {model.__class__.__name__} to {target_device} with preserved memory: {preserved_memory_gb} GB')
86
+
87
+ for m in model.modules():
88
+ if get_cuda_free_memory_gb(target_device) <= preserved_memory_gb:
89
+ torch.cuda.empty_cache()
90
+ return
91
+
92
+ if hasattr(m, 'weight'):
93
+ m.to(device=target_device)
94
+
95
+ model.to(device=target_device)
96
+ torch.cuda.empty_cache()
97
+ return
98
+
99
+
100
+ def offload_model_from_device_for_memory_preservation(model, target_device, preserved_memory_gb=0):
101
+ print(f'Offloading {model.__class__.__name__} from {target_device} to preserve memory: {preserved_memory_gb} GB')
102
+
103
+ for m in model.modules():
104
+ if get_cuda_free_memory_gb(target_device) >= preserved_memory_gb:
105
+ torch.cuda.empty_cache()
106
+ return
107
+
108
+ if hasattr(m, 'weight'):
109
+ m.to(device=cpu)
110
+
111
+ model.to(device=cpu)
112
+ torch.cuda.empty_cache()
113
+ return
114
+
115
+
116
+ def unload_complete_models(*args):
117
+ for m in gpu_complete_modules + list(args):
118
+ m.to(device=cpu)
119
+ print(f'Unloaded {m.__class__.__name__} as complete.')
120
+
121
+ gpu_complete_modules.clear()
122
+ torch.cuda.empty_cache()
123
+ return
124
+
125
+
126
+ def load_model_as_complete(model, target_device, unload=True):
127
+ if unload:
128
+ unload_complete_models()
129
+
130
+ model.to(device=target_device)
131
+ print(f'Loaded {model.__class__.__name__} to {target_device} as complete.')
132
+
133
+ gpu_complete_modules.append(model)
134
+ return
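
DynamicSwapInstaller is what lets the transformer stay in CPU RAM while individual weights are streamed to the GPU at attribute-access time, which is what the low-VRAM path in app.py (earlier in this diff) relies on. A minimal sketch with a toy module, guarded so it is a no-op without CUDA:

import torch
from diffusers_helper.memory import DynamicSwapInstaller

model = torch.nn.Linear(8, 8)                 # parameters live in CPU RAM
if torch.cuda.is_available():
    DynamicSwapInstaller.install_model(model, device='cuda')

    x = torch.randn(2, 8, device='cuda')
    y = model(x)                              # weight/bias are copied to CUDA on access
    print(y.device, next(model.parameters()).device)   # cuda:0 cpu - the stored weights never move

    DynamicSwapInstaller.uninstall_model(model)        # restore the original class
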
diffusers_helper/models/hunyuan_video_packed.py ADDED
@@ -0,0 +1,1035 @@
1
+ from typing import Any, Dict, List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import einops
5
+ import torch.nn as nn
6
+ import numpy as np
7
+
8
+ from diffusers.loaders import FromOriginalModelMixin
9
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
10
+ from diffusers.loaders import PeftAdapterMixin
11
+ from diffusers.utils import logging
12
+ from diffusers.models.attention import FeedForward
13
+ from diffusers.models.attention_processor import Attention
14
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps, PixArtAlphaTextProjection
15
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
16
+ from diffusers.models.modeling_utils import ModelMixin
17
+ from diffusers_helper.dit_common import LayerNorm
18
+ from diffusers_helper.utils import zero_module
19
+
20
+
21
+ enabled_backends = []
22
+
23
+ if torch.backends.cuda.flash_sdp_enabled():
24
+ enabled_backends.append("flash")
25
+ if torch.backends.cuda.math_sdp_enabled():
26
+ enabled_backends.append("math")
27
+ if torch.backends.cuda.mem_efficient_sdp_enabled():
28
+ enabled_backends.append("mem_efficient")
29
+ if torch.backends.cuda.cudnn_sdp_enabled():
30
+ enabled_backends.append("cudnn")
31
+
32
+ print("Currently enabled native sdp backends:", enabled_backends)
33
+
34
+ try:
35
+ # raise NotImplementedError
36
+ from xformers.ops import memory_efficient_attention as xformers_attn_func
37
+ print('Xformers is installed!')
38
+ except:
39
+ print('Xformers is not installed!')
40
+ xformers_attn_func = None
41
+
42
+ try:
43
+ # raise NotImplementedError
44
+ from flash_attn import flash_attn_varlen_func, flash_attn_func
45
+ print('Flash Attn is installed!')
46
+ except:
47
+ print('Flash Attn is not installed!')
48
+ flash_attn_varlen_func = None
49
+ flash_attn_func = None
50
+
51
+ try:
52
+ # raise NotImplementedError
53
+ from sageattention import sageattn_varlen, sageattn
54
+ print('Sage Attn is installed!')
55
+ except:
56
+ print('Sage Attn is not installed!')
57
+ sageattn_varlen = None
58
+ sageattn = None
59
+
60
+
61
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
62
+
63
+
64
+ def pad_for_3d_conv(x, kernel_size):
65
+ b, c, t, h, w = x.shape
66
+ pt, ph, pw = kernel_size
67
+ pad_t = (pt - (t % pt)) % pt
68
+ pad_h = (ph - (h % ph)) % ph
69
+ pad_w = (pw - (w % pw)) % pw
70
+ return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode='replicate')
71
+
72
+
73
+ def center_down_sample_3d(x, kernel_size):
74
+ # pt, ph, pw = kernel_size
75
+ # cp = (pt * ph * pw) // 2
76
+ # xp = einops.rearrange(x, 'b c (t pt) (h ph) (w pw) -> (pt ph pw) b c t h w', pt=pt, ph=ph, pw=pw)
77
+ # xc = xp[cp]
78
+ # return xc
79
+ return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
80
+
81
+
82
+ def get_cu_seqlens(text_mask, img_len):
83
+ batch_size = text_mask.shape[0]
84
+ text_len = text_mask.sum(dim=1)
85
+ max_len = text_mask.shape[1] + img_len
86
+
87
+ cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
88
+
89
+ for i in range(batch_size):
90
+ s = text_len[i] + img_len
91
+ s1 = i * max_len + s
92
+ s2 = (i + 1) * max_len
93
+ cu_seqlens[2 * i + 1] = s1
94
+ cu_seqlens[2 * i + 2] = s2
95
+
96
+ return cu_seqlens
97
+
98
+
99
+ def apply_rotary_emb_transposed(x, freqs_cis):
100
+ cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
101
+ x_real, x_imag = x.unflatten(-1, (-1, 2)).unbind(-1)
102
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
103
+ out = x.float() * cos + x_rotated.float() * sin
104
+ out = out.to(x)
105
+ return out
106
+
107
+
108
+ def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv):
109
+ if cu_seqlens_q is None and cu_seqlens_kv is None and max_seqlen_q is None and max_seqlen_kv is None:
110
+ if sageattn is not None:
111
+ x = sageattn(q, k, v, tensor_layout='NHD')
112
+ return x
113
+
114
+ if flash_attn_func is not None:
115
+ x = flash_attn_func(q, k, v)
116
+ return x
117
+
118
+ if xformers_attn_func is not None:
119
+ x = xformers_attn_func(q, k, v)
120
+ return x
121
+
122
+ x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2)
123
+ return x
124
+
125
+ B, L, H, C = q.shape
126
+
127
+ q = q.flatten(0, 1)
128
+ k = k.flatten(0, 1)
129
+ v = v.flatten(0, 1)
130
+
131
+ if sageattn_varlen is not None:
132
+ x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
133
+ elif flash_attn_varlen_func is not None:
134
+ x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
135
+ else:
136
+ raise NotImplementedError('No Attn Installed!')
137
+
138
+ x = x.unflatten(0, (B, L))
139
+
140
+ return x
141
+
142
+
143
+ class HunyuanAttnProcessorFlashAttnDouble:
144
+ def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb):
145
+ cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
146
+
147
+ query = attn.to_q(hidden_states)
148
+ key = attn.to_k(hidden_states)
149
+ value = attn.to_v(hidden_states)
150
+
151
+ query = query.unflatten(2, (attn.heads, -1))
152
+ key = key.unflatten(2, (attn.heads, -1))
153
+ value = value.unflatten(2, (attn.heads, -1))
154
+
155
+ query = attn.norm_q(query)
156
+ key = attn.norm_k(key)
157
+
158
+ query = apply_rotary_emb_transposed(query, image_rotary_emb)
159
+ key = apply_rotary_emb_transposed(key, image_rotary_emb)
160
+
161
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
162
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
163
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
164
+
165
+ encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
166
+ encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
167
+ encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
168
+
169
+ encoder_query = attn.norm_added_q(encoder_query)
170
+ encoder_key = attn.norm_added_k(encoder_key)
171
+
172
+ query = torch.cat([query, encoder_query], dim=1)
173
+ key = torch.cat([key, encoder_key], dim=1)
174
+ value = torch.cat([value, encoder_value], dim=1)
175
+
176
+ hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
177
+ hidden_states = hidden_states.flatten(-2)
178
+
179
+ txt_length = encoder_hidden_states.shape[1]
180
+ hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]
181
+
182
+ hidden_states = attn.to_out[0](hidden_states)
183
+ hidden_states = attn.to_out[1](hidden_states)
184
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
185
+
186
+ return hidden_states, encoder_hidden_states
187
+
188
+
189
+ class HunyuanAttnProcessorFlashAttnSingle:
190
+ def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb):
191
+ cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
192
+
193
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
194
+
195
+ query = attn.to_q(hidden_states)
196
+ key = attn.to_k(hidden_states)
197
+ value = attn.to_v(hidden_states)
198
+
199
+ query = query.unflatten(2, (attn.heads, -1))
200
+ key = key.unflatten(2, (attn.heads, -1))
201
+ value = value.unflatten(2, (attn.heads, -1))
202
+
203
+ query = attn.norm_q(query)
204
+ key = attn.norm_k(key)
205
+
206
+ txt_length = encoder_hidden_states.shape[1]
207
+
208
+ query = torch.cat([apply_rotary_emb_transposed(query[:, :-txt_length], image_rotary_emb), query[:, -txt_length:]], dim=1)
209
+ key = torch.cat([apply_rotary_emb_transposed(key[:, :-txt_length], image_rotary_emb), key[:, -txt_length:]], dim=1)
210
+
211
+ hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
212
+ hidden_states = hidden_states.flatten(-2)
213
+
214
+ hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]
215
+
216
+ return hidden_states, encoder_hidden_states
217
+
218
+
219
+ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
220
+ def __init__(self, embedding_dim, pooled_projection_dim):
221
+ super().__init__()
222
+
223
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
224
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
225
+ self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
226
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
227
+
228
+ def forward(self, timestep, guidance, pooled_projection):
229
+ timesteps_proj = self.time_proj(timestep)
230
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
231
+
232
+ guidance_proj = self.time_proj(guidance)
233
+ guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
234
+
235
+ time_guidance_emb = timesteps_emb + guidance_emb
236
+
237
+ pooled_projections = self.text_embedder(pooled_projection)
238
+ conditioning = time_guidance_emb + pooled_projections
239
+
240
+ return conditioning
241
+
242
+
243
+ class CombinedTimestepTextProjEmbeddings(nn.Module):
244
+ def __init__(self, embedding_dim, pooled_projection_dim):
245
+ super().__init__()
246
+
247
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
248
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
249
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
250
+
251
+ def forward(self, timestep, pooled_projection):
252
+ timesteps_proj = self.time_proj(timestep)
253
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
254
+
255
+ pooled_projections = self.text_embedder(pooled_projection)
256
+
257
+ conditioning = timesteps_emb + pooled_projections
258
+
259
+ return conditioning
260
+
261
+
262
+ class HunyuanVideoAdaNorm(nn.Module):
263
+ def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
264
+ super().__init__()
265
+
266
+ out_features = out_features or 2 * in_features
267
+ self.linear = nn.Linear(in_features, out_features)
268
+ self.nonlinearity = nn.SiLU()
269
+
270
+ def forward(
271
+ self, temb: torch.Tensor
272
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
273
+ temb = self.linear(self.nonlinearity(temb))
274
+ gate_msa, gate_mlp = temb.chunk(2, dim=-1)
275
+ gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
276
+ return gate_msa, gate_mlp
277
+
278
+
279
+ class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
280
+ def __init__(
281
+ self,
282
+ num_attention_heads: int,
283
+ attention_head_dim: int,
284
+ mlp_width_ratio: str = 4.0,
285
+ mlp_drop_rate: float = 0.0,
286
+ attention_bias: bool = True,
287
+ ) -> None:
288
+ super().__init__()
289
+
290
+ hidden_size = num_attention_heads * attention_head_dim
291
+
292
+ self.norm1 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
293
+ self.attn = Attention(
294
+ query_dim=hidden_size,
295
+ cross_attention_dim=None,
296
+ heads=num_attention_heads,
297
+ dim_head=attention_head_dim,
298
+ bias=attention_bias,
299
+ )
300
+
301
+ self.norm2 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
302
+ self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
303
+
304
+ self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)
305
+
306
+ def forward(
307
+ self,
308
+ hidden_states: torch.Tensor,
309
+ temb: torch.Tensor,
310
+ attention_mask: Optional[torch.Tensor] = None,
311
+ ) -> torch.Tensor:
312
+ norm_hidden_states = self.norm1(hidden_states)
313
+
314
+ attn_output = self.attn(
315
+ hidden_states=norm_hidden_states,
316
+ encoder_hidden_states=None,
317
+ attention_mask=attention_mask,
318
+ )
319
+
320
+ gate_msa, gate_mlp = self.norm_out(temb)
321
+ hidden_states = hidden_states + attn_output * gate_msa
322
+
323
+ ff_output = self.ff(self.norm2(hidden_states))
324
+ hidden_states = hidden_states + ff_output * gate_mlp
325
+
326
+ return hidden_states
327
+
328
+
329
+ class HunyuanVideoIndividualTokenRefiner(nn.Module):
330
+ def __init__(
331
+ self,
332
+ num_attention_heads: int,
333
+ attention_head_dim: int,
334
+ num_layers: int,
335
+ mlp_width_ratio: float = 4.0,
336
+ mlp_drop_rate: float = 0.0,
337
+ attention_bias: bool = True,
338
+ ) -> None:
339
+ super().__init__()
340
+
341
+ self.refiner_blocks = nn.ModuleList(
342
+ [
343
+ HunyuanVideoIndividualTokenRefinerBlock(
344
+ num_attention_heads=num_attention_heads,
345
+ attention_head_dim=attention_head_dim,
346
+ mlp_width_ratio=mlp_width_ratio,
347
+ mlp_drop_rate=mlp_drop_rate,
348
+ attention_bias=attention_bias,
349
+ )
350
+ for _ in range(num_layers)
351
+ ]
352
+ )
353
+
354
+ def forward(
355
+ self,
356
+ hidden_states: torch.Tensor,
357
+ temb: torch.Tensor,
358
+ attention_mask: Optional[torch.Tensor] = None,
359
+ ) -> None:
360
+ self_attn_mask = None
361
+ if attention_mask is not None:
362
+ batch_size = attention_mask.shape[0]
363
+ seq_len = attention_mask.shape[1]
364
+ attention_mask = attention_mask.to(hidden_states.device).bool()
365
+ self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).expand(-1, -1, seq_len, -1)
366
+ self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
367
+ self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
368
+ self_attn_mask[:, :, :, 0] = True
369
+
370
+ for block in self.refiner_blocks:
371
+ hidden_states = block(hidden_states, temb, self_attn_mask)
372
+
373
+ return hidden_states
374
+
375
+
376
+ class HunyuanVideoTokenRefiner(nn.Module):
377
+ def __init__(
378
+ self,
379
+ in_channels: int,
380
+ num_attention_heads: int,
381
+ attention_head_dim: int,
382
+ num_layers: int,
383
+ mlp_ratio: float = 4.0,
384
+ mlp_drop_rate: float = 0.0,
385
+ attention_bias: bool = True,
386
+ ) -> None:
387
+ super().__init__()
388
+
389
+ hidden_size = num_attention_heads * attention_head_dim
390
+
391
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(
392
+ embedding_dim=hidden_size, pooled_projection_dim=in_channels
393
+ )
394
+ self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
395
+ self.token_refiner = HunyuanVideoIndividualTokenRefiner(
396
+ num_attention_heads=num_attention_heads,
397
+ attention_head_dim=attention_head_dim,
398
+ num_layers=num_layers,
399
+ mlp_width_ratio=mlp_ratio,
400
+ mlp_drop_rate=mlp_drop_rate,
401
+ attention_bias=attention_bias,
402
+ )
403
+
404
+ def forward(
405
+ self,
406
+ hidden_states: torch.Tensor,
407
+ timestep: torch.LongTensor,
408
+ attention_mask: Optional[torch.LongTensor] = None,
409
+ ) -> torch.Tensor:
410
+ if attention_mask is None:
411
+ pooled_projections = hidden_states.mean(dim=1)
412
+ else:
413
+ original_dtype = hidden_states.dtype
414
+ mask_float = attention_mask.float().unsqueeze(-1)
415
+ pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
416
+ pooled_projections = pooled_projections.to(original_dtype)
417
+
418
+ temb = self.time_text_embed(timestep, pooled_projections)
419
+ hidden_states = self.proj_in(hidden_states)
420
+ hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
421
+
422
+ return hidden_states
423
+
424
+
425
+ class HunyuanVideoRotaryPosEmbed(nn.Module):
426
+ def __init__(self, rope_dim, theta):
427
+ super().__init__()
428
+ self.DT, self.DY, self.DX = rope_dim
429
+ self.theta = theta
430
+
431
+ @torch.no_grad()
432
+ def get_frequency(self, dim, pos):
433
+ T, H, W = pos.shape
434
+ freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / dim))
435
+ freqs = torch.outer(freqs, pos.reshape(-1)).unflatten(-1, (T, H, W)).repeat_interleave(2, dim=0)
436
+ return freqs.cos(), freqs.sin()
437
+
438
+ @torch.no_grad()
439
+ def forward_inner(self, frame_indices, height, width, device):
440
+ GT, GY, GX = torch.meshgrid(
441
+ frame_indices.to(device=device, dtype=torch.float32),
442
+ torch.arange(0, height, device=device, dtype=torch.float32),
443
+ torch.arange(0, width, device=device, dtype=torch.float32),
444
+ indexing="ij"
445
+ )
446
+
447
+ FCT, FST = self.get_frequency(self.DT, GT)
448
+ FCY, FSY = self.get_frequency(self.DY, GY)
449
+ FCX, FSX = self.get_frequency(self.DX, GX)
450
+
451
+ result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0)
452
+
453
+ return result.to(device)
454
+
455
+ @torch.no_grad()
456
+ def forward(self, frame_indices, height, width, device):
457
+ frame_indices = frame_indices.unbind(0)
458
+ results = [self.forward_inner(f, height, width, device) for f in frame_indices]
459
+ results = torch.stack(results, dim=0)
460
+ return results
461
+
462
+
463
+ class AdaLayerNormZero(nn.Module):
464
+ def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
465
+ super().__init__()
466
+ self.silu = nn.SiLU()
467
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
468
+ if norm_type == "layer_norm":
469
+ self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
470
+ else:
471
+ raise ValueError(f"unknown norm_type {norm_type}")
472
+
473
+ def forward(
474
+ self,
475
+ x: torch.Tensor,
476
+ emb: Optional[torch.Tensor] = None,
477
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
478
+ emb = emb.unsqueeze(-2)
479
+ emb = self.linear(self.silu(emb))
480
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
481
+ x = self.norm(x) * (1 + scale_msa) + shift_msa
482
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
483
+
484
+
485
+ class AdaLayerNormZeroSingle(nn.Module):
486
+ def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
487
+ super().__init__()
488
+
489
+ self.silu = nn.SiLU()
490
+ self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
491
+ if norm_type == "layer_norm":
492
+ self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
493
+ else:
494
+ raise ValueError(f"unknown norm_type {norm_type}")
495
+
496
+ def forward(
497
+ self,
498
+ x: torch.Tensor,
499
+ emb: Optional[torch.Tensor] = None,
500
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
501
+ emb = emb.unsqueeze(-2)
502
+ emb = self.linear(self.silu(emb))
503
+ shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1)
504
+ x = self.norm(x) * (1 + scale_msa) + shift_msa
505
+ return x, gate_msa
506
+
507
+
508
+ class AdaLayerNormContinuous(nn.Module):
509
+ def __init__(
510
+ self,
511
+ embedding_dim: int,
512
+ conditioning_embedding_dim: int,
513
+ elementwise_affine=True,
514
+ eps=1e-5,
515
+ bias=True,
516
+ norm_type="layer_norm",
517
+ ):
518
+ super().__init__()
519
+ self.silu = nn.SiLU()
520
+ self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
521
+ if norm_type == "layer_norm":
522
+ self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
523
+ else:
524
+ raise ValueError(f"unknown norm_type {norm_type}")
525
+
526
+ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
527
+ emb = emb.unsqueeze(-2)
528
+ emb = self.linear(self.silu(emb))
529
+ scale, shift = emb.chunk(2, dim=-1)
530
+ x = self.norm(x) * (1 + scale) + shift
531
+ return x
532
+
533
+
534
+ class HunyuanVideoSingleTransformerBlock(nn.Module):
535
+ def __init__(
536
+ self,
537
+ num_attention_heads: int,
538
+ attention_head_dim: int,
539
+ mlp_ratio: float = 4.0,
540
+ qk_norm: str = "rms_norm",
541
+ ) -> None:
542
+ super().__init__()
543
+
544
+ hidden_size = num_attention_heads * attention_head_dim
545
+ mlp_dim = int(hidden_size * mlp_ratio)
546
+
547
+ self.attn = Attention(
548
+ query_dim=hidden_size,
549
+ cross_attention_dim=None,
550
+ dim_head=attention_head_dim,
551
+ heads=num_attention_heads,
552
+ out_dim=hidden_size,
553
+ bias=True,
554
+ processor=HunyuanAttnProcessorFlashAttnSingle(),
555
+ qk_norm=qk_norm,
556
+ eps=1e-6,
557
+ pre_only=True,
558
+ )
559
+
560
+ self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
561
+ self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
562
+ self.act_mlp = nn.GELU(approximate="tanh")
563
+ self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
564
+
565
+ def forward(
566
+ self,
567
+ hidden_states: torch.Tensor,
568
+ encoder_hidden_states: torch.Tensor,
569
+ temb: torch.Tensor,
570
+ attention_mask: Optional[torch.Tensor] = None,
571
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
572
+ ) -> torch.Tensor:
573
+ text_seq_length = encoder_hidden_states.shape[1]
574
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
575
+
576
+ residual = hidden_states
577
+
578
+ # 1. Input normalization
579
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
580
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
581
+
582
+ norm_hidden_states, norm_encoder_hidden_states = (
583
+ norm_hidden_states[:, :-text_seq_length, :],
584
+ norm_hidden_states[:, -text_seq_length:, :],
585
+ )
586
+
587
+ # 2. Attention
588
+ attn_output, context_attn_output = self.attn(
589
+ hidden_states=norm_hidden_states,
590
+ encoder_hidden_states=norm_encoder_hidden_states,
591
+ attention_mask=attention_mask,
592
+ image_rotary_emb=image_rotary_emb,
593
+ )
594
+ attn_output = torch.cat([attn_output, context_attn_output], dim=1)
595
+
596
+ # 3. Modulation and residual connection
597
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
598
+ hidden_states = gate * self.proj_out(hidden_states)
599
+ hidden_states = hidden_states + residual
600
+
601
+ hidden_states, encoder_hidden_states = (
602
+ hidden_states[:, :-text_seq_length, :],
603
+ hidden_states[:, -text_seq_length:, :],
604
+ )
605
+ return hidden_states, encoder_hidden_states
606
+
607
+
608
+ class HunyuanVideoTransformerBlock(nn.Module):
609
+ def __init__(
610
+ self,
611
+ num_attention_heads: int,
612
+ attention_head_dim: int,
613
+ mlp_ratio: float,
614
+ qk_norm: str = "rms_norm",
615
+ ) -> None:
616
+ super().__init__()
617
+
618
+ hidden_size = num_attention_heads * attention_head_dim
619
+
620
+ self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
621
+ self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
622
+
623
+ self.attn = Attention(
624
+ query_dim=hidden_size,
625
+ cross_attention_dim=None,
626
+ added_kv_proj_dim=hidden_size,
627
+ dim_head=attention_head_dim,
628
+ heads=num_attention_heads,
629
+ out_dim=hidden_size,
630
+ context_pre_only=False,
631
+ bias=True,
632
+ processor=HunyuanAttnProcessorFlashAttnDouble(),
633
+ qk_norm=qk_norm,
634
+ eps=1e-6,
635
+ )
636
+
637
+ self.norm2 = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
638
+ self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
639
+
640
+ self.norm2_context = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
641
+ self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
642
+
643
+ def forward(
644
+ self,
645
+ hidden_states: torch.Tensor,
646
+ encoder_hidden_states: torch.Tensor,
647
+ temb: torch.Tensor,
648
+ attention_mask: Optional[torch.Tensor] = None,
649
+ freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
650
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
651
+ # 1. Input normalization
652
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
653
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(encoder_hidden_states, emb=temb)
654
+
655
+ # 2. Joint attention
656
+ attn_output, context_attn_output = self.attn(
657
+ hidden_states=norm_hidden_states,
658
+ encoder_hidden_states=norm_encoder_hidden_states,
659
+ attention_mask=attention_mask,
660
+ image_rotary_emb=freqs_cis,
661
+ )
662
+
663
+ # 3. Modulation and residual connection
664
+ hidden_states = hidden_states + attn_output * gate_msa
665
+ encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa
666
+
667
+ norm_hidden_states = self.norm2(hidden_states)
668
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
669
+
670
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
671
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
672
+
673
+ # 4. Feed-forward
674
+ ff_output = self.ff(norm_hidden_states)
675
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
676
+
677
+ hidden_states = hidden_states + gate_mlp * ff_output
678
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
679
+
680
+ return hidden_states, encoder_hidden_states
681
+
682
+
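Both transformer blocks above use the same AdaLN-Zero style modulation: the conditioning embedding `temb` yields shift/scale/gate vectors, the normalized activations are modulated as `x * (1 + scale) + shift`, and the gate scales the residual update. A toy illustration with synthetic tensors (the shapes and the `nn.Linear` stand-in for the attention/feed-forward sublayer are assumptions, not part of the PR):

```python
# Toy illustration of the shift/scale/gate modulation pattern (synthetic data;
# nn.Linear stands in for the real attention or feed-forward sublayer).
import torch
import torch.nn as nn

hidden = 16
x = torch.randn(1, 8, hidden)                      # (batch, tokens, hidden)
shift, scale, gate = (torch.randn(1, 1, hidden) for _ in range(3))

norm = nn.LayerNorm(hidden, elementwise_affine=False, eps=1e-6)
sublayer = nn.Linear(hidden, hidden)

modulated = norm(x) * (1 + scale) + shift          # cf. norm2(...) * (1 + scale_mlp) + shift_mlp
x = x + gate * sublayer(modulated)                 # cf. hidden_states + gate_mlp * ff_output
print(x.shape)                                     # torch.Size([1, 8, 16])
```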
683
+ class ClipVisionProjection(nn.Module):
684
+ def __init__(self, in_channels, out_channels):
685
+ super().__init__()
686
+ self.up = nn.Linear(in_channels, out_channels * 3)
687
+ self.down = nn.Linear(out_channels * 3, out_channels)
688
+
689
+ def forward(self, x):
690
+ projected_x = self.down(nn.functional.silu(self.up(x)))
691
+ return projected_x
692
+
693
+
694
+ class HunyuanVideoPatchEmbed(nn.Module):
695
+ def __init__(self, patch_size, in_chans, embed_dim):
696
+ super().__init__()
697
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
698
+
699
+
700
+ class HunyuanVideoPatchEmbedForCleanLatents(nn.Module):
701
+ def __init__(self, inner_dim):
702
+ super().__init__()
703
+ self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
704
+ self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
705
+ self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
706
+
707
+ @torch.no_grad()
708
+ def initialize_weight_from_another_conv3d(self, another_layer):
709
+ weight = another_layer.weight.detach().clone()
710
+ bias = another_layer.bias.detach().clone()
711
+
712
+ sd = {
713
+ 'proj.weight': weight.clone(),
714
+ 'proj.bias': bias.clone(),
715
+ 'proj_2x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=2, hk=2, wk=2) / 8.0,
716
+ 'proj_2x.bias': bias.clone(),
717
+ 'proj_4x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=4, hk=4, wk=4) / 64.0,
718
+ 'proj_4x.bias': bias.clone(),
719
+ }
720
+
721
+ sd = {k: v.clone() for k, v in sd.items()}
722
+
723
+ self.load_state_dict(sd)
724
+ return
725
+
726
+
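`initialize_weight_from_another_conv3d` above seeds the 2x and 4x downsampling projections by tiling the base (1, 2, 2) kernel with `einops.repeat` and dividing by the number of repeated taps (8 and 64), so every branch initially gives the same response to a constant input. A quick sanity-check sketch with assumed channel counts:

```python
# Sanity check (assumed shapes) that the tiled-and-rescaled 2x kernel matches the
# base kernel on constant inputs, mirroring initialize_weight_from_another_conv3d.
import einops
import torch
import torch.nn as nn

inner_dim = 32
proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))

with torch.no_grad():
    proj_2x.weight.copy_(einops.repeat(proj.weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=2, hk=2, wk=2) / 8.0)
    proj_2x.bias.copy_(proj.bias)

x = torch.full((1, 16, 2, 4, 4), 0.5)              # constant latent patch
out_1x = proj(x)                                   # (1, 32, 2, 2, 2), identical at every position
out_2x = proj_2x(x)                                # (1, 32, 1, 1, 1)
print(torch.allclose(out_1x[:, :, :1, :1, :1], out_2x, atol=1e-5))   # True
```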
727
+ class HunyuanVideoTransformer3DModelPacked(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
728
+ @register_to_config
729
+ def __init__(
730
+ self,
731
+ in_channels: int = 16,
732
+ out_channels: int = 16,
733
+ num_attention_heads: int = 24,
734
+ attention_head_dim: int = 128,
735
+ num_layers: int = 20,
736
+ num_single_layers: int = 40,
737
+ num_refiner_layers: int = 2,
738
+ mlp_ratio: float = 4.0,
739
+ patch_size: int = 2,
740
+ patch_size_t: int = 1,
741
+ qk_norm: str = "rms_norm",
742
+ guidance_embeds: bool = True,
743
+ text_embed_dim: int = 4096,
744
+ pooled_projection_dim: int = 768,
745
+ rope_theta: float = 256.0,
746
+ rope_axes_dim: Tuple[int] = (16, 56, 56),
747
+ has_image_proj=False,
748
+ image_proj_dim=1152,
749
+ has_clean_x_embedder=False,
750
+ ) -> None:
751
+ super().__init__()
752
+
753
+ inner_dim = num_attention_heads * attention_head_dim
754
+ out_channels = out_channels or in_channels
755
+
756
+ # 1. Latent and condition embedders
757
+ self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
758
+ self.context_embedder = HunyuanVideoTokenRefiner(
759
+ text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
760
+ )
761
+ self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
762
+
763
+ self.clean_x_embedder = None
764
+ self.image_projection = None
765
+
766
+ # 2. RoPE
767
+ self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta)
768
+
769
+ # 3. Dual stream transformer blocks
770
+ self.transformer_blocks = nn.ModuleList(
771
+ [
772
+ HunyuanVideoTransformerBlock(
773
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
774
+ )
775
+ for _ in range(num_layers)
776
+ ]
777
+ )
778
+
779
+ # 4. Single stream transformer blocks
780
+ self.single_transformer_blocks = nn.ModuleList(
781
+ [
782
+ HunyuanVideoSingleTransformerBlock(
783
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
784
+ )
785
+ for _ in range(num_single_layers)
786
+ ]
787
+ )
788
+
789
+ # 5. Output projection
790
+ self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
791
+ self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
792
+
793
+ self.inner_dim = inner_dim
794
+ self.use_gradient_checkpointing = False
795
+ self.enable_teacache = False
796
+
797
+ if has_image_proj:
798
+ self.install_image_projection(image_proj_dim)
799
+
800
+ if has_clean_x_embedder:
801
+ self.install_clean_x_embedder()
802
+
803
+ self.high_quality_fp32_output_for_inference = False
804
+
805
+ def install_image_projection(self, in_channels):
806
+ self.image_projection = ClipVisionProjection(in_channels=in_channels, out_channels=self.inner_dim)
807
+ self.config['has_image_proj'] = True
808
+ self.config['image_proj_dim'] = in_channels
809
+
810
+ def install_clean_x_embedder(self):
811
+ self.clean_x_embedder = HunyuanVideoPatchEmbedForCleanLatents(self.inner_dim)
812
+ self.config['has_clean_x_embedder'] = True
813
+
814
+ def enable_gradient_checkpointing(self):
815
+ self.use_gradient_checkpointing = True
816
+ print('self.use_gradient_checkpointing = True')
817
+
818
+ def disable_gradient_checkpointing(self):
819
+ self.use_gradient_checkpointing = False
820
+ print('self.use_gradient_checkpointing = False')
821
+
822
+ def initialize_teacache(self, enable_teacache=True, num_steps=25, rel_l1_thresh=0.15):
823
+ self.enable_teacache = enable_teacache
824
+ self.cnt = 0
825
+ self.num_steps = num_steps
826
+ self.rel_l1_thresh = rel_l1_thresh # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
827
+ self.accumulated_rel_l1_distance = 0
828
+ self.previous_modulated_input = None
829
+ self.previous_residual = None
830
+ self.teacache_rescale_func = np.poly1d([7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02])
831
+
832
+ def gradient_checkpointing_method(self, block, *args):
833
+ if self.use_gradient_checkpointing:
834
+ result = torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False)
835
+ else:
836
+ result = block(*args)
837
+ return result
838
+
839
+ def process_input_hidden_states(
840
+ self,
841
+ latents, latent_indices=None,
842
+ clean_latents=None, clean_latent_indices=None,
843
+ clean_latents_2x=None, clean_latent_2x_indices=None,
844
+ clean_latents_4x=None, clean_latent_4x_indices=None
845
+ ):
846
+ hidden_states = self.gradient_checkpointing_method(self.x_embedder.proj, latents)
847
+ B, C, T, H, W = hidden_states.shape
848
+
849
+ if latent_indices is None:
850
+ latent_indices = torch.arange(0, T).unsqueeze(0).expand(B, -1)
851
+
852
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
853
+
854
+ rope_freqs = self.rope(frame_indices=latent_indices, height=H, width=W, device=hidden_states.device)
855
+ rope_freqs = rope_freqs.flatten(2).transpose(1, 2)
856
+
857
+ if clean_latents is not None and clean_latent_indices is not None:
858
+ clean_latents = clean_latents.to(hidden_states)
859
+ clean_latents = self.gradient_checkpointing_method(self.clean_x_embedder.proj, clean_latents)
860
+ clean_latents = clean_latents.flatten(2).transpose(1, 2)
861
+
862
+ clean_latent_rope_freqs = self.rope(frame_indices=clean_latent_indices, height=H, width=W, device=clean_latents.device)
863
+ clean_latent_rope_freqs = clean_latent_rope_freqs.flatten(2).transpose(1, 2)
864
+
865
+ hidden_states = torch.cat([clean_latents, hidden_states], dim=1)
866
+ rope_freqs = torch.cat([clean_latent_rope_freqs, rope_freqs], dim=1)
867
+
868
+ if clean_latents_2x is not None and clean_latent_2x_indices is not None:
869
+ clean_latents_2x = clean_latents_2x.to(hidden_states)
870
+ clean_latents_2x = pad_for_3d_conv(clean_latents_2x, (2, 4, 4))
871
+ clean_latents_2x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_2x, clean_latents_2x)
872
+ clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2)
873
+
874
+ clean_latent_2x_rope_freqs = self.rope(frame_indices=clean_latent_2x_indices, height=H, width=W, device=clean_latents_2x.device)
875
+ clean_latent_2x_rope_freqs = pad_for_3d_conv(clean_latent_2x_rope_freqs, (2, 2, 2))
876
+ clean_latent_2x_rope_freqs = center_down_sample_3d(clean_latent_2x_rope_freqs, (2, 2, 2))
877
+ clean_latent_2x_rope_freqs = clean_latent_2x_rope_freqs.flatten(2).transpose(1, 2)
878
+
879
+ hidden_states = torch.cat([clean_latents_2x, hidden_states], dim=1)
880
+ rope_freqs = torch.cat([clean_latent_2x_rope_freqs, rope_freqs], dim=1)
881
+
882
+ if clean_latents_4x is not None and clean_latent_4x_indices is not None:
883
+ clean_latents_4x = clean_latents_4x.to(hidden_states)
884
+ clean_latents_4x = pad_for_3d_conv(clean_latents_4x, (4, 8, 8))
885
+ clean_latents_4x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_4x, clean_latents_4x)
886
+ clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2)
887
+
888
+ clean_latent_4x_rope_freqs = self.rope(frame_indices=clean_latent_4x_indices, height=H, width=W, device=clean_latents_4x.device)
889
+ clean_latent_4x_rope_freqs = pad_for_3d_conv(clean_latent_4x_rope_freqs, (4, 4, 4))
890
+ clean_latent_4x_rope_freqs = center_down_sample_3d(clean_latent_4x_rope_freqs, (4, 4, 4))
891
+ clean_latent_4x_rope_freqs = clean_latent_4x_rope_freqs.flatten(2).transpose(1, 2)
892
+
893
+ hidden_states = torch.cat([clean_latents_4x, hidden_states], dim=1)
894
+ rope_freqs = torch.cat([clean_latent_4x_rope_freqs, rope_freqs], dim=1)
895
+
896
+ return hidden_states, rope_freqs
897
+
898
+ def forward(
899
+ self,
900
+ hidden_states, timestep, encoder_hidden_states, encoder_attention_mask, pooled_projections, guidance,
901
+ latent_indices=None,
902
+ clean_latents=None, clean_latent_indices=None,
903
+ clean_latents_2x=None, clean_latent_2x_indices=None,
904
+ clean_latents_4x=None, clean_latent_4x_indices=None,
905
+ image_embeddings=None,
906
+ attention_kwargs=None, return_dict=True
907
+ ):
908
+
909
+ if attention_kwargs is None:
910
+ attention_kwargs = {}
911
+
912
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
913
+ p, p_t = self.config['patch_size'], self.config['patch_size_t']
914
+ post_patch_num_frames = num_frames // p_t
915
+ post_patch_height = height // p
916
+ post_patch_width = width // p
917
+ original_context_length = post_patch_num_frames * post_patch_height * post_patch_width
918
+
919
+ hidden_states, rope_freqs = self.process_input_hidden_states(hidden_states, latent_indices, clean_latents, clean_latent_indices, clean_latents_2x, clean_latent_2x_indices, clean_latents_4x, clean_latent_4x_indices)
920
+
921
+ temb = self.gradient_checkpointing_method(self.time_text_embed, timestep, guidance, pooled_projections)
922
+ encoder_hidden_states = self.gradient_checkpointing_method(self.context_embedder, encoder_hidden_states, timestep, encoder_attention_mask)
923
+
924
+ if self.image_projection is not None:
925
+ assert image_embeddings is not None, 'You must use image embeddings!'
926
+ extra_encoder_hidden_states = self.gradient_checkpointing_method(self.image_projection, image_embeddings)
927
+ extra_attention_mask = torch.ones((batch_size, extra_encoder_hidden_states.shape[1]), dtype=encoder_attention_mask.dtype, device=encoder_attention_mask.device)
928
+
929
+ # must cat before (not after) encoder_hidden_states, due to attn masking
930
+ encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1)
931
+ encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1)
932
+
933
+ if batch_size == 1:
934
+ # When batch size is 1, we do not need any masks or var-len attention functions, since cropping is mathematically the same as what we want.
935
+ # If they are not the same, then their implementations are wrong; ours is always the correct one.
936
+ text_len = encoder_attention_mask.sum().item()
937
+ encoder_hidden_states = encoder_hidden_states[:, :text_len]
938
+ attention_mask = None, None, None, None
939
+ else:
940
+ img_seq_len = hidden_states.shape[1]
941
+ txt_seq_len = encoder_hidden_states.shape[1]
942
+
943
+ cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len)
944
+ cu_seqlens_kv = cu_seqlens_q
945
+ max_seqlen_q = img_seq_len + txt_seq_len
946
+ max_seqlen_kv = max_seqlen_q
947
+
948
+ attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
949
+
950
+ if self.enable_teacache:
951
+ modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]
952
+
953
+ if self.cnt == 0 or self.cnt == self.num_steps-1:
954
+ should_calc = True
955
+ self.accumulated_rel_l1_distance = 0
956
+ else:
957
+ curr_rel_l1 = ((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()
958
+ self.accumulated_rel_l1_distance += self.teacache_rescale_func(curr_rel_l1)
959
+ should_calc = self.accumulated_rel_l1_distance >= self.rel_l1_thresh
960
+
961
+ if should_calc:
962
+ self.accumulated_rel_l1_distance = 0
963
+
964
+ self.previous_modulated_input = modulated_inp
965
+ self.cnt += 1
966
+
967
+ if self.cnt == self.num_steps:
968
+ self.cnt = 0
969
+
970
+ if not should_calc:
971
+ hidden_states = hidden_states + self.previous_residual
972
+ else:
973
+ ori_hidden_states = hidden_states.clone()
974
+
975
+ for block_id, block in enumerate(self.transformer_blocks):
976
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
977
+ block,
978
+ hidden_states,
979
+ encoder_hidden_states,
980
+ temb,
981
+ attention_mask,
982
+ rope_freqs
983
+ )
984
+
985
+ for block_id, block in enumerate(self.single_transformer_blocks):
986
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
987
+ block,
988
+ hidden_states,
989
+ encoder_hidden_states,
990
+ temb,
991
+ attention_mask,
992
+ rope_freqs
993
+ )
994
+
995
+ self.previous_residual = hidden_states - ori_hidden_states
996
+ else:
997
+ for block_id, block in enumerate(self.transformer_blocks):
998
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
999
+ block,
1000
+ hidden_states,
1001
+ encoder_hidden_states,
1002
+ temb,
1003
+ attention_mask,
1004
+ rope_freqs
1005
+ )
1006
+
1007
+ for block_id, block in enumerate(self.single_transformer_blocks):
1008
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
1009
+ block,
1010
+ hidden_states,
1011
+ encoder_hidden_states,
1012
+ temb,
1013
+ attention_mask,
1014
+ rope_freqs
1015
+ )
1016
+
1017
+ hidden_states = self.gradient_checkpointing_method(self.norm_out, hidden_states, temb)
1018
+
1019
+ hidden_states = hidden_states[:, -original_context_length:, :]
1020
+
1021
+ if self.high_quality_fp32_output_for_inference:
1022
+ hidden_states = hidden_states.to(dtype=torch.float32)
1023
+ if self.proj_out.weight.dtype != torch.float32:
1024
+ self.proj_out.to(dtype=torch.float32)
1025
+
1026
+ hidden_states = self.gradient_checkpointing_method(self.proj_out, hidden_states)
1027
+
1028
+ hidden_states = einops.rearrange(hidden_states, 'b (t h w) (c pt ph pw) -> b c (t pt) (h ph) (w pw)',
1029
+ t=post_patch_num_frames, h=post_patch_height, w=post_patch_width,
1030
+ pt=p_t, ph=p, pw=p)
1031
+
1032
+ if return_dict:
1033
+ return Transformer2DModelOutput(sample=hidden_states)
1034
+
1035
+ return hidden_states,
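For reference, the `initialize_teacache` / `forward` pair above implements a TeaCache-style step skipper: the relative L1 change of the first block's modulated input is passed through a fitted polynomial, accumulated across steps, and the transformer blocks are only re-run once the accumulated value crosses `rel_l1_thresh`; otherwise the cached residual is added back. A minimal standalone sketch of that decision loop, with synthetic tensors standing in for the real modulated inputs:

```python
# Standalone sketch of the TeaCache skip heuristic used in
# HunyuanVideoTransformer3DModelPacked.forward (synthetic tensors stand in for
# the first block's modulated input; this is not a drop-in replacement).
import numpy as np
import torch

rescale = np.poly1d([7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02])
num_steps, rel_l1_thresh = 25, 0.15

accumulated = 0.0
previous_modulated = None
base = torch.randn(1, 256, 3072)

for step in range(num_steps):
    modulated = base + 0.01 * step * torch.randn_like(base)  # stand-in for norm1(hidden_states, emb=temb)[0]

    if step == 0 or step == num_steps - 1:
        should_calc, accumulated = True, 0.0
    else:
        curr_rel_l1 = ((modulated - previous_modulated).abs().mean() / previous_modulated.abs().mean()).item()
        accumulated += rescale(curr_rel_l1)
        should_calc = accumulated >= rel_l1_thresh
        if should_calc:
            accumulated = 0.0

    previous_modulated = modulated
    # When should_calc is False, the real forward() adds the cached previous_residual
    # instead of running the dual- and single-stream blocks again.
    print(step, 'compute blocks' if should_calc else 'reuse cached residual')
```

In the real forward pass, `previous_residual` holds the difference between the block stack's output and its input from the last fully computed step.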
diffusers_helper/pipelines/k_diffusion_hunyuan.py ADDED
@@ -0,0 +1,120 @@
1
+ import torch
2
+ import math
3
+
4
+ from diffusers_helper.k_diffusion.uni_pc_fm import sample_unipc
5
+ from diffusers_helper.k_diffusion.wrapper import fm_wrapper
6
+ from diffusers_helper.utils import repeat_to_batch_size
7
+
8
+
9
+ def flux_time_shift(t, mu=1.15, sigma=1.0):
10
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
11
+
12
+
13
+ def calculate_flux_mu(context_length, x1=256, y1=0.5, x2=4096, y2=1.15, exp_max=7.0):
14
+ k = (y2 - y1) / (x2 - x1)
15
+ b = y1 - k * x1
16
+ mu = k * context_length + b
17
+ mu = min(mu, math.log(exp_max))
18
+ return mu
19
+
20
+
21
+ def get_flux_sigmas_from_mu(n, mu):
22
+ sigmas = torch.linspace(1, 0, steps=n + 1)
23
+ sigmas = flux_time_shift(sigmas, mu=mu)
24
+ return sigmas
25
+
26
+
27
+ @torch.inference_mode()
28
+ def sample_hunyuan(
29
+ transformer,
30
+ sampler='unipc',
31
+ initial_latent=None,
32
+ concat_latent=None,
33
+ strength=1.0,
34
+ width=512,
35
+ height=512,
36
+ frames=16,
37
+ real_guidance_scale=1.0,
38
+ distilled_guidance_scale=6.0,
39
+ guidance_rescale=0.0,
40
+ shift=None,
41
+ num_inference_steps=25,
42
+ batch_size=None,
43
+ generator=None,
44
+ prompt_embeds=None,
45
+ prompt_embeds_mask=None,
46
+ prompt_poolers=None,
47
+ negative_prompt_embeds=None,
48
+ negative_prompt_embeds_mask=None,
49
+ negative_prompt_poolers=None,
50
+ dtype=torch.bfloat16,
51
+ device=None,
52
+ negative_kwargs=None,
53
+ callback=None,
54
+ **kwargs,
55
+ ):
56
+ device = device or transformer.device
57
+
58
+ if batch_size is None:
59
+ batch_size = int(prompt_embeds.shape[0])
60
+
61
+ latents = torch.randn((batch_size, 16, (frames + 3) // 4, height // 8, width // 8), generator=generator, device=generator.device).to(device=device, dtype=torch.float32)
62
+
63
+ B, C, T, H, W = latents.shape
64
+ seq_length = T * H * W // 4
65
+
66
+ if shift is None:
67
+ mu = calculate_flux_mu(seq_length, exp_max=7.0)
68
+ else:
69
+ mu = math.log(shift)
70
+
71
+ sigmas = get_flux_sigmas_from_mu(num_inference_steps, mu).to(device)
72
+
73
+ k_model = fm_wrapper(transformer)
74
+
75
+ if initial_latent is not None:
76
+ sigmas = sigmas * strength
77
+ first_sigma = sigmas[0].to(device=device, dtype=torch.float32)
78
+ initial_latent = initial_latent.to(device=device, dtype=torch.float32)
79
+ latents = initial_latent.float() * (1.0 - first_sigma) + latents.float() * first_sigma
80
+
81
+ if concat_latent is not None:
82
+ concat_latent = concat_latent.to(latents)
83
+
84
+ distilled_guidance = torch.tensor([distilled_guidance_scale * 1000.0] * batch_size).to(device=device, dtype=dtype)
85
+
86
+ prompt_embeds = repeat_to_batch_size(prompt_embeds, batch_size)
87
+ prompt_embeds_mask = repeat_to_batch_size(prompt_embeds_mask, batch_size)
88
+ prompt_poolers = repeat_to_batch_size(prompt_poolers, batch_size)
89
+ negative_prompt_embeds = repeat_to_batch_size(negative_prompt_embeds, batch_size)
90
+ negative_prompt_embeds_mask = repeat_to_batch_size(negative_prompt_embeds_mask, batch_size)
91
+ negative_prompt_poolers = repeat_to_batch_size(negative_prompt_poolers, batch_size)
92
+ concat_latent = repeat_to_batch_size(concat_latent, batch_size)
93
+
94
+ sampler_kwargs = dict(
95
+ dtype=dtype,
96
+ cfg_scale=real_guidance_scale,
97
+ cfg_rescale=guidance_rescale,
98
+ concat_latent=concat_latent,
99
+ positive=dict(
100
+ pooled_projections=prompt_poolers,
101
+ encoder_hidden_states=prompt_embeds,
102
+ encoder_attention_mask=prompt_embeds_mask,
103
+ guidance=distilled_guidance,
104
+ **kwargs,
105
+ ),
106
+ negative=dict(
107
+ pooled_projections=negative_prompt_poolers,
108
+ encoder_hidden_states=negative_prompt_embeds,
109
+ encoder_attention_mask=negative_prompt_embeds_mask,
110
+ guidance=distilled_guidance,
111
+ **(kwargs if negative_kwargs is None else {**kwargs, **negative_kwargs}),
112
+ )
113
+ )
114
+
115
+ if sampler == 'unipc':
116
+ results = sample_unipc(k_model, latents, sigmas, extra_args=sampler_kwargs, disable=False, callback=callback)
117
+ else:
118
+ raise NotImplementedError(f'Sampler {sampler} is not supported.')
119
+
120
+ return results
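The noise schedule above is a Flux-style shifted schedule: `calculate_flux_mu` interpolates `mu` linearly in the token count between (256, 0.5) and (4096, 1.15) and clamps it at `log(exp_max)`, then `get_flux_sigmas_from_mu` warps a linear 1-to-0 ramp through `flux_time_shift`. A small worked example, reproducing those helpers standalone; the 640x640, 33-frame request is only an illustrative assumption:

```python
# Worked example of the sigma schedule used by sample_hunyuan (helpers copied
# from above so the snippet runs standalone; the video size is an assumption).
import math
import torch

def flux_time_shift(t, mu=1.15, sigma=1.0):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

def calculate_flux_mu(context_length, x1=256, y1=0.5, x2=4096, y2=1.15, exp_max=7.0):
    k = (y2 - y1) / (x2 - x1)
    mu = k * context_length + (y1 - k * x1)
    return min(mu, math.log(exp_max))

height = width = 640
frames = 33
T, H, W = (frames + 3) // 4, height // 8, width // 8   # latent grid built in sample_hunyuan
seq_length = T * H * W // 4                             # tokens after the 2x2 spatial patchify

mu = calculate_flux_mu(seq_length)                      # clamped to log(7.0) for a sequence this long
sigmas = flux_time_shift(torch.linspace(1, 0, steps=26), mu=mu)
print(round(mu, 4))                                     # 1.9459, i.e. log(7.0)
print(sigmas[:3], sigmas[-1])                           # starts at 1.0, ends at 0.0
```

Passing `shift` explicitly bypasses this interpolation (`mu = math.log(shift)`), so the clamp above is equivalent to `shift=7.0`.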
diffusers_helper/thread_utils.py ADDED
@@ -0,0 +1,76 @@
1
+ import time
2
+
3
+ from threading import Thread, Lock
4
+
5
+
6
+ class Listener:
7
+ task_queue = []
8
+ lock = Lock()
9
+ thread = None
10
+
11
+ @classmethod
12
+ def _process_tasks(cls):
13
+ while True:
14
+ task = None
15
+ with cls.lock:
16
+ if cls.task_queue:
17
+ task = cls.task_queue.pop(0)
18
+
19
+ if task is None:
20
+ time.sleep(0.001)
21
+ continue
22
+
23
+ func, args, kwargs = task
24
+ try:
25
+ func(*args, **kwargs)
26
+ except Exception as e:
27
+ print(f"Error in listener thread: {e}")
28
+
29
+ @classmethod
30
+ def add_task(cls, func, *args, **kwargs):
31
+ with cls.lock:
32
+ cls.task_queue.append((func, args, kwargs))
33
+
34
+ if cls.thread is None:
35
+ cls.thread = Thread(target=cls._process_tasks, daemon=True)
36
+ cls.thread.start()
37
+
38
+
39
+ def async_run(func, *args, **kwargs):
40
+ Listener.add_task(func, *args, **kwargs)
41
+
42
+
43
+ class FIFOQueue:
44
+ def __init__(self):
45
+ self.queue = []
46
+ self.lock = Lock()
47
+
48
+ def push(self, item):
49
+ with self.lock:
50
+ self.queue.append(item)
51
+
52
+ def pop(self):
53
+ with self.lock:
54
+ if self.queue:
55
+ return self.queue.pop(0)
56
+ return None
57
+
58
+ def top(self):
59
+ with self.lock:
60
+ if self.queue:
61
+ return self.queue[0]
62
+ return None
63
+
64
+ def next(self):
65
+ while True:
66
+ with self.lock:
67
+ if self.queue:
68
+ return self.queue.pop(0)
69
+
70
+ time.sleep(0.001)
71
+
72
+
73
+ class AsyncStream:
74
+ def __init__(self):
75
+ self.input_queue = FIFOQueue()
76
+ self.output_queue = FIFOQueue()
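`thread_utils` is a small producer/consumer toolkit: `async_run` queues a callable onto a single daemon listener thread, and `AsyncStream` pairs two blocking `FIFOQueue`s so a worker can stream results back while the caller waits on `next()`. A hypothetical usage sketch (the `worker` function and its `('progress', ...)` / `('end', ...)` messages are invented here for illustration and are not prescribed by the module):

```python
# Hypothetical usage sketch of async_run + AsyncStream; the message format is
# made up for this example.
import time
from diffusers_helper.thread_utils import AsyncStream, async_run

stream = AsyncStream()

def worker():
    for i in range(3):
        time.sleep(0.1)                          # stand-in for one sampling step
        stream.output_queue.push(('progress', i))
    stream.output_queue.push(('end', None))

async_run(worker)                                # runs on the shared daemon listener thread

while True:
    flag, data = stream.output_queue.next()      # blocks until the worker pushes something
    print(flag, data)
    if flag == 'end':
        break
```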
diffusers_helper/utils.py ADDED
@@ -0,0 +1,613 @@
1
+ import os
2
+ import cv2
3
+ import json
4
+ import random
5
+ import glob
6
+ import torch
7
+ import einops
8
+ import numpy as np
9
+ import datetime
10
+ import torchvision
11
+
12
+ import safetensors.torch as sf
13
+ from PIL import Image
14
+
15
+
16
+ def min_resize(x, m):
17
+ if x.shape[0] < x.shape[1]:
18
+ s0 = m
19
+ s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1]))
20
+ else:
21
+ s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0]))
22
+ s1 = m
23
+ new_max = max(s1, s0)
24
+ raw_max = max(x.shape[0], x.shape[1])
25
+ if new_max < raw_max:
26
+ interpolation = cv2.INTER_AREA
27
+ else:
28
+ interpolation = cv2.INTER_LANCZOS4
29
+ y = cv2.resize(x, (s1, s0), interpolation=interpolation)
30
+ return y
31
+
32
+
33
+ def d_resize(x, y):
34
+ H, W, C = y.shape
35
+ new_min = min(H, W)
36
+ raw_min = min(x.shape[0], x.shape[1])
37
+ if new_min < raw_min:
38
+ interpolation = cv2.INTER_AREA
39
+ else:
40
+ interpolation = cv2.INTER_LANCZOS4
41
+ y = cv2.resize(x, (W, H), interpolation=interpolation)
42
+ return y
43
+
44
+
45
+ def resize_and_center_crop(image, target_width, target_height):
46
+ if target_height == image.shape[0] and target_width == image.shape[1]:
47
+ return image
48
+
49
+ pil_image = Image.fromarray(image)
50
+ original_width, original_height = pil_image.size
51
+ scale_factor = max(target_width / original_width, target_height / original_height)
52
+ resized_width = int(round(original_width * scale_factor))
53
+ resized_height = int(round(original_height * scale_factor))
54
+ resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
55
+ left = (resized_width - target_width) / 2
56
+ top = (resized_height - target_height) / 2
57
+ right = (resized_width + target_width) / 2
58
+ bottom = (resized_height + target_height) / 2
59
+ cropped_image = resized_image.crop((left, top, right, bottom))
60
+ return np.array(cropped_image)
61
+
62
+
63
+ def resize_and_center_crop_pytorch(image, target_width, target_height):
64
+ B, C, H, W = image.shape
65
+
66
+ if H == target_height and W == target_width:
67
+ return image
68
+
69
+ scale_factor = max(target_width / W, target_height / H)
70
+ resized_width = int(round(W * scale_factor))
71
+ resized_height = int(round(H * scale_factor))
72
+
73
+ resized = torch.nn.functional.interpolate(image, size=(resized_height, resized_width), mode='bilinear', align_corners=False)
74
+
75
+ top = (resized_height - target_height) // 2
76
+ left = (resized_width - target_width) // 2
77
+ cropped = resized[:, :, top:top + target_height, left:left + target_width]
78
+
79
+ return cropped
80
+
81
+
82
+ def resize_without_crop(image, target_width, target_height):
83
+ if target_height == image.shape[0] and target_width == image.shape[1]:
84
+ return image
85
+
86
+ pil_image = Image.fromarray(image)
87
+ resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
88
+ return np.array(resized_image)
89
+
90
+
91
+ def just_crop(image, w, h):
92
+ if h == image.shape[0] and w == image.shape[1]:
93
+ return image
94
+
95
+ original_height, original_width = image.shape[:2]
96
+ k = min(original_height / h, original_width / w)
97
+ new_width = int(round(w * k))
98
+ new_height = int(round(h * k))
99
+ x_start = (original_width - new_width) // 2
100
+ y_start = (original_height - new_height) // 2
101
+ cropped_image = image[y_start:y_start + new_height, x_start:x_start + new_width]
102
+ return cropped_image
103
+
104
+
105
+ def write_to_json(data, file_path):
106
+ temp_file_path = file_path + ".tmp"
107
+ with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
108
+ json.dump(data, temp_file, indent=4)
109
+ os.replace(temp_file_path, file_path)
110
+ return
111
+
112
+
113
+ def read_from_json(file_path):
114
+ with open(file_path, 'rt', encoding='utf-8') as file:
115
+ data = json.load(file)
116
+ return data
117
+
118
+
119
+ def get_active_parameters(m):
120
+ return {k: v for k, v in m.named_parameters() if v.requires_grad}
121
+
122
+
123
+ def cast_training_params(m, dtype=torch.float32):
124
+ result = {}
125
+ for n, param in m.named_parameters():
126
+ if param.requires_grad:
127
+ param.data = param.to(dtype)
128
+ result[n] = param
129
+ return result
130
+
131
+
132
+ def separate_lora_AB(parameters, B_patterns=None):
133
+ parameters_normal = {}
134
+ parameters_B = {}
135
+
136
+ if B_patterns is None:
137
+ B_patterns = ['.lora_B.', '__zero__']
138
+
139
+ for k, v in parameters.items():
140
+ if any(B_pattern in k for B_pattern in B_patterns):
141
+ parameters_B[k] = v
142
+ else:
143
+ parameters_normal[k] = v
144
+
145
+ return parameters_normal, parameters_B
146
+
147
+
148
+ def set_attr_recursive(obj, attr, value):
149
+ attrs = attr.split(".")
150
+ for name in attrs[:-1]:
151
+ obj = getattr(obj, name)
152
+ setattr(obj, attrs[-1], value)
153
+ return
154
+
155
+
156
+ def print_tensor_list_size(tensors):
157
+ total_size = 0
158
+ total_elements = 0
159
+
160
+ if isinstance(tensors, dict):
161
+ tensors = tensors.values()
162
+
163
+ for tensor in tensors:
164
+ total_size += tensor.nelement() * tensor.element_size()
165
+ total_elements += tensor.nelement()
166
+
167
+ total_size_MB = total_size / (1024 ** 2)
168
+ total_elements_B = total_elements / 1e9
169
+
170
+ print(f"Total number of tensors: {len(tensors)}")
171
+ print(f"Total size of tensors: {total_size_MB:.2f} MB")
172
+ print(f"Total number of parameters: {total_elements_B:.3f} billion")
173
+ return
174
+
175
+
176
+ @torch.no_grad()
177
+ def batch_mixture(a, b=None, probability_a=0.5, mask_a=None):
178
+ batch_size = a.size(0)
179
+
180
+ if b is None:
181
+ b = torch.zeros_like(a)
182
+
183
+ if mask_a is None:
184
+ mask_a = torch.rand(batch_size) < probability_a
185
+
186
+ mask_a = mask_a.to(a.device)
187
+ mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1))
188
+ result = torch.where(mask_a, a, b)
189
+ return result
190
+
191
+
192
+ @torch.no_grad()
193
+ def zero_module(module):
194
+ for p in module.parameters():
195
+ p.detach().zero_()
196
+ return module
197
+
198
+
199
+ @torch.no_grad()
200
+ def supress_lower_channels(m, k, alpha=0.01):
201
+ data = m.weight.data.clone()
202
+
203
+ assert int(data.shape[1]) >= k
204
+
205
+ data[:, :k] = data[:, :k] * alpha
206
+ m.weight.data = data.contiguous().clone()
207
+ return m
208
+
209
+
210
+ def freeze_module(m):
211
+ if not hasattr(m, '_forward_inside_frozen_module'):
212
+ m._forward_inside_frozen_module = m.forward
213
+ m.requires_grad_(False)
214
+ m.forward = torch.no_grad()(m.forward)
215
+ return m
216
+
217
+
218
+ def get_latest_safetensors(folder_path):
219
+ safetensors_files = glob.glob(os.path.join(folder_path, '*.safetensors'))
220
+
221
+ if not safetensors_files:
222
+ raise ValueError('No file to resume!')
223
+
224
+ latest_file = max(safetensors_files, key=os.path.getmtime)
225
+ latest_file = os.path.abspath(os.path.realpath(latest_file))
226
+ return latest_file
227
+
228
+
229
+ def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
230
+ tags = tags_str.split(', ')
231
+ tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags)))
232
+ prompt = ', '.join(tags)
233
+ return prompt
234
+
235
+
236
+ def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0):
237
+ numbers = a + (b - a) * (np.linspace(0, 1, n) ** gamma)
238
+ if round_to_int:
239
+ numbers = np.round(numbers).astype(int)
240
+ return numbers.tolist()
241
+
242
+
243
+ def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False):
244
+ edges = np.linspace(0, 1, n + 1)
245
+ points = np.random.uniform(edges[:-1], edges[1:])
246
+ numbers = inclusive + (exclusive - inclusive) * points
247
+ if round_to_int:
248
+ numbers = np.round(numbers).astype(int)
249
+ return numbers.tolist()
250
+
251
+
252
+ def soft_append_bcthw(history, current, overlap=0):
253
+ if overlap <= 0:
254
+ return torch.cat([history, current], dim=2)
255
+
256
+ assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})"
257
+ assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})"
258
+
259
+ weights = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1)
260
+ blended = weights * history[:, :, -overlap:] + (1 - weights) * current[:, :, :overlap]
261
+ output = torch.cat([history[:, :, :-overlap], blended, current[:, :, overlap:]], dim=2)
262
+
263
+ return output.to(history)
264
+
265
+
266
+ def save_bcthw_as_mp4(x, output_filename, fps=10, crf=0):
267
+ b, c, t, h, w = x.shape
268
+
269
+ per_row = b
270
+ for p in [6, 5, 4, 3, 2]:
271
+ if b % p == 0:
272
+ per_row = p
273
+ break
274
+
275
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
276
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
277
+ x = x.detach().cpu().to(torch.uint8)
278
+ x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
279
+ torchvision.io.write_video(output_filename, x, fps=fps, video_codec='libx264', options={'crf': str(int(crf))})
280
+ return x
281
+
282
+
283
+ def save_bcthw_as_png(x, output_filename):
284
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
285
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
286
+ x = x.detach().cpu().to(torch.uint8)
287
+ x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)')
288
+ torchvision.io.write_png(x, output_filename)
289
+ return output_filename
290
+
291
+
292
+ def save_bchw_as_png(x, output_filename):
293
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
294
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
295
+ x = x.detach().cpu().to(torch.uint8)
296
+ x = einops.rearrange(x, 'b c h w -> c h (b w)')
297
+ torchvision.io.write_png(x, output_filename)
298
+ return output_filename
299
+
300
+
301
+ def add_tensors_with_padding(tensor1, tensor2):
302
+ if tensor1.shape == tensor2.shape:
303
+ return tensor1 + tensor2
304
+
305
+ shape1 = tensor1.shape
306
+ shape2 = tensor2.shape
307
+
308
+ new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))
309
+
310
+ padded_tensor1 = torch.zeros(new_shape)
311
+ padded_tensor2 = torch.zeros(new_shape)
312
+
313
+ padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
314
+ padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2
315
+
316
+ result = padded_tensor1 + padded_tensor2
317
+ return result
318
+
319
+
320
+ def print_free_mem():
321
+ torch.cuda.empty_cache()
322
+ free_mem, total_mem = torch.cuda.mem_get_info(0)
323
+ free_mem_mb = free_mem / (1024 ** 2)
324
+ total_mem_mb = total_mem / (1024 ** 2)
325
+ print(f"Free memory: {free_mem_mb:.2f} MB")
326
+ print(f"Total memory: {total_mem_mb:.2f} MB")
327
+ return
328
+
329
+
330
+ def print_gpu_parameters(device, state_dict, log_count=1):
331
+ summary = {"device": device, "keys_count": len(state_dict)}
332
+
333
+ logged_params = {}
334
+ for i, (key, tensor) in enumerate(state_dict.items()):
335
+ if i >= log_count:
336
+ break
337
+ logged_params[key] = tensor.flatten()[:3].tolist()
338
+
339
+ summary["params"] = logged_params
340
+
341
+ print(str(summary))
342
+ return
343
+
344
+
345
+ def visualize_txt_as_img(width, height, text, font_path='font/DejaVuSans.ttf', size=18):
346
+ from PIL import Image, ImageDraw, ImageFont
347
+
348
+ txt = Image.new("RGB", (width, height), color="white")
349
+ draw = ImageDraw.Draw(txt)
350
+ font = ImageFont.truetype(font_path, size=size)
351
+
352
+ if text == '':
353
+ return np.array(txt)
354
+
355
+ # Split text into lines that fit within the image width
356
+ lines = []
357
+ words = text.split()
358
+ current_line = words[0]
359
+
360
+ for word in words[1:]:
361
+ line_with_word = f"{current_line} {word}"
362
+ if draw.textbbox((0, 0), line_with_word, font=font)[2] <= width:
363
+ current_line = line_with_word
364
+ else:
365
+ lines.append(current_line)
366
+ current_line = word
367
+
368
+ lines.append(current_line)
369
+
370
+ # Draw the text line by line
371
+ y = 0
372
+ line_height = draw.textbbox((0, 0), "A", font=font)[3]
373
+
374
+ for line in lines:
375
+ if y + line_height > height:
376
+ break # stop drawing if the next line will be outside the image
377
+ draw.text((0, y), line, fill="black", font=font)
378
+ y += line_height
379
+
380
+ return np.array(txt)
381
+
382
+
383
+ def blue_mark(x):
384
+ x = x.copy()
385
+ c = x[:, :, 2]
386
+ b = cv2.blur(c, (9, 9))
387
+ x[:, :, 2] = ((c - b) * 16.0 + b).clip(-1, 1)
388
+ return x
389
+
390
+
391
+ def green_mark(x):
392
+ x = x.copy()
393
+ x[:, :, 2] = -1
394
+ x[:, :, 0] = -1
395
+ return x
396
+
397
+
398
+ def frame_mark(x):
399
+ x = x.copy()
400
+ x[:64] = -1
401
+ x[-64:] = -1
402
+ x[:, :8] = 1
403
+ x[:, -8:] = 1
404
+ return x
405
+
406
+
407
+ @torch.inference_mode()
408
+ def pytorch2numpy(imgs):
409
+ results = []
410
+ for x in imgs:
411
+ y = x.movedim(0, -1)
412
+ y = y * 127.5 + 127.5
413
+ y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
414
+ results.append(y)
415
+ return results
416
+
417
+
418
+ @torch.inference_mode()
419
+ def numpy2pytorch(imgs):
420
+ h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
421
+ h = h.movedim(-1, 1)
422
+ return h
423
+
424
+
425
+ @torch.no_grad()
426
+ def duplicate_prefix_to_suffix(x, count, zero_out=False):
427
+ if zero_out:
428
+ return torch.cat([x, torch.zeros_like(x[:count])], dim=0)
429
+ else:
430
+ return torch.cat([x, x[:count]], dim=0)
431
+
432
+
433
+ def weighted_mse(a, b, weight):
434
+ return torch.mean(weight.float() * (a.float() - b.float()) ** 2)
435
+
436
+
437
+ def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0):
438
+ x = (x - x_min) / (x_max - x_min)
439
+ x = max(0.0, min(x, 1.0))
440
+ x = x ** sigma
441
+ return y_min + x * (y_max - y_min)
442
+
443
+
444
+ def expand_to_dims(x, target_dims):
445
+ return x.view(*x.shape, *([1] * max(0, target_dims - x.dim())))
446
+
447
+
448
+ def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int):
449
+ if tensor is None:
450
+ return None
451
+
452
+ first_dim = tensor.shape[0]
453
+
454
+ if first_dim == batch_size:
455
+ return tensor
456
+
457
+ if batch_size % first_dim != 0:
458
+ raise ValueError(f"Cannot evenly repeat first dim {first_dim} to match batch_size {batch_size}.")
459
+
460
+ repeat_times = batch_size // first_dim
461
+
462
+ return tensor.repeat(repeat_times, *[1] * (tensor.dim() - 1))
463
+
464
+
465
+ def dim5(x):
466
+ return expand_to_dims(x, 5)
467
+
468
+
469
+ def dim4(x):
470
+ return expand_to_dims(x, 4)
471
+
472
+
473
+ def dim3(x):
474
+ return expand_to_dims(x, 3)
475
+
476
+
477
+ def crop_or_pad_yield_mask(x, length):
478
+ B, F, C = x.shape
479
+ device = x.device
480
+ dtype = x.dtype
481
+
482
+ if F < length:
483
+ y = torch.zeros((B, length, C), dtype=dtype, device=device)
484
+ mask = torch.zeros((B, length), dtype=torch.bool, device=device)
485
+ y[:, :F, :] = x
486
+ mask[:, :F] = True
487
+ return y, mask
488
+
489
+ return x[:, :length, :], torch.ones((B, length), dtype=torch.bool, device=device)
490
+
491
+
492
+ def extend_dim(x, dim, minimal_length, zero_pad=False):
493
+ original_length = int(x.shape[dim])
494
+
495
+ if original_length >= minimal_length:
496
+ return x
497
+
498
+ if zero_pad:
499
+ padding_shape = list(x.shape)
500
+ padding_shape[dim] = minimal_length - original_length
501
+ padding = torch.zeros(padding_shape, dtype=x.dtype, device=x.device)
502
+ else:
503
+ idx = (slice(None),) * dim + (slice(-1, None),) + (slice(None),) * (len(x.shape) - dim - 1)
504
+ last_element = x[idx]
505
+ padding = last_element.repeat_interleave(minimal_length - original_length, dim=dim)
506
+
507
+ return torch.cat([x, padding], dim=dim)
508
+
509
+
510
+ def lazy_positional_encoding(t, repeats=None):
511
+ if not isinstance(t, list):
512
+ t = [t]
513
+
514
+ from diffusers.models.embeddings import get_timestep_embedding
515
+
516
+ te = torch.tensor(t)
517
+ te = get_timestep_embedding(timesteps=te, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0)
518
+
519
+ if repeats is None:
520
+ return te
521
+
522
+ te = te[:, None, :].expand(-1, repeats, -1)
523
+
524
+ return te
525
+
526
+
527
+ def state_dict_offset_merge(A, B, C=None):
528
+ result = {}
529
+ keys = A.keys()
530
+
531
+ for key in keys:
532
+ A_value = A[key]
533
+ B_value = B[key].to(A_value)
534
+
535
+ if C is None:
536
+ result[key] = A_value + B_value
537
+ else:
538
+ C_value = C[key].to(A_value)
539
+ result[key] = A_value + B_value - C_value
540
+
541
+ return result
542
+
543
+
544
+ def state_dict_weighted_merge(state_dicts, weights):
545
+ if len(state_dicts) != len(weights):
546
+ raise ValueError("Number of state dictionaries must match number of weights")
547
+
548
+ if not state_dicts:
549
+ return {}
550
+
551
+ total_weight = sum(weights)
552
+
553
+ if total_weight == 0:
554
+ raise ValueError("Sum of weights cannot be zero")
555
+
556
+ normalized_weights = [w / total_weight for w in weights]
557
+
558
+ keys = state_dicts[0].keys()
559
+ result = {}
560
+
561
+ for key in keys:
562
+ result[key] = state_dicts[0][key] * normalized_weights[0]
563
+
564
+ for i in range(1, len(state_dicts)):
565
+ state_dict_value = state_dicts[i][key].to(result[key])
566
+ result[key] += state_dict_value * normalized_weights[i]
567
+
568
+ return result
569
+
570
+
571
+ def group_files_by_folder(all_files):
572
+ grouped_files = {}
573
+
574
+ for file in all_files:
575
+ folder_name = os.path.basename(os.path.dirname(file))
576
+ if folder_name not in grouped_files:
577
+ grouped_files[folder_name] = []
578
+ grouped_files[folder_name].append(file)
579
+
580
+ list_of_lists = list(grouped_files.values())
581
+ return list_of_lists
582
+
583
+
584
+ def generate_timestamp():
585
+ now = datetime.datetime.now()
586
+ timestamp = now.strftime('%y%m%d_%H%M%S')
587
+ milliseconds = f"{int(now.microsecond / 1000):03d}"
588
+ random_number = random.randint(0, 9999)
589
+ return f"{timestamp}_{milliseconds}_{random_number}"
590
+
591
+
592
+ def write_PIL_image_with_png_info(image, metadata, path):
593
+ from PIL.PngImagePlugin import PngInfo
594
+
595
+ png_info = PngInfo()
596
+ for key, value in metadata.items():
597
+ png_info.add_text(key, value)
598
+
599
+ image.save(path, "PNG", pnginfo=png_info)
600
+ return image
601
+
602
+
603
+ def torch_safe_save(content, path):
604
+ torch.save(content, path + '_tmp')
605
+ os.replace(path + '_tmp', path)
606
+ return path
607
+
608
+
609
+ def move_optimizer_to_device(optimizer, device):
610
+ for state in optimizer.state.values():
611
+ for k, v in state.items():
612
+ if isinstance(v, torch.Tensor):
613
+ state[k] = v.to(device)
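Among the utilities above, `soft_append_bcthw` joins two B,C,T,H,W clips along the time axis with a linear cross-fade over the last/first `overlap` frames, so chunked generations can be appended without a hard seam. A tiny sketch on dummy tensors (the 0/1 values are chosen so the blend is easy to read):

```python
# Dummy-data sketch of soft_append_bcthw: a 4-frame clip of zeros and a 4-frame
# clip of ones are joined with a 3-frame linear cross-fade, giving 4 + 4 - 3 = 5 frames.
import torch
from diffusers_helper.utils import soft_append_bcthw

history = torch.zeros(1, 1, 4, 1, 1)   # 4 frames of value 0
current = torch.ones(1, 1, 4, 1, 1)    # 4 frames of value 1

out = soft_append_bcthw(history, current, overlap=3)
print(out.shape)                        # torch.Size([1, 1, 5, 1, 1])
print(out.flatten().tolist())           # [0.0, 0.0, 0.5, 1.0, 1.0]; the middle frame is the 50/50 blend
```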
img_examples/Example1.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a906a1d14d1699f67ca54865c7aa5857e55246f4ec63bbaf3edcf359e73bebd1
3
+ size 240647
img_examples/Example1.png ADDED

Git LFS Details

  • SHA256: a057c160bcf3ecfa41d150ec9550423f87efefb9a9793420fea382760daff98b
  • Pointer size: 131 Bytes
  • Size of remote file: 513 kB
img_examples/Example2.webp ADDED

Git LFS Details

  • SHA256: 736480a5f8d043eacad5758f0e80b427aabfa4d98839769615ee61f3fda9f77e
  • Pointer size: 131 Bytes
  • Size of remote file: 137 kB
img_examples/Example3.jpg ADDED

Git LFS Details

  • SHA256: b1a9be93d2f117d687e08c91c043e67598bdb7c44f5c932f18a3026790fb82fa
  • Pointer size: 131 Bytes
  • Size of remote file: 208 kB
img_examples/Example4.webp ADDED

Git LFS Details

  • SHA256: dd4e7ef35f4cfc8d44ff97f38b68ba7cc248ad5b54c89f8525f5046508f7c4a3
  • Pointer size: 131 Bytes
  • Size of remote file: 119 kB
requirements.txt CHANGED
@@ -1,6 +1,24 @@
1
- diffusers
2
- transformers
3
- torch
4
- accelerate
5
- opencv-python
6
- numpy
1
+ accelerate==1.7.0
2
+ diffusers==0.33.1
3
+ transformers==4.52.4
4
+ sentencepiece==0.2.0
5
+ pillow==11.2.1
6
+ av==12.1.0
7
+ numpy==1.26.2
8
+ scipy==1.12.0
9
+ requests==2.32.4
10
+ torchsde==0.2.6
11
+ torch>=2.0.0
12
+ torchvision
13
+ torchaudio
14
+ einops
15
+ opencv-contrib-python
16
+ safetensors
17
+ huggingface_hub
18
+ decord
19
+ imageio_ffmpeg==0.6.0
20
+ sageattention==1.0.6
21
+ xformers==0.0.29.post3
22
+ bitsandbytes==0.46.0
23
+ pillow-heif==0.22.0
24
+ spaces[security]