# fps-gaon / generation_core.py
import torch
import traceback
import einops
import numpy as np
import os
import threading
import json
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from diffusers_helper.hunyuan import (
encode_prompt_conds,
vae_decode,
vae_encode,
vae_decode_fake,
)
from diffusers_helper.utils import (
save_bcthw_as_mp4,
crop_or_pad_yield_mask,
soft_append_bcthw,
resize_and_center_crop,
generate_timestamp,
)
from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan
from diffusers_helper.memory import (
unload_complete_models,
load_model_as_complete,
move_model_to_device_with_memory_preservation,
offload_model_from_device_for_memory_preservation,
fake_diffusers_current_device,
gpu,
)
from diffusers_helper.clip_vision import hf_clip_vision_encode
from diffusers_helper.bucket_tools import find_nearest_bucket
from diffusers_helper.gradio.progress_bar import make_progress_bar_html
from ui import metadata as metadata_manager
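# worker() runs a single image-to-video generation task end to end:
# prompt/text encoding -> VAE encoding of the start image -> CLIP-Vision
# encoding -> segment-by-segment latent sampling -> incremental VAE decoding
# and MP4 export. Status is reported through output_queue_ref as
# ("progress" | "file" | "aborted" | "error" | "end") tuples.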
@torch.no_grad()
def worker(
# --- Task I/O & Identity ---
task_id,
input_image,
output_folder,
output_queue_ref,
# --- Creative Parameters (The "Recipe") ---
prompt,
n_prompt,
seed,
total_second_length,
steps,
cfg,
gs,
gs_final,
gs_schedule_active,
rs,
preview_frequency,
segments_to_decode_csv,
# --- Environment & Debug Parameters ---
latent_window_size,
gpu_memory_preservation,
use_teacache,
use_fp32_transformer_output,
mp4_crf,
# --- Model & System Objects (Passed from main app) ---
text_encoder,
text_encoder_2,
tokenizer,
tokenizer_2,
vae,
feature_extractor,
image_encoder,
transformer,
high_vram,
# --- Control Flow ---
abort_event: threading.Event = None,
):
outputs_folder = (
os.path.expanduser(output_folder) if output_folder else "./outputs/"
)
os.makedirs(outputs_folder, exist_ok=True)
    # Segment count: output runs at 30 fps and each latent frame decodes to
    # 4 pixel frames, so one window of latent_window_size latents covers
    # latent_window_size * 4 output frames.
total_latent_sections = (total_second_length * 30) / (latent_window_size * 4)
total_latent_sections = int(max(round(total_latent_sections), 1))
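    # Example: a 5 s request with latent_window_size=9 gives
    # (5 * 30) / (9 * 4) = 4.17 -> 4 segments after rounding (never fewer than 1).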
job_id = f"{generate_timestamp()}_task{task_id}"
output_queue_ref.push(
(
"progress",
(
task_id,
None,
f"Total Segments: {total_latent_sections}",
make_progress_bar_html(0, "Starting ..."),
),
)
)
parsed_segments_to_decode_set = set()
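    # e.g. "1, 3, 5" -> {1, 3, 5}; these are 1-indexed segment numbers matched
    # against the loop's segment counter when deciding which MP4s to save.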
if segments_to_decode_csv:
try:
parsed_segments_to_decode_set = {
int(s.strip()) for s in segments_to_decode_csv.split(",") if s.strip()
}
except ValueError:
print(
f"Task {task_id}: Warning - Could not parse 'Segments to Decode CSV': \"{segments_to_decode_csv}\"."
)
final_output_filename = None
success = False
initial_gs_from_ui = gs
gs_final_value_for_schedule = (
gs_final if gs_final is not None else initial_gs_from_ui
)
original_fp32_setting = transformer.high_quality_fp32_output_for_inference
transformer.high_quality_fp32_output_for_inference = use_fp32_transformer_output
print(
f"Task {task_id}: transformer.high_quality_fp32_output_for_inference set to {use_fp32_transformer_output}"
)
try:
if not isinstance(input_image, np.ndarray):
raise ValueError(f"Task {task_id}: input_image is not a NumPy array.")
output_queue_ref.push(
(
"progress",
(
task_id,
None,
f"Total Segments: {total_latent_sections}",
make_progress_bar_html(0, "Image processing ..."),
),
)
)
if input_image.shape[-1] == 4:
pil_img = Image.fromarray(input_image)
input_image = np.array(pil_img.convert("RGB"))
H, W, C = input_image.shape
if C != 3:
raise ValueError(
f"Task {task_id}: Input image must be RGB, found {C} channels."
)
height, width = find_nearest_bucket(H, W, resolution=640)
input_image_np = resize_and_center_crop(
input_image, target_width=width, target_height=height
)
metadata_obj = PngInfo()
params_to_save_in_metadata = {
"prompt": prompt,
"n_prompt": n_prompt,
"seed": seed,
"total_second_length": total_second_length,
"steps": steps,
"cfg": cfg,
"gs": gs,
"gs_final": gs_final,
"gs_schedule_active": gs_schedule_active,
"rs": rs,
"preview_frequency": preview_frequency,
"segments_to_decode_csv": segments_to_decode_csv,
}
metadata_obj.add_text("parameters", json.dumps(params_to_save_in_metadata))
initial_image_with_params_path = os.path.join(
outputs_folder, f"{job_id}_initial_image_with_params.png"
)
try:
Image.fromarray(input_image_np).save(
initial_image_with_params_path, pnginfo=metadata_obj
)
except Exception as e_png:
print(
f"Task {task_id}: WARNING - Failed to save initial image with parameters: {e_png}"
)
        # Low-VRAM path: clear the GPU before loading each encoder in turn.
if not high_vram:
unload_complete_models(
text_encoder, text_encoder_2, image_encoder, vae, transformer
)
output_queue_ref.push(
(
"progress",
(
task_id,
None,
f"Total Segments: {total_latent_sections}",
make_progress_bar_html(0, "Text encoding ..."),
),
)
)
if not high_vram:
fake_diffusers_current_device(text_encoder, gpu)
load_model_as_complete(text_encoder_2, target_device=gpu)
llama_vec, clip_l_pooler = encode_prompt_conds(
prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2
)
if cfg == 1:
llama_vec_n, clip_l_pooler_n = torch.zeros_like(
llama_vec
), torch.zeros_like(clip_l_pooler)
else:
llama_vec_n, clip_l_pooler_n = encode_prompt_conds(
n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2
)
llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(
llama_vec_n, length=512
)
input_image_pt = (
torch.from_numpy(input_image_np).float().permute(2, 0, 1).unsqueeze(0)
/ 127.5
- 1.0
)
input_image_pt = input_image_pt[:, :, None, :, :]
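        # Shapes: (1, 3, H, W) scaled to [-1, 1], then a time axis is inserted
        # to give (B, C, T, H, W) with T=1, i.e. a single-frame "video".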
output_queue_ref.push(
(
"progress",
(
task_id,
None,
f"Total Segments: {total_latent_sections}",
make_progress_bar_html(0, "VAE encoding ..."),
),
)
)
if not high_vram:
load_model_as_complete(vae, target_device=gpu)
start_latent = vae_encode(input_image_pt, vae)
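        # start_latent is the VAE latent of the input frame; it is reused below
        # as the clean "pre" context for every sampled segment.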
output_queue_ref.push(
(
"progress",
(
task_id,
None,
f"Total Segments: {total_latent_sections}",
make_progress_bar_html(0, "CLIP Vision encoding ..."),
),
)
)
if not high_vram:
load_model_as_complete(image_encoder, target_device=gpu)
image_encoder_output = hf_clip_vision_encode(
input_image_np, feature_extractor, image_encoder
)
image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
(
llama_vec,
llama_vec_n,
clip_l_pooler,
clip_l_pooler_n,
image_encoder_last_hidden_state,
) = [
t.to(transformer.dtype)
for t in [
llama_vec,
llama_vec_n,
clip_l_pooler,
clip_l_pooler_n,
image_encoder_last_hidden_state,
]
]
output_queue_ref.push(
(
"progress",
(
task_id,
None,
f"Total Segments: {total_latent_sections}",
make_progress_bar_html(0, "Start sampling ..."),
),
)
)
rnd = torch.Generator(device="cpu").manual_seed(int(seed))
num_frames = latent_window_size * 4 - 3
        # History buffer depth 1 + 2 + 16 holds the clean "post" latent plus
        # the 2x and 4x temporal context latents consumed by the sampler.
history_latents = torch.zeros(
size=(1, 16, 1 + 2 + 16, height // 8, width // 8),
dtype=torch.float32,
device="cpu",
)
history_pixels = None
total_generated_latent_frames = 0
latent_paddings = list(reversed(range(total_latent_sections)))
if total_latent_sections > 4:
latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
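        # Segments are sampled back-to-front: higher padding values come first
        # and the final section (padding 0) is anchored directly to the start
        # image. For longer videos the schedule is flattened to
        # [3, 2, 2, ..., 1, 0], which appears to behave like the plain reversed
        # range while capping the padding size.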
for latent_padding_iteration, latent_padding in enumerate(latent_paddings):
if abort_event and abort_event.is_set():
raise KeyboardInterrupt("Abort signal received.")
is_last_section = latent_padding == 0
latent_padding_size = latent_padding * latent_window_size
            # 1-indexed segment number, used for logs, filenames, and the
            # segments-to-decode check.
current_loop_segment_number = latent_padding_iteration + 1
print(
f"Task {task_id}: Seg {current_loop_segment_number}/{total_latent_sections} (lp_val={latent_padding}), last_loop_seg={is_last_section}"
)
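            # Build the time-axis index layout handed to the sampler:
            # [clean_pre(1) | pad | window(latent_window_size) | clean_post(1) | 2x(2) | 4x(16)]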
indices = torch.arange(
0,
sum([1, latent_padding_size, latent_window_size, 1, 2, 16]),
device="cpu",
).unsqueeze(0)
(
clean_latent_indices_pre,
_,
latent_indices,
clean_latent_indices_post,
clean_latent_2x_indices,
clean_latent_4x_indices,
) = indices.split(
[1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1
)
clean_latents_pre = start_latent.to(
history_latents.device, dtype=history_latents.dtype
)
clean_latent_indices = torch.cat(
[clean_latent_indices_pre, clean_latent_indices_post], dim=1
)
clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[
:, :, : 1 + 2 + 16, :, :
].split([1, 2, 16], dim=2)
clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
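            # clean_latents pairs the start-image latent ("pre") with the most
            # recently generated latent ("post") so each segment stays anchored
            # at both temporal ends.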
if not high_vram:
unload_complete_models()
move_model_to_device_with_memory_preservation(
transformer,
target_device=gpu,
preserved_memory_gb=gpu_memory_preservation,
)
transformer.initialize_teacache(
enable_teacache=use_teacache, num_steps=steps
)
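            # TeaCache speeds up sampling by reusing cached transformer
            # computation across similar denoising steps, trading a little
            # fidelity for speed.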
def callback_diffusion_step(d):
if abort_event and abort_event.is_set():
raise KeyboardInterrupt("Abort signal received during sampling.")
current_diffusion_step = d["i"] + 1
is_first_step = current_diffusion_step == 1
is_last_step = current_diffusion_step == steps
is_preview_step = preview_frequency > 0 and (
current_diffusion_step % preview_frequency == 0
)
if not (is_first_step or is_last_step or is_preview_step):
return
preview_latent = d["denoised"]
preview_img_np = vae_decode_fake(preview_latent)
preview_img_np = (
(preview_img_np * 255.0)
.detach()
.cpu()
.numpy()
.clip(0, 255)
.astype(np.uint8)
)
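                # Tile the preview into a grid: one row per batch item, latent
                # frames laid out left-to-right along the width axis.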
preview_img_np = einops.rearrange(
preview_img_np, "b c t h w -> (b h) (t w) c"
)
percentage = int(100.0 * current_diffusion_step / steps)
hint = f"Segment {current_loop_segment_number}, Sampling {current_diffusion_step}/{steps}" # Updated hint
current_video_frames_count = (
history_pixels.shape[2] if history_pixels is not None else 0
)
desc = f"Task {task_id}: Vid Frames: {current_video_frames_count}, Len: {current_video_frames_count / 30 :.2f}s. Seg {current_loop_segment_number}/{total_latent_sections}. Extending..." # Updated desc
output_queue_ref.push(
(
"progress",
(
task_id,
preview_img_np,
desc,
make_progress_bar_html(percentage, hint),
),
)
)
            # Linearly interpolate distilled guidance from the initial UI value
            # to the scheduled final value across loop segments.
            current_segment_gs_to_use = initial_gs_from_ui
            if gs_schedule_active and total_latent_sections > 1:
                progress_for_gs = latent_padding_iteration / (
                    total_latent_sections - 1
                )
                current_segment_gs_to_use = (
                    initial_gs_from_ui
                    + (gs_final_value_for_schedule - initial_gs_from_ui)
                    * progress_for_gs
                )
generated_latents = sample_hunyuan(
transformer=transformer,
sampler="unipc",
width=width,
height=height,
frames=num_frames,
real_guidance_scale=cfg,
distilled_guidance_scale=current_segment_gs_to_use,
guidance_rescale=rs,
num_inference_steps=steps,
generator=rnd,
prompt_embeds=llama_vec.to(transformer.device),
prompt_embeds_mask=llama_attention_mask.to(transformer.device),
prompt_poolers=clip_l_pooler.to(transformer.device),
negative_prompt_embeds=llama_vec_n.to(transformer.device),
negative_prompt_embeds_mask=llama_attention_mask_n.to(
transformer.device
),
negative_prompt_poolers=clip_l_pooler_n.to(transformer.device),
device=transformer.device,
dtype=transformer.dtype,
image_embeddings=image_encoder_last_hidden_state.to(transformer.device),
latent_indices=latent_indices.to(transformer.device),
clean_latents=clean_latents.to(
transformer.device, dtype=transformer.dtype
),
clean_latent_indices=clean_latent_indices.to(transformer.device),
clean_latents_2x=clean_latents_2x.to(
transformer.device, dtype=transformer.dtype
),
clean_latent_2x_indices=clean_latent_2x_indices.to(transformer.device),
clean_latents_4x=clean_latents_4x.to(
transformer.device, dtype=transformer.dtype
),
clean_latent_4x_indices=clean_latent_4x_indices.to(transformer.device),
callback=callback_diffusion_step,
)
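            # Segments are generated in reverse chronological order: each new
            # chunk is prepended to history_latents, and on the last section the
            # start-image latent is prepended as well so the finished video
            # opens on the exact input frame.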
if is_last_section:
generated_latents = torch.cat(
[start_latent.to(generated_latents), generated_latents], dim=2
)
total_generated_latent_frames += int(generated_latents.shape[2])
history_latents = torch.cat(
[generated_latents.to(history_latents), history_latents], dim=2
)
if not high_vram:
offload_model_from_device_for_memory_preservation(
transformer, target_device=gpu, preserved_memory_gb=8
)
load_model_as_complete(vae, target_device=gpu)
real_history_latents = history_latents[
:, :, :total_generated_latent_frames, :, :
]
if history_pixels is None:
history_pixels = vae_decode(real_history_latents, vae).cpu()
else:
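                # Decode only this segment's latents (with overlap) and
                # cross-fade into the existing pixels over overlapped_frames
                # frames to hide the seam between segments.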
section_latent_frames = (
(latent_window_size * 2 + 1)
if is_last_section
else (latent_window_size * 2)
)
overlapped_frames = latent_window_size * 4 - 3
current_pixels = vae_decode(
real_history_latents[:, :, :section_latent_frames], vae
).cpu()
history_pixels = soft_append_bcthw(
current_pixels, history_pixels, overlapped_frames
)
if not high_vram:
unload_complete_models()
current_video_frame_count = history_pixels.shape[2]
# Determine if we should save an intermediate MP4 for this loop segment
should_save_mp4_this_iteration = False
# Condition 1: Always save the last segment of the loop
if is_last_section:
should_save_mp4_this_iteration = True
# Condition 2: Save if the current loop segment number is explicitly in the parsed set
elif (
parsed_segments_to_decode_set
and current_loop_segment_number in parsed_segments_to_decode_set
):
should_save_mp4_this_iteration = True
# Condition 3: Save based on preview_frequency, if enabled (preview_frequency > 0)
elif preview_frequency > 0 and (
current_loop_segment_number % preview_frequency == 0
):
should_save_mp4_this_iteration = True
if should_save_mp4_this_iteration:
segment_mp4_filename = os.path.join(
outputs_folder,
f"{job_id}_segment_{current_loop_segment_number}_frames_{current_video_frame_count}.mp4",
                )
save_bcthw_as_mp4(
history_pixels, segment_mp4_filename, fps=30, crf=mp4_crf
)
final_output_filename = segment_mp4_filename
print(
f"Task {task_id}: SAVED MP4 for segment {current_loop_segment_number} to {segment_mp4_filename}. Total video frames: {current_video_frame_count}"
                )
output_queue_ref.push(
(
"file",
(
task_id,
segment_mp4_filename,
f"Segment {current_loop_segment_number} MP4 saved ({current_video_frame_count} frames)",
),
)
                )
else:
print(
f"Task {task_id}: SKIPPED MP4 save for intermediate segment {current_loop_segment_number}."
                )
except KeyboardInterrupt:
print(f"Worker task {task_id} caught KeyboardInterrupt (likely abort signal).")
output_queue_ref.push(("aborted", task_id))
success = False
except Exception as e:
print(f"Error in worker task {task_id}: {e}")
traceback.print_exc()
output_queue_ref.push(("error", (task_id, str(e))))
success = False
finally:
transformer.high_quality_fp32_output_for_inference = original_fp32_setting
print(
f"Task {task_id}: Restored transformer.high_quality_fp32_output_for_inference to {original_fp32_setting}"
)
if not high_vram:
unload_complete_models(
text_encoder, text_encoder_2, image_encoder, vae, transformer
)
        # Compare absolute directories so a relative outputs_folder does not
        # always trigger a re-join of the final path.
        if final_output_filename and os.path.abspath(
            os.path.dirname(final_output_filename)
        ) != os.path.abspath(outputs_folder):
final_output_filename = os.path.join(
outputs_folder, os.path.basename(final_output_filename)
)
output_queue_ref.push(("end", (task_id, success, final_output_filename)))