Create handler.py
handler.py    +504 -0
handler.py
ADDED
@@ -0,0 +1,504 @@
from dataclasses import dataclass
from pathlib import Path
import pathlib
from typing import Dict, Any, Optional, Tuple
import asyncio
import base64
import io
import pprint
import logging
import random
import traceback
import os
import numpy as np
import torch
import gc

from diffusers import AutoencoderKLLTXVideo, LTXPipeline, LTXImageToVideoPipeline, LTXVideoTransformer3DModel
#from diffusers.hooks import apply_enhance_a_video, EnhanceAVideoConfig
from teacache import apply_teacache

from PIL import Image

from varnish import Varnish
from varnish.utils import is_truthy, process_input_image

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Get token from environment
hf_token = os.getenv("HF_API_TOKEN")

# Constraints
MAX_LARGE_SIDE = 1280
MAX_SMALL_SIDE = 768  # should be 720 but it must be divisible by 32
MAX_FRAMES = (8 * 21) + 1  # visual glitches appear after about 169 frames, so we cap it

# Check environment variable for pipeline support
support_image_prompt = is_truthy(os.getenv("SUPPORT_INPUT_IMAGE_PROMPT"))
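
# Illustrative notes on the constants above: MAX_FRAMES = (8 * 21) + 1 = 169, so a request for
# the theoretical maximum of 257 frames gets clamped down to 169. The image-to-video pipeline is
# only loaded when the endpoint is deployed with something like SUPPORT_INPUT_IMAGE_PROMPT=true
# (the exact set of truthy spellings accepted by is_truthy lives in varnish.utils).
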
@dataclass
class GenerationConfig:
    """Configuration for video generation"""

    # general content settings
    prompt: str = ""
    negative_prompt: str = "saturated, highlight, overexposed, highlighted, overlit, shaking, too bright, worst quality, inconsistent motion, blurry, jittery, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles, lowres"

    # video model settings (will be used during generation of the initial raw video clip)
    # we use small values to make things a bit faster
    width: int = 768
    height: int = 416

    # this is a hack to fool LTX-Video into believing our input image is an actual video frame with poor encoding quality
    # after a quick benchmark, the value 70 seems like a sweet spot
    input_image_quality: int = 70

    # users may tend to always set this to the max, to get as much usable content as possible (which is MAX_FRAMES, i.e. 169).
    # The value must be a multiple of 8, plus 1 frame.
    # visual glitches appear after about 169 frames, so we don't need more anyway
    num_frames: int = (8 * 14) + 1

    # values between 3.0 and 4.0 are nice
    guidance_scale: float = 3.5

    num_inference_steps: int = 50

    # reproducible generation settings
    seed: int = -1  # -1 means random seed

    # varnish settings (will be used for post-processing after the raw video clip has been generated)
    fps: int = 30  # FPS of the final video (only applied at the very end, when converting to mp4)
    double_num_frames: bool = False  # if True, the number of frames will be multiplied by 2 using RIFE
    super_resolution: bool = False  # if True, the resolution will be multiplied by 2 using Real_ESRGAN

    grain_amount: float = 0.0  # be careful, adding film grain can negatively impact video compression

    # audio settings
    enable_audio: bool = False  # Whether to generate audio
    audio_prompt: str = ""  # Text prompt for audio generation
    audio_negative_prompt: str = "voices, voice, talking, speaking, speech"  # Negative prompt for audio generation

    # The range of the CRF scale is 0–51, where:
    # 0 is lossless (for 8 bit only, for 10 bit use -qp 0)
    # 23 is the default
    # 51 is worst quality possible
    # A lower value generally leads to higher quality, and a subjectively sane range is 17–28.
    # Consider 17 or 18 to be visually lossless or nearly so;
    # it should look the same or nearly the same as the input but it isn't technically lossless.
    # The range is exponential, so increasing the CRF value +6 results in roughly half the bitrate / file size, while -6 leads to roughly twice the bitrate.
    quality: int = 18

    # TeaCache settings
    enable_teacache: bool = True
    teacache_threshold: float = 0.05  # values: 0 (original), 0.03 (1.6x speedup), 0.05 (2.1x speedup)

    # Enhance-A-Video settings
    enable_enhance_a_video: bool = True
    enhance_a_video_weight: float = 5.0

    # LoRA settings
    lora_model_name: str = ""  # HuggingFace repo ID or path to LoRA model
    lora_model_weight_file: str = ""  # Specific weight file to load from the LoRA model
    lora_model_trigger: str = ""  # Optional trigger word to prepend to the prompt

    def validate_and_adjust(self) -> 'GenerationConfig':
        """Validate and adjust parameters to meet constraints"""
        # First check if it's one of our explicitly allowed resolutions
        if not ((self.width == MAX_LARGE_SIDE and self.height == MAX_SMALL_SIDE) or
                (self.width == MAX_SMALL_SIDE and self.height == MAX_LARGE_SIDE)):
            # For other resolutions, ensure total pixels don't exceed max
            MAX_TOTAL_PIXELS = MAX_SMALL_SIDE * MAX_LARGE_SIDE  # 983040, slightly above the 921600 = 1280 * 720 budget

            # If total pixels exceed maximum, scale down proportionally
            total_pixels = self.width * self.height
            if total_pixels > MAX_TOTAL_PIXELS:
                scale = (MAX_TOTAL_PIXELS / total_pixels) ** 0.5
                self.width = max(128, min(MAX_LARGE_SIDE, round(self.width * scale / 32) * 32))
                self.height = max(128, min(MAX_LARGE_SIDE, round(self.height * scale / 32) * 32))
            else:
                # Round dimensions to nearest multiple of 32
                self.width = max(128, min(MAX_LARGE_SIDE, round(self.width / 32) * 32))
                self.height = max(128, min(MAX_LARGE_SIDE, round(self.height / 32) * 32))

        # Adjust number of frames to be in format 8k + 1
        k = (self.num_frames - 1) // 8
        self.num_frames = min((k * 8) + 1, MAX_FRAMES)

        # Set random seed if not specified
        if self.seed == -1:
            self.seed = random.randint(0, 2**32 - 1)

        return self
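

# Worked example (illustrative, mirrors validate_and_adjust above): a request asking for
# width=1000, height=600, num_frames=100, seed=-1 would be adjusted as follows:
#   - 1000x600 is not an explicitly allowed resolution; 600,000 pixels is under the
#     983,040 budget, so each side is just rounded to the nearest multiple of 32
#     -> width=992, height=608
#   - num_frames is snapped onto the 8k + 1 grid: k = (100 - 1) // 8 = 12 -> 97 frames
#   - seed=-1 is replaced by a random 32-bit seed, which is reported back in the
#     response metadata so the generation can be reproduced
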
class EndpointHandler:
    """Handles video generation requests using LTX models and Varnish post-processing"""

    def __init__(self, model_path: str = ""):
        """Initialize the handler with LTX models and Varnish

        Args:
            model_path: Path to LTX model weights
        """
        # Enable TF32 for potential speedup on Ampere GPUs
        #torch.backends.cuda.matmul.allow_tf32 = True

        # use the distilled weights (single-file checkpoint) for the transformer and the VAE,
        # while the pipeline config, tokenizer and text encoder are loaded from the base directory
        distilled_weights_path = f"{model_path}/ltxv-2b-0.9.6-distilled-04-25.safetensors"

        transformer = LTXVideoTransformer3DModel.from_single_file(
            distilled_weights_path, torch_dtype=torch.bfloat16
        )

        vae = AutoencoderKLLTXVideo.from_single_file(distilled_weights_path, torch_dtype=torch.bfloat16)

        if support_image_prompt:
            self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
                model_path,
                transformer=transformer,
                vae=vae,
                torch_dtype=torch.bfloat16
            ).to("cuda")

            #apply_teacache(self.image_to_video)

            # Compilation requires some time to complete, so it is best suited for
            # situations where you prepare your pipeline once and then perform the
            # same type of inference operations multiple times.
            # For example, calling the compiled pipeline on a different image size
            # triggers compilation again which can be expensive.
            #self.image_to_video.unet = torch.compile(self.image_to_video.unet, mode="reduce-overhead", fullgraph=True)

        else:
            # Initialize models with bfloat16 precision
            self.text_to_video = LTXPipeline.from_pretrained(
                model_path,
                transformer=transformer,
                vae=vae,
                torch_dtype=torch.bfloat16
            ).to("cuda")

            #apply_teacache(self.text_to_video)

            # Compilation requires some time to complete, so it is best suited for
            # situations where you prepare your pipeline once and then perform the
            # same type of inference operations multiple times.
            # For example, calling the compiled pipeline on a different image size
            # triggers compilation again which can be expensive.
            #self.text_to_video.unet = torch.compile(self.text_to_video.unet, mode="reduce-overhead", fullgraph=True)

        # Initialize LoRA tracking
        self._current_lora_model = None

        #if support_image_prompt:
        #    # Enable CPU offload for memory efficiency
        #    self.image_to_video.enable_model_cpu_offload()
        #    # Inject enhance-a-video functionality
        #    inject_enhance_for_ltx(self.image_to_video.transformer)
        #else:
        #    # Enable CPU offload for memory efficiency
        #    self.text_to_video.enable_model_cpu_offload()
        #    # Inject enhance-a-video functionality
        #    inject_enhance_for_ltx(self.text_to_video.transformer)

        # Initialize Varnish for post-processing
        self.varnish = Varnish(
            device="cuda",
            model_base_dir="/repository/varnish",

            # there is currently a bug with MMAudio and/or torch and/or the weight format and/or version..
            # not sure how to fix that.. :/
            #
            # it says:
            #   File "dist-packages/varnish.py", line 152, in __init__
            #     self._setup_mmaudio()
            #   File "dist-packages/varnish/varnish.py", line 165, in _setup_mmaudio
            #     net.load_weights(torch.load(model.model_path, map_location=self.device, weights_only=False))
            #   File "dist-packages/torch/serialization.py", line 1384, in load
            #     return _legacy_load(
            #   File "dist-packages/torch/serialization.py", line 1628, in _legacy_load
            #     magic_number = pickle_module.load(f, **pickle_load_args)
            #   _pickle.UnpicklingError: invalid load key, '<'.
            enable_mmaudio=True,
        )

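        # Side note on the MMAudio failure quoted above (a guess, not a verified fix): an
        # UnpicklingError with load key '<' usually means the first byte of the "checkpoint"
        # is '<', i.e. the file on disk is an HTML/XML error page (a failed or unauthenticated
        # download) rather than a real torch pickle, so re-fetching the MMAudio weights into
        # /repository/varnish may be enough to make enable_mmaudio work.
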
        # Determine if TeaCache is already installed or not
        self.text_to_video_teacache = False
        self.image_to_video_teacache = False

    async def process_frames(
        self,
        frames: torch.Tensor,
        config: GenerationConfig
    ) -> tuple[str, dict]:
        """Post-process generated frames using Varnish

        Args:
            frames: Generated video frames tensor
            config: Generation configuration

        Returns:
            Tuple of (video data URI, metadata dictionary)
        """
        try:
            # Process video with Varnish
            result = await self.varnish(
                input_data=frames,  # note: this might contain a certain number of frames, e.g. 97, which will get doubled if double_num_frames is True
                fps=config.fps,  # FPS of the final output video; Varnish can use it to compute the clip duration ((frames * factor) / fps etc.)
                double_num_frames=config.double_num_frames,  # if True, the number of frames will be multiplied by 2 using RIFE
                super_resolution=config.super_resolution,  # if True, the resolution will be multiplied by 2 using Real_ESRGAN
                grain_amount=config.grain_amount,
                enable_audio=config.enable_audio,
                audio_prompt=config.audio_prompt,
                audio_negative_prompt=config.audio_negative_prompt,
            )

            # Convert to data URI
            video_uri = await result.write(type="data-uri", quality=config.quality)

            # Collect metadata
            metadata = {
                "width": result.metadata.width,
                "height": result.metadata.height,
                "num_frames": result.metadata.frame_count,
                "fps": result.metadata.fps,
                "duration": result.metadata.duration,
                "seed": config.seed,
            }

            return video_uri, metadata

        except Exception as e:
            logger.error(f"Error in process_frames: {str(e)}")
            raise RuntimeError(f"Failed to process frames: {str(e)}")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process incoming requests for video generation

        Args:
            data: Request data containing:
                - inputs (dict): Dictionary containing the input, which can be either "prompt" (text field) or "image" (input image)
                - parameters (dict):
                    - prompt (required, string): list of concepts to keep in the video.
                    - negative_prompt (optional, string): list of concepts to ignore in the video.
                    - width (optional, int, default to 768): width, or horizontal size in pixels.
                    - height (optional, int, default to 416): height, or vertical size in pixels.
                    - input_image_quality (optional, int, default to 70): this is a trick we use to convert a "pristine" image into a "dirty" video frame. This helps fooling LTX-Video into turning the image into an animated one.
                    - num_frames (optional, int, default to 113): the number of frames must be a multiple of 8, plus 1 frame.
                    - guidance_scale (optional, float, default to 3.5): guidance scale (values between 3.0 and 4.0 are nice)
                    - num_inference_steps (optional, int, default to 50): number of inference steps
                    - seed (optional, int, default to -1): set a random number generator seed, -1 means random seed.
                    - fps (optional, int, default to 30): FPS of the final video (e.g. 24, 25, 30, 60)
                    - double_num_frames (optional, bool): if enabled, the number of frames will be multiplied by 2 using RIFE
                    - super_resolution (optional, bool): if enabled, the resolution will be multiplied by 2 using Real_ESRGAN
                    - grain_amount (optional, float): amount of film grain to add to the output video
                    - enable_audio (optional, bool): automatically generate an audio track
                    - audio_prompt (optional, str): prompt to use for the audio generation (concepts to add)
                    - audio_negative_prompt (optional, str): negative prompt to use for the audio generation (concepts to ignore)
                    - quality (optional, int, default to 18): the range of the CRF scale is 0–51, where 0 is lossless (for 8 bit only, for 10 bit use -qp 0), 23 is the default, and 51 is the worst quality possible.
                    - enable_teacache (optional, bool, default to True): generate faster at the cost of a slight quality loss
                    - teacache_threshold (optional, float, default to 0.05): amount of caching, 0 (original), 0.03 (1.6x speedup), 0.05 (default, 2.1x speedup).
                    - enable_enhance_a_video (optional, bool, default to True): enable the enhance_a_video optimization
                    - enhance_a_video_weight (optional, float, default to 5.0): amount of video enhancement to apply
                    - lora_model_name (optional, str, default to ""): HuggingFace repo ID or path to LoRA model
                    - lora_model_weight_file (optional, str, default to ""): specific weight file to load from the LoRA model
                    - lora_model_trigger (optional, str, default to ""): optional trigger word to prepend to the prompt
        Returns:
            Dictionary containing:
                - video: Base64 encoded MP4 data URI
                - content-type: MIME type
                - metadata: Generation metadata
        """
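        # Illustrative request (an example payload, not an exhaustive schema; see the docstring above):
        # {
        #   "inputs": {"prompt": "a cat walking on a beach at sunset"},
        #   "parameters": {"width": 768, "height": 416, "num_frames": 113, "fps": 30, "seed": 42}
        # }
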
        inputs = data.get("inputs", dict())

        input_prompt = inputs.get("prompt", "")
        input_image = inputs.get("image")

        params = data.get("parameters", dict())

        if not input_image and not input_prompt:
            raise ValueError("Either prompt or image must be provided")

        #logger.debug(f"Raw parameters:")
        #pprint.pprint(params)

        # Create and validate configuration
        config = GenerationConfig(
            # general content settings
            prompt=input_prompt,
            negative_prompt=params.get("negative_prompt", GenerationConfig.negative_prompt),

            # video model settings (will be used during generation of the initial raw video clip)
            width=params.get("width", GenerationConfig.width),
            height=params.get("height", GenerationConfig.height),
            input_image_quality=params.get("input_image_quality", GenerationConfig.input_image_quality),
            num_frames=params.get("num_frames", GenerationConfig.num_frames),
            guidance_scale=params.get("guidance_scale", GenerationConfig.guidance_scale),
            num_inference_steps=params.get("num_inference_steps", GenerationConfig.num_inference_steps),

            # reproducible generation settings
            seed=params.get("seed", GenerationConfig.seed),

            # varnish settings (will be used for post-processing after the raw video clip has been generated)
            fps=params.get("fps", GenerationConfig.fps),  # FPS of the final video (only applied at the very end, when converting to mp4)
            double_num_frames=params.get("double_num_frames", GenerationConfig.double_num_frames),  # if True, the number of frames will be multiplied by 2 using RIFE
            super_resolution=params.get("super_resolution", GenerationConfig.super_resolution),  # if True, the resolution will be multiplied by 2 using Real_ESRGAN
            grain_amount=params.get("grain_amount", GenerationConfig.grain_amount),
            enable_audio=params.get("enable_audio", GenerationConfig.enable_audio),
            audio_prompt=params.get("audio_prompt", GenerationConfig.audio_prompt),
            audio_negative_prompt=params.get("audio_negative_prompt", GenerationConfig.audio_negative_prompt),
            quality=params.get("quality", GenerationConfig.quality),

            # TeaCache settings
            enable_teacache=params.get("enable_teacache", True),
            # values: 0 (original), 0.03 (1.6x speedup), 0.05 (2.1x speedup).
            teacache_threshold=params.get("teacache_threshold", 0.05),

            # Enhance-A-Video settings
            enable_enhance_a_video=params.get("enable_enhance_a_video", True),
            enhance_a_video_weight=params.get("enhance_a_video_weight", 5.0),

            # LoRA settings
            lora_model_name=params.get("lora_model_name", ""),
            lora_model_weight_file=params.get("lora_model_weight_file", ""),
            lora_model_trigger=params.get("lora_model_trigger", ""),
        ).validate_and_adjust()

        #logger.debug(f"Global request settings:")
        #pprint.pprint(config)

        try:
            with torch.amp.autocast_mode.autocast('cuda', torch.bfloat16), torch.no_grad(), torch.inference_mode():
                # Set random seeds
                random.seed(config.seed)
                np.random.seed(config.seed)
                torch.manual_seed(config.seed)
                generator = torch.Generator(device='cuda')
                generator = generator.manual_seed(config.seed)

                # Configure enhance-a-video
                #if config.enable_enhance_a_video:
                #    enable_enhance()
                #    set_enhance_weight(config.enhance_a_video_weight)

                # Prepare generation parameters for the video model (we omit params that are destined for Varnish, or things like the seed which is set externally)
                generation_kwargs = {
                    # general content settings
                    "prompt": config.prompt,
                    "negative_prompt": config.negative_prompt,

                    # video model settings (will be used during generation of the initial raw video clip)
                    "width": config.width,
                    "height": config.height,
                    "num_frames": config.num_frames,
                    "guidance_scale": config.guidance_scale,
                    "num_inference_steps": config.num_inference_steps,

                    # constants
                    "output_type": "pt",
                    "generator": generator,

                    # Timestep for decoding VAE noise: the timestep at which the generated video is decoded
                    "decode_timestep": 0.05,

                    # Noise level for decoding VAE noise: the interpolation factor between random noise and denoised latents at the decode timestep
                    "decode_noise_scale": 0.025,
                }
                #logger.info(f"Video model generation settings:")
                #pprint.pprint(generation_kwargs)

                # Handle LoRA loading/unloading
                if hasattr(self, '_current_lora_model'):
                    if self._current_lora_model != (config.lora_model_name, config.lora_model_weight_file):
                        # Unload previous LoRA if it exists and is different
                        # (guard with hasattr(self, ...) because only one of the two pipelines is instantiated)
                        if hasattr(self, 'text_to_video') and hasattr(self.text_to_video, 'unload_lora_weights'):
                            print("Unloading LoRA weights for the text_to_video pipeline..")
                            self.text_to_video.unload_lora_weights()

                        if support_image_prompt and hasattr(self, 'image_to_video') and hasattr(self.image_to_video, 'unload_lora_weights'):
                            print("Unloading LoRA weights for the image_to_video pipeline..")
                            self.image_to_video.unload_lora_weights()

                        if config.lora_model_name:
                            # Load new LoRA
                            if hasattr(self, 'text_to_video') and hasattr(self.text_to_video, 'load_lora_weights'):
                                print("Loading LoRA weights for the text_to_video pipeline..")
                                self.text_to_video.load_lora_weights(
                                    config.lora_model_name,
                                    weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
                                    token=hf_token,
                                )
                            if support_image_prompt and hasattr(self, 'image_to_video') and hasattr(self.image_to_video, 'load_lora_weights'):
                                print("Loading LoRA weights for the image_to_video pipeline..")
                                self.image_to_video.load_lora_weights(
                                    config.lora_model_name,
                                    weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
                                    token=hf_token,
                                )
                        self._current_lora_model = (config.lora_model_name, config.lora_model_weight_file)

                # Modify prompt if trigger word is provided
                if config.lora_model_trigger:
                    generation_kwargs["prompt"] = f"{config.lora_model_trigger} {generation_kwargs['prompt']}"

                #enhance_a_video_config = EnhanceAVideoConfig(
                #    weight=config.enhance_a_video_weight if config.enable_enhance_a_video else 0.0,
                #    # doing some testing
                #    num_frames_callback=lambda: (8 + 1),
                #    # num_frames_callback=lambda: config.num_frames,
                #    # num_frames_callback=lambda: (config.num_frames - 1),
                #
                #    _attention_type=1
                #)

                # Check if image-to-video generation is requested
                if support_image_prompt and input_image:
                    processed_image = process_input_image(
                        input_image,
                        config.width,
                        config.height,
                        config.input_image_quality,
                    )
                    generation_kwargs["image"] = processed_image
                    # disabled (we cannot install the hook multiple times, we would have to uninstall it first or find another way to dynamically enable it, e.g. using the weight only)
                    # apply_enhance_a_video(self.image_to_video.transformer, enhance_a_video_config)
                    frames = self.image_to_video(**generation_kwargs).frames
                else:
                    # disabled (we cannot install the hook multiple times, we would have to uninstall it first or find another way to dynamically enable it, e.g. using the weight only)
                    # apply_enhance_a_video(self.text_to_video.transformer, enhance_a_video_config)
                    frames = self.text_to_video(**generation_kwargs).frames

                try:
                    loop = asyncio.get_event_loop()
                except RuntimeError:
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)

                video_uri, metadata = loop.run_until_complete(self.process_frames(frames, config))

            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            gc.collect()

            return {
                "video": video_uri,
                "content-type": "video/mp4",
                "metadata": metadata
            }

        except Exception as e:
            message = f"Error generating video ({str(e)})\n{traceback.format_exc()}"
            print(message)
            raise RuntimeError(message)
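

if __name__ == "__main__":
    # Minimal local smoke test: a sketch under assumptions (that the LTX weights live under
    # /repository, that a CUDA GPU is available, and that the returned data URI uses the usual
    # "data:video/mp4;base64,..." layout). It is not part of the Inference Endpoints contract,
    # which only ever calls EndpointHandler.__call__ with a request payload.
    handler = EndpointHandler(model_path="/repository")
    output = handler({
        "inputs": {"prompt": "a slow aerial shot of a foggy forest at dawn"},
        "parameters": {"num_frames": 65, "num_inference_steps": 30},
    })
    # split off the "data:video/mp4;base64," header and decode the rest into an mp4 file
    header, encoded = output["video"].split(",", 1)
    with open("output.mp4", "wb") as f:
        f.write(base64.b64decode(encoded))
    print(output["metadata"])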