jbilcke-hf (HF Staff) committed
Commit 9985ce2 · verified · 1 Parent(s): 0bafb17

Update handler_LAST_WORKING.py

Files changed (1)
  1. handler_LAST_WORKING.py +506 -344
handler_LAST_WORKING.py CHANGED
@@ -1,24 +1,23 @@
1
  from dataclasses import dataclass
2
  from pathlib import Path
3
- import pathlib
4
- from typing import Dict, Any, Optional, Tuple
5
- import asyncio
6
- import base64
7
- import io
8
- import pprint
9
  import logging
 
10
  import random
11
- import traceback
12
  import os
13
  import numpy as np
14
  import torch
15
- import gc
16
-
17
- from diffusers import AutoencoderKLLTXVideo, LTXPipeline, LTXImageToVideoPipeline, LTXVideoTransformer3DModel
18
- #from diffusers.hooks import apply_enhance_a_video, EnhanceAVideoConfig
19
- #from teacache import apply_teacache
20
-
21
- from PIL import Image
 
 
 
 
22
 
23
  from varnish import Varnish
24
  from varnish.utils import is_truthy, process_input_image
@@ -27,14 +26,13 @@ from varnish.utils import is_truthy, process_input_image
27
  logging.basicConfig(level=logging.INFO)
28
  logger = logging.getLogger(__name__)
29
 
30
-
31
  # Get token from environment
32
  hf_token = os.getenv("HF_API_TOKEN")
33
 
34
  # Constraints
35
  MAX_LARGE_SIDE = 1280
36
- MAX_SMALL_SIDE = 768 # should be 720 but it must be divisible by 32
37
- MAX_FRAMES = 257
38
 
39
  # Check environment variable for pipeline support
40
  support_image_prompt = is_truthy(os.getenv("SUPPORT_INPUT_IMAGE_PROMPT"))
@@ -48,10 +46,8 @@ class GenerationConfig:
48
  negative_prompt: str = "saturated, highlight, overexposed, highlighted, overlit, shaking, too bright, worst quality, inconsistent motion, blurry, jittery, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles, lowres"
49
 
50
  # video model settings (will be used during generation of the initial raw video clip)
51
- # we use small values to make things a bit faster
52
- width: int = 768
53
- height: int = 416
54
-
55
 
56
  # this is a hack to fool LTX-Video into believing our input image is an actual video frame with poor encoding quality
57
  # after a quick benchmark, the value 70 seems like a sweet spot
@@ -62,8 +58,8 @@ class GenerationConfig:
62
  # visual glitches appear after about 169 frames, so we don't need more actually
63
  num_frames: int = (8 * 14) + 1
64
 
65
- # with the distilled model, a guidance scale of 1.0 is fine
66
- guidance_scale: float = 1.0
67
 
68
  num_inference_steps: int = 8
69
 
@@ -71,16 +67,16 @@ class GenerationConfig:
71
  seed: int = -1 # -1 means random seed
72
 
73
  # varnish settings (will be used for post-processing after the raw video clip has been generated)
74
- fps: int = 30 # FPS of the final video (only applied at the very end, when converting to mp4)
75
- double_num_frames: bool = False # if True, the number of frames will be multiplied by 2 using RIFE
76
- super_resolution: bool = False # if True, the resolution will be multiplied by 2 using Real_ESRGAN
77
 
78
- grain_amount: float = 0.0 # be careful, adding film grain can negatively impact video compression
79
 
80
  # audio settings
81
  enable_audio: bool = False # Whether to generate audio
82
  audio_prompt: str = "" # Text prompt for audio generation
83
- audio_negative_prompt: str = "voices, voice, talking, speaking, speech" # Negative prompt for audio generation
84
 
85
  # The range of the CRF scale is 0–51, where:
86
  # 0 is lossless (for 8 bit only, for 10 bit use -qp 0)
@@ -92,18 +88,26 @@ class GenerationConfig:
92
  # The range is exponential, so increasing the CRF value +6 results in roughly half the bitrate / file size, while -6 leads to roughly twice the bitrate.
93
  quality: int = 18
94
 
95
- # TeaCache settings
96
- enable_teacache: bool = False
97
- teacache_threshold: float = 0.05 # values: 0 (original), 0.03 (1.6x speedup), 0.05 (2.1x speedup).
 
98
 
99
- # Enhance-A-Video settings
100
- enable_enhance_a_video: bool = False
101
- enhance_a_video_weight: float = 5.0
102
 
103
- # LoRA settings
104
- lora_model_name: str = "" # HuggingFace repo ID or path to LoRA model
105
- lora_model_weight_file: str = "" # Specific weight file to load from the LoRA model
106
- lora_model_trigger: str = "" # Optional trigger word to prepend to the prompt
 
 
 
 
 
 
 
107
 
108
  def validate_and_adjust(self) -> 'GenerationConfig':
109
  """Validate and adjust parameters to meet constraints"""
@@ -111,7 +115,7 @@ class GenerationConfig:
111
  if not ((self.width == MAX_LARGE_SIDE and self.height == MAX_SMALL_SIDE) or
112
  (self.width == MAX_SMALL_SIDE and self.height == MAX_LARGE_SIDE)):
113
  # For other resolutions, ensure total pixels don't exceed max
114
- MAX_TOTAL_PIXELS = MAX_SMALL_SIDE * MAX_LARGE_SIDE # or 921600 = 1280 * 720
115
 
116
  # If total pixels exceed maximum, scale down proportionally
117
  total_pixels = self.width * self.height
@@ -131,371 +135,527 @@ class GenerationConfig:
131
  # Set random seed if not specified
132
  if self.seed == -1:
133
  self.seed = random.randint(0, 2**32 - 1)
134
-
135
- return self
136
 
137
- class EndpointHandler:
138
- """Handles video generation requests using LTX models and Varnish post-processing"""
139
 
140
- def __init__(self, model_path: str = ""):
141
- """Initialize the handler with LTX models and Varnish
142
-
143
- Args:
144
- model_path: Path to LTX model weights
145
- """
146
- print("EndpointHandler.__init__(): initializing..")
147
- # Enable TF32 for potential speedup on Ampere GPUs
148
- #torch.backends.cuda.matmul.allow_tf32 = True
149
-
150
- # use distilled weights
151
- model_path = Path("/repository/ltxv-2b-0.9.6-distilled-04-25.safetensors")
152
 
153
- print("EndpointHandler.__init__(): initializing LTXVideoTransformer3DModel..")
154
- transformer = LTXVideoTransformer3DModel.from_single_file(
155
- model_path, torch_dtype=torch.bfloat16
156
  )
157
 
158
- print("EndpointHandler.__init__(): initializing AutoencoderKLLTXVideo..")
159
- vae = AutoencoderKLLTXVideo.from_single_file(model_path, torch_dtype=torch.bfloat16)
160
-
161
- if support_image_prompt:
162
- print("EndpointHandler.__init__(): initializing LTXImageToVideoPipeline..")
163
- self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
164
- "/repository",
165
- transformer=transformer,
166
- vae=vae,
167
- torch_dtype=torch.bfloat16
168
- ).to("cuda")
169
-
170
- #apply_teacache(self.image_to_video)
171
 
172
- # Compilation requires some time to complete, so it is best suited for
173
- # situations where you prepare your pipeline once and then perform the
174
- # same type of inference operations multiple times.
175
- # For example, calling the compiled pipeline on a different image size
176
- # triggers compilation again which can be expensive.
177
- #self.image_to_video = torch.compile(self.image_to_video, mode="reduce-overhead", fullgraph=True)
178
 
179
- else:
180
- print("EndpointHandler.__init__(): initializing LTXPipeline..")
181
- # Initialize models with bfloat16 precision
182
- self.text_to_video = LTXPipeline.from_pretrained(
183
- "/repository",
184
- transformer=transformer,
185
- vae=vae,
186
- torch_dtype=torch.bfloat16
187
- ).to("cuda")
188
 
189
- #apply_teacache(self.text_to_video)
190
-
191
- # Compilation requires some time to complete, so it is best suited for
192
- # situations where you prepare your pipeline once and then perform the
193
- # same type of inference operations multiple times.
194
- # For example, calling the compiled pipeline on a different image size
195
- # triggers compilation again which can be expensive.
196
- #self.text_to_video = torch.compile(self.text_to_video, mode="reduce-overhead", fullgraph=True)
197
-
198
-
199
- # Initialize LoRA tracking
200
- self._current_lora_model = None
201
-
202
- #if support_image_prompt:
203
- # # Enable CPU offload for memory efficiency
204
- # self.image_to_video.enable_model_cpu_offload()
205
- # # Inject enhance-a-video functionality
206
- # inject_enhance_for_ltx(self.image_to_video.transformer)
207
- #else:
208
- # # Enable CPU offload for memory efficiency
209
- # self.text_to_video.enable_model_cpu_offload()
210
- # # Inject enhance-a-video functionality
211
- # inject_enhance_for_ltx(self.text_to_video.transformer)
212
-
213
 
214
  # Initialize Varnish for post-processing
215
  self.varnish = Varnish(
216
  device="cuda",
217
  model_base_dir="/repository/varnish",
218
-
219
- # there is currently a bug with MMAudio and/or torch and/or the weight format and/or version..
220
- # not sure how to fix that.. :/
221
- #
222
- # it says:
223
- # File "dist-packages/varnish.py", line 152, in __init__
224
- # self._setup_mmaudio()
225
- # File "dist-packages/varnish/varnish.py", line 165, in _setup_mmaudio
226
- # net.load_weights(torch.load(model.model_path, map_location=self.device, weights_only=False))
227
- # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
228
- # File "dist-packages/torch/serialization.py", line 1384, in load
229
- # return _legacy_load(
230
- # ^^^^^^^^^^^^^
231
- # File "dist-packages/torch/serialization.py", line 1628, in _legacy_load
232
- # magic_number = pickle_module.load(f, **pickle_load_args)
233
- # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
234
- # _pickle.UnpicklingError: invalid load key, '<'.
235
- enable_mmaudio=False,
236
  )
237
-
238
- # Determine if TeaCache is already installed or not
239
- self.text_to_video_teacache = False
240
- self.image_to_video_teacache = False
241
-
242
-
243
- async def process_frames(
244
- self,
245
- frames: torch.Tensor,
246
- config: GenerationConfig
247
- ) -> tuple[str, dict]:
248
- """Post-process generated frames using Varnish
249
 
250
- Args:
251
- frames: Generated video frames tensor
252
- config: Generation configuration
253
-
254
- Returns:
255
- Tuple of (video data URI, metadata dictionary)
256
- """
 
 
 
257
  try:
258
- # Process video with Varnish
259
- result = await self.varnish(
260
- input_data=frames, # note: this might contain a certain number of frames eg. 97, which will get doubled if double_num_frames is True
261
- fps=config.fps, # this is the FPS of the final output video. This number can be used by Varnish to calculate the duration of a clip ((using frames * factor) / fps etc)
262
- double_num_frames=config.double_num_frames, # if True, the number of frames will be multiplied by 2 using RIFE
263
- super_resolution=config.super_resolution, # if True, the resolution will be multiplied by 2 using Real_ESRGAN
264
- grain_amount=config.grain_amount,
265
- enable_audio=config.enable_audio,
266
- audio_prompt=config.audio_prompt,
267
- audio_negative_prompt=config.audio_negative_prompt,
268
- )
269
-
270
- # Convert to data URI
271
- video_uri = await result.write(type="data-uri", quality=config.quality)
272
 
273
- # Collect metadata
274
- metadata = {
275
- "width": result.metadata.width,
276
- "height": result.metadata.height,
277
- "num_frames": result.metadata.frame_count,
278
- "fps": result.metadata.fps,
279
- "duration": result.metadata.duration,
280
- "seed": config.seed,
281
- }
282
 
283
- return video_uri, metadata
284
-
 
285
  except Exception as e:
286
- logger.error(f"Error in process_frames: {str(e)}")
287
- raise RuntimeError(f"Failed to process frames: {str(e)}")
288
-
289
-
 
290
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
291
- """Process incoming requests for video generation
292
 
293
  Args:
294
- data: Request data containing:
295
- - inputs (dict): Dictionary containing input, which can be either "prompt" (text field) or "image" (input image)
296
- - parameters (dict):
297
- - prompt (required, string): list of concepts to keep in the video.
298
- - negative_prompt (optional, string): list of concepts to ignore in the video.
299
- - width (optional, int, default to 768): width, or horizontal size in pixels.
300
- - height (optional, int, default to 512): height, or vertical size in pixels.
301
- - input_image_quality (optional, int, default to 100): this is a trick we use to convert a "pristine" image into a "dirty" video frame. This helps fool LTX-Video into turning the image into an animated one.
302
- - num_frames (optional, int, default to 129): the number of frames must be a multiple of 8, plus 1 frame.
303
- - guidance_scale (optional, float, default to 3.5): Guidance scale (values between 3.0 and 4.0 are nice)
304
- - num_inference_steps (optional, int, default to 50): number of inference steps
305
- - seed (optional, int, default to -1): set a random number generator seed, -1 means random seed.
306
- - fps (optional, int, default to 24): FPS of the final video (eg. 24, 25, 30, 60)
307
- - double_num_frames (optional, bool): if enabled, the number of frames will be multiplied by 2 using RIFE
308
- - super_resolution (optional, bool): if enabled, the resolution will be multiplied by 2 using Real_ESRGAN
309
- - grain_amount (optional, float): amount of film grain to add to the output video
310
- - enable_audio (optional, bool): automatically generate an audio track
311
- - audio_prompt (optional, str): prompt to use for the audio generation (concepts to add)
312
- - audio_negative_prompt (optional, str): negative prompt to use for the audio generation (concepts to ignore)
313
- - quality (optional, int, default to 18): The range of the CRF scale is 0–51, where 0 is lossless (for 8 bit only, for 10 bit use -qp 0), 23 is the default, and 51 is worst quality possible.
314
- - enable_teacache (optional, bool, default to False): Generate faster at the cost of a slight quality loss
315
- - teacache_threshold (optional, float, default to 0.05): Amount of cache, 0 (original), 0.03 (1.6x speedup), 0.05 (Default, 2.1x speedup).
316
- - enable_enhance_a_video (optional, bool, default to False): enable the enhance_a_video optimization
317
- - enhance_a_video_weight(optional, float, default to 5.0): amount of video enhancement to apply
318
- - lora_model_name(optional, str, default to ""): HuggingFace repo ID or path to LoRA model
319
- - lora_model_weight_file(optional, str, default to ""): Specific weight file to load from the LoRA model
320
- - lora_model_trigger(optional, str, default to ""): Optional trigger word to prepend to the prompt
321
  Returns:
322
- Dictionary containing:
323
- - video: Base64 encoded MP4 data URI
324
- - content-type: MIME type
325
- - metadata: Generation metadata
326
  """
327
- inputs = data.get("inputs", dict())
328
- #print(inputs)
329
 
330
- input_prompt = inputs.get("prompt", "")
331
- input_image = inputs.get("image")
332
 
333
- params = data.get("parameters", dict())
334
-
335
- if not input_image and not input_prompt:
336
  raise ValueError("Either prompt or image must be provided")
337
-
338
- #logger.debug(f"Raw parameters:")
339
- # pprint.pprint(params)
340
-
341
  # Create and validate configuration
342
  config = GenerationConfig(
343
  # general content settings
344
  prompt=input_prompt,
345
  negative_prompt=params.get("negative_prompt", GenerationConfig.negative_prompt),
346
-
347
- # video model settings (will be used during generation of the initial raw video clip)
348
  width=params.get("width", GenerationConfig.width),
349
  height=params.get("height", GenerationConfig.height),
350
  input_image_quality=params.get("input_image_quality", GenerationConfig.input_image_quality),
351
  num_frames=params.get("num_frames", GenerationConfig.num_frames),
352
  guidance_scale=params.get("guidance_scale", GenerationConfig.guidance_scale),
353
  num_inference_steps=params.get("num_inference_steps", GenerationConfig.num_inference_steps),
354
-
355
  # reproducible generation settings
356
  seed=params.get("seed", GenerationConfig.seed),
357
 
358
- # varnish settings (will be used for post-processing after the raw video clip has been generated)
359
- fps=params.get("fps", GenerationConfig.fps), # FPS of the final video (only applied at the very end, when converting to mp4)
360
- double_num_frames=params.get("double_num_frames", GenerationConfig.double_num_frames), # if True, the number of frames will be multiplied by 2 using RIFE
361
- super_resolution=params.get("super_resolution", GenerationConfig.super_resolution), # if True, the resolution will be multiplied by 2 using Real_ESRGAN
362
  grain_amount=params.get("grain_amount", GenerationConfig.grain_amount),
363
  enable_audio=params.get("enable_audio", GenerationConfig.enable_audio),
364
  audio_prompt=params.get("audio_prompt", GenerationConfig.audio_prompt),
365
  audio_negative_prompt=params.get("audio_negative_prompt", GenerationConfig.audio_negative_prompt),
366
  quality=params.get("quality", GenerationConfig.quality),
367
 
368
- # TeaCache settings
369
- enable_teacache=params.get("enable_teacache", False),
370
-
371
- # values: 0 (original), 0.03 (1.6x speedup), 0.05 (2.1x speedup).
372
- teacache_threshold=params.get("teacache_threshold", 0.05),
373
-
374
 
375
- # Add enhance-a-video settings
376
- enable_enhance_a_video=params.get("enable_enhance_a_video", False),
377
- enhance_a_video_weight=params.get("enhance_a_video_weight", 5.0),
378
-
379
- # LoRA settings
380
- lora_model_name=params.get("lora_model_name", ""),
381
- lora_model_weight_file=params.get("lora_model_weight_file", ""),
382
- lora_model_trigger=params.get("lora_model_trigger", ""),
383
  ).validate_and_adjust()
384
 
385
- #logger.debug(f"Global request settings:")
386
- #pprint.pprint(config)
387
-
388
  try:
389
- with torch.amp.autocast_mode.autocast('cuda', torch.bfloat16), torch.no_grad(), torch.inference_mode():
390
- # Set random seeds
391
  random.seed(config.seed)
392
  np.random.seed(config.seed)
393
  torch.manual_seed(config.seed)
394
- generator = torch.Generator(device='cuda')
395
- generator = generator.manual_seed(config.seed)
396
-
397
- # Configure enhance-a-video
398
- #if config.enable_enhance_a_video:
399
- # enable_enhance()
400
- # set_enhance_weight(config.enhance_a_video_weight)
401
 
402
- # Prepare generation parameters for the video model (we omit params that are destined to Varnish, or things like the seed which is set externally)
403
- generation_kwargs = {
404
- # general content settings
405
- "prompt": config.prompt,
406
- "negative_prompt": config.negative_prompt,
407
-
408
- # video model settings (will be used during generation of the initial raw video clip)
409
- "width": config.width,
410
- "height": config.height,
411
- "num_frames": config.num_frames,
412
- "guidance_scale": config.guidance_scale,
413
- "num_inference_steps": config.num_inference_steps,
414
-
415
- # constants
416
- "output_type": "pt",
417
- "generator": generator,
418
-
419
- # Timestep for decoding VAE noise: the timestep at which generated video is decoded
420
- "decode_timestep": 0.05,
421
-
422
- # Noise level for decoding VAE noise: the interpolation factor between random noise and denoised latents at the decode timestep
423
- "decode_noise_scale": 0.025,
424
- }
425
- #logger.info(f"Video model generation settings:")
426
- #pprint.pprint(generation_kwargs)
427
-
428
- # Handle LoRA loading/unloading
429
- if hasattr(self, '_current_lora_model'):
430
- if self._current_lora_model != (config.lora_model_name, config.lora_model_weight_file):
431
- # Unload previous LoRA if it exists and is different
432
- if hasattr(self.text_to_video, 'unload_lora_weights'):
433
- print("Unloading LoRA weights for the text_to_video pipeline..")
434
- self.text_to_video.unload_lora_weights()
435
-
436
- if support_image_prompt and hasattr(self.image_to_video, 'unload_lora_weights'):
437
- print("Unloading LoRA weights for the image_to_video pipeline..")
438
- self.image_to_video.unload_lora_weights()
439
-
440
- if config.lora_model_name:
441
- # Load new LoRA
442
- if hasattr(self.text_to_video, 'load_lora_weights'):
443
- print("Loading LoRA weights for the text_to_video pipeline..")
444
- self.text_to_video.load_lora_weights(
445
- config.lora_model_name,
446
- weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
447
- token=hf_token,
448
- )
449
- if support_image_prompt and hasattr(self.image_to_video, 'load_lora_weights'):
450
- print("Loading LoRA weights for the image_to_video pipeline..")
451
- self.image_to_video.load_lora_weights(
452
- config.lora_model_name,
453
- weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
454
- token=hf_token,
455
  )
456
- self._current_lora_model = (config.lora_model_name, config.lora_model_weight_file)
457
-
458
- # Modify prompt if trigger word is provided
459
- if config.lora_model_trigger:
460
- generation_kwargs["prompt"] = f"{config.lora_model_trigger} {generation_kwargs['prompt']}"
461
-
462
- #enhance_a_video_config = EnhanceAVideoConfig(
463
- # weight=config.enhance_a_video_weight if config.enable_enhance_a_video else 0.0,
464
- # # doing some testing
465
- # num_frames_callback=lambda: (8 + 1),
466
- # # num_frames_callback=lambda: config.num_frames,
467
- # # num_frames_callback=lambda: (config.num_frames - 1),
468
- #
469
- # _attention_type=1
470
- #)
471
 
472
- # Check if image-to-video generation is requested
473
- if support_image_prompt and input_image:
474
- processed_image = process_input_image(
475
- input_image,
476
- config.width,
477
- config.height,
478
- config.input_image_quality,
479
- )
480
- generation_kwargs["image"] = processed_image
481
- # disabled (we cannot install the hook multiple times, we would have to uninstall it first or find another way to dynamically enable it, eg. using the weight only)
482
- # apply_enhance_a_video(self.image_to_video.transformer, enhance_a_video_config)
483
- frames = self.image_to_video(**generation_kwargs).frames
484
- else:
485
- # disabled (we cannot install the hook multiple times, we would have to uninstall it first or find another way to dynamically enable it, eg. using the weight only)
486
- # apply_enhance_a_video(self.text_to_video.transformer, enhance_a_video_config)
487
- frames = self.text_to_video(**generation_kwargs).frames
488
-
 
489
  try:
490
  loop = asyncio.get_event_loop()
491
  except RuntimeError:
492
  loop = asyncio.new_event_loop()
493
  asyncio.set_event_loop(loop)
494
 
495
- video_uri, metadata = loop.run_until_complete(self.process_frames(frames, config))
496
-
497
  torch.cuda.empty_cache()
498
- torch.cuda.reset_peak_memory_stats()
499
  gc.collect()
500
 
501
  return {
@@ -503,8 +663,10 @@ class EndpointHandler:
503
  "content-type": "video/mp4",
504
  "metadata": metadata
505
  }
506
-
507
  except Exception as e:
508
- message = f"Error generating video ({str(e)})\n{traceback.format_exc()}"
509
- print(message)
510
- raise RuntimeError(message)
 
 
 
1
  from dataclasses import dataclass
2
  from pathlib import Path
 
 
 
 
 
 
3
  import logging
4
+ import base64
5
  import random
6
+ import gc
7
  import os
8
  import numpy as np
9
  import torch
10
+ from typing import Dict, Any, Optional, List, Union, Tuple
11
+ import json
12
+ from safetensors import safe_open
13
+
14
+ from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
15
+ from ltx_video.models.transformers.transformer3d import Transformer3DModel
16
+ from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
17
+ from ltx_video.schedulers.rf import RectifiedFlowScheduler, TimestepShifter
18
+ from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXVideoPipeline
19
+ from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
20
+ from transformers import T5EncoderModel, T5Tokenizer, AutoModelForCausalLM, AutoProcessor, AutoTokenizer
21
 
22
  from varnish import Varnish
23
  from varnish.utils import is_truthy, process_input_image
 
26
  logging.basicConfig(level=logging.INFO)
27
  logger = logging.getLogger(__name__)
28
 
 
29
  # Get token from environment
30
  hf_token = os.getenv("HF_API_TOKEN")
31
 
32
  # Constraints
33
  MAX_LARGE_SIDE = 1280
34
+ MAX_SMALL_SIDE = 768 # should be 720 but it must be divisible by 32
35
+ MAX_FRAMES = (8 * 21) + 1 # visual glitches appear after about 169 frames, so we cap it
36
 
37
  # Check environment variable for pipeline support
38
  support_image_prompt = is_truthy(os.getenv("SUPPORT_INPUT_IMAGE_PROMPT"))
 
46
  negative_prompt: str = "saturated, highlight, overexposed, highlighted, overlit, shaking, too bright, worst quality, inconsistent motion, blurry, jittery, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles, lowres"
47
 
48
  # video model settings (will be used during generation of the initial raw video clip)
49
+ width: int = 1216 # 768
50
+ height: int = 704 # 416
 
 
51
 
52
  # this is a hack to fool LTX-Video into believing our input image is an actual video frame with poor encoding quality
53
  # after a quick benchmark, the value 70 seems like a sweet spot
 
58
  # visual glitches appear after about 169 frames, so we don't need more actually
59
  num_frames: int = (8 * 14) + 1
60
 
61
+ # values between 3.0 and 4.0 are nice
62
+ guidance_scale: float = 3.0
63
 
64
  num_inference_steps: int = 8
65
 
 
67
  seed: int = -1 # -1 means random seed
68
 
69
  # varnish settings (will be used for post-processing after the raw video clip has been generated)
70
+ fps: int = 30 # FPS of the final video (only applied at the very end, when converting to mp4)
71
+ double_num_frames: bool = False # if True, the number of frames will be multiplied by 2 using RIFE
72
+ super_resolution: bool = False # if True, the resolution will be multiplied by 2 using Real_ESRGAN
73
 
74
+ grain_amount: float = 0.0 # be careful, adding film grain can negatively impact video compression
75
 
76
  # audio settings
77
  enable_audio: bool = False # Whether to generate audio
78
  audio_prompt: str = "" # Text prompt for audio generation
79
+ audio_negative_prompt: str = "voices, voice, talking, speaking, speech" # Negative prompt for audio generation
80
 
81
  # The range of the CRF scale is 0–51, where:
82
  # 0 is lossless (for 8 bit only, for 10 bit use -qp 0)
 
88
  # The range is exponential, so increasing the CRF value +6 results in roughly half the bitrate / file size, while -6 leads to roughly twice the bitrate.
89
  quality: int = 18
90
 
91
+ # STG (Spatiotemporal Guidance) settings
92
+ stg_scale: float = 0.0
93
+ stg_rescale: float = 1.0
94
+ stg_mode: str = "attention_values" # Can be "attention_values", "attention_skip", "residual", or "transformer_block"
95
 
96
+ # VAE noise augmentation
97
+ decode_timestep: float = 0.05
98
+ decode_noise_scale: float = 0.025
99
 
100
+ # Other advanced settings
101
+ image_cond_noise_scale: float = 0.15
102
+ mixed_precision: bool = True # Use mixed precision for inference
103
+ stochastic_sampling: bool = True # Use stochastic sampling
104
+
105
+ # Sampling settings
106
+ sampler: Optional[str] = "from_checkpoint" # "uniform" or "linear-quadratic" or None (use default from checkpoint)
107
+
108
+ # Prompt enhancement
109
+ enhance_prompt: bool = False # Whether to enhance the prompt using an LLM
110
+ prompt_enhancement_words_threshold: int = 50 # Enhance prompt only if it has fewer words than this
111
 
112
  def validate_and_adjust(self) -> 'GenerationConfig':
113
  """Validate and adjust parameters to meet constraints"""
 
115
  if not ((self.width == MAX_LARGE_SIDE and self.height == MAX_SMALL_SIDE) or
116
  (self.width == MAX_SMALL_SIDE and self.height == MAX_LARGE_SIDE)):
117
  # For other resolutions, ensure total pixels don't exceed max
118
+ MAX_TOTAL_PIXELS = MAX_SMALL_SIDE * MAX_LARGE_SIDE # or 921600 = 1280 * 720
119
 
120
  # If total pixels exceed maximum, scale down proportionally
121
  total_pixels = self.width * self.height
 
135
  # Set random seed if not specified
136
  if self.seed == -1:
137
  self.seed = random.randint(0, 2**32 - 1)
 
 
138
 
139
+ # Set up STG parameters
140
+ if self.stg_mode.lower() == "stg_av" or self.stg_mode.lower() == "attention_values":
141
+ self.stg_mode = "attention_values"
142
+ elif self.stg_mode.lower() == "stg_as" or self.stg_mode.lower() == "attention_skip":
143
+ self.stg_mode = "attention_skip"
144
+ elif self.stg_mode.lower() == "stg_r" or self.stg_mode.lower() == "residual":
145
+ self.stg_mode = "residual"
146
+ elif self.stg_mode.lower() == "stg_t" or self.stg_mode.lower() == "transformer_block":
147
+ self.stg_mode = "transformer_block"
148
+
149
+ # Check if we should enhance the prompt
150
+ if self.enhance_prompt and self.prompt:
151
+ prompt_word_count = len(self.prompt.split())
152
+ if prompt_word_count >= self.prompt_enhancement_words_threshold:
153
+ logger.info(f"Prompt has {prompt_word_count} words, which meets or exceeds the threshold of {self.prompt_enhancement_words_threshold}. Prompt enhancement disabled.")
154
+ self.enhance_prompt = False
155
 
156
+ return self
157
 
158
+ def load_image_to_tensor_with_resize_and_crop(
159
+ image_input: Union[str, bytes],
160
+ target_height: int = 704,
161
+ target_width: int = 1216,
162
+ quality: int = 100
163
+ ) -> torch.Tensor:
164
+ """Load and process an image into a tensor.
165
+
166
+ Args:
167
+ image_input: Either a file path (str) or image data (bytes)
168
+ target_height: Desired height of output tensor
169
+ target_width: Desired width of output tensor
170
+ quality: JPEG quality to use when re-encoding (to simulate lower quality images)
171
+ """
172
+ from PIL import Image
173
+ import io
174
+ import numpy as np
175
+
176
+ # Handle base64 data URI
177
+ if isinstance(image_input, str) and image_input.startswith('data:'):
178
+ header, encoded = image_input.split(",", 1)
179
+ image_data = base64.b64decode(encoded)
180
+ image = Image.open(io.BytesIO(image_data)).convert("RGB")
181
+ # Handle raw bytes
182
+ elif isinstance(image_input, bytes):
183
+ image = Image.open(io.BytesIO(image_input)).convert("RGB")
184
+ # Handle file path
185
+ elif isinstance(image_input, str):
186
+ image = Image.open(image_input).convert("RGB")
187
+ else:
188
+ raise ValueError("image_input must be either a file path, bytes, or base64 data URI")
189
+
190
+ # Apply JPEG compression if quality < 100 (to simulate a video frame)
191
+ if quality < 100:
192
+ buffer = io.BytesIO()
193
+ image.save(buffer, format="JPEG", quality=quality)
194
+ buffer.seek(0)
195
+ image = Image.open(buffer).convert("RGB")
196
+
197
+ input_width, input_height = image.size
198
+ aspect_ratio_target = target_width / target_height
199
+ aspect_ratio_frame = input_width / input_height
200
+ if aspect_ratio_frame > aspect_ratio_target:
201
+ new_width = int(input_height * aspect_ratio_target)
202
+ new_height = input_height
203
+ x_start = (input_width - new_width) // 2
204
+ y_start = 0
205
+ else:
206
+ new_width = input_width
207
+ new_height = int(input_width / aspect_ratio_target)
208
+ x_start = 0
209
+ y_start = (input_height - new_height) // 2
210
+
211
+ image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
212
+ image = image.resize((target_width, target_height))
213
+ frame_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).float()
214
+ frame_tensor = (frame_tensor / 127.5) - 1.0
215
+ # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
216
+ return frame_tensor.unsqueeze(0).unsqueeze(2)
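For reference, a minimal usage sketch (illustrative only, not part of this commit; the file name and sizes are hypothetical). The helper returns a 5D tensor of shape (1, 3, 1, target_height, target_width) with values normalized to [-1, 1]:

    frame = load_image_to_tensor_with_resize_and_crop("first_frame.jpg", 704, 1216, quality=70)
    assert frame.shape == (1, 3, 1, 704, 1216)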
217
+
218
+ def calculate_padding(
219
+ source_height: int, source_width: int, target_height: int, target_width: int
220
+ ) -> tuple[int, int, int, int]:
221
+ """Calculate padding to reach target dimensions"""
222
+ # Calculate total padding needed
223
+ pad_height = target_height - source_height
224
+ pad_width = target_width - source_width
225
+
226
+ # Calculate padding for each side
227
+ pad_top = pad_height // 2
228
+ pad_bottom = pad_height - pad_top # Handles odd padding
229
+ pad_left = pad_width // 2
230
+ pad_right = pad_width - pad_left # Handles odd padding
231
+
232
+ # Return padded tensor
233
+ # Padding format is (left, right, top, bottom)
234
+ padding = (pad_left, pad_right, pad_top, pad_bottom)
235
+ return padding
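A minimal usage sketch (illustrative only, not part of this commit): the returned tuple follows the (left, right, top, bottom) convention expected by torch.nn.functional.pad for the last two dimensions of an image tensor:

    import torch.nn.functional as F
    padding = calculate_padding(416, 768, 704, 1216)  # hypothetical source/target sizes
    padded = F.pad(frame_tensor, padding, mode="constant", value=-1.0)  # frame_tensor as produced by the helper above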
236
+
237
+ def prepare_conditioning(
238
+ conditioning_media_paths: List[str],
239
+ conditioning_strengths: List[float],
240
+ conditioning_start_frames: List[int],
241
+ height: int,
242
+ width: int,
243
+ num_frames: int,
244
+ input_image_quality: int = 100,
245
+ pipeline: Optional[LTXVideoPipeline] = None,
246
+ ) -> Optional[List[ConditioningItem]]:
247
+ """Prepare conditioning items based on input media paths and their parameters"""
248
+ conditioning_items = []
249
+ for path, strength, start_frame in zip(
250
+ conditioning_media_paths, conditioning_strengths, conditioning_start_frames
251
+ ):
252
+ # Load and process the conditioning image
253
+ frame_tensor = load_image_to_tensor_with_resize_and_crop(
254
+ path, height, width, quality=input_image_quality
255
+ )
256
+
257
+ # Trim frame count if needed
258
+ if pipeline:
259
+ frame_count = 1 # For image inputs, it's always 1
260
+ frame_count = pipeline.trim_conditioning_sequence(
261
+ start_frame, frame_count, num_frames
262
+ )
263
+
264
+ conditioning_items.append(
265
+ ConditioningItem(frame_tensor, start_frame, strength)
266
  )
267
 
268
+ return conditioning_items
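An illustrative call (paths and values are hypothetical, not taken from this commit):

    items = prepare_conditioning(
        conditioning_media_paths=["first_frame.png"],
        conditioning_strengths=[1.0],
        conditioning_start_frames=[0],
        height=704, width=1216, num_frames=113,
        input_image_quality=70,
        pipeline=pipeline,
    )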
269
 
270
+ def create_ltx_video_pipeline(
271
+ config: GenerationConfig,
272
+ device: str = "cuda"
273
+ ) -> LTXVideoPipeline:
274
+ """Create and configure the LTX video pipeline"""
 
275
 
276
+ ckpt_path = "/repository/ltxv-2b-0.9.6-distilled-04-25.safetensors"
277
+
278
+ # Get allowed inference steps from config if available
279
+ allowed_inference_steps = None
280
+
281
+ assert os.path.exists(
282
+ ckpt_path
283
+ ), f"Ckpt path provided (--ckpt_path) {ckpt_path} does not exist"
284
+
285
+ with safe_open(ckpt_path, framework="pt") as f:
286
+ metadata = f.metadata()
287
+ config_str = metadata.get("config")
288
+ configs = json.loads(config_str)
289
+ allowed_inference_steps = configs.get("allowed_inference_steps", None)
290
+
291
+ # Initialize model components
292
+ vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
293
+ transformer = Transformer3DModel.from_pretrained(ckpt_path)
294
+
295
+ # Use constructor if sampler is specified, otherwise use from_pretrained
296
+ if config.sampler:
297
+ scheduler = RectifiedFlowScheduler(
298
+ sampler=("Uniform" if config.sampler.lower() == "uniform" else "LinearQuadratic")
299
+ )
300
+ else:
301
+ scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
302
+
303
+ text_encoder = T5EncoderModel.from_pretrained("/repository/text_encoder")
304
+ patchifier = SymmetricPatchifier(patch_size=1)
305
+ tokenizer = T5Tokenizer.from_pretrained("/repository/tokenizer")
306
+
307
+ # Move models to the correct device
308
+ vae = vae.to(device)
309
+ transformer = transformer.to(device)
310
+ text_encoder = text_encoder.to(device)
311
+
312
+ # Set up precision
313
+ vae = vae.to(torch.bfloat16)
314
+ transformer = transformer.to(torch.bfloat16)
315
+ text_encoder = text_encoder.to(torch.bfloat16)
316
+
317
+ # Initialize prompt enhancer components if needed
318
+ prompt_enhancer_components = {
319
+ "prompt_enhancer_image_caption_model": None,
320
+ "prompt_enhancer_image_caption_processor": None,
321
+ "prompt_enhancer_llm_model": None,
322
+ "prompt_enhancer_llm_tokenizer": None
323
+ }
324
+
325
+ if config.enhance_prompt:
326
+ try:
327
+ # Use default models or ones specified by config
328
+ prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained(
329
+ "MiaoshouAI/Florence-2-large-PromptGen-v2.0",
330
+ trust_remote_code=True
331
+ )
332
+ prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained(
333
+ "MiaoshouAI/Florence-2-large-PromptGen-v2.0",
334
+ trust_remote_code=True
335
+ )
336
+ prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained(
337
+ "unsloth/Llama-3.2-3B-Instruct",
338
+ torch_dtype="bfloat16",
339
+ )
340
+ prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained(
341
+ "unsloth/Llama-3.2-3B-Instruct",
342
+ )
343
 
344
+ prompt_enhancer_components = {
345
+ "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model,
346
+ "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor,
347
+ "prompt_enhancer_llm_model": prompt_enhancer_llm_model,
348
+ "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer
349
+ }
350
+ except Exception as e:
351
+ logger.warning(f"Failed to load prompt enhancer models: {e}")
352
+ config.enhance_prompt = False
353
+
354
+ # Construct the pipeline
355
+ pipeline = LTXVideoPipeline(
356
+ transformer=transformer,
357
+ patchifier=patchifier,
358
+ text_encoder=text_encoder,
359
+ tokenizer=tokenizer,
360
+ scheduler=scheduler,
361
+ vae=vae,
362
+ allowed_inference_steps=allowed_inference_steps,
363
+ **prompt_enhancer_components
364
+ )
365
+
366
+ return pipeline
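A minimal sketch of how this factory could be exercised on its own (assuming the checkpoint layout under /repository described above; not part of this commit):

    config = GenerationConfig(prompt="a sailboat drifting at sunset")
    pipeline = create_ltx_video_pipeline(config, device="cuda")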
 
367
 
368
+ class EndpointHandler:
369
+ """Handler for the LTX Video endpoint"""
370
+
371
+ def __init__(self, model_path: str = "/repository/"):
372
+ """Initialize the endpoint handler
373
+
374
+ Args:
375
+ model_path: Path to model weights (not used, as weights are in current directory)
376
+ """
377
+ # Enable TF32 for potential speedup on Ampere GPUs
378
+ torch.backends.cuda.matmul.allow_tf32 = True
379
+
380
  # Initialize Varnish for post-processing
381
  self.varnish = Varnish(
382
  device="cuda",
383
  model_base_dir="/repository/varnish",
384
+ enable_mmaudio=False, # Disable audio generation for now, since it is broken
385
  )
386
 
387
+ # The actual LTX pipeline will be loaded during inference to save memory
388
+ self.pipeline = None
389
+
390
+ # Perform warm-up inference
391
+ logger.info("Performing warm-up inference...")
392
+ self._warmup()
393
+ logger.info("Warm-up completed!")
394
+
395
+ def _warmup(self):
396
+ """Perform a warm-up inference to prepare the model for future requests"""
397
  try:
398
+ # Create a simple test configuration
399
+ test_config = GenerationConfig(
400
+ prompt="an astronaut is riding a cow in the desert, during golden hour",
401
+ negative_prompt="worst quality, lowres",
402
+ width=768, # Using smaller resolution for faster warm-up
403
+ height=416,
404
+ num_frames=33, # Just enough frames for a valid video
405
+ guidance_scale=1.0,
406
+ num_inference_steps=4, # Fewer steps for faster warm-up
407
+ seed=42, # Fixed seed for consistent warm-up
408
+ fps=16, # Lower FPS for faster processing
409
+ enable_audio=False, # No audio for warm-up
410
+ mixed_precision=True,
411
+ ).validate_and_adjust()
412
 
413
+ # Create the pipeline if it doesn't exist
414
+ if self.pipeline is None:
415
+ self.pipeline = create_ltx_video_pipeline(test_config)
 
 
 
 
 
 
416
 
417
+ # Run a quick inference
418
+ with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16), torch.no_grad():
419
+ # Set seeds for reproducibility
420
+ random.seed(test_config.seed)
421
+ np.random.seed(test_config.seed)
422
+ torch.manual_seed(test_config.seed)
423
+ generator = torch.Generator(device='cuda').manual_seed(test_config.seed)
424
+
425
+ # Generate video
426
+ result = self.pipeline(
427
+ height=test_config.height,
428
+ width=test_config.width,
429
+ num_frames=test_config.num_frames,
430
+ frame_rate=test_config.fps,
431
+ prompt=test_config.prompt,
432
+ negative_prompt=test_config.negative_prompt,
433
+ guidance_scale=test_config.guidance_scale,
434
+ num_inference_steps=test_config.num_inference_steps,
435
+ generator=generator,
436
+ output_type="pt",
437
+ mixed_precision=test_config.mixed_precision,
438
+ is_video=True,
439
+ vae_per_channel_normalize=True,
440
+ )
441
+
442
+ # Just get the frames without full processing (faster warm-up)
443
+ frames = result.images
444
+
445
+ # Clean up
446
+ del result
447
+ torch.cuda.empty_cache()
448
+ gc.collect()
449
+
450
+ logger.info(f"Warm-up successful! Generated {frames.shape[2]} frames at {frames.shape[3]}x{frames.shape[4]}")
451
+
452
  except Exception as e:
453
+ # Log the error but don't fail initialization
454
+ import traceback
455
+ error_message = f"Warm-up failed (but this is non-critical): {str(e)}\n{traceback.format_exc()}"
456
+ logger.warning(error_message)
457
+
458
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
459
+ """Process inference requests
460
 
461
  Args:
462
+ data: Request data containing inputs and parameters
463
+
464
  Returns:
465
+ Dictionary with generated video and metadata
 
 
 
466
  """
467
+ # Extract inputs and parameters
468
+ inputs = data.get("inputs", {})
469
 
470
+ # Support both formats:
471
+ # 1. {"inputs": {"prompt": "...", "image": "..."}}
472
+ # 2. {"inputs": "..."} (prompt only)
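For illustration, a hypothetical combined request (field names assumed from this handler, not an official schema):

    {
      "inputs": {"prompt": "a cat walking on wet grass", "image": "data:image/jpeg;base64,..."},
      "parameters": {"width": 768, "height": 416, "num_frames": 113, "guidance_scale": 1.0, "seed": 42}
    }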
473
+ if isinstance(inputs, str):
474
+ input_prompt = inputs
475
+ input_image = None
476
+ else:
477
+ input_prompt = inputs.get("prompt", "")
478
+ input_image = inputs.get("image")
479
 
480
+ params = data.get("parameters", {})
481
+
482
+ if not input_prompt and not input_image:
483
  raise ValueError("Either prompt or image must be provided")
484
+
 
 
 
485
  # Create and validate configuration
486
  config = GenerationConfig(
487
  # general content settings
488
  prompt=input_prompt,
489
  negative_prompt=params.get("negative_prompt", GenerationConfig.negative_prompt),
490
+
491
+ # video model settings
492
  width=params.get("width", GenerationConfig.width),
493
  height=params.get("height", GenerationConfig.height),
494
  input_image_quality=params.get("input_image_quality", GenerationConfig.input_image_quality),
495
  num_frames=params.get("num_frames", GenerationConfig.num_frames),
496
  guidance_scale=params.get("guidance_scale", GenerationConfig.guidance_scale),
497
  num_inference_steps=params.get("num_inference_steps", GenerationConfig.num_inference_steps),
498
+
499
+ # STG settings
500
+ stg_scale=params.get("stg_scale", GenerationConfig.stg_scale),
501
+ stg_rescale=params.get("stg_rescale", GenerationConfig.stg_rescale),
502
+ stg_mode=params.get("stg_mode", GenerationConfig.stg_mode),
503
+
504
+ # VAE noise settings
505
+ decode_timestep=params.get("decode_timestep", GenerationConfig.decode_timestep),
506
+ decode_noise_scale=params.get("decode_noise_scale", GenerationConfig.decode_noise_scale),
507
+ image_cond_noise_scale=params.get("image_cond_noise_scale", GenerationConfig.image_cond_noise_scale),
508
+
509
  # reproducible generation settings
510
  seed=params.get("seed", GenerationConfig.seed),
511
 
512
+ # varnish settings
513
+ fps=params.get("fps", GenerationConfig.fps),
514
+ double_num_frames=params.get("double_num_frames", GenerationConfig.double_num_frames),
515
+ super_resolution=params.get("super_resolution", GenerationConfig.super_resolution),
516
  grain_amount=params.get("grain_amount", GenerationConfig.grain_amount),
517
  enable_audio=params.get("enable_audio", GenerationConfig.enable_audio),
518
  audio_prompt=params.get("audio_prompt", GenerationConfig.audio_prompt),
519
  audio_negative_prompt=params.get("audio_negative_prompt", GenerationConfig.audio_negative_prompt),
520
  quality=params.get("quality", GenerationConfig.quality),
521
 
522
+ # advanced settings
523
+ mixed_precision=params.get("mixed_precision", GenerationConfig.mixed_precision),
524
+ stochastic_sampling=params.get("stochastic_sampling", GenerationConfig.stochastic_sampling),
525
+ sampler=params.get("sampler", GenerationConfig.sampler),
 
 
526
 
527
+ # prompt enhancement
528
+ enhance_prompt=params.get("enhance_prompt", GenerationConfig.enhance_prompt),
529
+ prompt_enhancement_words_threshold=params.get(
530
+ "prompt_enhancement_words_threshold",
531
+ GenerationConfig.prompt_enhancement_words_threshold
532
+ ),
 
 
533
  ).validate_and_adjust()
534
 
 
 
 
535
  try:
536
+ with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16), torch.no_grad():
537
+ # Set random seeds for reproducibility
538
  random.seed(config.seed)
539
  np.random.seed(config.seed)
540
  torch.manual_seed(config.seed)
541
+ generator = torch.Generator(device='cuda').manual_seed(config.seed)
542
 
543
+ # Create pipeline if not already created
544
+ if self.pipeline is None:
545
+ self.pipeline = create_ltx_video_pipeline(config)
546
+
547
+ # Prepare conditioning items if an image is provided
548
+ conditioning_items = None
549
+ if input_image:
550
+ conditioning_items = [
551
+ ConditioningItem(
552
+ load_image_to_tensor_with_resize_and_crop(
553
+ input_image,
554
+ config.height,
555
+ config.width,
556
+ quality=config.input_image_quality
557
+ ),
558
+ 0, # Start frame
559
+ 1.0 # Conditioning strength
560
  )
561
+ ]
562
 
563
+ # Set up spatiotemporal guidance strategy
564
+ if config.stg_mode == "attention_values":
565
+ skip_layer_strategy = SkipLayerStrategy.AttentionValues
566
+ elif config.stg_mode == "attention_skip":
567
+ skip_layer_strategy = SkipLayerStrategy.AttentionSkip
568
+ elif config.stg_mode == "residual":
569
+ skip_layer_strategy = SkipLayerStrategy.Residual
570
+ elif config.stg_mode == "transformer_block":
571
+ skip_layer_strategy = SkipLayerStrategy.TransformerBlock
572
+
573
+ # Generate video with LTX pipeline
574
+ result = self.pipeline(
575
+ height=config.height,
576
+ width=config.width,
577
+ num_frames=config.num_frames,
578
+ frame_rate=config.fps,
579
+ prompt=config.prompt,
580
+ negative_prompt=config.negative_prompt,
581
+ guidance_scale=config.guidance_scale,
582
+ num_inference_steps=config.num_inference_steps,
583
+ generator=generator,
584
+ output_type="pt", # Return as PyTorch tensor
585
+ skip_layer_strategy=skip_layer_strategy,
586
+ stg_scale=config.stg_scale,
587
+ do_rescaling=config.stg_rescale != 1.0,
588
+ rescaling_scale=config.stg_rescale,
589
+ conditioning_items=conditioning_items,
590
+ decode_timestep=config.decode_timestep,
591
+ decode_noise_scale=config.decode_noise_scale,
592
+ image_cond_noise_scale=config.image_cond_noise_scale,
593
+ mixed_precision=config.mixed_precision,
594
+ is_video=True,
595
+ vae_per_channel_normalize=True,
596
+ stochastic_sampling=config.stochastic_sampling,
597
+ enhance_prompt=config.enhance_prompt,
598
+ )
599
+
600
+ # Get the generated frames
601
+ frames = result.images
602
+
603
+ # FIX: Convert LTX output format to varnish-compatible format
604
+ # LTX outputs: [batch, channels, frames, height, width]
605
+ # We need: [frames, channels, height, width] for varnish
606
+ frames = frames.squeeze(0) # Remove batch: [channels, frames, height, width]
607
+ frames = frames.permute(1, 0, 2, 3) # Reorder to: [frames, channels, height, width]
608
+
609
+ # Convert from [0, 1] to [0, 255] range
610
+ frames = frames * 255.0
611
+
612
+ # Convert to uint8
613
+ frames = frames.to(torch.uint8)
614
+
615
+ # Process the generated frames with Varnish
616
+ import asyncio
617
  try:
618
  loop = asyncio.get_event_loop()
619
  except RuntimeError:
620
  loop = asyncio.new_event_loop()
621
  asyncio.set_event_loop(loop)
622
+
623
+ # Process with Varnish for post-processing
624
+ varnish_result = loop.run_until_complete(
625
+ self.varnish(
626
+ frames,
627
+ fps=config.fps,
628
+ double_num_frames=config.double_num_frames,
629
+ super_resolution=config.super_resolution,
630
+ grain_amount=config.grain_amount,
631
+ enable_audio=config.enable_audio,
632
+ audio_prompt=config.audio_prompt or config.prompt,
633
+ audio_negative_prompt=config.audio_negative_prompt,
634
+ )
635
+ )
636
 
637
+ # Get the final video as a data URI
638
+ video_uri = loop.run_until_complete(
639
+ varnish_result.write(
640
+ type="data-uri",
641
+ quality=config.quality
642
+ )
643
+ )
644
+
645
+ # Prepare metadata about the generated video
646
+ metadata = {
647
+ "width": varnish_result.metadata.width,
648
+ "height": varnish_result.metadata.height,
649
+ "num_frames": varnish_result.metadata.frame_count,
650
+ "fps": varnish_result.metadata.fps,
651
+ "duration": varnish_result.metadata.duration,
652
+ "seed": config.seed,
653
+ "prompt": config.prompt,
654
+ }
655
+
656
+ # Clean up to prevent CUDA OOM errors
657
+ del result
658
  torch.cuda.empty_cache()
 
659
  gc.collect()
660
 
661
  return {
 
663
  "content-type": "video/mp4",
664
  "metadata": metadata
665
  }
666
+
667
  except Exception as e:
668
+ # Log the error and reraise
669
+ import traceback
670
+ error_message = f"Error generating video: {str(e)}\n{traceback.format_exc()}"
671
+ logger.error(error_message)
672
+ raise RuntimeError(error_message)