jbilcke-hf (HF Staff) committed
Commit 03b7962 · verified · 1 parent: 5dc9220

Update handler.py

Files changed (1)
  1. handler.py +12 -4
handler.py CHANGED
@@ -143,19 +143,23 @@ class EndpointHandler:
         Args:
             model_path: Path to LTX model weights
         """
+        print("EndpointHandler.__init__(): initializing..")
         # Enable TF32 for potential speedup on Ampere GPUs
         #torch.backends.cuda.matmul.allow_tf32 = True
 
         # use distilled weights
-        model_path = "/repository/ltxv-2b-0.9.6-distilled-04-25.safetensors"
+        model_path = Path("/repository/ltxv-2b-0.9.6-distilled-04-25.safetensors")
 
+        print("EndpointHandler.__init__(): initializing LTXVideoTransformer3DModel..")
         transformer = LTXVideoTransformer3DModel.from_single_file(
             model_path, torch_dtype=torch.bfloat16
         )
 
+        print("EndpointHandler.__init__(): initializing AutoencoderKLLTXVideo..")
         vae = AutoencoderKLLTXVideo.from_single_file(model_path, torch_dtype=torch.bfloat16)
 
         if support_image_prompt:
+            print("EndpointHandler.__init__(): initializing LTXImageToVideoPipeline..")
             self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
                 "/repository",
                 transformer=transformer,
@@ -173,6 +177,7 @@ class EndpointHandler:
             #self.image_to_video.unet = torch.compile(self.image_to_video.unet, mode="reduce-overhead", fullgraph=True)
 
         else:
+            print("EndpointHandler.__init__(): initializing LTXPipeline..")
             # Initialize models with bfloat16 precision
             self.text_to_video = LTXPipeline.from_pretrained(
                 "/repository",
@@ -227,7 +232,7 @@ class EndpointHandler:
             # magic_number = pickle_module.load(f, **pickle_load_args)
             # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
             # _pickle.UnpicklingError: invalid load key, '<'.
-            enable_mmaudio=True,
+            enable_mmaudio=False,
         )
 
         # Determine if TeaCache is already installed or not
@@ -319,7 +324,10 @@ class EndpointHandler:
             - content-type: MIME type
             - metadata: Generation metadata
         """
+        print("__call__(): inputs = data.get('inputs', dict())")
         inputs = data.get("inputs", dict())
+        print("inputs = ")
+        print(inputs)
 
         input_prompt = inputs.get("prompt", "")
         input_image = inputs.get("image")
@@ -360,14 +368,14 @@ class EndpointHandler:
             quality=params.get("quality", GenerationConfig.quality),
 
             # TeaCache settings
-            enable_teacache=params.get("enable_teacache", True),
+            enable_teacache=params.get("enable_teacache", False),
 
             # values: 0 (original), 0.03 (1.6x speedup), 0.05 (2.1x speedup).
             teacache_threshold=params.get("teacache_threshold", 0.05),
 
 
             # Add enhance-a-video settings
-            enable_enhance_a_video=params.get("enable_enhance_a_video", True),
+            enable_enhance_a_video=params.get("enable_enhance_a_video", False),
             enhance_a_video_weight=params.get("enhance_a_video_weight", 5.0),
 
             # LoRA settings
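Since the handler still reads every toggle through params.get(...), this commit only changes the defaults: callers can re-enable TeaCache and enhance-a-video per request. Below is a minimal, hypothetical client sketch; the endpoint URL, token, and the assumption that generation parameters travel under a "parameters" key in the payload are illustrative only and are not shown in this diff.

import requests

# Placeholder values -- substitute your own Inference Endpoint URL and token.
ENDPOINT_URL = "https://YOUR-ENDPOINT.endpoints.huggingface.cloud"
HF_TOKEN = "hf_xxx"

payload = {
    "inputs": {
        "prompt": "a drone shot over a foggy mountain ridge at sunrise",
        # "image": "<base64-encoded image>",  # only needed for image-to-video
    },
    # Assumed key: the diff only shows the params.get(...) reads inside the handler.
    "parameters": {
        "enable_teacache": True,          # new default is False
        "teacache_threshold": 0.05,       # 0 (original), 0.03 (~1.6x), 0.05 (~2.1x speedup)
        "enable_enhance_a_video": True,   # new default is False
        "enhance_a_video_weight": 5.0,
    },
}

response = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
    json=payload,
)
response.raise_for_status()
print(response.headers.get("content-type"), len(response.content))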