Update handler_LAST_WORKING.py

handler_LAST_WORKING.py  (+506 −344)  CHANGED
@@ -1,24 +1,23 @@
 from dataclasses import dataclass
 from pathlib import Path
-import pathlib
-from typing import Dict, Any, Optional, Tuple
-import asyncio
-import base64
-import io
-import pprint
 import logging
 import random
-import
 import os
 import numpy as np
 import torch
-import
-from
-from

 from varnish import Varnish
 from varnish.utils import is_truthy, process_input_image
@@ -27,14 +26,13 @@ from varnish.utils import is_truthy, process_input_image
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

-
 # Get token from environment
 hf_token = os.getenv("HF_API_TOKEN")

 # Constraints
 MAX_LARGE_SIDE = 1280
-MAX_SMALL_SIDE = 768
-MAX_FRAMES =

 # Check environment variable for pipeline support
 support_image_prompt = is_truthy(os.getenv("SUPPORT_INPUT_IMAGE_PROMPT"))
@@ -48,10 +46,8 @@ class GenerationConfig:
     negative_prompt: str = "saturated, highlight, overexposed, highlighted, overlit, shaking, too bright, worst quality, inconsistent motion, blurry, jittery, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles, lowres"

     # video model settings (will be used during generation of the initial raw video clip)
-    height: int = 416

     # this is a hack to fool LTX-Video into believing our input image is an actual video frame with poor encoding quality
     # after a quick benchmark using the value 70 seems like a sweet spot
@@ -62,8 +58,8 @@ class GenerationConfig:
     # visual glitches appear after about 169 frames, so we don't need more actually
     num_frames: int = (8 * 14) + 1

-    #
-    guidance_scale: float =

     num_inference_steps: int = 8
@@ -71,16 +67,16 @@ class GenerationConfig:
     seed: int = -1 # -1 means random seed

     # varnish settings (will be used for post-processing after the raw video clip has been generated
-    fps: int = 30
-    double_num_frames: bool = False
-    super_resolution: bool = False

-    grain_amount: float = 0.0

     # audio settings
     enable_audio: bool = False # Whether to generate audio
     audio_prompt: str = "" # Text prompt for audio generation
-    audio_negative_prompt: str = "voices, voice, talking, speaking, speech"

     # The range of the CRF scale is 0–51, where:
     # 0 is lossless (for 8 bit only, for 10 bit use -qp 0)
@@ -92,18 +88,26 @@ class GenerationConfig:
     # The range is exponential, so increasing the CRF value +6 results in roughly half the bitrate / file size, while -6 leads to roughly twice the bitrate.
     quality: int = 18

-    #

-    #

-    #

     def validate_and_adjust(self) -> 'GenerationConfig':
         """Validate and adjust parameters to meet constraints"""
@@ -111,7 +115,7 @@ class GenerationConfig:
         if not ((self.width == MAX_LARGE_SIDE and self.height == MAX_SMALL_SIDE) or
                 (self.width == MAX_SMALL_SIDE and self.height == MAX_LARGE_SIDE)):
             # For other resolutions, ensure total pixels don't exceed max
-            MAX_TOTAL_PIXELS = MAX_SMALL_SIDE * MAX_LARGE_SIDE

             # If total pixels exceed maximum, scale down proportionally
             total_pixels = self.width * self.height
@@ -131,371 +135,527 @@ class GenerationConfig:
         # Set random seed if not specified
         if self.seed == -1:
             self.seed = random.randint(0, 2**32 - 1)
-
-        return self

-[…]
-        """Initialize the handler with LTX models and Varnish
-
-        Args:
-            model_path: Path to LTX model weights
-        """
-        print("EndpointHandler.__init__(): initializing..")
-        # Enable TF32 for potential speedup on Ampere GPUs
-        #torch.backends.cuda.matmul.allow_tf32 = True
-
-        # use distilled weights
-        model_path = Path("/repository/ltxv-2b-0.9.6-distilled-04-25.safetensors")

-[…]
         )

-        vae = AutoencoderKLLTXVideo.from_single_file(model_path, torch_dtype=torch.bfloat16)
-
-        if support_image_prompt:
-            print("EndpointHandler.__init__(): initializing LTXImageToVideoPipeline..")
-            self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
-                "/repository",
-                transformer=transformer,
-                vae=vae,
-                torch_dtype=torch.bfloat16
-            ).to("cuda")
-
-            #apply_teacache(self.image_to_video)

-            #self.image_to_video = torch.compile(self.image_to_video, mode="reduce-overhead", fullgraph=True)

-[…]
         # Initialize Varnish for post-processing
         self.varnish = Varnish(
             device="cuda",
             model_base_dir="/repository/varnish",
-
-            # there is currently a bug with MMAudio and/or torch and/or the weight format and/or version..
-            # not sure how to fix that.. :/
-            #
-            # it says:
-            #   File "dist-packages/varnish.py", line 152, in __init__
-            #     self._setup_mmaudio()
-            #   File "dist-packages/varnish/varnish.py", line 165, in _setup_mmaudio
-            #     net.load_weights(torch.load(model.model_path, map_location=self.device, weights_only=False))
-            #   File "dist-packages/torch/serialization.py", line 1384, in load
-            #     return _legacy_load(
-            #   File "dist-packages/torch/serialization.py", line 1628, in _legacy_load
-            #     magic_number = pickle_module.load(f, **pickle_load_args)
-            #   _pickle.UnpicklingError: invalid load key, '<'.
-            enable_mmaudio=False,
         )
-
-        # Determine if TeaCache is already installed or not
-        self.text_to_video_teacache = False
-        self.image_to_video_teacache = False
-
-    async def process_frames(
-        self,
-        frames: torch.Tensor,
-        config: GenerationConfig
-    ) -> tuple[str, dict]:
-        """Post-process generated frames using Varnish
-        […]
-        """
         try:
-            […]
-                "height": result.metadata.height,
-                "num_frames": result.metadata.frame_count,
-                "fps": result.metadata.fps,
-                "duration": result.metadata.duration,
-                "seed": config.seed,
-            }
         except Exception as e:
-            […]
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        """Process

         Args:
-            data: Request data containing:
-                - parameters (dict):
-                    - prompt (required, string): list of concepts to keep in the video.
-                    - negative_prompt (optional, string): list of concepts to ignore in the video.
-                    - width (optional, int, default to 768): width, or horizontal size in pixels.
-                    - height (optional, int, default to 512): height, or vertical size in pixels.
-                    - input_image_quality (optional, int, default to 100): this is a trick we use to convert a "pristine" image into a "dirty" video frame. This helps fooling LTX-Video into turning the image into an animated one.
-                    - num_frames (optional, int, default to 129): the number of frames must be a multiple of 8, plus 1 frame.
-                    - guidance_scale (optional, float, default to 3.5): Guidance scale (values between 3.0 and 4.0 are nice)
-                    - num_inference_steps (optional, int, default to 50): number of inference steps
-                    - seed (optional, int, default to -1): set a random number generator seed, -1 means random seed.
-                    - fps (optional, int, default to 24): FPS of the final video (eg. 24, 25, 30, 60)
-                    - double_num_frames (optional, bool): if enabled, the number of frames will be multiplied by 2 using RIFE
-                    - super_resolution (optional, bool): if enabled, the resolution will be multiplied by 2 using Real_ESRGAN
-                    - grain_amount (optional, float): amount of film grain to add to the output video
-                    - enable_audio (optional, bool): automatically generate an audio track
-                    - audio_prompt (optional, str): prompt to use for the audio generation (concepts to add)
-                    - audio_negative_prompt (optional, str): negative prompt to use for the audio generation (concepts to ignore)
-                    - quality (optional, int, default to 18): the range of the CRF scale is 0–51, where 0 is lossless (for 8 bit only, for 10 bit use -qp 0), 23 is the default, and 51 is worst quality possible.
-                    - enable_teacache (optional, bool, default to False): generate faster at the cost of a slight quality loss
-                    - teacache_threshold (optional, float, default to 0.05): amount of cache, 0 (original), 0.03 (1.6x speedup), 0.05 (default, 2.1x speedup).
-                    - enable_enhance_a_video (optional, bool, default to False): enable the enhance_a_video optimization
-                    - enhance_a_video_weight (optional, float, default to 5.0): amount of video enhancement to apply
-                    - lora_model_name (optional, str, default to ""): HuggingFace repo ID or path to LoRA model
-                    - lora_model_weight_file (optional, str, default to ""): specific weight file to load from the LoRA model
-                    - lora_model_trigger (optional, str, default to ""): optional trigger word to prepend to the prompt
         Returns:
-            Dictionary containing:
-                - video: Base64 encoded MP4 data URI
-                - content-type: MIME type
-                - metadata: Generation metadata
         """

-        params = data.get("parameters",
-
-        if not
             raise ValueError("Either prompt or image must be provided")
-
-        #logger.debug(f"Raw parameters:")
-        # pprint.pprint(params)
-
         # Create and validate configuration
         config = GenerationConfig(
             # general content settings
             prompt=input_prompt,
             negative_prompt=params.get("negative_prompt", GenerationConfig.negative_prompt),
-
-            # video model settings
             width=params.get("width", GenerationConfig.width),
             height=params.get("height", GenerationConfig.height),
             input_image_quality=params.get("input_image_quality", GenerationConfig.input_image_quality),
             num_frames=params.get("num_frames", GenerationConfig.num_frames),
             guidance_scale=params.get("guidance_scale", GenerationConfig.guidance_scale),
             num_inference_steps=params.get("num_inference_steps", GenerationConfig.num_inference_steps),
-
             # reproducible generation settings
             seed=params.get("seed", GenerationConfig.seed),

-            # varnish settings
-            fps=params.get("fps", GenerationConfig.fps),
-            double_num_frames=params.get("double_num_frames", GenerationConfig.double_num_frames),
-            super_resolution=params.get("super_resolution", GenerationConfig.super_resolution),
             grain_amount=params.get("grain_amount", GenerationConfig.grain_amount),
             enable_audio=params.get("enable_audio", GenerationConfig.enable_audio),
             audio_prompt=params.get("audio_prompt", GenerationConfig.audio_prompt),
             audio_negative_prompt=params.get("audio_negative_prompt", GenerationConfig.audio_negative_prompt),
             quality=params.get("quality", GenerationConfig.quality),

-            #
-            teacache_threshold=params.get("teacache_threshold", 0.05),

-            #
-            lora_model_weight_file=params.get("lora_model_weight_file", ""),
-            lora_model_trigger=params.get("lora_model_trigger", ""),
         ).validate_and_adjust()
384 |
|
385 |
-
#logger.debug(f"Global request settings:")
|
386 |
-
#pprint.pprint(config)
|
387 |
-
|
388 |
try:
|
389 |
-
with torch.amp.
|
390 |
-
# Set random seeds
|
391 |
random.seed(config.seed)
|
392 |
np.random.seed(config.seed)
|
393 |
torch.manual_seed(config.seed)
|
394 |
-
generator = torch.Generator(device='cuda')
|
395 |
-
generator = generator.manual_seed(config.seed)
|
396 |
-
|
397 |
-
# Configure enhance-a-video
|
398 |
-
#if config.enable_enhance_a_video:
|
399 |
-
# enable_enhance()
|
400 |
-
# set_enhance_weight(config.enhance_a_video_weight)
|
401 |
|
402 |
-
#
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
# Timestep for decoding VAE noise: the timestep at which generated video is decoded
|
420 |
-
"decode_timestep": 0.05,
|
421 |
-
|
422 |
-
# Noise level for decoding VAE noise: the interpolation factor between random noise and denoised latents at the decode timestep
|
423 |
-
"decode_noise_scale": 0.025,
|
424 |
-
}
|
425 |
-
#logger.info(f"Video model generation settings:")
|
426 |
-
#pprint.pprint(generation_kwargs)
|
427 |
-
|
428 |
-
# Handle LoRA loading/unloading
|
429 |
-
if hasattr(self, '_current_lora_model'):
|
430 |
-
if self._current_lora_model != (config.lora_model_name, config.lora_model_weight_file):
|
431 |
-
# Unload previous LoRA if it exists and is different
|
432 |
-
if hasattr(self.text_to_video, 'unload_lora_weights'):
|
433 |
-
print("Unloading LoRA weights for the text_to_video pipeline..")
|
434 |
-
self.text_to_video.unload_lora_weights()
|
435 |
-
|
436 |
-
if support_image_prompt and hasattr(self.image_to_video, 'unload_lora_weights'):
|
437 |
-
print("Unloading LoRA weights for the image_to_video pipeline..")
|
438 |
-
self.image_to_video.unload_lora_weights()
|
439 |
-
|
440 |
-
if config.lora_model_name:
|
441 |
-
# Load new LoRA
|
442 |
-
if hasattr(self.text_to_video, 'load_lora_weights'):
|
443 |
-
print("Loading LoRA weights for the text_to_video pipeline..")
|
444 |
-
self.text_to_video.load_lora_weights(
|
445 |
-
config.lora_model_name,
|
446 |
-
weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
|
447 |
-
token=hf_token,
|
448 |
-
)
|
449 |
-
if support_image_prompt and hasattr(self.image_to_video, 'load_lora_weights'):
|
450 |
-
print("Loading LoRA weights for the image_to_video pipeline..")
|
451 |
-
self.image_to_video.load_lora_weights(
|
452 |
-
config.lora_model_name,
|
453 |
-
weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
|
454 |
-
token=hf_token,
|
455 |
)
|
456 |
-
|
457 |
-
|
458 |
-
# Modify prompt if trigger word is provided
|
459 |
-
if config.lora_model_trigger:
|
460 |
-
generation_kwargs["prompt"] = f"{config.lora_model_trigger} {generation_kwargs['prompt']}"
|
461 |
-
|
462 |
-
#enhance_a_video_config = EnhanceAVideoConfig(
|
463 |
-
# weight=config.enhance_a_video_weight if config.enable_enhance_a_video else 0.0,
|
464 |
-
# # doing some testing
|
465 |
-
# num_frames_callback=lambda: (8 + 1),
|
466 |
-
# # num_frames_callback=lambda: config.num_frames,
|
467 |
-
# # num_frames_callback=lambda: (config.num_frames - 1),
|
468 |
-
#
|
469 |
-
# _attention_type=1
|
470 |
-
#)
|
471 |
|
472 |
-
#
|
473 |
-
if
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
|
487 |
-
|
488 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
489 |
try:
|
490 |
loop = asyncio.get_event_loop()
|
491 |
except RuntimeError:
|
492 |
loop = asyncio.new_event_loop()
|
493 |
asyncio.set_event_loop(loop)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
494 |
|
495 |
-
|
496 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
497 |
torch.cuda.empty_cache()
|
498 |
-
torch.cuda.reset_peak_memory_stats()
|
499 |
gc.collect()
|
500 |
|
501 |
return {
|
@@ -503,8 +663,10 @@ class EndpointHandler:
                     "content-type": "video/mp4",
                     "metadata": metadata
                 }
-
         except Exception as e:
-            […]

handler_LAST_WORKING.py
 from dataclasses import dataclass
 from pathlib import Path
 import logging
+import base64
 import random
+import gc
 import os
 import numpy as np
 import torch
+from typing import Dict, Any, Optional, List, Union, Tuple
+import json
+from safetensors import safe_open
+
+from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
+from ltx_video.models.transformers.transformer3d import Transformer3DModel
+from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
+from ltx_video.schedulers.rf import RectifiedFlowScheduler, TimestepShifter
+from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXVideoPipeline
+from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
+from transformers import T5EncoderModel, T5Tokenizer, AutoModelForCausalLM, AutoProcessor, AutoTokenizer

 from varnish import Varnish
 from varnish.utils import is_truthy, process_input_image
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

 # Get token from environment
 hf_token = os.getenv("HF_API_TOKEN")

 # Constraints
 MAX_LARGE_SIDE = 1280
+MAX_SMALL_SIDE = 768 # should be 720 but it must be divisible by 32
+MAX_FRAMES = (8 * 21) + 1 # visual glitches appear after about 169 frames, so we cap it

 # Check environment variable for pipeline support
 support_image_prompt = is_truthy(os.getenv("SUPPORT_INPUT_IMAGE_PROMPT"))
     negative_prompt: str = "saturated, highlight, overexposed, highlighted, overlit, shaking, too bright, worst quality, inconsistent motion, blurry, jittery, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles, lowres"

     # video model settings (will be used during generation of the initial raw video clip)
+    width: int = 1216 # 768
+    height: int = 704 # 416

     # this is a hack to fool LTX-Video into believing our input image is an actual video frame with poor encoding quality
     # after a quick benchmark using the value 70 seems like a sweet spot
     # visual glitches appear after about 169 frames, so we don't need more actually
     num_frames: int = (8 * 14) + 1

+    # values between 3.0 and 4.0 are nice
+    guidance_scale: float = 3.0

     num_inference_steps: int = 8
     seed: int = -1 # -1 means random seed

     # varnish settings (will be used for post-processing after the raw video clip has been generated
+    fps: int = 30 # FPS of the final video (only applied at the very end, when converting to mp4)
+    double_num_frames: bool = False # if True, the number of frames will be multiplied by 2 using RIFE
+    super_resolution: bool = False # if True, the resolution will be multiplied by 2 using Real_ESRGAN

+    grain_amount: float = 0.0 # be careful, adding film grain can negatively impact video compression

     # audio settings
     enable_audio: bool = False # Whether to generate audio
     audio_prompt: str = "" # Text prompt for audio generation
+    audio_negative_prompt: str = "voices, voice, talking, speaking, speech" # Negative prompt for audio generation

     # The range of the CRF scale is 0–51, where:
     # 0 is lossless (for 8 bit only, for 10 bit use -qp 0)
     # The range is exponential, so increasing the CRF value +6 results in roughly half the bitrate / file size, while -6 leads to roughly twice the bitrate.
     quality: int = 18

+    # STG (Spatiotemporal Guidance) settings
+    stg_scale: float = 0.0
+    stg_rescale: float = 1.0
+    stg_mode: str = "attention_values" # Can be "attention_values", "attention_skip", "residual", or "transformer_block"

+    # VAE noise augmentation
+    decode_timestep: float = 0.05
+    decode_noise_scale: float = 0.025

+    # Other advanced settings
+    image_cond_noise_scale: float = 0.15
+    mixed_precision: bool = True # Use mixed precision for inference
+    stochastic_sampling: bool = True # Use stochastic sampling
+
+    # Sampling settings
+    sampler: Optional[str] = "from_checkpoint" # "uniform" or "linear-quadratic" or None (use default from checkpoint)
+
+    # Prompt enhancement
+    enhance_prompt: bool = False # Whether to enhance the prompt using an LLM
+    prompt_enhancement_words_threshold: int = 50 # Enhance prompt only if it has fewer words than this
     def validate_and_adjust(self) -> 'GenerationConfig':
         """Validate and adjust parameters to meet constraints"""

         if not ((self.width == MAX_LARGE_SIDE and self.height == MAX_SMALL_SIDE) or
                 (self.width == MAX_SMALL_SIDE and self.height == MAX_LARGE_SIDE)):
             # For other resolutions, ensure total pixels don't exceed max
+            MAX_TOTAL_PIXELS = MAX_SMALL_SIDE * MAX_LARGE_SIDE # or 921600 = 1280 * 720

             # If total pixels exceed maximum, scale down proportionally
             total_pixels = self.width * self.height
         # Set random seed if not specified
         if self.seed == -1:
             self.seed = random.randint(0, 2**32 - 1)

+        # Set up STG parameters
+        if self.stg_mode.lower() == "stg_av" or self.stg_mode.lower() == "attention_values":
+            self.stg_mode = "attention_values"
+        elif self.stg_mode.lower() == "stg_as" or self.stg_mode.lower() == "attention_skip":
+            self.stg_mode = "attention_skip"
+        elif self.stg_mode.lower() == "stg_r" or self.stg_mode.lower() == "residual":
+            self.stg_mode = "residual"
+        elif self.stg_mode.lower() == "stg_t" or self.stg_mode.lower() == "transformer_block":
+            self.stg_mode = "transformer_block"
+
+        # Check if we should enhance the prompt
+        if self.enhance_prompt and self.prompt:
+            prompt_word_count = len(self.prompt.split())
+            if prompt_word_count >= self.prompt_enhancement_words_threshold:
+                logger.info(f"Prompt has {prompt_word_count} words, which exceeds the threshold of {self.prompt_enhancement_words_threshold}. Prompt enhancement disabled.")
+                self.enhance_prompt = False

+        return self
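The proportional downscale itself happens in unchanged lines that the diff does not show. As a rough illustration only (not the file's code), the constraints described above, namely a pixel budget of MAX_SMALL_SIDE * MAX_LARGE_SIDE, sides divisible by 32, and frame counts of the form 8*k + 1 capped at MAX_FRAMES, can be sketched like this:

import math

def scale_to_budget(width: int, height: int, num_frames: int,
                    max_pixels: int = 768 * 1280, max_frames: int = (8 * 21) + 1):
    # Rough sketch of the documented constraints; not the handler's actual code.
    total = width * height
    if total > max_pixels:
        ratio = math.sqrt(max_pixels / total)
        width = int(width * ratio)
        height = int(height * ratio)
    # Round sides down to a multiple of 32
    width = max(32, (width // 32) * 32)
    height = max(32, (height // 32) * 32)
    # Clamp the frame count to the 8*k + 1 pattern and the global cap
    num_frames = min(num_frames, max_frames)
    num_frames = ((num_frames - 1) // 8) * 8 + 1
    return width, height, num_frames

print(scale_to_budget(1920, 1080, 161))  # -> (1312, 736, 161)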
+def load_image_to_tensor_with_resize_and_crop(
+    image_input: Union[str, bytes],
+    target_height: int = 704,
+    target_width: int = 1216,
+    quality: int = 100
+) -> torch.Tensor:
+    """Load and process an image into a tensor.
+
+    Args:
+        image_input: Either a file path (str) or image data (bytes)
+        target_height: Desired height of output tensor
+        target_width: Desired width of output tensor
+        quality: JPEG quality to use when re-encoding (to simulate lower quality images)
+    """
+    from PIL import Image
+    import io
+    import numpy as np
+
+    # Handle base64 data URI
+    if isinstance(image_input, str) and image_input.startswith('data:'):
+        header, encoded = image_input.split(",", 1)
+        image_data = base64.b64decode(encoded)
+        image = Image.open(io.BytesIO(image_data)).convert("RGB")
+    # Handle raw bytes
+    elif isinstance(image_input, bytes):
+        image = Image.open(io.BytesIO(image_input)).convert("RGB")
+    # Handle file path
+    elif isinstance(image_input, str):
+        image = Image.open(image_input).convert("RGB")
+    else:
+        raise ValueError("image_input must be either a file path, bytes, or base64 data URI")
+
+    # Apply JPEG compression if quality < 100 (to simulate a video frame)
+    if quality < 100:
+        buffer = io.BytesIO()
+        image.save(buffer, format="JPEG", quality=quality)
+        buffer.seek(0)
+        image = Image.open(buffer).convert("RGB")
+
+    input_width, input_height = image.size
+    aspect_ratio_target = target_width / target_height
+    aspect_ratio_frame = input_width / input_height
+    if aspect_ratio_frame > aspect_ratio_target:
+        new_width = int(input_height * aspect_ratio_target)
+        new_height = input_height
+        x_start = (input_width - new_width) // 2
+        y_start = 0
+    else:
+        new_width = input_width
+        new_height = int(input_width / aspect_ratio_target)
+        x_start = 0
+        y_start = (input_height - new_height) // 2
+
+    image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
+    image = image.resize((target_width, target_height))
+    frame_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).float()
+    frame_tensor = (frame_tensor / 127.5) - 1.0
+    # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
+    return frame_tensor.unsqueeze(0).unsqueeze(2)
+
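A quick usage sketch for the helper above (the file path and quality value are examples only):

# Example call; the input path is hypothetical. The helper center-crops to the
# target aspect ratio, resizes, and scales pixel values to [-1, 1].
tensor = load_image_to_tensor_with_resize_and_crop(
    "/tmp/example_input.jpg",  # also accepts raw bytes or a "data:image/..." URI
    target_height=704,
    target_width=1216,
    quality=70,                # re-encode as JPEG q=70 to mimic a compressed video frame
)
print(tensor.shape)  # torch.Size([1, 3, 1, 704, 1216])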
+def calculate_padding(
+    source_height: int, source_width: int, target_height: int, target_width: int
+) -> tuple[int, int, int, int]:
+    """Calculate padding to reach target dimensions"""
+    # Calculate total padding needed
+    pad_height = target_height - source_height
+    pad_width = target_width - source_width
+
+    # Calculate padding for each side
+    pad_top = pad_height // 2
+    pad_bottom = pad_height - pad_top  # Handles odd padding
+    pad_left = pad_width // 2
+    pad_right = pad_width - pad_left  # Handles odd padding
+
+    # Return the padding tuple
+    # Padding format is (left, right, top, bottom)
+    padding = (pad_left, pad_right, pad_top, pad_bottom)
+    return padding
+
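The tuple returned above is already in torch.nn.functional.pad order for the last two dimensions, so it can be applied directly; a small usage sketch with made-up sizes:

import torch
import torch.nn.functional as F

padding = calculate_padding(512, 768, 704, 1216)  # -> (224, 224, 96, 96)
frame = torch.zeros(1, 3, 512, 768)               # dummy frame, example sizes only
padded = F.pad(frame, padding)
print(padded.shape)  # torch.Size([1, 3, 704, 1216])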
+def prepare_conditioning(
+    conditioning_media_paths: List[str],
+    conditioning_strengths: List[float],
+    conditioning_start_frames: List[int],
+    height: int,
+    width: int,
+    num_frames: int,
+    input_image_quality: int = 100,
+    pipeline: Optional[LTXVideoPipeline] = None,
+) -> Optional[List[ConditioningItem]]:
+    """Prepare conditioning items based on input media paths and their parameters"""
+    conditioning_items = []
+    for path, strength, start_frame in zip(
+        conditioning_media_paths, conditioning_strengths, conditioning_start_frames
+    ):
+        # Load and process the conditioning image
+        frame_tensor = load_image_to_tensor_with_resize_and_crop(
+            path, height, width, quality=input_image_quality
+        )
+
+        # Trim frame count if needed
+        if pipeline:
+            frame_count = 1  # For image inputs, it's always 1
+            frame_count = pipeline.trim_conditioning_sequence(
+                start_frame, frame_count, num_frames
+            )
+
+        conditioning_items.append(
+            ConditioningItem(frame_tensor, start_frame, strength)
         )

+    return conditioning_items

+def create_ltx_video_pipeline(
+    config: GenerationConfig,
+    device: str = "cuda"
+) -> LTXVideoPipeline:
+    """Create and configure the LTX video pipeline"""

+    ckpt_path = "/repository/ltxv-2b-0.9.6-distilled-04-25.safetensors"
+
+    # Get allowed inference steps from config if available
+    allowed_inference_steps = None
+
+    assert os.path.exists(
+        ckpt_path
+    ), f"Ckpt path provided (--ckpt_path) {ckpt_path} does not exist"
+
+    with safe_open(ckpt_path, framework="pt") as f:
+        metadata = f.metadata()
+        config_str = metadata.get("config")
+        configs = json.loads(config_str)
+        allowed_inference_steps = configs.get("allowed_inference_steps", None)
+
+    # Initialize model components
+    vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
+    transformer = Transformer3DModel.from_pretrained(ckpt_path)
+
+    # Use constructor if sampler is specified, otherwise use from_pretrained
+    if config.sampler:
+        scheduler = RectifiedFlowScheduler(
+            sampler=("Uniform" if config.sampler.lower() == "uniform" else "LinearQuadratic")
+        )
+    else:
+        scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
+
+    text_encoder = T5EncoderModel.from_pretrained("/repository/text_encoder")
+    patchifier = SymmetricPatchifier(patch_size=1)
+    tokenizer = T5Tokenizer.from_pretrained("/repository/tokenizer")
+
+    # Move models to the correct device
+    vae = vae.to(device)
+    transformer = transformer.to(device)
+    text_encoder = text_encoder.to(device)
+
+    # Set up precision
+    vae = vae.to(torch.bfloat16)
+    transformer = transformer.to(torch.bfloat16)
+    text_encoder = text_encoder.to(torch.bfloat16)
+
+    # Initialize prompt enhancer components if needed
+    prompt_enhancer_components = {
+        "prompt_enhancer_image_caption_model": None,
+        "prompt_enhancer_image_caption_processor": None,
+        "prompt_enhancer_llm_model": None,
+        "prompt_enhancer_llm_tokenizer": None
+    }
+
+    if config.enhance_prompt:
+        try:
+            # Use default models or ones specified by config
+            prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained(
+                "MiaoshouAI/Florence-2-large-PromptGen-v2.0",
+                trust_remote_code=True
+            )
+            prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained(
+                "MiaoshouAI/Florence-2-large-PromptGen-v2.0",
+                trust_remote_code=True
+            )
+            prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained(
+                "unsloth/Llama-3.2-3B-Instruct",
+                torch_dtype="bfloat16",
+            )
+            prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained(
+                "unsloth/Llama-3.2-3B-Instruct",
+            )

+            prompt_enhancer_components = {
+                "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model,
+                "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor,
+                "prompt_enhancer_llm_model": prompt_enhancer_llm_model,
+                "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer
+            }
+        except Exception as e:
+            logger.warning(f"Failed to load prompt enhancer models: {e}")
+            config.enhance_prompt = False
+
+    # Construct the pipeline
+    pipeline = LTXVideoPipeline(
+        transformer=transformer,
+        patchifier=patchifier,
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        scheduler=scheduler,
+        vae=vae,
+        allowed_inference_steps=allowed_inference_steps,
+        **prompt_enhancer_components
+    )
+
+    return pipeline
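The allowed_inference_steps value comes from the checkpoint's safetensors metadata; a minimal standalone check of that header (same path as above, key name assumed to match what the loader reads) would be:

import json
from safetensors import safe_open

# Inspect the metadata the loader above relies on; prints None if the key is absent.
with safe_open("/repository/ltxv-2b-0.9.6-distilled-04-25.safetensors", framework="pt") as f:
    meta = f.metadata() or {}
    cfg = json.loads(meta.get("config", "{}"))
    print(cfg.get("allowed_inference_steps"))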

+class EndpointHandler:
+    """Handler for the LTX Video endpoint"""
+
+    def __init__(self, model_path: str = "/repository/"):
+        """Initialize the endpoint handler
+
+        Args:
+            model_path: Path to model weights (not used, as weights are in current directory)
+        """
+        # Enable TF32 for potential speedup on Ampere GPUs
+        torch.backends.cuda.matmul.allow_tf32 = True
+
         # Initialize Varnish for post-processing
         self.varnish = Varnish(
             device="cuda",
             model_base_dir="/repository/varnish",
+            enable_mmaudio=False, # Disable audio generation for now, since it is broken
         )

+        # The actual LTX pipeline will be loaded during inference to save memory
+        self.pipeline = None
+
+        # Perform warm-up inference
+        logger.info("Performing warm-up inference...")
+        self._warmup()
+        logger.info("Warm-up completed!")
+
+    def _warmup(self):
+        """Perform a warm-up inference to prepare the model for future requests"""
         try:
+            # Create a simple test configuration
+            test_config = GenerationConfig(
+                prompt="an astronaut is riding a cow in the desert, during golden hour",
+                negative_prompt="worst quality, lowres",
+                width=768,  # Using smaller resolution for faster warm-up
+                height=416,
+                num_frames=33,  # Just enough frames for a valid video
+                guidance_scale=1.0,
+                num_inference_steps=4,  # Fewer steps for faster warm-up
+                seed=42,  # Fixed seed for consistent warm-up
+                fps=16,  # Lower FPS for faster processing
+                enable_audio=False,  # No audio for warm-up
+                mixed_precision=True,
+            ).validate_and_adjust()

+            # Create the pipeline if it doesn't exist
+            if self.pipeline is None:
+                self.pipeline = create_ltx_video_pipeline(test_config)

+            # Run a quick inference
+            with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16), torch.no_grad():
+                # Set seeds for reproducibility
+                random.seed(test_config.seed)
+                np.random.seed(test_config.seed)
+                torch.manual_seed(test_config.seed)
+                generator = torch.Generator(device='cuda').manual_seed(test_config.seed)
+
+                # Generate video
+                result = self.pipeline(
+                    height=test_config.height,
+                    width=test_config.width,
+                    num_frames=test_config.num_frames,
+                    frame_rate=test_config.fps,
+                    prompt=test_config.prompt,
+                    negative_prompt=test_config.negative_prompt,
+                    guidance_scale=test_config.guidance_scale,
+                    num_inference_steps=test_config.num_inference_steps,
+                    generator=generator,
+                    output_type="pt",
+                    mixed_precision=test_config.mixed_precision,
+                    is_video=True,
+                    vae_per_channel_normalize=True,
+                )
+
+                # Just get the frames without full processing (faster warm-up)
+                frames = result.images
+
+                # Clean up
+                del result
+                torch.cuda.empty_cache()
+                gc.collect()
+
+                logger.info(f"Warm-up successful! Generated {frames.shape[2]} frames at {frames.shape[3]}x{frames.shape[4]}")
+
         except Exception as e:
+            # Log the error but don't fail initialization
+            import traceback
+            error_message = f"Warm-up failed (but this is non-critical): {str(e)}\n{traceback.format_exc()}"
+            logger.warning(error_message)
+
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        """Process inference requests

         Args:
+            data: Request data containing inputs and parameters

         Returns:
+            Dictionary with generated video and metadata
         """
+        # Extract inputs and parameters
+        inputs = data.get("inputs", {})

+        # Support both formats:
+        # 1. {"inputs": {"prompt": "...", "image": "..."}}
+        # 2. {"inputs": "..."} (prompt only)
+        if isinstance(inputs, str):
+            input_prompt = inputs
+            input_image = None
+        else:
+            input_prompt = inputs.get("prompt", "")
+            input_image = inputs.get("image")

+        params = data.get("parameters", {})
+
+        if not input_prompt and not input_image:
             raise ValueError("Either prompt or image must be provided")
+
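For illustration, a request that exercises the "inputs" and "parameters" contract handled above could look like the sketch below; every value is a made-up example and only keys appearing in this file are used:

# Hypothetical request payload; the handler also accepts {"inputs": "a prompt string"}.
example_request = {
    "inputs": {
        "prompt": "a fox running through tall grass at sunset",
        # "image": "data:image/jpeg;base64,...",  # optional image conditioning
    },
    "parameters": {
        "width": 1216,
        "height": 704,
        "num_frames": (8 * 14) + 1,
        "guidance_scale": 3.0,
        "num_inference_steps": 8,
        "fps": 30,
        "seed": -1,        # -1 picks a random seed
        "enable_audio": False,
        "quality": 18,     # CRF, lower means better quality
    },
}
# handler = EndpointHandler()
# response = handler(example_request)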
         # Create and validate configuration
         config = GenerationConfig(
             # general content settings
             prompt=input_prompt,
             negative_prompt=params.get("negative_prompt", GenerationConfig.negative_prompt),
+
+            # video model settings
             width=params.get("width", GenerationConfig.width),
             height=params.get("height", GenerationConfig.height),
             input_image_quality=params.get("input_image_quality", GenerationConfig.input_image_quality),
             num_frames=params.get("num_frames", GenerationConfig.num_frames),
             guidance_scale=params.get("guidance_scale", GenerationConfig.guidance_scale),
             num_inference_steps=params.get("num_inference_steps", GenerationConfig.num_inference_steps),
+
+            # STG settings
+            stg_scale=params.get("stg_scale", GenerationConfig.stg_scale),
+            stg_rescale=params.get("stg_rescale", GenerationConfig.stg_rescale),
+            stg_mode=params.get("stg_mode", GenerationConfig.stg_mode),
+
+            # VAE noise settings
+            decode_timestep=params.get("decode_timestep", GenerationConfig.decode_timestep),
+            decode_noise_scale=params.get("decode_noise_scale", GenerationConfig.decode_noise_scale),
+            image_cond_noise_scale=params.get("image_cond_noise_scale", GenerationConfig.image_cond_noise_scale),
+
             # reproducible generation settings
             seed=params.get("seed", GenerationConfig.seed),

+            # varnish settings
+            fps=params.get("fps", GenerationConfig.fps),
+            double_num_frames=params.get("double_num_frames", GenerationConfig.double_num_frames),
+            super_resolution=params.get("super_resolution", GenerationConfig.super_resolution),
             grain_amount=params.get("grain_amount", GenerationConfig.grain_amount),
             enable_audio=params.get("enable_audio", GenerationConfig.enable_audio),
             audio_prompt=params.get("audio_prompt", GenerationConfig.audio_prompt),
             audio_negative_prompt=params.get("audio_negative_prompt", GenerationConfig.audio_negative_prompt),
             quality=params.get("quality", GenerationConfig.quality),

+            # advanced settings
+            mixed_precision=params.get("mixed_precision", GenerationConfig.mixed_precision),
+            stochastic_sampling=params.get("stochastic_sampling", GenerationConfig.stochastic_sampling),
+            sampler=params.get("sampler", GenerationConfig.sampler),

+            # prompt enhancement
+            enhance_prompt=params.get("enhance_prompt", GenerationConfig.enhance_prompt),
+            prompt_enhancement_words_threshold=params.get(
+                "prompt_enhancement_words_threshold",
+                GenerationConfig.prompt_enhancement_words_threshold
+            ),
         ).validate_and_adjust()

         try:
+            with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16), torch.no_grad():
+                # Set random seeds for reproducibility
                 random.seed(config.seed)
                 np.random.seed(config.seed)
                 torch.manual_seed(config.seed)
+                generator = torch.Generator(device='cuda').manual_seed(config.seed)

+                # Create pipeline if not already created
+                if self.pipeline is None:
+                    self.pipeline = create_ltx_video_pipeline(config)
+
+                # Prepare conditioning items if an image is provided
+                conditioning_items = None
+                if input_image:
+                    conditioning_items = [
+                        ConditioningItem(
+                            load_image_to_tensor_with_resize_and_crop(
+                                input_image,
+                                config.height,
+                                config.width,
+                                quality=config.input_image_quality
+                            ),
+                            0,   # Start frame
+                            1.0  # Conditioning strength
                         )
+                    ]

+                # Set up spatiotemporal guidance strategy
+                if config.stg_mode == "attention_values":
+                    skip_layer_strategy = SkipLayerStrategy.AttentionValues
+                elif config.stg_mode == "attention_skip":
+                    skip_layer_strategy = SkipLayerStrategy.AttentionSkip
+                elif config.stg_mode == "residual":
+                    skip_layer_strategy = SkipLayerStrategy.Residual
+                elif config.stg_mode == "transformer_block":
+                    skip_layer_strategy = SkipLayerStrategy.TransformerBlock
+
+                # Generate video with LTX pipeline
+                result = self.pipeline(
+                    height=config.height,
+                    width=config.width,
+                    num_frames=config.num_frames,
+                    frame_rate=config.fps,
+                    prompt=config.prompt,
+                    negative_prompt=config.negative_prompt,
+                    guidance_scale=config.guidance_scale,
+                    num_inference_steps=config.num_inference_steps,
+                    generator=generator,
+                    output_type="pt",  # Return as PyTorch tensor
+                    skip_layer_strategy=skip_layer_strategy,
+                    stg_scale=config.stg_scale,
+                    do_rescaling=config.stg_rescale != 1.0,
+                    rescaling_scale=config.stg_rescale,
+                    conditioning_items=conditioning_items,
+                    decode_timestep=config.decode_timestep,
+                    decode_noise_scale=config.decode_noise_scale,
+                    image_cond_noise_scale=config.image_cond_noise_scale,
+                    mixed_precision=config.mixed_precision,
+                    is_video=True,
+                    vae_per_channel_normalize=True,
+                    stochastic_sampling=config.stochastic_sampling,
+                    enhance_prompt=config.enhance_prompt,
+                )
+
+                # Get the generated frames
+                frames = result.images
+
+                # FIX: Convert LTX output format to varnish-compatible format
+                # LTX outputs: [batch, channels, frames, height, width]
+                # We need: [frames, channels, height, width] for varnish
+                frames = frames.squeeze(0)  # Remove batch: [channels, frames, height, width]
+                frames = frames.permute(1, 0, 2, 3)  # Reorder to: [frames, channels, height, width]
+
+                # Convert from [0, 1] to [0, 255] range
+                frames = frames * 255.0
+
+                # Convert to uint8
+                frames = frames.to(torch.uint8)
+
+                # Process the generated frames with Varnish
+                import asyncio
                 try:
                     loop = asyncio.get_event_loop()
                 except RuntimeError:
                     loop = asyncio.new_event_loop()
                     asyncio.set_event_loop(loop)
+
+                # Process with Varnish for post-processing
+                varnish_result = loop.run_until_complete(
+                    self.varnish(
+                        frames,
+                        fps=config.fps,
+                        double_num_frames=config.double_num_frames,
+                        super_resolution=config.super_resolution,
+                        grain_amount=config.grain_amount,
+                        enable_audio=config.enable_audio,
+                        audio_prompt=config.audio_prompt or config.prompt,
+                        audio_negative_prompt=config.audio_negative_prompt,
+                    )
+                )

+                # Get the final video as a data URI
+                video_uri = loop.run_until_complete(
+                    varnish_result.write(
+                        type="data-uri",
+                        quality=config.quality
+                    )
+                )
+
+                # Prepare metadata about the generated video
+                metadata = {
+                    "width": varnish_result.metadata.width,
+                    "height": varnish_result.metadata.height,
+                    "num_frames": varnish_result.metadata.frame_count,
+                    "fps": varnish_result.metadata.fps,
+                    "duration": varnish_result.metadata.duration,
+                    "seed": config.seed,
+                    "prompt": config.prompt,
+                }
+
+                # Clean up to prevent CUDA OOM errors
+                del result
                 torch.cuda.empty_cache()
                 gc.collect()

                 return {
                     "content-type": "video/mp4",
                     "metadata": metadata
                 }
+
         except Exception as e:
+            # Log the error and reraise
+            import traceback
+            error_message = f"Error generating video: {str(e)}\n{traceback.format_exc()}"
+            logger.error(error_message)
+            raise RuntimeError(error_message)
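On the client side, the returned data URI can be decoded straight back to an MP4 file; a minimal sketch, assuming the response carries the base64 video under the "video" key as the docstrings describe:

import base64

def save_video(response: dict, path: str = "output.mp4") -> None:
    # response["video"] is expected to be a data URI like "data:video/mp4;base64,AAAA..."
    _, encoded = response["video"].split(",", 1)
    with open(path, "wb") as f:
        f.write(base64.b64decode(encoded))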