import torch
import traceback
import einops
import numpy as np
import os
import threading
import json

from PIL import Image
from PIL.PngImagePlugin import PngInfo

from diffusers_helper.hunyuan import (
    encode_prompt_conds,
    vae_decode,
    vae_encode,
    vae_decode_fake,
)
from diffusers_helper.utils import (
    save_bcthw_as_mp4,
    crop_or_pad_yield_mask,
    soft_append_bcthw,
    resize_and_center_crop,
    generate_timestamp,
)
from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan
from diffusers_helper.memory import (
    unload_complete_models,
    load_model_as_complete,
    move_model_to_device_with_memory_preservation,
    offload_model_from_device_for_memory_preservation,
    fake_diffusers_current_device,
    gpu,
)
from diffusers_helper.clip_vision import hf_clip_vision_encode
from diffusers_helper.bucket_tools import find_nearest_bucket
from diffusers_helper.gradio.progress_bar import make_progress_bar_html
from ui import metadata as metadata_manager


@torch.no_grad()
def worker(
    # --- Task I/O & Identity ---
    task_id,
    input_image,
    output_folder,
    output_queue_ref,
    # --- Creative Parameters (The "Recipe") ---
    prompt,
    n_prompt,
    seed,
    total_second_length,
    steps,
    cfg,
    gs,
    gs_final,
    gs_schedule_active,
    rs,
    preview_frequency,
    segments_to_decode_csv,
    # --- Environment & Debug Parameters ---
    latent_window_size,
    gpu_memory_preservation,
    use_teacache,
    use_fp32_transformer_output,
    mp4_crf,
    # --- Model & System Objects (Passed from main app) ---
    text_encoder,
    text_encoder_2,
    tokenizer,
    tokenizer_2,
    vae,
    feature_extractor,
    image_encoder,
    transformer,
    high_vram,
    # --- Control Flow ---
    abort_event: threading.Event = None,
):
    outputs_folder = (
        os.path.expanduser(output_folder) if output_folder else "./outputs/"
    )
    os.makedirs(outputs_folder, exist_ok=True)

    # --- Gemini: do not touch - "secret sauce"
    total_latent_sections = (total_second_length * 30) / (latent_window_size * 4)
    total_latent_sections = int(max(round(total_latent_sections), 1))

    job_id = f"{generate_timestamp()}_task{task_id}"

    output_queue_ref.push(
        (
            "progress",
            (
                task_id,
                None,
                f"Total Segments: {total_latent_sections}",
                make_progress_bar_html(0, "Starting ..."),
            ),
        )
    )

    # ---
    parsed_segments_to_decode_set = set()
    if segments_to_decode_csv:
        try:
            parsed_segments_to_decode_set = {
                int(s.strip())
                for s in segments_to_decode_csv.split(",")
                if s.strip()
            }
        except ValueError:
            print(
                f"Task {task_id}: Warning - Could not parse 'Segments to Decode CSV': \"{segments_to_decode_csv}\"."
            )

    final_output_filename = None
    success = False

    initial_gs_from_ui = gs
    gs_final_value_for_schedule = (
        gs_final if gs_final is not None else initial_gs_from_ui
    )

    original_fp32_setting = transformer.high_quality_fp32_output_for_inference
    transformer.high_quality_fp32_output_for_inference = use_fp32_transformer_output
    print(
        f"Task {task_id}: transformer.high_quality_fp32_output_for_inference set to {use_fp32_transformer_output}"
    )

    try:
        if not isinstance(input_image, np.ndarray):
            raise ValueError(f"Task {task_id}: input_image is not a NumPy array.")

        output_queue_ref.push(
            (
                "progress",
                (
                    task_id,
                    None,
                    f"Total Segments: {total_latent_sections}",
                    make_progress_bar_html(0, "Image processing ..."),
                ),
            )
        )

        if input_image.shape[-1] == 4:
            pil_img = Image.fromarray(input_image)
            input_image = np.array(pil_img.convert("RGB"))

        H, W, C = input_image.shape
        if C != 3:
            raise ValueError(
                f"Task {task_id}: Input image must be RGB, found {C} channels."
            )
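        # Snap the input to the nearest supported aspect-ratio bucket at the 640
        # tier. Assumption: find_nearest_bucket returns a (height, width) pair
        # divisible by 8, so the // 8 latent shapes used below stay integral.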
        height, width = find_nearest_bucket(H, W, resolution=640)
        input_image_np = resize_and_center_crop(
            input_image, target_width=width, target_height=height
        )

        metadata_obj = PngInfo()
        params_to_save_in_metadata = {
            "prompt": prompt,
            "n_prompt": n_prompt,
            "seed": seed,
            "total_second_length": total_second_length,
            "steps": steps,
            "cfg": cfg,
            "gs": gs,
            "gs_final": gs_final,
            "gs_schedule_active": gs_schedule_active,
            "rs": rs,
            "preview_frequency": preview_frequency,
            "segments_to_decode_csv": segments_to_decode_csv,
        }
        metadata_obj.add_text("parameters", json.dumps(params_to_save_in_metadata))

        initial_image_with_params_path = os.path.join(
            outputs_folder, f"{job_id}_initial_image_with_params.png"
        )
        try:
            Image.fromarray(input_image_np).save(
                initial_image_with_params_path, pnginfo=metadata_obj
            )
        except Exception as e_png:
            print(
                f"Task {task_id}: WARNING - Failed to save initial image with parameters: {e_png}"
            )

        # --- Gemini: do not touch - "secret sauce"
        if not high_vram:
            unload_complete_models(
                text_encoder, text_encoder_2, image_encoder, vae, transformer
            )

        output_queue_ref.push(
            (
                "progress",
                (
                    task_id,
                    None,
                    f"Total Segments: {total_latent_sections}",
                    make_progress_bar_html(0, "Text encoding ..."),
                ),
            )
        )

        if not high_vram:
            fake_diffusers_current_device(text_encoder, gpu)
            load_model_as_complete(text_encoder_2, target_device=gpu)

        llama_vec, clip_l_pooler = encode_prompt_conds(
            prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2
        )

        if cfg == 1:
            llama_vec_n, clip_l_pooler_n = torch.zeros_like(
                llama_vec
            ), torch.zeros_like(clip_l_pooler)
        else:
            llama_vec_n, clip_l_pooler_n = encode_prompt_conds(
                n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2
            )

        llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
        llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(
            llama_vec_n, length=512
        )

        input_image_pt = (
            torch.from_numpy(input_image_np).float().permute(2, 0, 1).unsqueeze(0)
            / 127.5
            - 1.0
        )
        input_image_pt = input_image_pt[:, :, None, :, :]

        output_queue_ref.push(
            (
                "progress",
                (
                    task_id,
                    None,
                    f"Total Segments: {total_latent_sections}",
                    make_progress_bar_html(0, "VAE encoding ..."),
                ),
            )
        )

        if not high_vram:
            load_model_as_complete(vae, target_device=gpu)
        start_latent = vae_encode(input_image_pt, vae)

        output_queue_ref.push(
            (
                "progress",
                (
                    task_id,
                    None,
                    f"Total Segments: {total_latent_sections}",
                    make_progress_bar_html(0, "CLIP Vision encoding ..."),
                ),
            )
        )

        if not high_vram:
            load_model_as_complete(image_encoder, target_device=gpu)
        image_encoder_output = hf_clip_vision_encode(
            input_image_np, feature_extractor, image_encoder
        )
        image_encoder_last_hidden_state = image_encoder_output.last_hidden_state

        (
            llama_vec,
            llama_vec_n,
            clip_l_pooler,
            clip_l_pooler_n,
            image_encoder_last_hidden_state,
        ) = [
            t.to(transformer.dtype)
            for t in [
                llama_vec,
                llama_vec_n,
                clip_l_pooler,
                clip_l_pooler_n,
                image_encoder_last_hidden_state,
            ]
        ]

        output_queue_ref.push(
            (
                "progress",
                (
                    task_id,
                    None,
                    f"Total Segments: {total_latent_sections}",
                    make_progress_bar_html(0, "Start sampling ..."),
                ),
            )
        )

        rnd = torch.Generator(device="cpu").manual_seed(int(seed))
        num_frames = latent_window_size * 4 - 3  # overlapped_frames = num_frames

        history_latents = torch.zeros(
            size=(1, 16, 1 + 2 + 16, height // 8, width // 8),
            dtype=torch.float32,
            device="cpu",
        )
        history_pixels = None
        total_generated_latent_frames = 0
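        # Segments are generated back-to-front: larger latent_padding values sit
        # further from the end of the clip. The > 4 branch flattens the schedule
        # to [3, 2, 2, ..., 2, 1, 0]; presumably this mirrors the upstream
        # FramePack sampling trick, where padding values above ~3 add little and
        # a run of 2s extends the clip instead.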
        latent_paddings = list(reversed(range(total_latent_sections)))
        if total_latent_sections > 4:
            latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]

        # for latent_padding_iteration, latent_padding in enumerate(latent_paddings):
        #     if abort_event and abort_event.is_set(): raise KeyboardInterrupt("Abort signal received.")
        #     is_last_section = (latent_padding == 0)
        #     latent_padding_size = latent_padding * latent_window_size
        #     print(f'Task {task_id}: Seg {latent_padding_iteration + 1}/{total_latent_sections} (lp_val={latent_padding}), last_loop_seg={is_last_section}')
        # ^ our code | v Flash code
        for latent_padding_iteration, latent_padding in enumerate(latent_paddings):
            if abort_event and abort_event.is_set():
                raise KeyboardInterrupt("Abort signal received.")

            is_last_section = latent_padding == 0
            latent_padding_size = latent_padding * latent_window_size
            # Added for consistent 1-indexed segment number for loop segments
            current_loop_segment_number = latent_padding_iteration + 1

            print(
                f"Task {task_id}: Seg {current_loop_segment_number}/{total_latent_sections} (lp_val={latent_padding}), last_loop_seg={is_last_section}"
            )

            indices = torch.arange(
                0,
                sum([1, latent_padding_size, latent_window_size, 1, 2, 16]),
                device="cpu",
            ).unsqueeze(0)
            (
                clean_latent_indices_pre,
                _,
                latent_indices,
                clean_latent_indices_post,
                clean_latent_2x_indices,
                clean_latent_4x_indices,
            ) = indices.split(
                [1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1
            )
            clean_latents_pre = start_latent.to(
                history_latents.device, dtype=history_latents.dtype
            )
            clean_latent_indices = torch.cat(
                [clean_latent_indices_pre, clean_latent_indices_post], dim=1
            )

            # current_history_depth_for_clean_split = history_latents.shape[2]; needed_depth_for_clean_split = 1 + 2 + 16
            # history_latents_for_clean_split = history_latents
            # if current_history_depth_for_clean_split < needed_depth_for_clean_split:
            #     padding_needed = needed_depth_for_clean_split - current_history_depth_for_clean_split
            #     pad_tensor = torch.zeros(history_latents.shape[0], history_latents.shape[1], padding_needed, history_latents.shape[3], history_latents.shape[4], dtype=history_latents.dtype, device=history_latents.device)
            #     history_latents_for_clean_split = torch.cat((history_latents, pad_tensor), dim=2)
            clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[
                :, :, : 1 + 2 + 16, :, :
            ].split([1, 2, 16], dim=2)
            clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)

            if not high_vram:
                unload_complete_models()
                move_model_to_device_with_memory_preservation(
                    transformer,
                    target_device=gpu,
                    preserved_memory_gb=gpu_memory_preservation,
                )

            transformer.initialize_teacache(
                enable_teacache=use_teacache, num_steps=steps
            )

            def callback_diffusion_step(d):
                if abort_event and abort_event.is_set():
                    raise KeyboardInterrupt("Abort signal received during sampling.")

                current_diffusion_step = d["i"] + 1
                is_first_step = current_diffusion_step == 1
                is_last_step = current_diffusion_step == steps
                is_preview_step = preview_frequency > 0 and (
                    current_diffusion_step % preview_frequency == 0
                )
                if not (is_first_step or is_last_step or is_preview_step):
                    return

                preview_latent = d["denoised"]
                preview_img_np = vae_decode_fake(preview_latent)
                preview_img_np = (
                    (preview_img_np * 255.0)
                    .detach()
                    .cpu()
                    .numpy()
                    .clip(0, 255)
                    .astype(np.uint8)
                )
                preview_img_np = einops.rearrange(
                    preview_img_np, "b c t h w -> (b h) (t w) c"
                )

                # percentage = int(100.0 * current_diffusion_step / steps)
                # hint = f'Segment {latent_padding_iteration + 1}, Sampling {current_diffusion_step}/{steps}'
                # current_video_frames_count = history_pixels.shape[2] if history_pixels is not None else 0
                # desc = f'Task {task_id}: Vid Frames: {current_video_frames_count}, Len: {current_video_frames_count / 30:.2f}s. Seg {latent_padding_iteration + 1}/{total_latent_sections}. Extending...'
                # output_queue_ref.push(('progress', (task_id, preview_img_np, desc, make_progress_bar_html(percentage, hint))))
                # ^ our code | v Flash code
                percentage = int(100.0 * current_diffusion_step / steps)
                hint = f"Segment {current_loop_segment_number}, Sampling {current_diffusion_step}/{steps}"  # Updated hint
                current_video_frames_count = (
                    history_pixels.shape[2] if history_pixels is not None else 0
                )
                desc = f"Task {task_id}: Vid Frames: {current_video_frames_count}, Len: {current_video_frames_count / 30:.2f}s. Seg {current_loop_segment_number}/{total_latent_sections}. Extending..."  # Updated desc
                output_queue_ref.push(
                    (
                        "progress",
                        (
                            task_id,
                            preview_img_np,
                            desc,
                            make_progress_bar_html(percentage, hint),
                        ),
                    )
                )
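            # Linear guidance-scale schedule across segments, a straight
            # interpolation from the UI value toward gs_final:
            #   gs_seg = gs + (gs_final - gs) * i / (N - 1)
            # where i is the 0-indexed segment and N = total_latent_sections.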
            current_segment_gs_to_use = initial_gs_from_ui
            if gs_schedule_active and total_latent_sections > 1:
                progress_for_gs = (
                    latent_padding_iteration / (total_latent_sections - 1)
                    if total_latent_sections > 1
                    else 0
                )
                current_segment_gs_to_use = (
                    initial_gs_from_ui
                    + (gs_final_value_for_schedule - initial_gs_from_ui)
                    * progress_for_gs
                )

            generated_latents = sample_hunyuan(
                transformer=transformer,
                sampler="unipc",
                width=width,
                height=height,
                frames=num_frames,
                real_guidance_scale=cfg,
                distilled_guidance_scale=current_segment_gs_to_use,
                guidance_rescale=rs,
                num_inference_steps=steps,
                generator=rnd,
                prompt_embeds=llama_vec.to(transformer.device),
                prompt_embeds_mask=llama_attention_mask.to(transformer.device),
                prompt_poolers=clip_l_pooler.to(transformer.device),
                negative_prompt_embeds=llama_vec_n.to(transformer.device),
                negative_prompt_embeds_mask=llama_attention_mask_n.to(
                    transformer.device
                ),
                negative_prompt_poolers=clip_l_pooler_n.to(transformer.device),
                device=transformer.device,
                dtype=transformer.dtype,
                image_embeddings=image_encoder_last_hidden_state.to(
                    transformer.device
                ),
                latent_indices=latent_indices.to(transformer.device),
                clean_latents=clean_latents.to(
                    transformer.device, dtype=transformer.dtype
                ),
                clean_latent_indices=clean_latent_indices.to(transformer.device),
                clean_latents_2x=clean_latents_2x.to(
                    transformer.device, dtype=transformer.dtype
                ),
                clean_latent_2x_indices=clean_latent_2x_indices.to(transformer.device),
                clean_latents_4x=clean_latents_4x.to(
                    transformer.device, dtype=transformer.dtype
                ),
                clean_latent_4x_indices=clean_latent_4x_indices.to(transformer.device),
                callback=callback_diffusion_step,
            )

            if is_last_section:
                generated_latents = torch.cat(
                    [start_latent.to(generated_latents), generated_latents], dim=2
                )

            total_generated_latent_frames += int(generated_latents.shape[2])
            history_latents = torch.cat(
                [generated_latents.to(history_latents), history_latents], dim=2
            )

            if not high_vram:
                offload_model_from_device_for_memory_preservation(
                    transformer, target_device=gpu, preserved_memory_gb=8
                )
                load_model_as_complete(vae, target_device=gpu)
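            # Decode only the newest section and blend it into the running video.
            # Each latent frame maps to 4 pixel frames, so consecutive sections
            # overlap by latent_window_size * 4 - 3 pixel frames (the same count
            # as num_frames above); soft_append_bcthw blends across that overlap.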
            real_history_latents = history_latents[
                :, :, :total_generated_latent_frames, :, :
            ]

            if history_pixels is None:
                history_pixels = vae_decode(real_history_latents, vae).cpu()
            else:
                section_latent_frames = (
                    (latent_window_size * 2 + 1)
                    if is_last_section
                    else (latent_window_size * 2)
                )
                overlapped_frames = latent_window_size * 4 - 3
                current_pixels = vae_decode(
                    real_history_latents[:, :, :section_latent_frames], vae
                ).cpu()
                history_pixels = soft_append_bcthw(
                    current_pixels, history_pixels, overlapped_frames
                )

            if not high_vram:
                unload_complete_models()

            current_video_frame_count = history_pixels.shape[2]

            # --- Gemini start again
            # # Skip writing preview mp4 for this segment logic
            # should_save_mp4_this_iteration = False
            # current_segment_1_indexed = latent_padding_iteration + 1
            # if (latent_padding_iteration == 0) or is_last_section or (parsed_segments_to_decode_set and current_segment_1_indexed in parsed_segments_to_decode_set):
            #     should_save_mp4_this_iteration = True
            # if should_save_mp4_this_iteration:
            #     segment_mp4_filename = os.path.join(outputs_folder, f'{job_id}_segment_{latent_padding_iteration}_frames_{current_video_frame_count}.mp4')
            #     save_bcthw_as_mp4(history_pixels, segment_mp4_filename, fps=30, crf=mp4_crf)
            #     final_output_filename = segment_mp4_filename
            #     print(f"Task {task_id}: SAVED MP4 for segment {latent_padding_iteration} to {segment_mp4_filename}. Total video frames: {current_video_frame_count}")
            #     output_queue_ref.push(('file', (task_id, segment_mp4_filename, f"Segment {latent_padding_iteration} MP4 saved ({current_video_frame_count} frames)")))
            # else:
            #     print(f"Task {task_id}: SKIPPED MP4 save for intermediate segment {current_segment_1_indexed}.")
            # if is_last_section: success = True; break
            # --- Gemini start again
            # ^ original code | v Flash code
            # # Skip writing preview mp4 for this segment logic
            # should_save_mp4_this_iteration = False
            # # Use latent_padding_iteration directly here, as it's the 0-indexed loop counter
            # current_segment_index = latent_padding_iteration
            # # Condition 1: Always save the first segment (index 0)
            # if current_segment_index == 0:
            #     should_save_mp4_this_iteration = True
            # # Condition 2: Always save the last segment
            # elif is_last_section:
            #     should_save_mp4_this_iteration = True
            # # Condition 3: Save if the current segment index is in the parsed set
            # elif parsed_segments_to_decode_set and (current_segment_index + 1) in parsed_segments_to_decode_set:
            #     # Add 1 here if segments_to_decode_csv assumes 1-based indexing for user input
            #     should_save_mp4_this_iteration = True
            # # Condition 4: Save based on preview_frequency, if enabled (preview_frequency > 0)
            # elif preview_frequency > 0 and current_segment_index % preview_frequency == 0:
            #     should_save_mp4_this_iteration = True
            # if should_save_mp4_this_iteration:
            #     segment_mp4_filename = os.path.join(outputs_folder, f'{job_id}_segment_{latent_padding_iteration}_frames_{current_video_frame_count}.mp4')
            #     save_bcthw_as_mp4(history_pixels, segment_mp4_filename, fps=30, crf=mp4_crf)
            #     final_output_filename = segment_mp4_filename
            #     print(f"Task {task_id}: SAVED MP4 for segment {latent_padding_iteration} to {segment_mp4_filename}. Total video frames: {current_video_frame_count}")
            #     output_queue_ref.push(('file', (task_id, segment_mp4_filename, f"Segment {latent_padding_iteration} MP4 saved ({current_video_frame_count} frames)")))
            # else:
            #     print(f"Task {task_id}: SKIPPED MP4 save for intermediate segment {current_segment_index}.")
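            # Note on encoding: mp4_crf is forwarded untouched to save_bcthw_as_mp4.
            # Assuming the usual libx264-style CRF semantics, lower values mean
            # higher quality and larger files (roughly 0-51, with ~18-28 typical).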
            # Determine if we should save an intermediate MP4 for this loop segment
            should_save_mp4_this_iteration = False
            # Condition 1: Always save the last segment of the loop
            if is_last_section:
                should_save_mp4_this_iteration = True
            # Condition 2: Save if the current loop segment number is explicitly in the parsed set
            elif (
                parsed_segments_to_decode_set
                and current_loop_segment_number in parsed_segments_to_decode_set
            ):
                should_save_mp4_this_iteration = True
            # Condition 3: Save based on preview_frequency, if enabled (preview_frequency > 0)
            elif preview_frequency > 0 and (
                current_loop_segment_number % preview_frequency == 0
            ):
                should_save_mp4_this_iteration = True

            if should_save_mp4_this_iteration:
                segment_mp4_filename = os.path.join(
                    outputs_folder,
                    f"{job_id}_segment_{current_loop_segment_number}_frames_{current_video_frame_count}.mp4",
                )  # Updated filename to use 1-indexed segment
                save_bcthw_as_mp4(
                    history_pixels, segment_mp4_filename, fps=30, crf=mp4_crf
                )
                final_output_filename = segment_mp4_filename
                print(
                    f"Task {task_id}: SAVED MP4 for segment {current_loop_segment_number} to {segment_mp4_filename}. Total video frames: {current_video_frame_count}"
                )  # Updated log to use 1-indexed segment
                output_queue_ref.push(
                    (
                        "file",
                        (
                            task_id,
                            segment_mp4_filename,
                            f"Segment {current_loop_segment_number} MP4 saved ({current_video_frame_count} frames)",
                        ),
                    )
                )  # Updated output queue message to use 1-indexed segment
            else:
                print(
                    f"Task {task_id}: SKIPPED MP4 save for intermediate segment {current_loop_segment_number}."
                )  # Updated log to use 1-indexed segment

            # Mark the task successful once the final section has been appended;
            # `success` is what the "end" message below reports. (This matches the
            # `if is_last_section: success = True; break` in the original commented
            # logic above, which the live rewrite otherwise leaves unset.)
            if is_last_section:
                success = True
                break

    except KeyboardInterrupt:
        print(f"Worker task {task_id} caught KeyboardInterrupt (likely abort signal).")
        output_queue_ref.push(("aborted", task_id))
        success = False
    except Exception as e:
        print(f"Error in worker task {task_id}: {e}")
        traceback.print_exc()
        output_queue_ref.push(("error", (task_id, str(e))))
        success = False
    finally:
        transformer.high_quality_fp32_output_for_inference = original_fp32_setting
        print(
            f"Task {task_id}: Restored transformer.high_quality_fp32_output_for_inference to {original_fp32_setting}"
        )
        if not high_vram:
            unload_complete_models(
                text_encoder, text_encoder_2, image_encoder, vae, transformer
            )
        if final_output_filename and os.path.dirname(
            final_output_filename
        ) != os.path.abspath(outputs_folder):
            final_output_filename = os.path.join(
                outputs_folder, os.path.basename(final_output_filename)
            )
        output_queue_ref.push(("end", (task_id, success, final_output_filename)))