jbilcke-hf (HF Staff) committed
Commit 2d2c4c3 · verified · 1 Parent(s): bea6e03

Create handler.py

Files changed (1)
  1. handler.py +601 -0
handler.py ADDED
@@ -0,0 +1,601 @@
+ from dataclasses import dataclass
+ from pathlib import Path
+ import logging
+ import base64
+ import random
+ import gc
+ import os
+ import numpy as np
+ import torch
+ from typing import Dict, Any, Optional, List, Union, Tuple
+ import json
+ from safetensors import safe_open
+
+ from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
+ from ltx_video.models.transformers.transformer3d import Transformer3DModel
+ from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
+ from ltx_video.schedulers.rf import RectifiedFlowScheduler, TimestepShifter
+ from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXVideoPipeline
+ from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
+ from transformers import T5EncoderModel, T5Tokenizer, AutoModelForCausalLM, AutoProcessor, AutoTokenizer
+
+ from varnish import Varnish
+ from varnish.utils import is_truthy, process_input_image
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Get token from environment
+ hf_token = os.getenv("HF_API_TOKEN")
+
+ # Constraints
+ MAX_LARGE_SIDE = 1280
+ MAX_SMALL_SIDE = 768  # should be 720, but it must be divisible by 32
+ MAX_FRAMES = (8 * 21) + 1  # visual glitches appear after about 169 frames, so we cap it
+
+ # Check environment variable for pipeline support
+ support_image_prompt = is_truthy(os.getenv("SUPPORT_INPUT_IMAGE_PROMPT"))
+
+ @dataclass
+ class GenerationConfig:
+     """Configuration for video generation"""
+
+     # general content settings
+     prompt: str = ""
+     negative_prompt: str = "saturated, highlight, overexposed, highlighted, overlit, shaking, too bright, worst quality, inconsistent motion, blurry, jittery, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles, lowres"
+
+     # video model settings (will be used during generation of the initial raw video clip)
+     width: int = 768
+     height: int = 416
+
+     # this is a hack to fool LTX-Video into believing our input image is an actual video frame with poor encoding quality
+     # after a quick benchmark, the value 70 seems like a sweet spot
+     input_image_quality: int = 70
+
+     # users may tend to always set this to the max, to get as much usable content as possible (which is MAX_FRAMES, i.e. 169).
+     # The value must be a multiple of 8, plus 1 frame.
+     # visual glitches appear after about 169 frames, so we don't need more anyway
+     num_frames: int = (8 * 14) + 1
+
+     # values between 3.0 and 4.0 are nice
+     guidance_scale: float = 3.5
+
+     num_inference_steps: int = 50
+
+     # reproducible generation settings
+     seed: int = -1  # -1 means random seed
+
+     # varnish settings (will be used for post-processing after the raw video clip has been generated)
+     fps: int = 30  # FPS of the final video (only applied at the very end, when converting to mp4)
+     double_num_frames: bool = False  # if True, the number of frames will be multiplied by 2 using RIFE
+     super_resolution: bool = False  # if True, the resolution will be multiplied by 2 using Real_ESRGAN
+
+     grain_amount: float = 0.0  # be careful, adding film grain can negatively impact video compression
+
+     # audio settings
+     enable_audio: bool = False  # Whether to generate audio
+     audio_prompt: str = ""  # Text prompt for audio generation
+     audio_negative_prompt: str = "voices, voice, talking, speaking, speech"  # Negative prompt for audio generation
+
+     # The range of the CRF scale is 0–51, where:
+     # 0 is lossless (for 8 bit only, for 10 bit use -qp 0)
+     # 23 is the default
+     # 51 is worst quality possible
+     # A lower value generally leads to higher quality, and a subjectively sane range is 17–28.
+     # Consider 17 or 18 to be visually lossless or nearly so;
+     # it should look the same or nearly the same as the input, but it isn't technically lossless.
+     # The range is exponential, so increasing the CRF value by +6 results in roughly half the bitrate / file size, while -6 leads to roughly twice the bitrate.
+     quality: int = 18
+
+     # STG (Spatiotemporal Guidance) settings
+     stg_scale: float = 1.0
+     stg_rescale: float = 0.7
+     stg_mode: str = "attention_values"  # Can be "attention_values", "attention_skip", "residual", or "transformer_block"
+     stg_skip_layers: str = "19"  # Comma-separated list of layers to block for spatiotemporal guidance
+
+     # VAE noise augmentation
+     decode_timestep: float = 0.05
+     decode_noise_scale: float = 0.025
+
+     # Other advanced settings
+     image_cond_noise_scale: float = 0.15
+     mixed_precision: bool = True  # Use mixed precision for inference
+     stochastic_sampling: bool = False  # Use stochastic sampling
+
+     # Sampling settings
+     sampler: Optional[str] = None  # "uniform" or "linear-quadratic" or None (use default from checkpoint)
+
+     # Prompt enhancement
+     enhance_prompt: bool = False  # Whether to enhance the prompt using an LLM
+     prompt_enhancement_words_threshold: int = 50  # Enhance the prompt only if it has fewer words than this
+
+     def validate_and_adjust(self) -> 'GenerationConfig':
+         """Validate and adjust parameters to meet constraints"""
+         # First check if it's one of our explicitly allowed resolutions
+         if not ((self.width == MAX_LARGE_SIDE and self.height == MAX_SMALL_SIDE) or
+                 (self.width == MAX_SMALL_SIDE and self.height == MAX_LARGE_SIDE)):
+             # For other resolutions, ensure total pixels don't exceed max
+             MAX_TOTAL_PIXELS = MAX_SMALL_SIDE * MAX_LARGE_SIDE  # 983040 (the nominal 720p budget would be 921600 = 1280 * 720)
+
+             # If total pixels exceed maximum, scale down proportionally
+             total_pixels = self.width * self.height
+             if total_pixels > MAX_TOTAL_PIXELS:
+                 scale = (MAX_TOTAL_PIXELS / total_pixels) ** 0.5
+                 self.width = max(128, min(MAX_LARGE_SIDE, round(self.width * scale / 32) * 32))
+                 self.height = max(128, min(MAX_LARGE_SIDE, round(self.height * scale / 32) * 32))
+             else:
+                 # Round dimensions to nearest multiple of 32
+                 self.width = max(128, min(MAX_LARGE_SIDE, round(self.width / 32) * 32))
+                 self.height = max(128, min(MAX_LARGE_SIDE, round(self.height / 32) * 32))
+
+         # Adjust number of frames to be in format 8k + 1
+         k = (self.num_frames - 1) // 8
+         self.num_frames = min((k * 8) + 1, MAX_FRAMES)
+
+         # Set random seed if not specified
+         if self.seed == -1:
+             self.seed = random.randint(0, 2**32 - 1)
+
+         # Set up STG parameters
+         if self.stg_mode.lower() == "stg_av" or self.stg_mode.lower() == "attention_values":
+             self.stg_mode = "attention_values"
+         elif self.stg_mode.lower() == "stg_as" or self.stg_mode.lower() == "attention_skip":
+             self.stg_mode = "attention_skip"
+         elif self.stg_mode.lower() == "stg_r" or self.stg_mode.lower() == "residual":
+             self.stg_mode = "residual"
+         elif self.stg_mode.lower() == "stg_t" or self.stg_mode.lower() == "transformer_block":
+             self.stg_mode = "transformer_block"
+
+         # Convert STG skip layers from string to list of integers
+         if isinstance(self.stg_skip_layers, str):
+             self.stg_skip_layers = [int(x.strip()) for x in self.stg_skip_layers.split(",")]
+
+         # Check if we should enhance the prompt
+         if self.enhance_prompt and self.prompt:
+             prompt_word_count = len(self.prompt.split())
+             if prompt_word_count >= self.prompt_enhancement_words_threshold:
+                 logger.info(f"Prompt has {prompt_word_count} words, which exceeds the threshold of {self.prompt_enhancement_words_threshold}. Prompt enhancement disabled.")
+                 self.enhance_prompt = False
+
+         return self
+
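+ # Illustrative examples of what validate_and_adjust() produces (values worked out from the rules above):
+ #   - width=800, height=450   -> 800x448 (each side rounded to the nearest multiple of 32)
+ #   - width=1920, height=1080 -> 1280x736 in this case (scaled down to fit the pixel budget, then rounded to multiples of 32)
+ #   - num_frames=100 -> 97, num_frames=200 -> 169 (nearest lower 8k + 1 value, capped at MAX_FRAMES)
+ #   - seed=-1 -> replaced by a random seed in [0, 2**32 - 1]
+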
+ def load_image_to_tensor_with_resize_and_crop(
+     image_input: Union[str, bytes],
+     target_height: int = 512,
+     target_width: int = 768,
+     quality: int = 100
+ ) -> torch.Tensor:
+     """Load and process an image into a tensor.
+
+     Args:
+         image_input: Either a file path (str), a base64 data URI (str), or raw image data (bytes)
+         target_height: Desired height of the output tensor
+         target_width: Desired width of the output tensor
+         quality: JPEG quality to use when re-encoding (to simulate lower-quality images)
+     """
+     from PIL import Image
+     import io
+     import numpy as np
+
+     # Handle base64 data URI
+     if isinstance(image_input, str) and image_input.startswith('data:'):
+         header, encoded = image_input.split(",", 1)
+         image_data = base64.b64decode(encoded)
+         image = Image.open(io.BytesIO(image_data)).convert("RGB")
+     # Handle raw bytes
+     elif isinstance(image_input, bytes):
+         image = Image.open(io.BytesIO(image_input)).convert("RGB")
+     # Handle file path
+     elif isinstance(image_input, str):
+         image = Image.open(image_input).convert("RGB")
+     else:
+         raise ValueError("image_input must be either a file path, bytes, or base64 data URI")
+
+     # Apply JPEG compression if quality < 100 (to simulate a video frame)
+     if quality < 100:
+         buffer = io.BytesIO()
+         image.save(buffer, format="JPEG", quality=quality)
+         buffer.seek(0)
+         image = Image.open(buffer).convert("RGB")
+
+     input_width, input_height = image.size
+     aspect_ratio_target = target_width / target_height
+     aspect_ratio_frame = input_width / input_height
+     if aspect_ratio_frame > aspect_ratio_target:
+         new_width = int(input_height * aspect_ratio_target)
+         new_height = input_height
+         x_start = (input_width - new_width) // 2
+         y_start = 0
+     else:
+         new_width = input_width
+         new_height = int(input_width / aspect_ratio_target)
+         x_start = 0
+         y_start = (input_height - new_height) // 2
+
+     image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
+     image = image.resize((target_width, target_height))
+     frame_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).float()
+     frame_tensor = (frame_tensor / 127.5) - 1.0
+     # Create a 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
+     return frame_tensor.unsqueeze(0).unsqueeze(2)
+
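+ # Illustrative example: for an existing "input.jpg", load_image_to_tensor_with_resize_and_crop("input.jpg", 416, 768, quality=70)
+ # returns a tensor of shape [1, 3, 1, 416, 768] (batch, channels, frames, height, width) with values in [-1, 1].
+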
+ def calculate_padding(
+     source_height: int, source_width: int, target_height: int, target_width: int
+ ) -> tuple[int, int, int, int]:
+     """Calculate padding to reach target dimensions"""
+     # Calculate total padding needed
+     pad_height = target_height - source_height
+     pad_width = target_width - source_width
+
+     # Calculate padding for each side
+     pad_top = pad_height // 2
+     pad_bottom = pad_height - pad_top  # Handles odd padding
+     pad_left = pad_width // 2
+     pad_right = pad_width - pad_left  # Handles odd padding
+
+     # Return the padding tuple
+     # Padding format is (left, right, top, bottom)
+     padding = (pad_left, pad_right, pad_top, pad_bottom)
+     return padding
+
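+ # Illustrative example: calculate_padding(416, 768, 448, 768) returns (0, 0, 16, 16),
+ # i.e. no horizontal padding and 16 pixels of padding on both the top and the bottom.
+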
+ def prepare_conditioning(
+     conditioning_media_paths: List[str],
+     conditioning_strengths: List[float],
+     conditioning_start_frames: List[int],
+     height: int,
+     width: int,
+     num_frames: int,
+     input_image_quality: int = 100,
+     pipeline: Optional[LTXVideoPipeline] = None,
+ ) -> Optional[List[ConditioningItem]]:
+     """Prepare conditioning items based on input media paths and their parameters"""
+     conditioning_items = []
+     for path, strength, start_frame in zip(
+         conditioning_media_paths, conditioning_strengths, conditioning_start_frames
+     ):
+         # Load and process the conditioning image
+         frame_tensor = load_image_to_tensor_with_resize_and_crop(
+             path, height, width, quality=input_image_quality
+         )
+
+         # Trim frame count if needed
+         if pipeline:
+             frame_count = 1  # For image inputs, it's always 1
+             frame_count = pipeline.trim_conditioning_sequence(
+                 start_frame, frame_count, num_frames
+             )
+
+         conditioning_items.append(
+             ConditioningItem(frame_tensor, start_frame, strength)
+         )
+
+     return conditioning_items
+
+ def create_ltx_video_pipeline(
+     config: GenerationConfig,
+     device: str = "cuda"
+ ) -> LTXVideoPipeline:
+     """Create and configure the LTX video pipeline"""
+     # Get the absolute paths for the model components
+     current_dir = Path.cwd()
+
+     # Get allowed inference steps from config if available
+     allowed_inference_steps = None
+     try:
+         # Load allowed inference steps from metadata if available
+         if Path("transformer/config.json").exists():
+             with open("transformer/config.json", "r") as f:
+                 config_data = json.load(f)
+                 allowed_inference_steps = config_data.get("allowed_inference_steps")
+     except Exception as e:
+         logger.warning(f"Failed to load allowed_inference_steps from config: {e}")
+
+     # Initialize model components
+     vae = CausalVideoAutoencoder.from_pretrained(".")
+     transformer = Transformer3DModel.from_pretrained(".")
+
+     # Use the constructor if a sampler is specified, otherwise use from_pretrained
+     if config.sampler:
+         scheduler = RectifiedFlowScheduler(
+             sampler=("Uniform" if config.sampler.lower() == "uniform" else "LinearQuadratic")
+         )
+     else:
+         scheduler = RectifiedFlowScheduler.from_pretrained(".")
+
+     text_encoder = T5EncoderModel.from_pretrained("text_encoder")
+     patchifier = SymmetricPatchifier(patch_size=1)
+     tokenizer = T5Tokenizer.from_pretrained("tokenizer")
+
+     # Move models to the correct device
+     vae = vae.to(device)
+     transformer = transformer.to(device)
+     text_encoder = text_encoder.to(device)
+
+     # Set up precision
+     vae = vae.to(torch.bfloat16)
+     transformer = transformer.to(torch.bfloat16)
+     text_encoder = text_encoder.to(torch.bfloat16)
+
+     # Initialize prompt enhancer components if needed
+     prompt_enhancer_components = {
+         "prompt_enhancer_image_caption_model": None,
+         "prompt_enhancer_image_caption_processor": None,
+         "prompt_enhancer_llm_model": None,
+         "prompt_enhancer_llm_tokenizer": None
+     }
+
+     if config.enhance_prompt:
+         try:
+             # Use default models or ones specified by config
+             prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained(
+                 "MiaoshouAI/Florence-2-large-PromptGen-v2.0",
+                 trust_remote_code=True
+             )
+             prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained(
+                 "MiaoshouAI/Florence-2-large-PromptGen-v2.0",
+                 trust_remote_code=True
+             )
+             prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained(
+                 "unsloth/Llama-3.2-3B-Instruct",
+                 torch_dtype="bfloat16",
+             )
+             prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained(
+                 "unsloth/Llama-3.2-3B-Instruct",
+             )
+
+             prompt_enhancer_components = {
+                 "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model,
+                 "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor,
+                 "prompt_enhancer_llm_model": prompt_enhancer_llm_model,
+                 "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer
+             }
+         except Exception as e:
+             logger.warning(f"Failed to load prompt enhancer models: {e}")
+             config.enhance_prompt = False
+
+     # Construct the pipeline
+     pipeline = LTXVideoPipeline(
+         transformer=transformer,
+         patchifier=patchifier,
+         text_encoder=text_encoder,
+         tokenizer=tokenizer,
+         scheduler=scheduler,
+         vae=vae,
+         allowed_inference_steps=allowed_inference_steps,
+         **prompt_enhancer_components
+     )
+
+     return pipeline
+
+ class EndpointHandler:
+     """Handler for the LTX Video endpoint"""
+
+     def __init__(self, model_path: str = ""):
+         """Initialize the endpoint handler
+
+         Args:
+             model_path: Path to model weights (not used, as weights are in the current directory)
+         """
+         # Enable TF32 for a potential speedup on Ampere GPUs
+         torch.backends.cuda.matmul.allow_tf32 = True
+
+         # Initialize Varnish for post-processing
+         self.varnish = Varnish(
+             device="cuda",
+             model_base_dir="varnish",
+             enable_mmaudio=False,  # Disable audio generation for now, since it is broken
+         )
+
+         # The actual LTX pipeline will be loaded during inference to save memory
+         self.pipeline = None
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         """Process inference requests
+
+         Args:
+             data: Request data containing inputs and parameters
+
+         Returns:
+             Dictionary with generated video and metadata
+         """
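+         # Illustrative payload (field names follow the parsing logic below; prompt and parameter values are arbitrary):
+         # {
+         #     "inputs": {
+         #         "prompt": "a fox running through fresh snow, cinematic",
+         #         "image": "data:image/jpeg;base64,..."   # optional conditioning image
+         #     },
+         #     "parameters": {"width": 768, "height": 416, "num_frames": 113, "fps": 30, "seed": 42}
+         # }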
+         # Extract inputs and parameters
+         inputs = data.get("inputs", {})
+
+         # Support both formats:
+         # 1. {"inputs": {"prompt": "...", "image": "..."}}
+         # 2. {"inputs": "..."} (prompt only)
+         if isinstance(inputs, str):
+             input_prompt = inputs
+             input_image = None
+         else:
+             input_prompt = inputs.get("prompt", "")
+             input_image = inputs.get("image")
+
+         params = data.get("parameters", {})
+
+         if not input_prompt and not input_image:
+             raise ValueError("Either prompt or image must be provided")
+
+         # Create and validate configuration
+         config = GenerationConfig(
+             # general content settings
+             prompt=input_prompt,
+             negative_prompt=params.get("negative_prompt", GenerationConfig.negative_prompt),
+
+             # video model settings
+             width=params.get("width", GenerationConfig.width),
+             height=params.get("height", GenerationConfig.height),
+             input_image_quality=params.get("input_image_quality", GenerationConfig.input_image_quality),
+             num_frames=params.get("num_frames", GenerationConfig.num_frames),
+             guidance_scale=params.get("guidance_scale", GenerationConfig.guidance_scale),
+             num_inference_steps=params.get("num_inference_steps", GenerationConfig.num_inference_steps),
+
+             # STG settings
+             stg_scale=params.get("stg_scale", GenerationConfig.stg_scale),
+             stg_rescale=params.get("stg_rescale", GenerationConfig.stg_rescale),
+             stg_mode=params.get("stg_mode", GenerationConfig.stg_mode),
+             stg_skip_layers=params.get("stg_skip_layers", GenerationConfig.stg_skip_layers),
+
+             # VAE noise settings
+             decode_timestep=params.get("decode_timestep", GenerationConfig.decode_timestep),
+             decode_noise_scale=params.get("decode_noise_scale", GenerationConfig.decode_noise_scale),
+             image_cond_noise_scale=params.get("image_cond_noise_scale", GenerationConfig.image_cond_noise_scale),
+
+             # reproducible generation settings
+             seed=params.get("seed", GenerationConfig.seed),
+
+             # varnish settings
+             fps=params.get("fps", GenerationConfig.fps),
+             double_num_frames=params.get("double_num_frames", GenerationConfig.double_num_frames),
+             super_resolution=params.get("super_resolution", GenerationConfig.super_resolution),
+             grain_amount=params.get("grain_amount", GenerationConfig.grain_amount),
+             enable_audio=params.get("enable_audio", GenerationConfig.enable_audio),
+             audio_prompt=params.get("audio_prompt", GenerationConfig.audio_prompt),
+             audio_negative_prompt=params.get("audio_negative_prompt", GenerationConfig.audio_negative_prompt),
+             quality=params.get("quality", GenerationConfig.quality),
+
+             # advanced settings
+             mixed_precision=params.get("mixed_precision", GenerationConfig.mixed_precision),
+             stochastic_sampling=params.get("stochastic_sampling", GenerationConfig.stochastic_sampling),
+             sampler=params.get("sampler", GenerationConfig.sampler),
+
+             # prompt enhancement
+             enhance_prompt=params.get("enhance_prompt", GenerationConfig.enhance_prompt),
+             prompt_enhancement_words_threshold=params.get(
+                 "prompt_enhancement_words_threshold",
+                 GenerationConfig.prompt_enhancement_words_threshold
+             ),
+         ).validate_and_adjust()
+
+         try:
+             with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16), torch.no_grad():
+                 # Set random seeds for reproducibility
+                 random.seed(config.seed)
+                 np.random.seed(config.seed)
+                 torch.manual_seed(config.seed)
+                 generator = torch.Generator(device='cuda').manual_seed(config.seed)
+
+                 # Create pipeline if not already created
+                 if self.pipeline is None:
+                     self.pipeline = create_ltx_video_pipeline(config)
+
+                 # Prepare conditioning items if an image is provided
+                 conditioning_items = None
+                 if input_image:
+                     conditioning_items = [
+                         ConditioningItem(
+                             load_image_to_tensor_with_resize_and_crop(
+                                 input_image,
+                                 config.height,
+                                 config.width,
+                                 quality=config.input_image_quality
+                             ),
+                             0,  # Start frame
+                             1.0  # Conditioning strength
+                         )
+                     ]
+
+                 # Set up spatiotemporal guidance strategy
+                 if config.stg_mode == "attention_values":
+                     skip_layer_strategy = SkipLayerStrategy.AttentionValues
+                 elif config.stg_mode == "attention_skip":
+                     skip_layer_strategy = SkipLayerStrategy.AttentionSkip
+                 elif config.stg_mode == "residual":
+                     skip_layer_strategy = SkipLayerStrategy.Residual
+                 elif config.stg_mode == "transformer_block":
+                     skip_layer_strategy = SkipLayerStrategy.TransformerBlock
+
+                 # Generate video with LTX pipeline
+                 result = self.pipeline(
+                     height=config.height,
+                     width=config.width,
+                     num_frames=config.num_frames,
+                     frame_rate=config.fps,
+                     prompt=config.prompt,
+                     negative_prompt=config.negative_prompt,
+                     guidance_scale=config.guidance_scale,
+                     num_inference_steps=config.num_inference_steps,
+                     generator=generator,
+                     output_type="pt",  # Return as PyTorch tensor
+                     skip_layer_strategy=skip_layer_strategy,
+                     skip_block_list=config.stg_skip_layers,
+                     stg_scale=config.stg_scale,
+                     do_rescaling=config.stg_rescale != 1.0,
+                     rescaling_scale=config.stg_rescale,
+                     conditioning_items=conditioning_items,
+                     decode_timestep=config.decode_timestep,
+                     decode_noise_scale=config.decode_noise_scale,
+                     image_cond_noise_scale=config.image_cond_noise_scale,
+                     mixed_precision=config.mixed_precision,
+                     is_video=True,
+                     vae_per_channel_normalize=True,
+                     stochastic_sampling=config.stochastic_sampling,
+                     enhance_prompt=config.enhance_prompt,
+                 )
+
+                 # Get the generated frames
+                 frames = result.images
+
+                 # Process the generated frames with Varnish
+                 import asyncio
+                 try:
+                     loop = asyncio.get_event_loop()
+                 except RuntimeError:
+                     loop = asyncio.new_event_loop()
+                     asyncio.set_event_loop(loop)
+
+                 # Prepare frames for Varnish (denormalize to the 0-255 range)
+                 frames = frames * 127.5 + 127.5
+                 frames = frames.to(torch.uint8)
+
+                 # Process with Varnish for post-processing
+                 varnish_result = loop.run_until_complete(
+                     self.varnish(
+                         frames,
+                         fps=config.fps,
+                         double_num_frames=config.double_num_frames,
+                         super_resolution=config.super_resolution,
+                         grain_amount=config.grain_amount,
+                         enable_audio=config.enable_audio,
+                         audio_prompt=config.audio_prompt or config.prompt,
+                         audio_negative_prompt=config.audio_negative_prompt,
+                     )
+                 )
+
+                 # Get the final video as a data URI
+                 video_uri = loop.run_until_complete(
+                     varnish_result.write(
+                         type="data-uri",
+                         quality=config.quality
+                     )
+                 )
+
+                 # Prepare metadata about the generated video
+                 metadata = {
+                     "width": varnish_result.metadata.width,
+                     "height": varnish_result.metadata.height,
+                     "num_frames": varnish_result.metadata.frame_count,
+                     "fps": varnish_result.metadata.fps,
+                     "duration": varnish_result.metadata.duration,
+                     "seed": config.seed,
+                     "prompt": config.prompt,
+                 }
+
+                 # Clean up to prevent CUDA OOM errors
+                 del result
+                 torch.cuda.empty_cache()
+                 gc.collect()
+
+                 return {
+                     "video": video_uri,
+                     "content-type": "video/mp4",
+                     "metadata": metadata
+                 }
+
+         except Exception as e:
+             # Log the error and re-raise
+             import traceback
+             error_message = f"Error generating video: {str(e)}\n{traceback.format_exc()}"
+             logger.error(error_message)
+             raise RuntimeError(error_message)
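
For reference, a minimal sketch of how this handler might be invoked locally, assuming the LTX-Video weights, the text_encoder/ and tokenizer/ directories, and the varnish assets are present in the working directory and a CUDA GPU is available; the prompt and parameter values are only illustrative.

from handler import EndpointHandler

handler = EndpointHandler()
result = handler({
    "inputs": {"prompt": "a red fox running through fresh snow, cinematic lighting"},
    "parameters": {"width": 768, "height": 416, "num_frames": 113, "fps": 30, "seed": 42},
})
print(result["metadata"])     # width, height, num_frames, fps, duration, seed, prompt
print(result["video"][:64])   # truncated base64 data URI of the generated MP4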