Create handler.py
handler.py +601 -0
handler.py
ADDED
@@ -0,0 +1,601 @@
from dataclasses import dataclass
from pathlib import Path
import logging
import base64
import random
import gc
import os
import numpy as np
import torch
from typing import Dict, Any, Optional, List, Union, Tuple
import json
from safetensors import safe_open

from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from ltx_video.models.transformers.transformer3d import Transformer3DModel
from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
from ltx_video.schedulers.rf import RectifiedFlowScheduler, TimestepShifter
from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXVideoPipeline
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
from transformers import T5EncoderModel, T5Tokenizer, AutoModelForCausalLM, AutoProcessor, AutoTokenizer

from varnish import Varnish
from varnish.utils import is_truthy, process_input_image

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Get token from environment
hf_token = os.getenv("HF_API_TOKEN")

# Constraints
MAX_LARGE_SIDE = 1280
MAX_SMALL_SIDE = 768  # should be 720, but it must be divisible by 32
MAX_FRAMES = (8 * 21) + 1  # visual glitches appear after about 169 frames, so we cap it

# Check environment variable for pipeline support
support_image_prompt = is_truthy(os.getenv("SUPPORT_INPUT_IMAGE_PROMPT"))

@dataclass
class GenerationConfig:
    """Configuration for video generation"""

    # general content settings
    prompt: str = ""
    negative_prompt: str = "saturated, highlight, overexposed, highlighted, overlit, shaking, too bright, worst quality, inconsistent motion, blurry, jittery, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles, lowres"

    # video model settings (will be used during generation of the initial raw video clip)
    width: int = 768
    height: int = 416

    # this is a hack to fool LTX-Video into believing our input image is an actual video frame with poor encoding quality
    # after a quick benchmark, the value 70 seems like a sweet spot
    input_image_quality: int = 70

    # users may tend to always set this to the max to get as much usable content as possible (which is MAX_FRAMES, i.e. 169).
    # The value must be a multiple of 8, plus 1 frame.
    # Visual glitches appear after about 169 frames, so we don't need more than that anyway.
    num_frames: int = (8 * 14) + 1

    # values between 3.0 and 4.0 are nice
    guidance_scale: float = 3.5

    num_inference_steps: int = 50

    # reproducible generation settings
    seed: int = -1  # -1 means random seed

    # varnish settings (will be used for post-processing after the raw video clip has been generated)
    fps: int = 30  # FPS of the final video (only applied at the very end, when converting to mp4)
    double_num_frames: bool = False  # if True, the number of frames will be multiplied by 2 using RIFE
    super_resolution: bool = False  # if True, the resolution will be multiplied by 2 using Real_ESRGAN

    grain_amount: float = 0.0  # be careful, adding film grain can negatively impact video compression

    # audio settings
    enable_audio: bool = False  # Whether to generate audio
    audio_prompt: str = ""  # Text prompt for audio generation
    audio_negative_prompt: str = "voices, voice, talking, speaking, speech"  # Negative prompt for audio generation

    # The range of the CRF scale is 0–51, where:
    # 0 is lossless (for 8 bit only, for 10 bit use -qp 0)
    # 23 is the default
    # 51 is the worst quality possible
    # A lower value generally leads to higher quality, and a subjectively sane range is 17–28.
    # Consider 17 or 18 to be visually lossless or nearly so;
    # it should look the same or nearly the same as the input, but it isn't technically lossless.
    # The range is exponential, so increasing the CRF value +6 results in roughly half the bitrate / file size, while -6 leads to roughly twice the bitrate.
    quality: int = 18

    # STG (Spatiotemporal Guidance) settings
    stg_scale: float = 1.0
    stg_rescale: float = 0.7
    stg_mode: str = "attention_values"  # Can be "attention_values", "attention_skip", "residual", or "transformer_block"
    stg_skip_layers: str = "19"  # Comma-separated list of layers to block for spatiotemporal guidance

    # VAE noise augmentation
    decode_timestep: float = 0.05
    decode_noise_scale: float = 0.025

    # Other advanced settings
    image_cond_noise_scale: float = 0.15
    mixed_precision: bool = True  # Use mixed precision for inference
    stochastic_sampling: bool = False  # Use stochastic sampling

    # Sampling settings
    sampler: Optional[str] = None  # "uniform" or "linear-quadratic" or None (use default from checkpoint)

    # Prompt enhancement
    enhance_prompt: bool = False  # Whether to enhance the prompt using an LLM
    prompt_enhancement_words_threshold: int = 50  # Enhance prompt only if it has fewer words than this

    def validate_and_adjust(self) -> 'GenerationConfig':
        """Validate and adjust parameters to meet constraints"""
        # First check if it's one of our explicitly allowed resolutions
        if not ((self.width == MAX_LARGE_SIDE and self.height == MAX_SMALL_SIDE) or
                (self.width == MAX_SMALL_SIDE and self.height == MAX_LARGE_SIDE)):
            # For other resolutions, ensure total pixels don't exceed max
            MAX_TOTAL_PIXELS = MAX_SMALL_SIDE * MAX_LARGE_SIDE  # 983040 (a strict 1280 * 720 budget would be 921600)

            # If total pixels exceed maximum, scale down proportionally
            total_pixels = self.width * self.height
            if total_pixels > MAX_TOTAL_PIXELS:
                scale = (MAX_TOTAL_PIXELS / total_pixels) ** 0.5
                self.width = max(128, min(MAX_LARGE_SIDE, round(self.width * scale / 32) * 32))
                self.height = max(128, min(MAX_LARGE_SIDE, round(self.height * scale / 32) * 32))
            else:
                # Round dimensions to nearest multiple of 32
                self.width = max(128, min(MAX_LARGE_SIDE, round(self.width / 32) * 32))
                self.height = max(128, min(MAX_LARGE_SIDE, round(self.height / 32) * 32))

        # Adjust number of frames to be in format 8k + 1
        k = (self.num_frames - 1) // 8
        self.num_frames = min((k * 8) + 1, MAX_FRAMES)

        # Set random seed if not specified
        if self.seed == -1:
            self.seed = random.randint(0, 2**32 - 1)

        # Set up STG parameters
        if self.stg_mode.lower() == "stg_av" or self.stg_mode.lower() == "attention_values":
            self.stg_mode = "attention_values"
        elif self.stg_mode.lower() == "stg_as" or self.stg_mode.lower() == "attention_skip":
            self.stg_mode = "attention_skip"
        elif self.stg_mode.lower() == "stg_r" or self.stg_mode.lower() == "residual":
            self.stg_mode = "residual"
        elif self.stg_mode.lower() == "stg_t" or self.stg_mode.lower() == "transformer_block":
            self.stg_mode = "transformer_block"

        # Convert STG skip layers from string to list of integers
        if isinstance(self.stg_skip_layers, str):
            self.stg_skip_layers = [int(x.strip()) for x in self.stg_skip_layers.split(",")]

        # Check if we should enhance the prompt
        if self.enhance_prompt and self.prompt:
            prompt_word_count = len(self.prompt.split())
            if prompt_word_count >= self.prompt_enhancement_words_threshold:
                logger.info(f"Prompt has {prompt_word_count} words, which exceeds the threshold of {self.prompt_enhancement_words_threshold}. Prompt enhancement disabled.")
                self.enhance_prompt = False

        return self

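# Illustrative example of validate_and_adjust() (the values below are chosen for this comment, not part of the API):
#   GenerationConfig(width=1000, height=600, num_frames=100).validate_and_adjust()
#   -> width=992, height=608 (snapped to multiples of 32), num_frames=97 (8*12 + 1), seed randomized
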
def load_image_to_tensor_with_resize_and_crop(
    image_input: Union[str, bytes],
    target_height: int = 512,
    target_width: int = 768,
    quality: int = 100
) -> torch.Tensor:
    """Load and process an image into a tensor.

    Args:
        image_input: Either a file path (str) or image data (bytes)
        target_height: Desired height of output tensor
        target_width: Desired width of output tensor
        quality: JPEG quality to use when re-encoding (to simulate lower quality images)
    """
    from PIL import Image
    import io
    import numpy as np

    # Handle base64 data URI
    if isinstance(image_input, str) and image_input.startswith('data:'):
        header, encoded = image_input.split(",", 1)
        image_data = base64.b64decode(encoded)
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
    # Handle raw bytes
    elif isinstance(image_input, bytes):
        image = Image.open(io.BytesIO(image_input)).convert("RGB")
    # Handle file path
    elif isinstance(image_input, str):
        image = Image.open(image_input).convert("RGB")
    else:
        raise ValueError("image_input must be either a file path, bytes, or base64 data URI")

    # Apply JPEG compression if quality < 100 (to simulate a video frame)
    if quality < 100:
        buffer = io.BytesIO()
        image.save(buffer, format="JPEG", quality=quality)
        buffer.seek(0)
        image = Image.open(buffer).convert("RGB")

    input_width, input_height = image.size
    aspect_ratio_target = target_width / target_height
    aspect_ratio_frame = input_width / input_height
    if aspect_ratio_frame > aspect_ratio_target:
        new_width = int(input_height * aspect_ratio_target)
        new_height = input_height
        x_start = (input_width - new_width) // 2
        y_start = 0
    else:
        new_width = input_width
        new_height = int(input_width / aspect_ratio_target)
        x_start = 0
        y_start = (input_height - new_height) // 2

    image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
    image = image.resize((target_width, target_height))
    frame_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).float()
    frame_tensor = (frame_tensor / 127.5) - 1.0
    # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
    return frame_tensor.unsqueeze(0).unsqueeze(2)

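# Illustrative usage of the helper above (hypothetical file name):
#   load_image_to_tensor_with_resize_and_crop("photo.jpg", 416, 768, quality=70)
#   -> tensor of shape (1, 3, 1, 416, 768) with values in [-1, 1]
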
def calculate_padding(
    source_height: int, source_width: int, target_height: int, target_width: int
) -> tuple[int, int, int, int]:
    """Calculate padding to reach target dimensions"""
    # Calculate total padding needed
    pad_height = target_height - source_height
    pad_width = target_width - source_width

    # Calculate padding for each side
    pad_top = pad_height // 2
    pad_bottom = pad_height - pad_top  # Handles odd padding
    pad_left = pad_width // 2
    pad_right = pad_width - pad_left  # Handles odd padding

    # Return the padding values
    # Padding format is (left, right, top, bottom)
    padding = (pad_left, pad_right, pad_top, pad_bottom)
    return padding

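# Illustrative example: padding a 416x768 frame up to 480x832 (hypothetical sizes):
#   calculate_padding(416, 768, 480, 832) -> (32, 32, 32, 32)  # (left, right, top, bottom)
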
def prepare_conditioning(
    conditioning_media_paths: List[str],
    conditioning_strengths: List[float],
    conditioning_start_frames: List[int],
    height: int,
    width: int,
    num_frames: int,
    input_image_quality: int = 100,
    pipeline: Optional[LTXVideoPipeline] = None,
) -> Optional[List[ConditioningItem]]:
    """Prepare conditioning items based on input media paths and their parameters"""
    conditioning_items = []
    for path, strength, start_frame in zip(
        conditioning_media_paths, conditioning_strengths, conditioning_start_frames
    ):
        # Load and process the conditioning image
        frame_tensor = load_image_to_tensor_with_resize_and_crop(
            path, height, width, quality=input_image_quality
        )

        # Trim frame count if needed
        if pipeline:
            frame_count = 1  # For image inputs, it's always 1
            frame_count = pipeline.trim_conditioning_sequence(
                start_frame, frame_count, num_frames
            )

        conditioning_items.append(
            ConditioningItem(frame_tensor, start_frame, strength)
        )

    return conditioning_items

def create_ltx_video_pipeline(
    config: GenerationConfig,
    device: str = "cuda"
) -> LTXVideoPipeline:
    """Create and configure the LTX video pipeline"""
    # Get the absolute paths for the model components
    current_dir = Path.cwd()

    # Get allowed inference steps from config if available
    allowed_inference_steps = None
    try:
        # Load allowed inference steps from metadata if available
        if Path("transformer/config.json").exists():
            with open("transformer/config.json", "r") as f:
                config_data = json.load(f)
                allowed_inference_steps = config_data.get("allowed_inference_steps")
    except Exception as e:
        logger.warning(f"Failed to load allowed_inference_steps from config: {e}")

    # Initialize model components
    vae = CausalVideoAutoencoder.from_pretrained(".")
    transformer = Transformer3DModel.from_pretrained(".")

    # Use constructor if sampler is specified, otherwise use from_pretrained
    if config.sampler:
        scheduler = RectifiedFlowScheduler(
            sampler=("Uniform" if config.sampler.lower() == "uniform" else "LinearQuadratic")
        )
    else:
        scheduler = RectifiedFlowScheduler.from_pretrained(".")

    text_encoder = T5EncoderModel.from_pretrained("text_encoder")
    patchifier = SymmetricPatchifier(patch_size=1)
    tokenizer = T5Tokenizer.from_pretrained("tokenizer")

    # Move models to the correct device
    vae = vae.to(device)
    transformer = transformer.to(device)
    text_encoder = text_encoder.to(device)

    # Set up precision
    vae = vae.to(torch.bfloat16)
    transformer = transformer.to(torch.bfloat16)
    text_encoder = text_encoder.to(torch.bfloat16)

    # Initialize prompt enhancer components if needed
    prompt_enhancer_components = {
        "prompt_enhancer_image_caption_model": None,
        "prompt_enhancer_image_caption_processor": None,
        "prompt_enhancer_llm_model": None,
        "prompt_enhancer_llm_tokenizer": None
    }

    if config.enhance_prompt:
        try:
            # Use default models or ones specified by config
            prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained(
                "MiaoshouAI/Florence-2-large-PromptGen-v2.0",
                trust_remote_code=True
            )
            prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained(
                "MiaoshouAI/Florence-2-large-PromptGen-v2.0",
                trust_remote_code=True
            )
            prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained(
                "unsloth/Llama-3.2-3B-Instruct",
                torch_dtype="bfloat16",
            )
            prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained(
                "unsloth/Llama-3.2-3B-Instruct",
            )

            prompt_enhancer_components = {
                "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model,
                "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor,
                "prompt_enhancer_llm_model": prompt_enhancer_llm_model,
                "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer
            }
        except Exception as e:
            logger.warning(f"Failed to load prompt enhancer models: {e}")
            config.enhance_prompt = False

    # Construct the pipeline
    pipeline = LTXVideoPipeline(
        transformer=transformer,
        patchifier=patchifier,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        scheduler=scheduler,
        vae=vae,
        allowed_inference_steps=allowed_inference_steps,
        **prompt_enhancer_components
    )

    return pipeline

class EndpointHandler:
    """Handler for the LTX Video endpoint"""

    def __init__(self, model_path: str = ""):
        """Initialize the endpoint handler

        Args:
            model_path: Path to model weights (not used, as weights are in current directory)
        """
        # Enable TF32 for potential speedup on Ampere GPUs
        torch.backends.cuda.matmul.allow_tf32 = True

        # Initialize Varnish for post-processing
        self.varnish = Varnish(
            device="cuda",
            model_base_dir="varnish",
            enable_mmaudio=False,  # Disable audio generation for now, since it is broken
        )

        # The actual LTX pipeline will be loaded during inference to save memory
        self.pipeline = None

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process inference requests

        Args:
            data: Request data containing inputs and parameters

        Returns:
            Dictionary with generated video and metadata
        """
        # Extract inputs and parameters
        inputs = data.get("inputs", {})

        # Support both formats:
        # 1. {"inputs": {"prompt": "...", "image": "..."}}
        # 2. {"inputs": "..."} (prompt only)
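        # Illustrative request payload (hypothetical values, shown here for reference only):
        #   {
        #       "inputs": {"prompt": "a cat walking on a beach", "image": "data:image/jpeg;base64,..."},
        #       "parameters": {"width": 768, "height": 416, "num_frames": 113, "fps": 30}
        #   }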
        if isinstance(inputs, str):
            input_prompt = inputs
            input_image = None
        else:
            input_prompt = inputs.get("prompt", "")
            input_image = inputs.get("image")

        params = data.get("parameters", {})

        if not input_prompt and not input_image:
            raise ValueError("Either prompt or image must be provided")

        # Create and validate configuration
        config = GenerationConfig(
            # general content settings
            prompt=input_prompt,
            negative_prompt=params.get("negative_prompt", GenerationConfig.negative_prompt),

            # video model settings
            width=params.get("width", GenerationConfig.width),
            height=params.get("height", GenerationConfig.height),
            input_image_quality=params.get("input_image_quality", GenerationConfig.input_image_quality),
            num_frames=params.get("num_frames", GenerationConfig.num_frames),
            guidance_scale=params.get("guidance_scale", GenerationConfig.guidance_scale),
            num_inference_steps=params.get("num_inference_steps", GenerationConfig.num_inference_steps),

            # STG settings
            stg_scale=params.get("stg_scale", GenerationConfig.stg_scale),
            stg_rescale=params.get("stg_rescale", GenerationConfig.stg_rescale),
            stg_mode=params.get("stg_mode", GenerationConfig.stg_mode),
            stg_skip_layers=params.get("stg_skip_layers", GenerationConfig.stg_skip_layers),

            # VAE noise settings
            decode_timestep=params.get("decode_timestep", GenerationConfig.decode_timestep),
            decode_noise_scale=params.get("decode_noise_scale", GenerationConfig.decode_noise_scale),
            image_cond_noise_scale=params.get("image_cond_noise_scale", GenerationConfig.image_cond_noise_scale),

            # reproducible generation settings
            seed=params.get("seed", GenerationConfig.seed),

            # varnish settings
            fps=params.get("fps", GenerationConfig.fps),
            double_num_frames=params.get("double_num_frames", GenerationConfig.double_num_frames),
            super_resolution=params.get("super_resolution", GenerationConfig.super_resolution),
            grain_amount=params.get("grain_amount", GenerationConfig.grain_amount),
            enable_audio=params.get("enable_audio", GenerationConfig.enable_audio),
            audio_prompt=params.get("audio_prompt", GenerationConfig.audio_prompt),
            audio_negative_prompt=params.get("audio_negative_prompt", GenerationConfig.audio_negative_prompt),
            quality=params.get("quality", GenerationConfig.quality),

            # advanced settings
            mixed_precision=params.get("mixed_precision", GenerationConfig.mixed_precision),
            stochastic_sampling=params.get("stochastic_sampling", GenerationConfig.stochastic_sampling),
            sampler=params.get("sampler", GenerationConfig.sampler),

            # prompt enhancement
            enhance_prompt=params.get("enhance_prompt", GenerationConfig.enhance_prompt),
            prompt_enhancement_words_threshold=params.get(
                "prompt_enhancement_words_threshold",
                GenerationConfig.prompt_enhancement_words_threshold
            ),
        ).validate_and_adjust()

        try:
            with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16), torch.no_grad():
                # Set random seeds for reproducibility
                random.seed(config.seed)
                np.random.seed(config.seed)
                torch.manual_seed(config.seed)
                generator = torch.Generator(device='cuda').manual_seed(config.seed)

                # Create pipeline if not already created
                if self.pipeline is None:
                    self.pipeline = create_ltx_video_pipeline(config)

                # Prepare conditioning items if an image is provided
                conditioning_items = None
                if input_image:
                    conditioning_items = [
                        ConditioningItem(
                            load_image_to_tensor_with_resize_and_crop(
                                input_image,
                                config.height,
                                config.width,
                                quality=config.input_image_quality
                            ),
                            0,    # Start frame
                            1.0   # Conditioning strength
                        )
                    ]

                # Set up spatiotemporal guidance strategy
                if config.stg_mode == "attention_values":
                    skip_layer_strategy = SkipLayerStrategy.AttentionValues
                elif config.stg_mode == "attention_skip":
                    skip_layer_strategy = SkipLayerStrategy.AttentionSkip
                elif config.stg_mode == "residual":
                    skip_layer_strategy = SkipLayerStrategy.Residual
                elif config.stg_mode == "transformer_block":
                    skip_layer_strategy = SkipLayerStrategy.TransformerBlock
                else:
                    # Guard against an unrecognized mode, which would otherwise leave the variable unbound
                    raise ValueError(f"Unknown STG mode: {config.stg_mode}")

                # Generate video with LTX pipeline
                result = self.pipeline(
                    height=config.height,
                    width=config.width,
                    num_frames=config.num_frames,
                    frame_rate=config.fps,
                    prompt=config.prompt,
                    negative_prompt=config.negative_prompt,
                    guidance_scale=config.guidance_scale,
                    num_inference_steps=config.num_inference_steps,
                    generator=generator,
                    output_type="pt",  # Return as PyTorch tensor
                    skip_layer_strategy=skip_layer_strategy,
                    skip_block_list=config.stg_skip_layers,
                    stg_scale=config.stg_scale,
                    do_rescaling=config.stg_rescale != 1.0,
                    rescaling_scale=config.stg_rescale,
                    conditioning_items=conditioning_items,
                    decode_timestep=config.decode_timestep,
                    decode_noise_scale=config.decode_noise_scale,
                    image_cond_noise_scale=config.image_cond_noise_scale,
                    mixed_precision=config.mixed_precision,
                    is_video=True,
                    vae_per_channel_normalize=True,
                    stochastic_sampling=config.stochastic_sampling,
                    enhance_prompt=config.enhance_prompt,
                )

                # Get the generated frames
                frames = result.images

                # Process the generated frames with Varnish
                import asyncio
                try:
                    loop = asyncio.get_event_loop()
                except RuntimeError:
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)

                # Prepare frames for Varnish (denormalize to 0-255 range)
                frames = frames * 127.5 + 127.5
                frames = frames.to(torch.uint8)

                # Process with Varnish for post-processing
                varnish_result = loop.run_until_complete(
                    self.varnish(
                        frames,
                        fps=config.fps,
                        double_num_frames=config.double_num_frames,
                        super_resolution=config.super_resolution,
                        grain_amount=config.grain_amount,
                        enable_audio=config.enable_audio,
                        audio_prompt=config.audio_prompt or config.prompt,
                        audio_negative_prompt=config.audio_negative_prompt,
                    )
                )

                # Get the final video as a data URI
                video_uri = loop.run_until_complete(
                    varnish_result.write(
                        type="data-uri",
                        quality=config.quality
                    )
                )

                # Prepare metadata about the generated video
                metadata = {
                    "width": varnish_result.metadata.width,
                    "height": varnish_result.metadata.height,
                    "num_frames": varnish_result.metadata.frame_count,
                    "fps": varnish_result.metadata.fps,
                    "duration": varnish_result.metadata.duration,
                    "seed": config.seed,
                    "prompt": config.prompt,
                }

                # Clean up to prevent CUDA OOM errors
                del result
                torch.cuda.empty_cache()
                gc.collect()

                return {
                    "video": video_uri,
                    "content-type": "video/mp4",
                    "metadata": metadata
                }

        except Exception as e:
            # Log the error and reraise
            import traceback
            error_message = f"Error generating video: {str(e)}\n{traceback.format_exc()}"
            logger.error(error_message)
            raise RuntimeError(error_message)
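
# Minimal local smoke test (a sketch; assumes the LTX-Video weights, the tokenizer/text_encoder folders,
# and the varnish assets are present in the working directory of a CUDA machine):
#   handler = EndpointHandler()
#   result = handler({"inputs": {"prompt": "a red fox running through snow"}, "parameters": {"num_frames": 65}})
#   print(result["metadata"])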