Update handler.py
handler.py  CHANGED  (+12 -4)
@@ -143,19 +143,23 @@ class EndpointHandler:
         Args:
             model_path: Path to LTX model weights
         """
+        print("EndpointHandler.__init__(): initializing..")
         # Enable TF32 for potential speedup on Ampere GPUs
         #torch.backends.cuda.matmul.allow_tf32 = True
 
         # use distilled weights
-        model_path = "/repository/ltxv-2b-0.9.6-distilled-04-25.safetensors"
+        model_path = Path("/repository/ltxv-2b-0.9.6-distilled-04-25.safetensors")
 
+        print("EndpointHandler.__init__(): initializing LTXVideoTransformer3DModel..")
         transformer = LTXVideoTransformer3DModel.from_single_file(
             model_path, torch_dtype=torch.bfloat16
         )
 
+        print("EndpointHandler.__init__(): initializing AutoencoderKLLTXVideo..")
        vae = AutoencoderKLLTXVideo.from_single_file(model_path, torch_dtype=torch.bfloat16)
 
         if support_image_prompt:
+            print("EndpointHandler.__init__(): initializing LTXImageToVideoPipeline..")
             self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
                 "/repository",
                 transformer=transformer,
@@ -173,6 +177,7 @@ class EndpointHandler:
             #self.image_to_video.unet = torch.compile(self.image_to_video.unet, mode="reduce-overhead", fullgraph=True)
 
         else:
+            print("EndpointHandler.__init__(): initializing LTXPipeline..")
             # Initialize models with bfloat16 precision
             self.text_to_video = LTXPipeline.from_pretrained(
                 "/repository",
@@ -227,7 +232,7 @@ class EndpointHandler:
             # magic_number = pickle_module.load(f, **pickle_load_args)
             #                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
             # _pickle.UnpicklingError: invalid load key, '<'.
-            enable_mmaudio=
+            enable_mmaudio=False,
         )
 
         # Determine if TeaCache is already installed or not
@@ -319,7 +324,10 @@ class EndpointHandler:
             - content-type: MIME type
             - metadata: Generation metadata
         """
+        print("__call__(): inputs = data.get('inputs', dict())")
         inputs = data.get("inputs", dict())
+        print("inputs = ")
+        print(inputs)
 
         input_prompt = inputs.get("prompt", "")
         input_image = inputs.get("image")
@@ -360,14 +368,14 @@ class EndpointHandler:
             quality=params.get("quality", GenerationConfig.quality),
 
             # TeaCache settings
-            enable_teacache=params.get("enable_teacache",
+            enable_teacache=params.get("enable_teacache", False),
 
             # values: 0 (original), 0.03 (1.6x speedup), 0.05 (2.1x speedup).
             teacache_threshold=params.get("teacache_threshold", 0.05),
 
 
             # Add enhance-a-video settings
-            enable_enhance_a_video=params.get("enable_enhance_a_video",
+            enable_enhance_a_video=params.get("enable_enhance_a_video", False),
             enhance_a_video_weight=params.get("enhance_a_video_weight", 5.0),
 
             # LoRA settings
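For reference, a minimal sketch of a request payload this handler could receive after the change. The field names come from the hunks above; the outer "parameters" key and the concrete values are assumptions, since the code that builds `params` is not shown in this diff.

    # Hypothetical request payload, inferred from the fields referenced in the diff.
    # "parameters" as the key feeding `params` is an assumption; only
    # `inputs = data.get("inputs", dict())` and `params.get(...)` appear above.
    payload = {
        "inputs": {
            "prompt": "a drone shot over a snowy mountain ridge at sunrise",
            # "image": "<base64-encoded image>",  # only relevant when support_image_prompt is enabled
        },
        "parameters": {
            "quality": 8,                       # illustrative value; default is GenerationConfig.quality
            "enable_teacache": True,            # defaults to False after this change
            "teacache_threshold": 0.05,         # 0 (original), 0.03 (~1.6x), 0.05 (~2.1x speedup)
            "enable_enhance_a_video": True,     # defaults to False after this change
            "enhance_a_video_weight": 5.0,
        },
    }

    # Driven locally, the handler would be exercised roughly like this:
    # handler = EndpointHandler()   # loads the distilled LTX weights from /repository
    # result = handler(payload)     # per the docstring, includes content-type and generation metadata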