jbilcke-hf HF Staff committed on
Commit ad1d6a3 · verified · 1 Parent(s): eb06e8c

Create handler.py

Files changed (1)
  1. handler.py +504 -0
handler.py ADDED
@@ -0,0 +1,504 @@
+ from dataclasses import dataclass
+ from pathlib import Path
+ import pathlib
+ from typing import Dict, Any, Optional, Tuple
+ import asyncio
+ import base64
+ import io
+ import pprint
+ import logging
+ import random
+ import traceback
+ import os
+ import numpy as np
+ import torch
+ import gc
+
+ from diffusers import AutoencoderKLLTXVideo, LTXPipeline, LTXImageToVideoPipeline, LTXVideoTransformer3DModel
+ #from diffusers.hooks import apply_enhance_a_video, EnhanceAVideoConfig
+ from teacache import apply_teacache
+
+ from PIL import Image
+
+ from varnish import Varnish
+ from varnish.utils import is_truthy, process_input_image
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+
+ # Get token from environment
+ hf_token = os.getenv("HF_API_TOKEN")
+
+ # Constraints
+ MAX_LARGE_SIDE = 1280
+ MAX_SMALL_SIDE = 768  # should be 720, but it must be divisible by 32
+ MAX_FRAMES = (8 * 21) + 1  # visual glitches appear after about 169 frames, so we cap it
+
+ # Check environment variable for pipeline support
+ support_image_prompt = is_truthy(os.getenv("SUPPORT_INPUT_IMAGE_PROMPT"))
+
+ @dataclass
+ class GenerationConfig:
+     """Configuration for video generation"""
+
+     # general content settings
+     prompt: str = ""
+     negative_prompt: str = "saturated, highlight, overexposed, highlighted, overlit, shaking, too bright, worst quality, inconsistent motion, blurry, jittery, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles, lowres"
+
+     # video model settings (will be used during generation of the initial raw video clip)
+     # we use small values to make things a bit faster
+     width: int = 768
+     height: int = 416
+
+
+     # this is a hack to fool LTX-Video into believing our input image is an actual video frame with poor encoding quality
+     # after a quick benchmark, the value 70 seems like a sweet spot
+     input_image_quality: int = 70
+
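For context, this "dirty frame" trick boils down to a lossy re-encode of the input image. Below is a minimal sketch of the idea only; the real work is done by varnish.utils.process_input_image, whose internals are not shown in this file, and the helper name here is hypothetical:

from io import BytesIO
from PIL import Image

def degrade_like_a_video_frame(image: Image.Image, quality: int = 70) -> Image.Image:
    # Re-encode the image as JPEG at the given quality so it resembles a lossy
    # video frame rather than a pristine still (illustrative sketch only).
    buffer = BytesIO()
    image.convert("RGB").save(buffer, format="JPEG", quality=quality)
    buffer.seek(0)
    degraded = Image.open(buffer)
    degraded.load()  # force decoding while the buffer is still in scope
    return degraded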
+     # users may tend to always set this to the max, to get as much usable content as possible (which is MAX_FRAMES, i.e. 169).
+     # The value must be a multiple of 8, plus 1 frame.
+     # visual glitches appear after about 169 frames, so we don't need more than that
+     num_frames: int = (8 * 14) + 1
+
+     # values between 3.0 and 4.0 are nice
+     guidance_scale: float = 3.5
+
+     num_inference_steps: int = 50
+
+     # reproducible generation settings
+     seed: int = -1  # -1 means random seed
+
+     # varnish settings (will be used for post-processing after the raw video clip has been generated)
+     fps: int = 30  # FPS of the final video (only applied at the very end, when converting to mp4)
+     double_num_frames: bool = False  # if True, the number of frames will be multiplied by 2 using RIFE
+     super_resolution: bool = False  # if True, the resolution will be multiplied by 2 using Real_ESRGAN
+
+     grain_amount: float = 0.0  # be careful, adding film grain can negatively impact video compression
+
+     # audio settings
+     enable_audio: bool = False  # Whether to generate audio
+     audio_prompt: str = ""  # Text prompt for audio generation
+     audio_negative_prompt: str = "voices, voice, talking, speaking, speech"  # Negative prompt for audio generation
+
+     # The range of the CRF scale is 0–51, where:
+     # 0 is lossless (for 8 bit only; for 10 bit use -qp 0)
+     # 23 is the default
+     # 51 is worst quality possible
+     # A lower value generally leads to higher quality, and a subjectively sane range is 17–28.
+     # Consider 17 or 18 to be visually lossless or nearly so;
+     # it should look the same or nearly the same as the input, but it isn't technically lossless.
+     # The range is exponential, so increasing the CRF value by +6 results in roughly half the bitrate / file size, while -6 leads to roughly twice the bitrate.
+     quality: int = 18
+
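For intuition, a small worked example of the exponential rule of thumb described in the comment above (a rough approximation, not an exact encoder model):

# +6 CRF ~ half the bitrate, -6 CRF ~ double, relative to some reference setting.
def approx_relative_size(crf: int, reference_crf: int = 18) -> float:
    # Estimated output size relative to encoding the same clip at reference_crf.
    return 0.5 ** ((crf - reference_crf) / 6)

approx_relative_size(24)  # ~0.5: about half the size of the CRF 18 default
approx_relative_size(12)  # ~2.0: about twice the size of the CRF 18 default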
+     # TeaCache settings
+     enable_teacache: bool = True
+     teacache_threshold: float = 0.05  # values: 0 (original), 0.03 (1.6x speedup), 0.05 (2.1x speedup).
+
+     # Enhance-A-Video settings
+     enable_enhance_a_video: bool = True
+     enhance_a_video_weight: float = 5.0
+
+     # LoRA settings
+     lora_model_name: str = ""  # HuggingFace repo ID or path to LoRA model
+     lora_model_weight_file: str = ""  # Specific weight file to load from the LoRA model
+     lora_model_trigger: str = ""  # Optional trigger word to prepend to the prompt
+
+     def validate_and_adjust(self) -> 'GenerationConfig':
+         """Validate and adjust parameters to meet constraints"""
+         # First check if it's one of our explicitly allowed resolutions
+         if not ((self.width == MAX_LARGE_SIDE and self.height == MAX_SMALL_SIDE) or
+                 (self.width == MAX_SMALL_SIDE and self.height == MAX_LARGE_SIDE)):
+             # For other resolutions, ensure total pixels don't exceed max
+             MAX_TOTAL_PIXELS = MAX_SMALL_SIDE * MAX_LARGE_SIDE  # 983040 pixels (the 720p budget would be 921600 = 1280 * 720)
+
+             # If total pixels exceed maximum, scale down proportionally
+             total_pixels = self.width * self.height
+             if total_pixels > MAX_TOTAL_PIXELS:
+                 scale = (MAX_TOTAL_PIXELS / total_pixels) ** 0.5
+                 self.width = max(128, min(MAX_LARGE_SIDE, round(self.width * scale / 32) * 32))
+                 self.height = max(128, min(MAX_LARGE_SIDE, round(self.height * scale / 32) * 32))
+             else:
+                 # Round dimensions to nearest multiple of 32
+                 self.width = max(128, min(MAX_LARGE_SIDE, round(self.width / 32) * 32))
+                 self.height = max(128, min(MAX_LARGE_SIDE, round(self.height / 32) * 32))
+
+         # Adjust number of frames to be in the format 8k + 1
+         k = (self.num_frames - 1) // 8
+         self.num_frames = min((k * 8) + 1, MAX_FRAMES)
+
+         # Set random seed if not specified
+         if self.seed == -1:
+             self.seed = random.randint(0, 2**32 - 1)
+
+         return self
+
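To make the adjustment rules concrete, here is roughly what happens to a request that falls outside the supported envelope (the values follow from the formulas above, assuming the module-level constants):

config = GenerationConfig(width=1920, height=1080, num_frames=100, seed=42)
config.validate_and_adjust()
assert (config.width, config.height) == (1280, 736)  # scaled to fit the 768 * 1280 pixel budget, snapped to multiples of 32
assert config.num_frames == 97                        # snapped down to the 8k + 1 grid (and capped at MAX_FRAMES = 169)
assert config.seed == 42                              # only replaced by a random seed when left at -1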
+ class EndpointHandler:
+     """Handles video generation requests using LTX models and Varnish post-processing"""
+
+     def __init__(self, model_path: str = ""):
+         """Initialize the handler with LTX models and Varnish
+
+         Args:
+             model_path: Path to LTX model weights
+         """
+         # Enable TF32 for potential speedup on Ampere GPUs
+         #torch.backends.cuda.matmul.allow_tf32 = True
+
+         # use distilled weights for the transformer and VAE
+         # (keep model_path pointing at the repository root: from_pretrained below needs a directory, not a single file)
+         distilled_model_path = f"{model_path}/ltxv-2b-0.9.6-distilled-04-25.safetensors"
+
+         transformer = LTXVideoTransformer3DModel.from_single_file(
+             distilled_model_path, torch_dtype=torch.bfloat16
+         )
+
+         vae = AutoencoderKLLTXVideo.from_single_file(distilled_model_path, torch_dtype=torch.bfloat16)
+
+         if support_image_prompt:
+             self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
+                 model_path,
+                 transformer=transformer,
+                 vae=vae,
+                 torch_dtype=torch.bfloat16
+             ).to("cuda")
+
+             #apply_teacache(self.image_to_video)
+
+             # Compilation requires some time to complete, so it is best suited for
+             # situations where you prepare your pipeline once and then perform the
+             # same type of inference operations multiple times.
+             # For example, calling the compiled pipeline on a different image size
+             # triggers compilation again which can be expensive.
+             #self.image_to_video.unet = torch.compile(self.image_to_video.unet, mode="reduce-overhead", fullgraph=True)
+
+         else:
+             # Initialize models with bfloat16 precision
+             self.text_to_video = LTXPipeline.from_pretrained(
+                 model_path,
+                 transformer=transformer,
+                 vae=vae,
+                 torch_dtype=torch.bfloat16
+             ).to("cuda")
+
+             #apply_teacache(self.text_to_video)
+
+             # Compilation requires some time to complete, so it is best suited for
+             # situations where you prepare your pipeline once and then perform the
+             # same type of inference operations multiple times.
+             # For example, calling the compiled pipeline on a different image size
+             # triggers compilation again which can be expensive.
+             #self.text_to_video.unet = torch.compile(self.text_to_video.unet, mode="reduce-overhead", fullgraph=True)
+
+
+         # Initialize LoRA tracking
+         self._current_lora_model = None
+
+         #if support_image_prompt:
+         #    # Enable CPU offload for memory efficiency
+         #    self.image_to_video.enable_model_cpu_offload()
+         #    # Inject enhance-a-video functionality
+         #    inject_enhance_for_ltx(self.image_to_video.transformer)
+         #else:
+         #    # Enable CPU offload for memory efficiency
+         #    self.text_to_video.enable_model_cpu_offload()
+         #    # Inject enhance-a-video functionality
+         #    inject_enhance_for_ltx(self.text_to_video.transformer)
+
+
+         # Initialize Varnish for post-processing
+         self.varnish = Varnish(
+             device="cuda",
+             model_base_dir="/repository/varnish",
+
+             # there is currently a bug with MMAudio and/or torch and/or the weight format and/or version..
+             # not sure how to fix that.. :/
+             #
+             # it says:
+             #   File "dist-packages/varnish.py", line 152, in __init__
+             #     self._setup_mmaudio()
+             #   File "dist-packages/varnish/varnish.py", line 165, in _setup_mmaudio
+             #     net.load_weights(torch.load(model.model_path, map_location=self.device, weights_only=False))
+             #                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+             #   File "dist-packages/torch/serialization.py", line 1384, in load
+             #     return _legacy_load(
+             #            ^^^^^^^^^^^^^
+             #   File "dist-packages/torch/serialization.py", line 1628, in _legacy_load
+             #     magic_number = pickle_module.load(f, **pickle_load_args)
+             #                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+             #   _pickle.UnpicklingError: invalid load key, '<'.
+             enable_mmaudio=True,
+         )
+
+         # Determine if TeaCache is already installed or not
+         self.text_to_video_teacache = False
+         self.image_to_video_teacache = False
+
+
+     async def process_frames(
+         self,
+         frames: torch.Tensor,
+         config: GenerationConfig
+     ) -> tuple[str, dict]:
+         """Post-process generated frames using Varnish
+
+         Args:
+             frames: Generated video frames tensor
+             config: Generation configuration
+
+         Returns:
+             Tuple of (video data URI, metadata dictionary)
+         """
+         try:
+             # Process video with Varnish
+             result = await self.varnish(
+                 input_data=frames,  # note: this might contain a certain number of frames, e.g. 97, which will get doubled if double_num_frames is True
+                 fps=config.fps,  # this is the FPS of the final output video; Varnish can use it to calculate the duration of a clip (frames * factor / fps, etc.)
+                 double_num_frames=config.double_num_frames,  # if True, the number of frames will be multiplied by 2 using RIFE
+                 super_resolution=config.super_resolution,  # if True, the resolution will be multiplied by 2 using Real_ESRGAN
+                 grain_amount=config.grain_amount,
+                 enable_audio=config.enable_audio,
+                 audio_prompt=config.audio_prompt,
+                 audio_negative_prompt=config.audio_negative_prompt,
+             )
+
+             # Convert to data URI
+             video_uri = await result.write(type="data-uri", quality=config.quality)
+
+             # Collect metadata
+             metadata = {
+                 "width": result.metadata.width,
+                 "height": result.metadata.height,
+                 "num_frames": result.metadata.frame_count,
+                 "fps": result.metadata.fps,
+                 "duration": result.metadata.duration,
+                 "seed": config.seed,
+             }
+
+             return video_uri, metadata
+
+         except Exception as e:
+             logger.error(f"Error in process_frames: {str(e)}")
+             raise RuntimeError(f"Failed to process frames: {str(e)}")
+
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         """Process incoming requests for video generation
+
+         Args:
+             data: Request data containing:
+                 - inputs (dict): Dictionary containing the input, which can be either "prompt" (text field) or "image" (input image)
+                 - parameters (dict):
+                     - prompt (required, string): list of concepts to keep in the video.
+                     - negative_prompt (optional, string): list of concepts to ignore in the video.
+                     - width (optional, int, default to 768): width, or horizontal size in pixels.
+                     - height (optional, int, default to 416): height, or vertical size in pixels.
+                     - input_image_quality (optional, int, default to 70): this is a trick we use to convert a "pristine" image into a "dirty" video frame. This helps fool LTX-Video into turning the image into an animated one.
+                     - num_frames (optional, int, default to 113): the number of frames must be a multiple of 8, plus 1 frame.
+                     - guidance_scale (optional, float, default to 3.5): guidance scale (values between 3.0 and 4.0 are nice)
+                     - num_inference_steps (optional, int, default to 50): number of inference steps
+                     - seed (optional, int, default to -1): random number generator seed; -1 means a random seed is picked.
+                     - fps (optional, int, default to 30): FPS of the final video (e.g. 24, 25, 30, 60)
+                     - double_num_frames (optional, bool): if enabled, the number of frames will be multiplied by 2 using RIFE
+                     - super_resolution (optional, bool): if enabled, the resolution will be multiplied by 2 using Real_ESRGAN
+                     - grain_amount (optional, float): amount of film grain to add to the output video
+                     - enable_audio (optional, bool): automatically generate an audio track
+                     - audio_prompt (optional, str): prompt to use for the audio generation (concepts to add)
+                     - audio_negative_prompt (optional, str): negative prompt to use for the audio generation (concepts to ignore)
+                     - quality (optional, int, default to 18): CRF scale from 0 to 51, where 0 is lossless (for 8 bit only; for 10 bit use -qp 0), 23 is the default, and 51 is the worst quality possible.
+                     - enable_teacache (optional, bool, default to True): generate faster at the cost of a slight quality loss
+                     - teacache_threshold (optional, float, default to 0.05): amount of caching: 0 (original), 0.03 (1.6x speedup), 0.05 (default, 2.1x speedup).
+                     - enable_enhance_a_video (optional, bool, default to True): enable the enhance_a_video optimization
+                     - enhance_a_video_weight (optional, float, default to 5.0): amount of video enhancement to apply
+                     - lora_model_name (optional, str, default to ""): HuggingFace repo ID or path to LoRA model
+                     - lora_model_weight_file (optional, str, default to ""): specific weight file to load from the LoRA model
+                     - lora_model_trigger (optional, str, default to ""): optional trigger word to prepend to the prompt
+         Returns:
+             Dictionary containing:
+                 - video: Base64 encoded MP4 data URI
+                 - content-type: MIME type
+                 - metadata: Generation metadata
+         """
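For reference, a request body matching the schema documented above might look like this (all values are illustrative, not recommendations):

example_request = {
    "inputs": {
        "prompt": "a spaceship drifting through a colorful nebula, cinematic lighting",
        # "image": "<base64-encoded image>",  # only used when SUPPORT_INPUT_IMAGE_PROMPT is enabled
    },
    "parameters": {
        "negative_prompt": "blurry, watermark",
        "width": 768,
        "height": 416,
        "num_frames": 113,   # must be a multiple of 8, plus 1
        "guidance_scale": 3.5,
        "num_inference_steps": 50,
        "seed": -1,          # -1 -> a random seed is picked
        "fps": 30,
        "double_num_frames": False,
        "super_resolution": False,
        "enable_audio": False,
        "quality": 18,
    },
}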
+         inputs = data.get("inputs", dict())
+
+         input_prompt = inputs.get("prompt", "")
+         input_image = inputs.get("image")
+
+         params = data.get("parameters", dict())
+
+         if not input_image and not input_prompt:
+             raise ValueError("Either prompt or image must be provided")
+
+         #logger.debug(f"Raw parameters:")
+         #pprint.pprint(params)
+
+         # Create and validate configuration
+         config = GenerationConfig(
+             # general content settings
+             prompt=input_prompt,
+             negative_prompt=params.get("negative_prompt", GenerationConfig.negative_prompt),
+
+             # video model settings (will be used during generation of the initial raw video clip)
+             width=params.get("width", GenerationConfig.width),
+             height=params.get("height", GenerationConfig.height),
+             input_image_quality=params.get("input_image_quality", GenerationConfig.input_image_quality),
+             num_frames=params.get("num_frames", GenerationConfig.num_frames),
+             guidance_scale=params.get("guidance_scale", GenerationConfig.guidance_scale),
+             num_inference_steps=params.get("num_inference_steps", GenerationConfig.num_inference_steps),
+
+             # reproducible generation settings
+             seed=params.get("seed", GenerationConfig.seed),
+
+             # varnish settings (will be used for post-processing after the raw video clip has been generated)
+             fps=params.get("fps", GenerationConfig.fps),  # FPS of the final video (only applied at the very end, when converting to mp4)
+             double_num_frames=params.get("double_num_frames", GenerationConfig.double_num_frames),  # if True, the number of frames will be multiplied by 2 using RIFE
+             super_resolution=params.get("super_resolution", GenerationConfig.super_resolution),  # if True, the resolution will be multiplied by 2 using Real_ESRGAN
+             grain_amount=params.get("grain_amount", GenerationConfig.grain_amount),
+             enable_audio=params.get("enable_audio", GenerationConfig.enable_audio),
+             audio_prompt=params.get("audio_prompt", GenerationConfig.audio_prompt),
+             audio_negative_prompt=params.get("audio_negative_prompt", GenerationConfig.audio_negative_prompt),
+             quality=params.get("quality", GenerationConfig.quality),
+
+             # TeaCache settings
+             enable_teacache=params.get("enable_teacache", True),
+
+             # values: 0 (original), 0.03 (1.6x speedup), 0.05 (2.1x speedup).
+             teacache_threshold=params.get("teacache_threshold", 0.05),
+
+
+             # Add enhance-a-video settings
+             enable_enhance_a_video=params.get("enable_enhance_a_video", True),
+             enhance_a_video_weight=params.get("enhance_a_video_weight", 5.0),
+
+             # LoRA settings
+             lora_model_name=params.get("lora_model_name", ""),
+             lora_model_weight_file=params.get("lora_model_weight_file", ""),
+             lora_model_trigger=params.get("lora_model_trigger", ""),
+         ).validate_and_adjust()
+
+         #logger.debug(f"Global request settings:")
+         #pprint.pprint(config)
+
+         try:
+             with torch.amp.autocast_mode.autocast('cuda', torch.bfloat16), torch.no_grad(), torch.inference_mode():
+                 # Set random seeds
+                 random.seed(config.seed)
+                 np.random.seed(config.seed)
+                 torch.manual_seed(config.seed)
+                 generator = torch.Generator(device='cuda')
+                 generator = generator.manual_seed(config.seed)
+
+                 # Configure enhance-a-video
+                 #if config.enable_enhance_a_video:
+                 #    enable_enhance()
+                 #    set_enhance_weight(config.enhance_a_video_weight)
+
+                 # Prepare generation parameters for the video model (we omit params that are meant for Varnish, or things like the seed, which is set externally)
+                 generation_kwargs = {
+                     # general content settings
+                     "prompt": config.prompt,
+                     "negative_prompt": config.negative_prompt,
+
+                     # video model settings (will be used during generation of the initial raw video clip)
+                     "width": config.width,
+                     "height": config.height,
+                     "num_frames": config.num_frames,
+                     "guidance_scale": config.guidance_scale,
+                     "num_inference_steps": config.num_inference_steps,
+
+                     # constants
+                     "output_type": "pt",
+                     "generator": generator,
+
+                     # Timestep for decoding VAE noise: the timestep at which the generated video is decoded
+                     "decode_timestep": 0.05,
+
+                     # Noise level for decoding VAE noise: the interpolation factor between random noise and denoised latents at the decode timestep
+                     "decode_noise_scale": 0.025,
+                 }
+                 #logger.info(f"Video model generation settings:")
+                 #pprint.pprint(generation_kwargs)
+
+                 # Handle LoRA loading/unloading
+                 if hasattr(self, '_current_lora_model'):
+                     if self._current_lora_model != (config.lora_model_name, config.lora_model_weight_file):
+                         # Unload previous LoRA if it exists and is different
+                         if hasattr(self, 'text_to_video') and hasattr(self.text_to_video, 'unload_lora_weights'):
+                             print("Unloading LoRA weights for the text_to_video pipeline..")
+                             self.text_to_video.unload_lora_weights()
+
+                         if support_image_prompt and hasattr(self.image_to_video, 'unload_lora_weights'):
+                             print("Unloading LoRA weights for the image_to_video pipeline..")
+                             self.image_to_video.unload_lora_weights()
+
+                         if config.lora_model_name:
+                             # Load new LoRA
+                             if hasattr(self, 'text_to_video') and hasattr(self.text_to_video, 'load_lora_weights'):
+                                 print("Loading LoRA weights for the text_to_video pipeline..")
+                                 self.text_to_video.load_lora_weights(
+                                     config.lora_model_name,
+                                     weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
+                                     token=hf_token,
+                                 )
+                             if support_image_prompt and hasattr(self.image_to_video, 'load_lora_weights'):
+                                 print("Loading LoRA weights for the image_to_video pipeline..")
+                                 self.image_to_video.load_lora_weights(
+                                     config.lora_model_name,
+                                     weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
+                                     token=hf_token,
+                                 )
+                             self._current_lora_model = (config.lora_model_name, config.lora_model_weight_file)
+
+                 # Modify prompt if trigger word is provided
+                 if config.lora_model_trigger:
+                     generation_kwargs["prompt"] = f"{config.lora_model_trigger} {generation_kwargs['prompt']}"
+
+                 #enhance_a_video_config = EnhanceAVideoConfig(
+                 #    weight=config.enhance_a_video_weight if config.enable_enhance_a_video else 0.0,
+                 #    # doing some testing
+                 #    num_frames_callback=lambda: (8 + 1),
+                 #    # num_frames_callback=lambda: config.num_frames,
+                 #    # num_frames_callback=lambda: (config.num_frames - 1),
+                 #
+                 #    _attention_type=1
+                 #)
+
+                 # Check if image-to-video generation is requested
+                 if support_image_prompt and input_image:
+                     processed_image = process_input_image(
+                         input_image,
+                         config.width,
+                         config.height,
+                         config.input_image_quality,
+                     )
+                     generation_kwargs["image"] = processed_image
+                     # disabled (we cannot install the hook multiple times, we would have to uninstall it first or find another way to dynamically enable it, eg. using the weight only)
+                     # apply_enhance_a_video(self.image_to_video.transformer, enhance_a_video_config)
+                     frames = self.image_to_video(**generation_kwargs).frames
+                 else:
+                     # disabled (we cannot install the hook multiple times, we would have to uninstall it first or find another way to dynamically enable it, eg. using the weight only)
+                     # apply_enhance_a_video(self.text_to_video.transformer, enhance_a_video_config)
+                     frames = self.text_to_video(**generation_kwargs).frames
+
+                 try:
+                     loop = asyncio.get_event_loop()
+                 except RuntimeError:
+                     loop = asyncio.new_event_loop()
+                     asyncio.set_event_loop(loop)
+
+                 video_uri, metadata = loop.run_until_complete(self.process_frames(frames, config))
+
+                 torch.cuda.empty_cache()
+                 torch.cuda.reset_peak_memory_stats()
+                 gc.collect()
+
+                 return {
+                     "video": video_uri,
+                     "content-type": "video/mp4",
+                     "metadata": metadata
+                 }
+
+         except Exception as e:
+             message = f"Error generating video ({str(e)})\n{traceback.format_exc()}"
+             print(message)
+             raise RuntimeError(message)
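For completeness, a minimal local-testing sketch of how the handler can be driven. On Hugging Face Inference Endpoints the toolkit instantiates EndpointHandler itself; the "/repository" path and the payload values below are assumptions, chosen to match the request schema documented in __call__:

if __name__ == "__main__":
    handler = EndpointHandler("/repository")  # assumed checkout path of the model repository
    output = handler({
        "inputs": {"prompt": "a calm ocean at sunset, gentle waves"},
        "parameters": {"num_frames": 97, "fps": 30},
    })
    print(output["content-type"])  # "video/mp4"
    print(output["metadata"])      # width, height, num_frames, fps, duration, seed
    # output["video"] contains the MP4 as a base64 data URI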