jbilcke-hf HF Staff commited on
Commit
70cf89f
·
verified ·
1 Parent(s): 4da9e55

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +18 -14
handler.py CHANGED
@@ -279,21 +279,25 @@ def create_ltx_video_pipeline(
279
  """Create and configure the LTX video pipeline"""
280
  # Get the absolute paths for the model components
281
  current_dir = Path.cwd()
 
 
282
 
283
  # Get allowed inference steps from config if available
284
  allowed_inference_steps = None
285
- try:
286
- # Load allowed inference steps from metadata if available
287
- if Path("transformer/config.json").exists():
288
- with open("transformer/config.json", "r") as f:
289
- config_data = json.load(f)
290
- allowed_inference_steps = config_data.get("allowed_inference_steps")
291
- except Exception as e:
292
- logger.warning(f"Failed to load allowed_inference_steps from config: {e}")
293
-
 
 
294
  # Initialize model components
295
- vae = CausalVideoAutoencoder.from_pretrained("vae")
296
- transformer = Transformer3DModel.from_pretrained("transformer")
297
 
298
  # Use constructor if sampler is specified, otherwise use from_pretrained
299
  if config.sampler:
@@ -301,11 +305,11 @@ def create_ltx_video_pipeline(
301
  sampler=("Uniform" if config.sampler.lower() == "uniform" else "LinearQuadratic")
302
  )
303
  else:
304
- scheduler = RectifiedFlowScheduler.from_pretrained("scheduler")
305
 
306
- text_encoder = T5EncoderModel.from_pretrained("text_encoder")
307
  patchifier = SymmetricPatchifier(patch_size=1)
308
- tokenizer = T5Tokenizer.from_pretrained("tokenizer")
309
 
310
  # Move models to the correct device
311
  vae = vae.to(device)
 
279
  """Create and configure the LTX video pipeline"""
280
  # Get the absolute paths for the model components
281
  current_dir = Path.cwd()
282
+
283
+ ckpt_path = "./txv-2b-0.9.6-distilled-04-25.safetensors"
284
 
285
  # Get allowed inference steps from config if available
286
  allowed_inference_steps = None
287
+
288
+ assert os.path.exists(
289
+ ckpt_path
290
+ ), f"Ckpt path provided (--ckpt_path) {ckpt_path} does not exist"
291
+
292
+ with safe_open(ckpt_path, framework="pt") as f:
293
+ metadata = f.metadata()
294
+ config_str = metadata.get("config")
295
+ configs = json.loads(config_str)
296
+ allowed_inference_steps = configs.get("allowed_inference_steps", None)
297
+
298
  # Initialize model components
299
+ vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
300
+ transformer = Transformer3DModel.from_pretrained(ckpt_path)
301
 
302
  # Use constructor if sampler is specified, otherwise use from_pretrained
303
  if config.sampler:
 
305
  sampler=("Uniform" if config.sampler.lower() == "uniform" else "LinearQuadratic")
306
  )
307
  else:
308
+ scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
309
 
310
+ text_encoder = T5EncoderModel.from_pretrained("./text_encoder")
311
  patchifier = SymmetricPatchifier(patch_size=1)
312
+ tokenizer = T5Tokenizer.from_pretrained("./tokenizer")
313
 
314
  # Move models to the correct device
315
  vae = vae.to(device)