from dataclasses import dataclass
from pathlib import Path
import logging
import base64
import random
import gc
import os
import numpy as np
import torch
from typing import Dict, Any, Optional, List, Union, Tuple
import json
from safetensors import safe_open

from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from ltx_video.models.transformers.transformer3d import Transformer3DModel
from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
from ltx_video.schedulers.rf import RectifiedFlowScheduler, TimestepShifter
from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXVideoPipeline
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
from transformers import T5EncoderModel, T5Tokenizer, AutoModelForCausalLM, AutoProcessor, AutoTokenizer

from varnish import Varnish
from varnish.utils import is_truthy, process_input_image

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

hf_token = os.getenv("HF_API_TOKEN")

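# Hard limits enforced by GenerationConfig.validate_and_adjust(): at most
# 1280x768 (or 768x1280) pixels per frame, and frame counts on the 8*k + 1
# grid, capped at 8*21 + 1 = 169 frames.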
MAX_LARGE_SIDE = 1280
MAX_SMALL_SIDE = 768
MAX_FRAMES = (8 * 21) + 1

support_image_prompt = is_truthy(os.getenv("SUPPORT_INPUT_IMAGE_PROMPT"))


@dataclass
class GenerationConfig:
    """Configuration for video generation"""

    # Text prompts
    prompt: str = ""
    negative_prompt: str = "saturated, highlight, overexposed, highlighted, overlit, shaking, too bright, worst quality, inconsistent motion, blurry, jittery, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles, lowres"

    # Output resolution (adjusted to the supported grid in validate_and_adjust)
    width: int = 768
    height: int = 416

    # JPEG quality used to re-encode the input image (simulates lower-quality sources)
    input_image_quality: int = 70

    # Number of frames to generate (must be of the form 8*k + 1)
    num_frames: int = (8 * 14) + 1

    # Diffusion sampling parameters
    guidance_scale: float = 3.5
    num_inference_steps: int = 50

    # Random seed (-1 means "pick one at random")
    seed: int = -1

    # Output frame rate and Varnish post-processing options
    fps: int = 30
    double_num_frames: bool = False
    super_resolution: bool = False
    grain_amount: float = 0.0

    # Optional audio track generation
    enable_audio: bool = False
    audio_prompt: str = ""
    audio_negative_prompt: str = "voices, voice, talking, speaking, speech"

    # Video encoding quality passed to Varnish when writing the output file
    quality: int = 18

    # Spatiotemporal guidance (STG) parameters
    stg_scale: float = 1.0
    stg_rescale: float = 0.7
    stg_mode: str = "attention_values"
    stg_skip_layers: str = "19"

    # VAE decoding parameters
    decode_timestep: float = 0.05
    decode_noise_scale: float = 0.025

    # Conditioning noise and precision options
    image_cond_noise_scale: float = 0.15
    mixed_precision: bool = True
    stochastic_sampling: bool = False

    # Scheduler sampler override ("uniform" selects Uniform, anything else LinearQuadratic);
    # None keeps the sampler stored in the checkpoint
    sampler: Optional[str] = None

    # Optional prompt enhancement (Florence-2 captioner + Llama-3.2 LLM, see create_ltx_video_pipeline)
    enhance_prompt: bool = False
    prompt_enhancement_words_threshold: int = 50

    def validate_and_adjust(self) -> 'GenerationConfig':
        """Validate and adjust parameters to meet constraints"""
        # Snap the resolution to the supported grid unless it is exactly one of
        # the two maximum landscape/portrait sizes.
        if not ((self.width == MAX_LARGE_SIDE and self.height == MAX_SMALL_SIDE) or
                (self.width == MAX_SMALL_SIDE and self.height == MAX_LARGE_SIDE)):
            MAX_TOTAL_PIXELS = MAX_SMALL_SIDE * MAX_LARGE_SIDE

            # Scale down if the pixel budget is exceeded, then round each side
            # to a multiple of 32 within [128, MAX_LARGE_SIDE].
            total_pixels = self.width * self.height
            if total_pixels > MAX_TOTAL_PIXELS:
                scale = (MAX_TOTAL_PIXELS / total_pixels) ** 0.5
                self.width = max(128, min(MAX_LARGE_SIDE, round(self.width * scale / 32) * 32))
                self.height = max(128, min(MAX_LARGE_SIDE, round(self.height * scale / 32) * 32))
            else:
                self.width = max(128, min(MAX_LARGE_SIDE, round(self.width / 32) * 32))
                self.height = max(128, min(MAX_LARGE_SIDE, round(self.height / 32) * 32))

        # Clamp the frame count to the 8*k + 1 grid, capped at MAX_FRAMES
        k = (self.num_frames - 1) // 8
        self.num_frames = min((k * 8) + 1, MAX_FRAMES)

        # Draw a random seed if none was provided
        if self.seed == -1:
            self.seed = random.randint(0, 2**32 - 1)

        # Normalize STG mode aliases to their canonical names
        stg_mode = self.stg_mode.lower()
        if stg_mode in ("stg_av", "attention_values"):
            self.stg_mode = "attention_values"
        elif stg_mode in ("stg_as", "attention_skip"):
            self.stg_mode = "attention_skip"
        elif stg_mode in ("stg_r", "residual"):
            self.stg_mode = "residual"
        elif stg_mode in ("stg_t", "transformer_block"):
            self.stg_mode = "transformer_block"
        else:
            raise ValueError(f"Invalid stg_mode: {self.stg_mode}")

        # Allow stg_skip_layers to be passed as a comma-separated string
        if isinstance(self.stg_skip_layers, str):
            self.stg_skip_layers = [int(x.strip()) for x in self.stg_skip_layers.split(",")]

        # Skip prompt enhancement for prompts that are already long enough
        if self.enhance_prompt and self.prompt:
            prompt_word_count = len(self.prompt.split())
            if prompt_word_count >= self.prompt_enhancement_words_threshold:
                logger.info(f"Prompt has {prompt_word_count} words, which exceeds the threshold of {self.prompt_enhancement_words_threshold}. Prompt enhancement disabled.")
                self.enhance_prompt = False

        return self
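

# The loader below returns a single-frame clip shaped (1, C, 1, H, W) with values
# in [-1, 1], which is the layout wrapped into ConditioningItem entries by
# prepare_conditioning() and by the handler's image-to-video path.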
def load_image_to_tensor_with_resize_and_crop(
    image_input: Union[str, bytes],
    target_height: int = 512,
    target_width: int = 768,
    quality: int = 100
) -> torch.Tensor:
    """Load and process an image into a tensor.

    Args:
        image_input: Either a file path (str), raw image bytes, or a base64 data URI
        target_height: Desired height of output tensor
        target_width: Desired width of output tensor
        quality: JPEG quality to use when re-encoding (to simulate lower quality images)
    """
    from PIL import Image
    import io

    # Accept a base64 data URI, raw bytes, or a file path
    if isinstance(image_input, str) and image_input.startswith('data:'):
        header, encoded = image_input.split(",", 1)
        image_data = base64.b64decode(encoded)
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
    elif isinstance(image_input, bytes):
        image = Image.open(io.BytesIO(image_input)).convert("RGB")
    elif isinstance(image_input, str):
        image = Image.open(image_input).convert("RGB")
    else:
        raise ValueError("image_input must be either a file path, bytes, or base64 data URI")

    # Optionally re-encode as JPEG to simulate a lower-quality source image
    if quality < 100:
        buffer = io.BytesIO()
        image.save(buffer, format="JPEG", quality=quality)
        buffer.seek(0)
        image = Image.open(buffer).convert("RGB")

    # Center-crop to the target aspect ratio, then resize
    input_width, input_height = image.size
    aspect_ratio_target = target_width / target_height
    aspect_ratio_frame = input_width / input_height
    if aspect_ratio_frame > aspect_ratio_target:
        new_width = int(input_height * aspect_ratio_target)
        new_height = input_height
        x_start = (input_width - new_width) // 2
        y_start = 0
    else:
        new_width = input_width
        new_height = int(input_width / aspect_ratio_target)
        x_start = 0
        y_start = (input_height - new_height) // 2

    image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
    image = image.resize((target_width, target_height))

    # Scale pixel values to [-1, 1] and add batch and frame dimensions
    frame_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).float()
    frame_tensor = (frame_tensor / 127.5) - 1.0
    return frame_tensor.unsqueeze(0).unsqueeze(2)


def calculate_padding(
    source_height: int, source_width: int, target_height: int, target_width: int
) -> tuple[int, int, int, int]:
    """Calculate padding to reach target dimensions"""
    pad_height = target_height - source_height
    pad_width = target_width - source_width

    # Split the padding as evenly as possible between the two sides
    pad_top = pad_height // 2
    pad_bottom = pad_height - pad_top
    pad_left = pad_width // 2
    pad_right = pad_width - pad_left

    # (left, right, top, bottom) -- the ordering used by torch.nn.functional.pad
    padding = (pad_left, pad_right, pad_top, pad_bottom)
    return padding


def prepare_conditioning(
    conditioning_media_paths: List[str],
    conditioning_strengths: List[float],
    conditioning_start_frames: List[int],
    height: int,
    width: int,
    num_frames: int,
    input_image_quality: int = 100,
    pipeline: Optional[LTXVideoPipeline] = None,
) -> Optional[List[ConditioningItem]]:
    """Prepare conditioning items based on input media paths and their parameters"""
    conditioning_items = []
    for path, strength, start_frame in zip(
        conditioning_media_paths, conditioning_strengths, conditioning_start_frames
    ):
        frame_tensor = load_image_to_tensor_with_resize_and_crop(
            path, height, width, quality=input_image_quality
        )

        if pipeline:
            frame_count = 1
            frame_count = pipeline.trim_conditioning_sequence(
                start_frame, frame_count, num_frames
            )

        conditioning_items.append(
            ConditioningItem(frame_tensor, start_frame, strength)
        )

    return conditioning_items
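

# The pipeline factory below loads the VAE, transformer and scheduler from a
# single-file safetensors checkpoint (whose metadata also carries the allowed
# inference steps), the T5 text encoder and tokenizer from local directories,
# and, when prompt enhancement is enabled, a Florence-2 captioner plus a
# Llama-3.2 instruct model.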
def create_ltx_video_pipeline(
    config: GenerationConfig,
    device: str = "cuda"
) -> LTXVideoPipeline:
    """Create and configure the LTX video pipeline"""
    current_dir = Path.cwd()

    ckpt_path = "./ltxv-2b-0.9.6-distilled-04-25.safetensors"

    allowed_inference_steps = None

    assert os.path.exists(
        ckpt_path
    ), f"Checkpoint path {ckpt_path} does not exist"

    # Read the list of allowed inference steps from the checkpoint metadata
    with safe_open(ckpt_path, framework="pt") as f:
        metadata = f.metadata()
        config_str = metadata.get("config")
        configs = json.loads(config_str)
        allowed_inference_steps = configs.get("allowed_inference_steps", None)

    vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
    transformer = Transformer3DModel.from_pretrained(ckpt_path)

    # Use the scheduler stored in the checkpoint unless a sampler override is given
    if config.sampler:
        scheduler = RectifiedFlowScheduler(
            sampler=("Uniform" if config.sampler.lower() == "uniform" else "LinearQuadratic")
        )
    else:
        scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)

    text_encoder = T5EncoderModel.from_pretrained("./text_encoder")
    patchifier = SymmetricPatchifier(patch_size=1)
    tokenizer = T5Tokenizer.from_pretrained("./tokenizer")

    # Move all models to the target device and cast them to bfloat16
    vae = vae.to(device)
    transformer = transformer.to(device)
    text_encoder = text_encoder.to(device)

    vae = vae.to(torch.bfloat16)
    transformer = transformer.to(torch.bfloat16)
    text_encoder = text_encoder.to(torch.bfloat16)

    # Optional prompt-enhancement models (prompt enhancement is disabled if loading fails)
    prompt_enhancer_components = {
        "prompt_enhancer_image_caption_model": None,
        "prompt_enhancer_image_caption_processor": None,
        "prompt_enhancer_llm_model": None,
        "prompt_enhancer_llm_tokenizer": None
    }

    if config.enhance_prompt:
        try:
            prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained(
                "MiaoshouAI/Florence-2-large-PromptGen-v2.0",
                trust_remote_code=True
            )
            prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained(
                "MiaoshouAI/Florence-2-large-PromptGen-v2.0",
                trust_remote_code=True
            )
            prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained(
                "unsloth/Llama-3.2-3B-Instruct",
                torch_dtype="bfloat16",
            )
            prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained(
                "unsloth/Llama-3.2-3B-Instruct",
            )

            prompt_enhancer_components = {
                "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model,
                "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor,
                "prompt_enhancer_llm_model": prompt_enhancer_llm_model,
                "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer
            }
        except Exception as e:
            logger.warning(f"Failed to load prompt enhancer models: {e}")
            config.enhance_prompt = False

    pipeline = LTXVideoPipeline(
        transformer=transformer,
        patchifier=patchifier,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        scheduler=scheduler,
        vae=vae,
        allowed_inference_steps=allowed_inference_steps,
        **prompt_enhancer_components
    )

    return pipeline
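

# EndpointHandler follows the Hugging Face Inference Endpoints convention:
# it is constructed once per replica and then called with a dict of the form
# {"inputs": <prompt string or {"prompt", "image"}>, "parameters": {...}},
# returning a dict with a base64 "video" data URI, its "content-type", and
# generation "metadata".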
class EndpointHandler:
    """Handler for the LTX Video endpoint"""

    def __init__(self, model_path: str = ""):
        """Initialize the endpoint handler

        Args:
            model_path: Path to model weights (not used, as weights are in current directory)
        """
        # Enable TF32 for faster matmuls on Ampere+ GPUs
        torch.backends.cuda.matmul.allow_tf32 = True

        # Varnish post-processes the frames (optional frame doubling, upscaling,
        # film grain, audio) and encodes the final video
        self.varnish = Varnish(
            device="cuda",
            model_base_dir="varnish",
            enable_mmaudio=False,
        )

        # The LTX video pipeline is created lazily on the first request
        self.pipeline = None

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process inference requests

        Args:
            data: Request data containing inputs and parameters

        Returns:
            Dictionary with generated video and metadata
        """
        inputs = data.get("inputs", {})

        # "inputs" may be a plain prompt string or a dict with "prompt" and/or "image"
        if isinstance(inputs, str):
            input_prompt = inputs
            input_image = None
        else:
            input_prompt = inputs.get("prompt", "")
            input_image = inputs.get("image")

        params = data.get("parameters", {})

        if not input_prompt and not input_image:
            raise ValueError("Either prompt or image must be provided")
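
        # Build the generation config from the request parameters, falling back to
        # the dataclass defaults, then validate/adjust it (resolution grid, frame
        # count, seed, STG mode normalization).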
        config = GenerationConfig(
            prompt=input_prompt,
            negative_prompt=params.get("negative_prompt", GenerationConfig.negative_prompt),
            width=params.get("width", GenerationConfig.width),
            height=params.get("height", GenerationConfig.height),
            input_image_quality=params.get("input_image_quality", GenerationConfig.input_image_quality),
            num_frames=params.get("num_frames", GenerationConfig.num_frames),
            guidance_scale=params.get("guidance_scale", GenerationConfig.guidance_scale),
            num_inference_steps=params.get("num_inference_steps", GenerationConfig.num_inference_steps),
            stg_scale=params.get("stg_scale", GenerationConfig.stg_scale),
            stg_rescale=params.get("stg_rescale", GenerationConfig.stg_rescale),
            stg_mode=params.get("stg_mode", GenerationConfig.stg_mode),
            stg_skip_layers=params.get("stg_skip_layers", GenerationConfig.stg_skip_layers),
            decode_timestep=params.get("decode_timestep", GenerationConfig.decode_timestep),
            decode_noise_scale=params.get("decode_noise_scale", GenerationConfig.decode_noise_scale),
            image_cond_noise_scale=params.get("image_cond_noise_scale", GenerationConfig.image_cond_noise_scale),
            seed=params.get("seed", GenerationConfig.seed),
            fps=params.get("fps", GenerationConfig.fps),
            double_num_frames=params.get("double_num_frames", GenerationConfig.double_num_frames),
            super_resolution=params.get("super_resolution", GenerationConfig.super_resolution),
            grain_amount=params.get("grain_amount", GenerationConfig.grain_amount),
            enable_audio=params.get("enable_audio", GenerationConfig.enable_audio),
            audio_prompt=params.get("audio_prompt", GenerationConfig.audio_prompt),
            audio_negative_prompt=params.get("audio_negative_prompt", GenerationConfig.audio_negative_prompt),
            quality=params.get("quality", GenerationConfig.quality),
            mixed_precision=params.get("mixed_precision", GenerationConfig.mixed_precision),
            stochastic_sampling=params.get("stochastic_sampling", GenerationConfig.stochastic_sampling),
            sampler=params.get("sampler", GenerationConfig.sampler),
            enhance_prompt=params.get("enhance_prompt", GenerationConfig.enhance_prompt),
            prompt_enhancement_words_threshold=params.get(
                "prompt_enhancement_words_threshold",
                GenerationConfig.prompt_enhancement_words_threshold
            ),
        ).validate_and_adjust()
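
        # Run generation under bfloat16 autocast without gradients; seed Python,
        # NumPy and PyTorch so results are reproducible for a given seed.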
        try:
            with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16), torch.no_grad():
                random.seed(config.seed)
                np.random.seed(config.seed)
                torch.manual_seed(config.seed)
                generator = torch.Generator(device='cuda').manual_seed(config.seed)

                # Create the pipeline on the first request and reuse it afterwards
                if self.pipeline is None:
                    self.pipeline = create_ltx_video_pipeline(config)

                # Wrap the optional input image as a single conditioning frame at frame 0
                conditioning_items = None
                if input_image:
                    conditioning_items = [
                        ConditioningItem(
                            load_image_to_tensor_with_resize_and_crop(
                                input_image,
                                config.height,
                                config.width,
                                quality=config.input_image_quality
                            ),
                            0,
                            1.0
                        )
                    ]

                # Map the normalized STG mode to the pipeline's skip-layer strategy
                if config.stg_mode == "attention_values":
                    skip_layer_strategy = SkipLayerStrategy.AttentionValues
                elif config.stg_mode == "attention_skip":
                    skip_layer_strategy = SkipLayerStrategy.AttentionSkip
                elif config.stg_mode == "residual":
                    skip_layer_strategy = SkipLayerStrategy.Residual
                elif config.stg_mode == "transformer_block":
                    skip_layer_strategy = SkipLayerStrategy.TransformerBlock
                else:
                    raise ValueError(f"Invalid stg_mode: {config.stg_mode}")

                result = self.pipeline(
                    height=config.height,
                    width=config.width,
                    num_frames=config.num_frames,
                    frame_rate=config.fps,
                    prompt=config.prompt,
                    negative_prompt=config.negative_prompt,
                    guidance_scale=config.guidance_scale,
                    num_inference_steps=config.num_inference_steps,
                    generator=generator,
                    output_type="pt",
                    skip_layer_strategy=skip_layer_strategy,
                    skip_block_list=config.stg_skip_layers,
                    stg_scale=config.stg_scale,
                    do_rescaling=config.stg_rescale != 1.0,
                    rescaling_scale=config.stg_rescale,
                    conditioning_items=conditioning_items,
                    decode_timestep=config.decode_timestep,
                    decode_noise_scale=config.decode_noise_scale,
                    image_cond_noise_scale=config.image_cond_noise_scale,
                    mixed_precision=config.mixed_precision,
                    is_video=True,
                    vae_per_channel_normalize=True,
                    stochastic_sampling=config.stochastic_sampling,
                    enhance_prompt=config.enhance_prompt,
                )

                frames = result.images
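
                # Post-process with Varnish: convert frames to 8-bit, optionally double
                # the frame count, upscale, add film grain and an audio track, then encode
                # the result as an MP4 data URI. Varnish exposes an async API, so drive it
                # from an event loop here.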
                import asyncio
                try:
                    loop = asyncio.get_event_loop()
                except RuntimeError:
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)

                # Scale pixel values from [-1, 1] to [0, 255] and convert to uint8
                frames = frames * 127.5 + 127.5
                frames = frames.to(torch.uint8)

                varnish_result = loop.run_until_complete(
                    self.varnish(
                        frames,
                        fps=config.fps,
                        double_num_frames=config.double_num_frames,
                        super_resolution=config.super_resolution,
                        grain_amount=config.grain_amount,
                        enable_audio=config.enable_audio,
                        audio_prompt=config.audio_prompt or config.prompt,
                        audio_negative_prompt=config.audio_negative_prompt,
                    )
                )

                # Encode the processed video as a base64 data URI
                video_uri = loop.run_until_complete(
                    varnish_result.write(
                        type="data-uri",
                        quality=config.quality
                    )
                )

                # Metadata reported back to the client alongside the video
                metadata = {
                    "width": varnish_result.metadata.width,
                    "height": varnish_result.metadata.height,
                    "num_frames": varnish_result.metadata.frame_count,
                    "fps": varnish_result.metadata.fps,
                    "duration": varnish_result.metadata.duration,
                    "seed": config.seed,
                    "prompt": config.prompt,
                }

                # Free GPU memory before returning
                del result
                torch.cuda.empty_cache()
                gc.collect()

                return {
                    "video": video_uri,
                    "content-type": "video/mp4",
                    "metadata": metadata
                }

        except Exception as e:
            import traceback
            error_message = f"Error generating video: {str(e)}\n{traceback.format_exc()}"
            logger.error(error_message)
            raise RuntimeError(error_message) from e
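

# The block below is a minimal local smoke-test sketch, not part of the endpoint
# contract. It assumes the checkpoint, ./text_encoder and ./tokenizer directories
# and the varnish assets referenced above are present in the working directory,
# and that a CUDA GPU is available; the small frame count is only meant to
# exercise the code path quickly.
if __name__ == "__main__":
    handler = EndpointHandler()
    output = handler({
        "inputs": {"prompt": "A sailboat gliding across a calm lake at sunset"},
        "parameters": {
            "width": 768,
            "height": 416,
            "num_frames": 33,  # 8 * 4 + 1, kept small for a quick test
            "fps": 24,
        },
    })
    logger.info("Video data URI length: %d; metadata: %s", len(output["video"]), output["metadata"])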