jbilcke-hf (HF Staff) committed on
Commit 34dc7eb · verified · 1 Parent(s): c447016

Create handler_LAST_WORKING.py

Files changed (1)
  1. handler_LAST_WORKING.py +510 -0
handler_LAST_WORKING.py ADDED
@@ -0,0 +1,510 @@
+ from dataclasses import dataclass
+ from pathlib import Path
+ import pathlib
+ from typing import Dict, Any, Optional, Tuple
+ import asyncio
+ import base64
+ import io
+ import pprint
+ import logging
+ import random
+ import traceback
+ import os
+ import numpy as np
+ import torch
+ import gc
+
+ from diffusers import AutoencoderKLLTXVideo, LTXPipeline, LTXImageToVideoPipeline, LTXVideoTransformer3DModel
+ #from diffusers.hooks import apply_enhance_a_video, EnhanceAVideoConfig
+ #from teacache import apply_teacache
+
+ from PIL import Image
+
+ from varnish import Varnish
+ from varnish.utils import is_truthy, process_input_image
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+
+ # Get token from environment
+ hf_token = os.getenv("HF_API_TOKEN")
+
+ # Constraints
+ MAX_LARGE_SIDE = 1280
+ MAX_SMALL_SIDE = 768  # should be 720 but it must be divisible by 32
+ MAX_FRAMES = 257
+
+ # Check environment variable for pipeline support
+ support_image_prompt = is_truthy(os.getenv("SUPPORT_INPUT_IMAGE_PROMPT"))
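+ # when truthy, the handler loads the image-to-video pipeline below; otherwise it loads the text-to-video pipeline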
+
+ @dataclass
+ class GenerationConfig:
+     """Configuration for video generation"""
+
+     # general content settings
+     prompt: str = ""
+     negative_prompt: str = "saturated, highlight, overexposed, highlighted, overlit, shaking, too bright, worst quality, inconsistent motion, blurry, jittery, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles, lowres"
+
+     # video model settings (will be used during generation of the initial raw video clip)
+     # we use small values to make things a bit faster
+     width: int = 768
+     height: int = 416
+
+
+     # this is a hack to fool LTX-Video into believing our input image is an actual video frame with poor encoding quality
+     # after a quick benchmark, a value of 70 seems like a sweet spot
+     input_image_quality: int = 70
+
+     # users may tend to always set this to the max, to get as much usable content as possible (which is MAX_FRAMES, i.e. 257).
+     # The value must be a multiple of 8, plus 1 frame.
+     # visual glitches appear after about 169 frames, so we don't actually need more
+     num_frames: int = (8 * 14) + 1
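+     # (8 * 14) + 1 = 113 frames, i.e. roughly 3.8 seconds at the default 30 fps (before any frame doubling)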
+
+     # with the distilled model, a guidance scale of 1.0 is fine
+     guidance_scale: float = 1.0
+
+     num_inference_steps: int = 8
+
+     # reproducible generation settings
+     seed: int = -1  # -1 means random seed
+
+     # varnish settings (will be used for post-processing after the raw video clip has been generated)
+     fps: int = 30  # FPS of the final video (only applied at the very end, when converting to mp4)
+     double_num_frames: bool = False  # if True, the number of frames will be multiplied by 2 using RIFE
+     super_resolution: bool = False  # if True, the resolution will be multiplied by 2 using Real_ESRGAN
+
+     grain_amount: float = 0.0  # be careful, adding film grain can negatively impact video compression
+
+     # audio settings
+     enable_audio: bool = False  # Whether to generate audio
+     audio_prompt: str = ""  # Text prompt for audio generation
+     audio_negative_prompt: str = "voices, voice, talking, speaking, speech"  # Negative prompt for audio generation
+
+     # The range of the CRF scale is 0–51, where:
+     # 0 is lossless (for 8 bit only, for 10 bit use -qp 0)
+     # 23 is the default
+     # 51 is worst quality possible
+     # A lower value generally leads to higher quality, and a subjectively sane range is 17–28.
+     # Consider 17 or 18 to be visually lossless or nearly so;
+     # it should look the same or nearly the same as the input but it isn't technically lossless.
+     # The range is exponential, so increasing the CRF value +6 results in roughly half the bitrate / file size, while -6 leads to roughly twice the bitrate.
+     quality: int = 18
+
+     # TeaCache settings
+     enable_teacache: bool = False
+     teacache_threshold: float = 0.05  # values: 0 (original), 0.03 (1.6x speedup), 0.05 (2.1x speedup).
+
+     # Enhance-A-Video settings
+     enable_enhance_a_video: bool = False
+     enhance_a_video_weight: float = 5.0
+
+     # LoRA settings
+     lora_model_name: str = ""  # HuggingFace repo ID or path to LoRA model
+     lora_model_weight_file: str = ""  # Specific weight file to load from the LoRA model
+     lora_model_trigger: str = ""  # Optional trigger word to prepend to the prompt
+
+     def validate_and_adjust(self) -> 'GenerationConfig':
+         """Validate and adjust parameters to meet constraints"""
+         # First check if it's one of our explicitly allowed resolutions
+         if not ((self.width == MAX_LARGE_SIDE and self.height == MAX_SMALL_SIDE) or
+                 (self.width == MAX_SMALL_SIDE and self.height == MAX_LARGE_SIDE)):
+             # For other resolutions, ensure total pixels don't exceed max
+             MAX_TOTAL_PIXELS = MAX_SMALL_SIDE * MAX_LARGE_SIDE  # 983,040 pixels (1280 * 768); a strict 720p budget would be 921,600 = 1280 * 720
+
+             # If total pixels exceed maximum, scale down proportionally
+             total_pixels = self.width * self.height
+             if total_pixels > MAX_TOTAL_PIXELS:
+                 scale = (MAX_TOTAL_PIXELS / total_pixels) ** 0.5
+                 self.width = max(128, min(MAX_LARGE_SIDE, round(self.width * scale / 32) * 32))
+                 self.height = max(128, min(MAX_LARGE_SIDE, round(self.height * scale / 32) * 32))
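+                 # e.g. a 1920x1080 request comes out as 1280x736 (942,080 pixels, under the 983,040 budget)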
+         else:
+             # Round dimensions to nearest multiple of 32
+             self.width = max(128, min(MAX_LARGE_SIDE, round(self.width / 32) * 32))
+             self.height = max(128, min(MAX_LARGE_SIDE, round(self.height / 32) * 32))
+
+         # Adjust number of frames to be in format 8k + 1
+         k = (self.num_frames - 1) // 8
+         self.num_frames = min((k * 8) + 1, MAX_FRAMES)
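+         # e.g. a request for 100 frames becomes ((100 - 1) // 8) * 8 + 1 = 97; anything above MAX_FRAMES is capped at 257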
+
+         # Set random seed if not specified
+         if self.seed == -1:
+             self.seed = random.randint(0, 2**32 - 1)
+
+         return self
+
+ class EndpointHandler:
+     """Handles video generation requests using LTX models and Varnish post-processing"""
+
+     def __init__(self, model_path: str = ""):
+         """Initialize the handler with LTX models and Varnish
+
+         Args:
+             model_path: Path to LTX model weights
+         """
+         print("EndpointHandler.__init__(): initializing..")
+         # Enable TF32 for potential speedup on Ampere GPUs
+         #torch.backends.cuda.matmul.allow_tf32 = True
+
+         # use distilled weights
+         model_path = Path("/repository/ltxv-2b-0.9.6-distilled-04-25.safetensors")
+
+         print("EndpointHandler.__init__(): initializing LTXVideoTransformer3DModel..")
+         transformer = LTXVideoTransformer3DModel.from_single_file(
+             model_path, torch_dtype=torch.bfloat16
+         )
+
+         print("EndpointHandler.__init__(): initializing AutoencoderKLLTXVideo..")
+         vae = AutoencoderKLLTXVideo.from_single_file(model_path, torch_dtype=torch.bfloat16)
+
+         if support_image_prompt:
+             print("EndpointHandler.__init__(): initializing LTXImageToVideoPipeline..")
+             self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
+                 "/repository",
+                 transformer=transformer,
+                 vae=vae,
+                 torch_dtype=torch.bfloat16
+             ).to("cuda")
+
+             #apply_teacache(self.image_to_video)
+
+             # Compilation requires some time to complete, so it is best suited for
+             # situations where you prepare your pipeline once and then perform the
+             # same type of inference operations multiple times.
+             # For example, calling the compiled pipeline on a different image size
+             # triggers compilation again which can be expensive.
+             #self.image_to_video = torch.compile(self.image_to_video, mode="reduce-overhead", fullgraph=True)
+
+         else:
+             print("EndpointHandler.__init__(): initializing LTXPipeline..")
+             # Initialize models with bfloat16 precision
+             self.text_to_video = LTXPipeline.from_pretrained(
+                 "/repository",
+                 transformer=transformer,
+                 vae=vae,
+                 torch_dtype=torch.bfloat16
+             ).to("cuda")
+
+             #apply_teacache(self.text_to_video)
+
+             # Compilation requires some time to complete, so it is best suited for
+             # situations where you prepare your pipeline once and then perform the
+             # same type of inference operations multiple times.
+             # For example, calling the compiled pipeline on a different image size
+             # triggers compilation again which can be expensive.
+             #self.text_to_video = torch.compile(self.text_to_video, mode="reduce-overhead", fullgraph=True)
+
+
+         # Initialize LoRA tracking
+         self._current_lora_model = None
+
+         #if support_image_prompt:
+         #    # Enable CPU offload for memory efficiency
+         #    self.image_to_video.enable_model_cpu_offload()
+         #    # Inject enhance-a-video functionality
+         #    inject_enhance_for_ltx(self.image_to_video.transformer)
+         #else:
+         #    # Enable CPU offload for memory efficiency
+         #    self.text_to_video.enable_model_cpu_offload()
+         #    # Inject enhance-a-video functionality
+         #    inject_enhance_for_ltx(self.text_to_video.transformer)
+
+
+         # Initialize Varnish for post-processing
+         self.varnish = Varnish(
+             device="cuda",
+             model_base_dir="/repository/varnish",
+
+             # there is currently a bug with MMAudio and/or torch and/or the weight format and/or version..
+             # not sure how to fix that.. :/
+             #
+             # it says:
+             #   File "dist-packages/varnish.py", line 152, in __init__
+             #     self._setup_mmaudio()
+             #   File "dist-packages/varnish/varnish.py", line 165, in _setup_mmaudio
+             #     net.load_weights(torch.load(model.model_path, map_location=self.device, weights_only=False))
+             #     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+             #   File "dist-packages/torch/serialization.py", line 1384, in load
+             #     return _legacy_load(
+             #            ^^^^^^^^^^^^^
+             #   File "dist-packages/torch/serialization.py", line 1628, in _legacy_load
+             #     magic_number = pickle_module.load(f, **pickle_load_args)
+             #                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+             #   _pickle.UnpicklingError: invalid load key, '<'.
+             enable_mmaudio=False,
+         )
+
+         # Determine if TeaCache is already installed or not
+         self.text_to_video_teacache = False
+         self.image_to_video_teacache = False
+
+
+     async def process_frames(
+         self,
+         frames: torch.Tensor,
+         config: GenerationConfig
+     ) -> tuple[str, dict]:
+         """Post-process generated frames using Varnish
+
+         Args:
+             frames: Generated video frames tensor
+             config: Generation configuration
+
+         Returns:
+             Tuple of (video data URI, metadata dictionary)
+         """
+         try:
+             # Process video with Varnish
+             result = await self.varnish(
+                 input_data=frames,  # note: this might contain a certain number of frames, e.g. 97, which will get doubled if double_num_frames is True
+                 fps=config.fps,  # this is the FPS of the final output video; Varnish can use it to calculate the duration of a clip ((frames * factor) / fps, etc.)
+                 double_num_frames=config.double_num_frames,  # if True, the number of frames will be multiplied by 2 using RIFE
+                 super_resolution=config.super_resolution,  # if True, the resolution will be multiplied by 2 using Real_ESRGAN
+                 grain_amount=config.grain_amount,
+                 enable_audio=config.enable_audio,
+                 audio_prompt=config.audio_prompt,
+                 audio_negative_prompt=config.audio_negative_prompt,
+             )
+
+             # Convert to data URI
+             video_uri = await result.write(type="data-uri", quality=config.quality)
+
+             # Collect metadata
+             metadata = {
+                 "width": result.metadata.width,
+                 "height": result.metadata.height,
+                 "num_frames": result.metadata.frame_count,
+                 "fps": result.metadata.fps,
+                 "duration": result.metadata.duration,
+                 "seed": config.seed,
+             }
+
+             return video_uri, metadata
+
+         except Exception as e:
+             logger.error(f"Error in process_frames: {str(e)}")
+             raise RuntimeError(f"Failed to process frames: {str(e)}")
+
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         """Process incoming requests for video generation
+
+         Args:
+             data: Request data containing:
+                 - inputs (dict): Dictionary containing the input, which can be either "prompt" (text field) or "image" (input image)
+                 - parameters (dict):
+                     - prompt (required, string): list of concepts to keep in the video.
+                     - negative_prompt (optional, string): list of concepts to ignore in the video.
+                     - width (optional, int, default to 768): width, or horizontal size in pixels.
+                     - height (optional, int, default to 416): height, or vertical size in pixels.
+                     - input_image_quality (optional, int, default to 70): this is a trick we use to convert a "pristine" image into a "dirty" video frame. This helps fool LTX-Video into turning the image into an animated one.
+                     - num_frames (optional, int, default to 113): the number of frames must be a multiple of 8, plus 1 frame.
+                     - guidance_scale (optional, float, default to 1.0): guidance scale (1.0 is fine for the distilled model; values between 3.0 and 4.0 suit the non-distilled one)
+                     - num_inference_steps (optional, int, default to 8): number of inference steps
+                     - seed (optional, int, default to -1): set a random number generator seed, -1 means random seed.
+                     - fps (optional, int, default to 30): FPS of the final video (e.g. 24, 25, 30, 60)
+                     - double_num_frames (optional, bool): if enabled, the number of frames will be multiplied by 2 using RIFE
+                     - super_resolution (optional, bool): if enabled, the resolution will be multiplied by 2 using Real_ESRGAN
+                     - grain_amount (optional, float): amount of film grain to add to the output video
+                     - enable_audio (optional, bool): automatically generate an audio track
+                     - audio_prompt (optional, str): prompt to use for the audio generation (concepts to add)
+                     - audio_negative_prompt (optional, str): negative prompt to use for the audio generation (concepts to ignore)
+                     - quality (optional, int, default to 18): The range of the CRF scale is 0–51, where 0 is lossless (for 8 bit only, for 10 bit use -qp 0), 23 is the default, and 51 is worst quality possible.
+                     - enable_teacache (optional, bool, default to False): generate faster at the cost of a slight quality loss
+                     - teacache_threshold (optional, float, default to 0.05): amount of cache, 0 (original), 0.03 (1.6x speedup), 0.05 (default, 2.1x speedup).
+                     - enable_enhance_a_video (optional, bool, default to False): enable the enhance_a_video optimization
+                     - enhance_a_video_weight (optional, float, default to 5.0): amount of video enhancement to apply
+                     - lora_model_name (optional, str, default to ""): HuggingFace repo ID or path to LoRA model
+                     - lora_model_weight_file (optional, str, default to ""): specific weight file to load from the LoRA model
+                     - lora_model_trigger (optional, str, default to ""): optional trigger word to prepend to the prompt
+         Returns:
+             Dictionary containing:
+                 - video: Base64 encoded MP4 data URI
+                 - content-type: MIME type
+                 - metadata: Generation metadata
+         """
+         inputs = data.get("inputs", dict())
+         #print(inputs)
+
+         input_prompt = inputs.get("prompt", "")
+         input_image = inputs.get("image")
+
+         params = data.get("parameters", dict())
+
+         if not input_image and not input_prompt:
+             raise ValueError("Either prompt or image must be provided")
+
+         #logger.debug(f"Raw parameters:")
+         #pprint.pprint(params)
+
+         # Create and validate configuration
+         config = GenerationConfig(
+             # general content settings
+             prompt=input_prompt,
+             negative_prompt=params.get("negative_prompt", GenerationConfig.negative_prompt),
+
+             # video model settings (will be used during generation of the initial raw video clip)
+             width=params.get("width", GenerationConfig.width),
+             height=params.get("height", GenerationConfig.height),
+             input_image_quality=params.get("input_image_quality", GenerationConfig.input_image_quality),
+             num_frames=params.get("num_frames", GenerationConfig.num_frames),
+             guidance_scale=params.get("guidance_scale", GenerationConfig.guidance_scale),
+             num_inference_steps=params.get("num_inference_steps", GenerationConfig.num_inference_steps),
+
+             # reproducible generation settings
+             seed=params.get("seed", GenerationConfig.seed),
+
+             # varnish settings (will be used for post-processing after the raw video clip has been generated)
+             fps=params.get("fps", GenerationConfig.fps),  # FPS of the final video (only applied at the very end, when converting to mp4)
+             double_num_frames=params.get("double_num_frames", GenerationConfig.double_num_frames),  # if True, the number of frames will be multiplied by 2 using RIFE
+             super_resolution=params.get("super_resolution", GenerationConfig.super_resolution),  # if True, the resolution will be multiplied by 2 using Real_ESRGAN
+             grain_amount=params.get("grain_amount", GenerationConfig.grain_amount),
+             enable_audio=params.get("enable_audio", GenerationConfig.enable_audio),
+             audio_prompt=params.get("audio_prompt", GenerationConfig.audio_prompt),
+             audio_negative_prompt=params.get("audio_negative_prompt", GenerationConfig.audio_negative_prompt),
+             quality=params.get("quality", GenerationConfig.quality),
+
+             # TeaCache settings
+             enable_teacache=params.get("enable_teacache", False),
+
+             # values: 0 (original), 0.03 (1.6x speedup), 0.05 (2.1x speedup).
+             teacache_threshold=params.get("teacache_threshold", 0.05),
+
+
+             # Add enhance-a-video settings
+             enable_enhance_a_video=params.get("enable_enhance_a_video", False),
+             enhance_a_video_weight=params.get("enhance_a_video_weight", 5.0),
+
+             # LoRA settings
+             lora_model_name=params.get("lora_model_name", ""),
+             lora_model_weight_file=params.get("lora_model_weight_file", ""),
+             lora_model_trigger=params.get("lora_model_trigger", ""),
+         ).validate_and_adjust()
+
+         #logger.debug(f"Global request settings:")
+         #pprint.pprint(config)
+
+         try:
+             with torch.amp.autocast_mode.autocast('cuda', torch.bfloat16), torch.no_grad(), torch.inference_mode():
+                 # Set random seeds
+                 random.seed(config.seed)
+                 np.random.seed(config.seed)
+                 torch.manual_seed(config.seed)
+                 generator = torch.Generator(device='cuda')
+                 generator = generator.manual_seed(config.seed)
+
+                 # Configure enhance-a-video
+                 #if config.enable_enhance_a_video:
+                 #    enable_enhance()
+                 #    set_enhance_weight(config.enhance_a_video_weight)
+
+                 # Prepare generation parameters for the video model (we omit params destined for Varnish, or things like the seed which is set externally)
+                 generation_kwargs = {
+                     # general content settings
+                     "prompt": config.prompt,
+                     "negative_prompt": config.negative_prompt,
+
+                     # video model settings (will be used during generation of the initial raw video clip)
+                     "width": config.width,
+                     "height": config.height,
+                     "num_frames": config.num_frames,
+                     "guidance_scale": config.guidance_scale,
+                     "num_inference_steps": config.num_inference_steps,
+
+                     # constants
+                     "output_type": "pt",
+                     "generator": generator,
+
+                     # Timestep for decoding VAE noise: the timestep at which generated video is decoded
+                     "decode_timestep": 0.05,
+
+                     # Noise level for decoding VAE noise: the interpolation factor between random noise and denoised latents at the decode timestep
+                     "decode_noise_scale": 0.025,
+                 }
+                 #logger.info(f"Video model generation settings:")
+                 #pprint.pprint(generation_kwargs)
+
+                 # Handle LoRA loading/unloading
+                 if hasattr(self, '_current_lora_model'):
+                     if self._current_lora_model != (config.lora_model_name, config.lora_model_weight_file):
+                         # Unload previous LoRA if it exists and is different
+                         if not support_image_prompt and hasattr(self.text_to_video, 'unload_lora_weights'):
+                             print("Unloading LoRA weights for the text_to_video pipeline..")
+                             self.text_to_video.unload_lora_weights()
+
+                         if support_image_prompt and hasattr(self.image_to_video, 'unload_lora_weights'):
+                             print("Unloading LoRA weights for the image_to_video pipeline..")
+                             self.image_to_video.unload_lora_weights()
+
+                         if config.lora_model_name:
+                             # Load new LoRA
+                             if not support_image_prompt and hasattr(self.text_to_video, 'load_lora_weights'):
+                                 print("Loading LoRA weights for the text_to_video pipeline..")
+                                 self.text_to_video.load_lora_weights(
+                                     config.lora_model_name,
+                                     weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
+                                     token=hf_token,
+                                 )
+                             if support_image_prompt and hasattr(self.image_to_video, 'load_lora_weights'):
+                                 print("Loading LoRA weights for the image_to_video pipeline..")
+                                 self.image_to_video.load_lora_weights(
+                                     config.lora_model_name,
+                                     weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
+                                     token=hf_token,
+                                 )
+                             self._current_lora_model = (config.lora_model_name, config.lora_model_weight_file)
+
+                 # Modify prompt if trigger word is provided
+                 if config.lora_model_trigger:
+                     generation_kwargs["prompt"] = f"{config.lora_model_trigger} {generation_kwargs['prompt']}"
+
+                 #enhance_a_video_config = EnhanceAVideoConfig(
+                 #    weight=config.enhance_a_video_weight if config.enable_enhance_a_video else 0.0,
+                 #    # doing some testing
+                 #    num_frames_callback=lambda: (8 + 1),
+                 #    # num_frames_callback=lambda: config.num_frames,
+                 #    # num_frames_callback=lambda: (config.num_frames - 1),
+                 #
+                 #    _attention_type=1
+                 #)
+
+                 # Check if image-to-video generation is requested
+                 if support_image_prompt and input_image:
+                     processed_image = process_input_image(
+                         input_image,
+                         config.width,
+                         config.height,
+                         config.input_image_quality,
+                     )
+                     generation_kwargs["image"] = processed_image
+                     # disabled (we cannot install the hook multiple times, we would have to uninstall it first or find another way to dynamically enable it, e.g. using the weight only)
+                     # apply_enhance_a_video(self.image_to_video.transformer, enhance_a_video_config)
+                     frames = self.image_to_video(**generation_kwargs).frames
+                 else:
+                     # disabled (we cannot install the hook multiple times, we would have to uninstall it first or find another way to dynamically enable it, e.g. using the weight only)
+                     # apply_enhance_a_video(self.text_to_video.transformer, enhance_a_video_config)
+                     frames = self.text_to_video(**generation_kwargs).frames
+
+                 try:
+                     loop = asyncio.get_event_loop()
+                 except RuntimeError:
+                     loop = asyncio.new_event_loop()
+                     asyncio.set_event_loop(loop)
+
+                 video_uri, metadata = loop.run_until_complete(self.process_frames(frames, config))
+
+                 torch.cuda.empty_cache()
+                 torch.cuda.reset_peak_memory_stats()
+                 gc.collect()
+
+                 return {
+                     "video": video_uri,
+                     "content-type": "video/mp4",
+                     "metadata": metadata
+                 }
+
+         except Exception as e:
+             message = f"Error generating video ({str(e)})\n{traceback.format_exc()}"
+             print(message)
+             raise RuntimeError(message)