Update handler.py
Browse files- handler.py +18 -14
handler.py
CHANGED
@@ -279,21 +279,25 @@ def create_ltx_video_pipeline(
|
|
279 |
"""Create and configure the LTX video pipeline"""
|
280 |
# Get the absolute paths for the model components
|
281 |
current_dir = Path.cwd()
|
|
|
|
|
282 |
|
283 |
# Get allowed inference steps from config if available
|
284 |
allowed_inference_steps = None
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
|
|
|
|
294 |
# Initialize model components
|
295 |
-
vae = CausalVideoAutoencoder.from_pretrained(
|
296 |
-
transformer = Transformer3DModel.from_pretrained(
|
297 |
|
298 |
# Use constructor if sampler is specified, otherwise use from_pretrained
|
299 |
if config.sampler:
|
@@ -301,11 +305,11 @@ def create_ltx_video_pipeline(
|
|
301 |
sampler=("Uniform" if config.sampler.lower() == "uniform" else "LinearQuadratic")
|
302 |
)
|
303 |
else:
|
304 |
-
scheduler = RectifiedFlowScheduler.from_pretrained(
|
305 |
|
306 |
-
text_encoder = T5EncoderModel.from_pretrained("text_encoder")
|
307 |
patchifier = SymmetricPatchifier(patch_size=1)
|
308 |
-
tokenizer = T5Tokenizer.from_pretrained("tokenizer")
|
309 |
|
310 |
# Move models to the correct device
|
311 |
vae = vae.to(device)
|
|
|
279 |
"""Create and configure the LTX video pipeline"""
|
280 |
# Get the absolute paths for the model components
|
281 |
current_dir = Path.cwd()
|
282 |
+
|
283 |
+
ckpt_path = "./txv-2b-0.9.6-distilled-04-25.safetensors"
|
284 |
|
285 |
# Get allowed inference steps from config if available
|
286 |
allowed_inference_steps = None
|
287 |
+
|
288 |
+
assert os.path.exists(
|
289 |
+
ckpt_path
|
290 |
+
), f"Ckpt path provided (--ckpt_path) {ckpt_path} does not exist"
|
291 |
+
|
292 |
+
with safe_open(ckpt_path, framework="pt") as f:
|
293 |
+
metadata = f.metadata()
|
294 |
+
config_str = metadata.get("config")
|
295 |
+
configs = json.loads(config_str)
|
296 |
+
allowed_inference_steps = configs.get("allowed_inference_steps", None)
|
297 |
+
|
298 |
# Initialize model components
|
299 |
+
vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
|
300 |
+
transformer = Transformer3DModel.from_pretrained(ckpt_path)
|
301 |
|
302 |
# Use constructor if sampler is specified, otherwise use from_pretrained
|
303 |
if config.sampler:
|
|
|
305 |
sampler=("Uniform" if config.sampler.lower() == "uniform" else "LinearQuadratic")
|
306 |
)
|
307 |
else:
|
308 |
+
scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
|
309 |
|
310 |
+
text_encoder = T5EncoderModel.from_pretrained("./text_encoder")
|
311 |
patchifier = SymmetricPatchifier(patch_size=1)
|
312 |
+
tokenizer = T5Tokenizer.from_pretrained("./tokenizer")
|
313 |
|
314 |
# Move models to the correct device
|
315 |
vae = vae.to(device)
|