import torch
import traceback
import einops
import numpy as np
import os
import threading
import json
from PIL import Image
from PIL.PngImagePlugin import PngInfo

from diffusers_helper.hunyuan import (
    encode_prompt_conds,
    vae_decode,
    vae_encode,
    vae_decode_fake,
)
from diffusers_helper.utils import (
    save_bcthw_as_mp4,
    crop_or_pad_yield_mask,
    soft_append_bcthw,
    resize_and_center_crop,
    generate_timestamp,
)
from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan
from diffusers_helper.memory import (
    unload_complete_models,
    load_model_as_complete,
    move_model_to_device_with_memory_preservation,
    offload_model_from_device_for_memory_preservation,
    fake_diffusers_current_device,
    gpu,
)
from diffusers_helper.clip_vision import hf_clip_vision_encode
from diffusers_helper.bucket_tools import find_nearest_bucket
from diffusers_helper.gradio.progress_bar import make_progress_bar_html
from ui import metadata as metadata_manager


@torch.no_grad()
def worker(
    # --- Task I/O & Identity ---
    task_id,
    input_image,
    output_folder,
    output_queue_ref,
    # --- Creative Parameters (The "Recipe") ---
    prompt,
    n_prompt,
    seed,
    total_second_length,
    steps,
    cfg,
    gs,
    gs_final,
    gs_schedule_active,
    rs,
    preview_frequency,
    segments_to_decode_csv,
    # --- Environment & Debug Parameters ---
    latent_window_size,
    gpu_memory_preservation,
    use_teacache,
    use_fp32_transformer_output,
    mp4_crf,
    # --- Model & System Objects (Passed from main app) ---
    text_encoder,
    text_encoder_2,
    tokenizer,
    tokenizer_2,
    vae,
    feature_extractor,
    image_encoder,
    transformer,
    high_vram,
    # --- Control Flow ---
    abort_event: threading.Event = None,
):
    outputs_folder = (
        os.path.expanduser(output_folder) if output_folder else "./outputs/"
    )
    os.makedirs(outputs_folder, exist_ok=True)

    # --- Gemini: do not touch - "secret sauce"
    total_latent_sections = (total_second_length * 30) / (latent_window_size * 4)
    total_latent_sections = int(max(round(total_latent_sections), 1))
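    # Illustrative arithmetic (example values, not defaults): at 30 fps with
    # latent_window_size=9, a 5-second request gives (5 * 30) / (9 * 4) = 150 / 36
    # ~= 4.17, rounded to 4 segments; each segment later samples
    # num_frames = 9 * 4 - 3 = 33 pixel frames.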

    job_id = f"{generate_timestamp()}_task{task_id}"
    output_queue_ref.push(
        (
            "progress",
            (
                task_id,
                None,
                f"Total Segments: {total_latent_sections}",
                make_progress_bar_html(0, "Starting ..."),
            ),
        )
    )
    # ---
    parsed_segments_to_decode_set = set()
    if segments_to_decode_csv:
        try:
            parsed_segments_to_decode_set = {
                int(s.strip()) for s in segments_to_decode_csv.split(",") if s.strip()
            }
        except ValueError:
            print(
                f"Task {task_id}: Warning - Could not parse 'Segments to Decode CSV': \"{segments_to_decode_csv}\"."
            )
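    # Example: "1, 3,5" parses to {1, 3, 5}; a malformed entry such as "1,x"
    # raises ValueError mid-parse, leaving the set empty so only the default
    # saves (last segment / preview frequency) happen.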
    final_output_filename = None
    success = False
    initial_gs_from_ui = gs
    gs_final_value_for_schedule = (
        gs_final if gs_final is not None else initial_gs_from_ui
    )
    original_fp32_setting = transformer.high_quality_fp32_output_for_inference
    transformer.high_quality_fp32_output_for_inference = use_fp32_transformer_output
    print(
        f"Task {task_id}: transformer.high_quality_fp32_output_for_inference set to {use_fp32_transformer_output}"
    )

    try:
        if not isinstance(input_image, np.ndarray):
            raise ValueError(f"Task {task_id}: input_image is not a NumPy array.")

        output_queue_ref.push(
            (
                "progress",
                (
                    task_id,
                    None,
                    f"Total Segments: {total_latent_sections}",
                    make_progress_bar_html(0, "Image processing ..."),
                ),
            )
        )
        if input_image.shape[-1] == 4:
            pil_img = Image.fromarray(input_image)
            input_image = np.array(pil_img.convert("RGB"))
        H, W, C = input_image.shape
        if C != 3:
            raise ValueError(
                f"Task {task_id}: Input image must be RGB, found {C} channels."
            )
        height, width = find_nearest_bucket(H, W, resolution=640)
        input_image_np = resize_and_center_crop(
            input_image, target_width=width, target_height=height
        )
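        # find_nearest_bucket snaps (H, W) to the closest supported aspect-ratio
        # bucket around 640 px, and resize_and_center_crop fits the image to it,
        # so the height // 8 and width // 8 divisions below are exact.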

        metadata_obj = PngInfo()
        params_to_save_in_metadata = {
            "prompt": prompt,
            "n_prompt": n_prompt,
            "seed": seed,
            "total_second_length": total_second_length,
            "steps": steps,
            "cfg": cfg,
            "gs": gs,
            "gs_final": gs_final,
            "gs_schedule_active": gs_schedule_active,
            "rs": rs,
            "preview_frequency": preview_frequency,
            "segments_to_decode_csv": segments_to_decode_csv,
        }
        metadata_obj.add_text("parameters", json.dumps(params_to_save_in_metadata))
        initial_image_with_params_path = os.path.join(
            outputs_folder, f"{job_id}_initial_image_with_params.png"
        )
        try:
            Image.fromarray(input_image_np).save(
                initial_image_with_params_path, pnginfo=metadata_obj
            )
        except Exception as e_png:
            print(
                f"Task {task_id}: WARNING - Failed to save initial image with parameters: {e_png}"
            )
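        # The full recipe travels with the PNG as a JSON "parameters" tEXt chunk;
        # a run can be reproduced later by reading it back, e.g. (sketch):
        # json.loads(Image.open(initial_image_with_params_path).text["parameters"]).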

        # --- Gemini: do not touch - "secret sauce"
        if not high_vram:
            unload_complete_models(
                text_encoder, text_encoder_2, image_encoder, vae, transformer
            )
        output_queue_ref.push(
            (
                "progress",
                (
                    task_id,
                    None,
                    f"Total Segments: {total_latent_sections}",
                    make_progress_bar_html(0, "Text encoding ..."),
                ),
            )
        )
        if not high_vram:
            fake_diffusers_current_device(text_encoder, gpu)
            load_model_as_complete(text_encoder_2, target_device=gpu)
        llama_vec, clip_l_pooler = encode_prompt_conds(
            prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2
        )
        if cfg == 1:
            llama_vec_n, clip_l_pooler_n = torch.zeros_like(
                llama_vec
            ), torch.zeros_like(clip_l_pooler)
        else:
            llama_vec_n, clip_l_pooler_n = encode_prompt_conds(
                n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2
            )
        llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
        llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(
            llama_vec_n, length=512
        )
        input_image_pt = (
            torch.from_numpy(input_image_np).float().permute(2, 0, 1).unsqueeze(0)
            / 127.5
            - 1.0
        )
        input_image_pt = input_image_pt[:, :, None, :, :]
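        # uint8 [0, 255] maps linearly to [-1.0, 1.0] (0 -> -1.0, 255 -> 1.0), and the
        # inserted axis turns (1, C, H, W) into (1, C, T=1, H, W) for the video VAE.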
        output_queue_ref.push(
            (
                "progress",
                (
                    task_id,
                    None,
                    f"Total Segments: {total_latent_sections}",
                    make_progress_bar_html(0, "VAE encoding ..."),
                ),
            )
        )
        if not high_vram:
            load_model_as_complete(vae, target_device=gpu)
        start_latent = vae_encode(input_image_pt, vae)
        output_queue_ref.push(
            (
                "progress",
                (
                    task_id,
                    None,
                    f"Total Segments: {total_latent_sections}",
                    make_progress_bar_html(0, "CLIP Vision encoding ..."),
                ),
            )
        )
        if not high_vram:
            load_model_as_complete(image_encoder, target_device=gpu)
        image_encoder_output = hf_clip_vision_encode(
            input_image_np, feature_extractor, image_encoder
        )
        image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
        (
            llama_vec,
            llama_vec_n,
            clip_l_pooler,
            clip_l_pooler_n,
            image_encoder_last_hidden_state,
        ) = [
            t.to(transformer.dtype)
            for t in [
                llama_vec,
                llama_vec_n,
                clip_l_pooler,
                clip_l_pooler_n,
                image_encoder_last_hidden_state,
            ]
        ]

        output_queue_ref.push(
            (
                "progress",
                (
                    task_id,
                    None,
                    f"Total Segments: {total_latent_sections}",
                    make_progress_bar_html(0, "Start sampling ..."),
                ),
            )
        )
        rnd = torch.Generator(device="cpu").manual_seed(int(seed))
        num_frames = latent_window_size * 4 - 3

        history_latents = torch.zeros(
            size=(1, 16, 1 + 2 + 16, height // 8, width // 8),
            dtype=torch.float32,
            device="cpu",
        )
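        # The rolling history holds 1 + 2 + 16 latent frames: one "post" anchor
        # frame plus the 2x and 4x context frames consumed by the sampler below.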
        history_pixels = None
        total_generated_latent_frames = 0
        latent_paddings = list(reversed(range(total_latent_sections)))
        if total_latent_sections > 4:
            latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
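        # Example schedules: 3 sections -> [2, 1, 0]; 6 sections -> [3, 2, 2, 2, 1, 0],
        # i.e. long videos reuse padding 2 for every mid-video segment and keep the
        # special values only at the ends.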

        for latent_padding_iteration, latent_padding in enumerate(latent_paddings):
            if abort_event and abort_event.is_set():
                raise KeyboardInterrupt("Abort signal received.")
            is_last_section = latent_padding == 0
            latent_padding_size = latent_padding * latent_window_size
            # 1-indexed segment number used in logs, filenames, and the decode set
            current_loop_segment_number = latent_padding_iteration + 1
            print(
                f"Task {task_id}: Seg {current_loop_segment_number}/{total_latent_sections} (lp_val={latent_padding}), last_loop_seg={is_last_section}"
            )

            indices = torch.arange(
                0,
                sum([1, latent_padding_size, latent_window_size, 1, 2, 16]),
                device="cpu",
            ).unsqueeze(0)
            (
                clean_latent_indices_pre,
                _,
                latent_indices,
                clean_latent_indices_post,
                clean_latent_2x_indices,
                clean_latent_4x_indices,
            ) = indices.split(
                [1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1
            )
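            # Index layout along the time axis produced by the split above:
            # [ start (1) | padding (latent_padding_size) | window (latent_window_size)
            #   | post (1) | 2x history (2) | 4x history (16) ]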
            clean_latents_pre = start_latent.to(
                history_latents.device, dtype=history_latents.dtype
            )
            clean_latent_indices = torch.cat(
                [clean_latent_indices_pre, clean_latent_indices_post], dim=1
            )
            # history_latents is initialized with temporal depth 1 + 2 + 16, so the
            # split below always has enough frames and needs no defensive padding.
            clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[
                :, :, : 1 + 2 + 16, :, :
            ].split([1, 2, 16], dim=2)
            clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)

            if not high_vram:
                unload_complete_models()
                move_model_to_device_with_memory_preservation(
                    transformer,
                    target_device=gpu,
                    preserved_memory_gb=gpu_memory_preservation,
                )
            transformer.initialize_teacache(
                enable_teacache=use_teacache, num_steps=steps
            )

            def callback_diffusion_step(d):
                if abort_event and abort_event.is_set():
                    raise KeyboardInterrupt("Abort signal received during sampling.")
                current_diffusion_step = d["i"] + 1
                is_first_step = current_diffusion_step == 1
                is_last_step = current_diffusion_step == steps
                is_preview_step = preview_frequency > 0 and (
                    current_diffusion_step % preview_frequency == 0
                )
                if not (is_first_step or is_last_step or is_preview_step):
                    return
                preview_latent = d["denoised"]
                preview_img_np = vae_decode_fake(preview_latent)
                preview_img_np = (
                    (preview_img_np * 255.0)
                    .detach()
                    .cpu()
                    .numpy()
                    .clip(0, 255)
                    .astype(np.uint8)
                )
                preview_img_np = einops.rearrange(
                    preview_img_np, "b c t h w -> (b h) (t w) c"
                )
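                # vae_decode_fake is a cheap approximate decode for previews; the
                # rearrange tiles the T frames side by side into one (B*H, T*W, C)
                # image so the whole segment fits in a single preview strip.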

                percentage = int(100.0 * current_diffusion_step / steps)
                hint = f"Segment {current_loop_segment_number}, Sampling {current_diffusion_step}/{steps}"
                current_video_frames_count = (
                    history_pixels.shape[2] if history_pixels is not None else 0
                )
                desc = f"Task {task_id}: Vid Frames: {current_video_frames_count}, Len: {current_video_frames_count / 30 :.2f}s. Seg {current_loop_segment_number}/{total_latent_sections}. Extending..."
                output_queue_ref.push(
                    (
                        "progress",
                        (
                            task_id,
                            preview_img_np,
                            desc,
                            make_progress_bar_html(percentage, hint),
                        ),
                    )
                )

            current_segment_gs_to_use = initial_gs_from_ui
            if gs_schedule_active and total_latent_sections > 1:
                progress_for_gs = latent_padding_iteration / (total_latent_sections - 1)
                current_segment_gs_to_use = (
                    initial_gs_from_ui
                    + (gs_final_value_for_schedule - initial_gs_from_ui) * progress_for_gs
                )
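            # Illustrative values (not defaults): gs=10.0 and gs_final=6.0 over 5
            # segments yields per-segment GS of 10.0, 9.0, 8.0, 7.0, 6.0
            # (progress_for_gs = 0.0, 0.25, 0.5, 0.75, 1.0).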

            generated_latents = sample_hunyuan(
                transformer=transformer,
                sampler="unipc",
                width=width,
                height=height,
                frames=num_frames,
                real_guidance_scale=cfg,
                distilled_guidance_scale=current_segment_gs_to_use,
                guidance_rescale=rs,
                num_inference_steps=steps,
                generator=rnd,
                prompt_embeds=llama_vec.to(transformer.device),
                prompt_embeds_mask=llama_attention_mask.to(transformer.device),
                prompt_poolers=clip_l_pooler.to(transformer.device),
                negative_prompt_embeds=llama_vec_n.to(transformer.device),
                negative_prompt_embeds_mask=llama_attention_mask_n.to(
                    transformer.device
                ),
                negative_prompt_poolers=clip_l_pooler_n.to(transformer.device),
                device=transformer.device,
                dtype=transformer.dtype,
                image_embeddings=image_encoder_last_hidden_state.to(transformer.device),
                latent_indices=latent_indices.to(transformer.device),
                clean_latents=clean_latents.to(
                    transformer.device, dtype=transformer.dtype
                ),
                clean_latent_indices=clean_latent_indices.to(transformer.device),
                clean_latents_2x=clean_latents_2x.to(
                    transformer.device, dtype=transformer.dtype
                ),
                clean_latent_2x_indices=clean_latent_2x_indices.to(transformer.device),
                clean_latents_4x=clean_latents_4x.to(
                    transformer.device, dtype=transformer.dtype
                ),
                clean_latent_4x_indices=clean_latent_4x_indices.to(transformer.device),
                callback=callback_diffusion_step,
            )

            if is_last_section:
                generated_latents = torch.cat(
                    [start_latent.to(generated_latents), generated_latents], dim=2
                )

            total_generated_latent_frames += int(generated_latents.shape[2])
            history_latents = torch.cat(
                [generated_latents.to(history_latents), history_latents], dim=2
            )
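            # Segments are generated in reverse temporal order, so each new chunk is
            # prepended to (not appended after) the accumulated history.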

            if not high_vram:
                offload_model_from_device_for_memory_preservation(
                    transformer, target_device=gpu, preserved_memory_gb=8
                )
                load_model_as_complete(vae, target_device=gpu)

            real_history_latents = history_latents[
                :, :, :total_generated_latent_frames, :, :
            ]

            if history_pixels is None:
                history_pixels = vae_decode(real_history_latents, vae).cpu()
            else:
                section_latent_frames = (
                    (latent_window_size * 2 + 1)
                    if is_last_section
                    else (latent_window_size * 2)
                )
                overlapped_frames = latent_window_size * 4 - 3
                current_pixels = vae_decode(
                    real_history_latents[:, :, :section_latent_frames], vae
                ).cpu()
                history_pixels = soft_append_bcthw(
                    current_pixels, history_pixels, overlapped_frames
                )
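                # With latent_window_size=9 (an example, not a default): non-final
                # segments decode 2 * 9 = 18 latent frames and blend them into the
                # existing video with an overlap of 9 * 4 - 3 = 33 pixel frames.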

            if not high_vram:
                unload_complete_models()

            current_video_frame_count = history_pixels.shape[2]

            # Determine if we should save an intermediate MP4 for this loop segment
            should_save_mp4_this_iteration = False

            # Condition 1: Always save the last segment of the loop
            if is_last_section:
                should_save_mp4_this_iteration = True
            # Condition 2: Save if the current loop segment number is explicitly in the parsed set
            elif (
                parsed_segments_to_decode_set
                and current_loop_segment_number in parsed_segments_to_decode_set
            ):
                should_save_mp4_this_iteration = True
            # Condition 3: Save based on preview_frequency, if enabled (preview_frequency > 0)
            elif preview_frequency > 0 and (
                current_loop_segment_number % preview_frequency == 0
            ):
                should_save_mp4_this_iteration = True

            if should_save_mp4_this_iteration:
                segment_mp4_filename = os.path.join(
                    outputs_folder,
                    f"{job_id}_segment_{current_loop_segment_number}_frames_{current_video_frame_count}.mp4",
                )
                save_bcthw_as_mp4(
                    history_pixels, segment_mp4_filename, fps=30, crf=mp4_crf
                )
                final_output_filename = segment_mp4_filename
                print(
                    f"Task {task_id}: SAVED MP4 for segment {current_loop_segment_number} to {segment_mp4_filename}. Total video frames: {current_video_frame_count}"
                )
                output_queue_ref.push(
                    (
                        "file",
                        (
                            task_id,
                            segment_mp4_filename,
                            f"Segment {current_loop_segment_number} MP4 saved ({current_video_frame_count} frames)",
                        ),
                    )
                )
            else:
                print(
                    f"Task {task_id}: SKIPPED MP4 save for intermediate segment {current_loop_segment_number}."
                )

            if is_last_section:
                success = True
                break

    except KeyboardInterrupt:
        print(f"Worker task {task_id} caught KeyboardInterrupt (likely abort signal).")
        output_queue_ref.push(("aborted", task_id))
        success = False
    except Exception as e:
        print(f"Error in worker task {task_id}: {e}")
        traceback.print_exc()
        output_queue_ref.push(("error", (task_id, str(e))))
        success = False
    finally:
        transformer.high_quality_fp32_output_for_inference = original_fp32_setting
        print(
            f"Task {task_id}: Restored transformer.high_quality_fp32_output_for_inference to {original_fp32_setting}"
        )
        if not high_vram:
            unload_complete_models(
                text_encoder, text_encoder_2, image_encoder, vae, transformer
            )
        # Normalize the reported path so it always points inside the outputs folder.
        if final_output_filename and os.path.abspath(
            os.path.dirname(final_output_filename)
        ) != os.path.abspath(outputs_folder):
            final_output_filename = os.path.join(
                outputs_folder, os.path.basename(final_output_filename)
            )
        output_queue_ref.push(("end", (task_id, success, final_output_filename)))