jacobitterman, linoyts (HF Staff) committed
Commit 68008ac · verified · 1 Parent(s): 77756bc

update to 0.9.8 (#23)


- update to 0.9.8 (0135c48cef03cb1cbdefffc47d83a85095518e10)
- Update app.py (e1c8d1db758dbb5663f2b31af5be19c7aaf5b5c3)
- update code base with main (27e0be50ac121b4b63239a5b5dc6a11eddac9209)
- Update app.py (c3db712c6a92b978a2feb71ccfd70038b6ee3108)
- Update app.py (ab760cfbc5339b4551b6e5a02020981b4c8affe3)


Co-authored-by: Linoy Tsaban <linoyts@users.noreply.huggingface.co>

app.py CHANGED
@@ -24,7 +24,7 @@ from inference import (
 from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXMultiScalePipeline, LTXVideoPipeline
 from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
 
-config_file_path = "configs/ltxv-13b-0.9.7-distilled.yaml"
+config_file_path = "configs/ltxv-13b-0.9.8-distilled.yaml"
 with open(config_file_path, "r") as file:
     PIPELINE_CONFIG_YAML = yaml.safe_load(file)
 
@@ -374,8 +374,8 @@ css="""
 """
 
 with gr.Blocks(css=css) as demo:
-    gr.Markdown("# LTX Video 0.9.7 Distilled")
-    gr.Markdown("Fast high quality video generation. [Model](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled.safetensors) [GitHub](https://github.com/Lightricks/LTX-Video) [Diffusers](https://huggingface.co/Lightricks/LTX-Video-0.9.7-distilled#diffusers-🧨)")
+    gr.Markdown("# LTX Video 0.9.8 13B Distilled")
+    gr.Markdown("Fast high quality video generation. [Model](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.8-distilled.safetensors) [GitHub](https://github.com/Lightricks/LTX-Video) [Diffusers](https://huggingface.co/Lightricks/LTX-Video-0.9.8-13B-distilled#diffusers-🧨)")
 
     with gr.Row():
         with gr.Column():
@@ -404,7 +404,7 @@ with gr.Blocks(css=css) as demo:
                 step=0.1,
                 info=f"Target video duration (0.3s to 8.5s)"
             )
-            improve_texture = gr.Checkbox(label="Improve Texture (multi-scale)", value=True, info="Uses a two-pass generation for better quality, but is slower. Recommended for final output.")
+            improve_texture = gr.Checkbox(label="Improve Texture (multi-scale)", value=True, visible=False, info="Uses a two-pass generation for better quality, but is slower. Recommended for final output.")
 
         with gr.Column():
             output_video = gr.Video(label="Generated Video", interactive=False)
@@ -416,7 +416,7 @@ with gr.Blocks(css=css) as demo:
         with gr.Row():
             seed_input = gr.Number(label="Seed", value=42, precision=0, minimum=0, maximum=2**32-1)
             randomize_seed_input = gr.Checkbox(label="Randomize Seed", value=True)
-        with gr.Row():
+        with gr.Row(visible=False):
            guidance_scale_input = gr.Slider(label="Guidance Scale (CFG)", minimum=1.0, maximum=10.0, value=PIPELINE_CONFIG_YAML.get("first_pass", {}).get("guidance_scale", 1.0), step=0.1, info="Controls how much the prompt influences the output. Higher values = stronger influence.")
         with gr.Row():
            height_input = gr.Slider(label="Height", value=512, step=32, minimum=MIN_DIM_SLIDER, maximum=MAX_IMAGE_SIZE, info="Must be divisible by 32.")
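Context for the app.py hunk above: the Space now loads the 0.9.8 distilled YAML at startup, and the (now hidden) Guidance Scale slider still takes its default from that file's first_pass section. A minimal standalone sketch of that lookup, assuming only PyYAML and the configs/ path added in this commit:

    import yaml

    # Same lookup app.py performs at import time (illustrative sketch, not the Space's full code).
    with open("configs/ltxv-13b-0.9.8-distilled.yaml", "r") as file:
        PIPELINE_CONFIG_YAML = yaml.safe_load(file)

    # The distilled config sets first_pass.guidance_scale to 1, so the hidden CFG
    # slider resolves to 1.0 through the same .get(..., 1.0) fallback used above.
    print(PIPELINE_CONFIG_YAML.get("first_pass", {}).get("guidance_scale", 1.0))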
configs/ltxv-13b-0.9.8-dev-fp8.yaml ADDED
@@ -0,0 +1,34 @@
+pipeline_type: multi-scale
+checkpoint_path: "ltxv-13b-0.9.8-dev-fp8.safetensors"
+downscale_factor: 0.6666666
+spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
+stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
+decode_timestep: 0.05
+decode_noise_scale: 0.025
+text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
+precision: "float8_e4m3fn" # options: "float8_e4m3fn", "bfloat16", "mixed_precision"
+sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
+prompt_enhancement_words_threshold: 120
+prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
+prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
+stochastic_sampling: false
+
+first_pass:
+  guidance_scale: [1, 1, 6, 8, 6, 1, 1]
+  stg_scale: [0, 0, 4, 4, 4, 2, 1]
+  rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1]
+  guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180]
+  skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]
+  num_inference_steps: 30
+  skip_final_inference_steps: 3
+  cfg_star_rescale: true
+
+second_pass:
+  guidance_scale: [1]
+  stg_scale: [1]
+  rescaling_scale: [1]
+  guidance_timesteps: [1.0]
+  skip_block_list: [27]
+  num_inference_steps: 30
+  skip_initial_inference_steps: 17
+  cfg_star_rescale: true
configs/ltxv-13b-0.9.8-dev.yaml ADDED
@@ -0,0 +1,34 @@
+pipeline_type: multi-scale
+checkpoint_path: "ltxv-13b-0.9.8-dev.safetensors"
+downscale_factor: 0.6666666
+spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
+stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
+decode_timestep: 0.05
+decode_noise_scale: 0.025
+text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
+precision: "bfloat16"
+sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
+prompt_enhancement_words_threshold: 120
+prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
+prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
+stochastic_sampling: false
+
+first_pass:
+  guidance_scale: [1, 1, 6, 8, 6, 1, 1]
+  stg_scale: [0, 0, 4, 4, 4, 2, 1]
+  rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1]
+  guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180]
+  skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]
+  num_inference_steps: 30
+  skip_final_inference_steps: 3
+  cfg_star_rescale: true
+
+second_pass:
+  guidance_scale: [1]
+  stg_scale: [1]
+  rescaling_scale: [1]
+  guidance_timesteps: [1.0]
+  skip_block_list: [27]
+  num_inference_steps: 30
+  skip_initial_inference_steps: 17
+  cfg_star_rescale: true
configs/ltxv-13b-0.9.8-distilled-fp8.yaml ADDED
@@ -0,0 +1,29 @@
+pipeline_type: multi-scale
+checkpoint_path: "ltxv-13b-0.9.8-distilled-fp8.safetensors"
+downscale_factor: 0.6666666
+spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
+stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
+decode_timestep: 0.05
+decode_noise_scale: 0.025
+text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
+precision: "float8_e4m3fn" # options: "float8_e4m3fn", "bfloat16", "mixed_precision"
+sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
+prompt_enhancement_words_threshold: 120
+prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
+prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
+stochastic_sampling: false
+
+first_pass:
+  timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
+  guidance_scale: 1
+  stg_scale: 0
+  rescaling_scale: 1
+  skip_block_list: [42]
+
+second_pass:
+  timesteps: [0.9094, 0.7250, 0.4219]
+  guidance_scale: 1
+  stg_scale: 0
+  rescaling_scale: 1
+  skip_block_list: [42]
+  tone_map_compression_ratio: 0.6
configs/ltxv-13b-0.9.8-distilled.yaml ADDED
@@ -0,0 +1,29 @@
+pipeline_type: multi-scale
+checkpoint_path: "ltxv-13b-0.9.8-distilled.safetensors"
+downscale_factor: 0.6666666
+spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
+stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
+decode_timestep: 0.05
+decode_noise_scale: 0.025
+text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
+precision: "bfloat16"
+sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
+prompt_enhancement_words_threshold: 120
+prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
+prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
+stochastic_sampling: false
+
+first_pass:
+  timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
+  guidance_scale: 1
+  stg_scale: 0
+  rescaling_scale: 1
+  skip_block_list: [42]
+
+second_pass:
+  timesteps: [0.9094, 0.7250, 0.4219]
+  guidance_scale: 1
+  stg_scale: 0
+  rescaling_scale: 1
+  skip_block_list: [42]
+  tone_map_compression_ratio: 0.6
configs/ltxv-2b-0.9.8-distilled-fp8.yaml ADDED
@@ -0,0 +1,28 @@
+pipeline_type: multi-scale
+checkpoint_path: "ltxv-2b-0.9.8-distilled-fp8.safetensors"
+downscale_factor: 0.6666666
+spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
+stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
+decode_timestep: 0.05
+decode_noise_scale: 0.025
+text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
+precision: "float8_e4m3fn" # options: "float8_e4m3fn", "bfloat16", "mixed_precision"
+sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
+prompt_enhancement_words_threshold: 120
+prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
+prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
+stochastic_sampling: false
+
+first_pass:
+  timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
+  guidance_scale: 1
+  stg_scale: 0
+  rescaling_scale: 1
+  skip_block_list: [42]
+
+second_pass:
+  timesteps: [0.9094, 0.7250, 0.4219]
+  guidance_scale: 1
+  stg_scale: 0
+  rescaling_scale: 1
+  skip_block_list: [42]
configs/ltxv-2b-0.9.8-distilled.yaml ADDED
@@ -0,0 +1,28 @@
+pipeline_type: multi-scale
+checkpoint_path: "ltxv-2b-0.9.8-distilled.safetensors"
+downscale_factor: 0.6666666
+spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
+stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
+decode_timestep: 0.05
+decode_noise_scale: 0.025
+text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
+precision: "bfloat16"
+sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
+prompt_enhancement_words_threshold: 120
+prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
+prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
+stochastic_sampling: false
+
+first_pass:
+  timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
+  guidance_scale: 1
+  stg_scale: 0
+  rescaling_scale: 1
+  skip_block_list: [42]
+
+second_pass:
+  timesteps: [0.9094, 0.7250, 0.4219]
+  guidance_scale: 1
+  stg_scale: 0
+  rescaling_scale: 1
+  skip_block_list: [42]
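The six configs added above split into two families: the dev configs run 30-step passes with a per-timestep CFG/STG schedule (guidance_scale, stg_scale and skip_block_list are lists indexed via guidance_timesteps), while the distilled configs pin a short fixed timestep list (7 first-pass steps, 3 second-pass steps) with guidance at 1 and STG at 0. A small sketch that surfaces this difference, assuming PyYAML and the config paths introduced in this commit:

    import yaml

    # Compare the two 0.9.8 config families added in this commit (illustrative sketch).
    with open("configs/ltxv-13b-0.9.8-dev.yaml") as f:
        dev = yaml.safe_load(f)
    with open("configs/ltxv-13b-0.9.8-distilled.yaml") as f:
        distilled = yaml.safe_load(f)

    # dev: 30 inference steps per pass with a CFG schedule that peaks at 8 mid-pass
    print(dev["first_pass"]["num_inference_steps"], dev["first_pass"]["guidance_scale"])
    # distilled: 7 fixed first-pass timesteps + 3 second-pass timesteps, CFG fixed at 1
    print(len(distilled["first_pass"]["timesteps"]), len(distilled["second_pass"]["timesteps"]))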
ltx_video/models/autoencoders/causal_video_autoencoder.py CHANGED
@@ -235,7 +235,7 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
                 "compress_time",
                 "compress_all",
                 "compress_all_res",
-                "compress_space_res",
+                "compress_time_res",
             ]
         ]
     )
@@ -608,7 +608,7 @@ class Decoder(nn.Module):
            block_params = block_params if isinstance(block_params, dict) else {}
            if block_name == "res_x_y":
                output_channel = output_channel * block_params.get("multiplier", 2)
-           if block_name == "compress_all":
+           if block_name.startswith("compress"):
                output_channel = output_channel * block_params.get("multiplier", 1)
 
        self.conv_in = make_conv_nd(
@@ -1303,20 +1303,15 @@ def create_video_autoencoder_demo_config(
     encoder_blocks = [
         ("res_x", {"num_layers": 2}),
         ("compress_space_res", {"multiplier": 2}),
-        ("res_x", {"num_layers": 2}),
         ("compress_time_res", {"multiplier": 2}),
-        ("res_x", {"num_layers": 1}),
         ("compress_all_res", {"multiplier": 2}),
-        ("res_x", {"num_layers": 1}),
         ("compress_all_res", {"multiplier": 2}),
         ("res_x", {"num_layers": 1}),
     ]
     decoder_blocks = [
         ("res_x", {"num_layers": 2, "inject_noise": False}),
         ("compress_all", {"residual": True, "multiplier": 2}),
-        ("res_x", {"num_layers": 2, "inject_noise": False}),
         ("compress_all", {"residual": True, "multiplier": 2}),
-        ("res_x", {"num_layers": 2, "inject_noise": False}),
         ("compress_all", {"residual": True, "multiplier": 2}),
         ("res_x", {"num_layers": 2, "inject_noise": False}),
     ]
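The Decoder change above widens the channel-multiplier check from an exact match on "compress_all" to any block whose name starts with "compress". A tiny illustrative sketch (the helper below is hypothetical; the block names are the ones appearing in this diff):

    # Which decoder block names pick up the channel multiplier before vs. after the change.
    def gets_multiplier(block_name: str, new_behavior: bool) -> bool:
        if new_behavior:
            return block_name.startswith("compress")  # 0.9.8 behaviour
        return block_name == "compress_all"           # previous behaviour

    for name in ["compress_all", "compress_all_res", "compress_time_res", "compress_space_res", "res_x"]:
        print(f"{name}: {gets_multiplier(name, False)} -> {gets_multiplier(name, True)}")
    # Only "compress_all" scaled output_channel before; every compress-type block does now.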
ltx_video/models/transformers/attention.py CHANGED
@@ -205,7 +205,6 @@ class BasicTransformerBlock(nn.Module):
         timestep: Optional[torch.LongTensor] = None,
         cross_attention_kwargs: Dict[str, Any] = None,
         class_labels: Optional[torch.LongTensor] = None,
-        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
         skip_layer_mask: Optional[torch.Tensor] = None,
         skip_layer_strategy: Optional[SkipLayerStrategy] = None,
     ) -> torch.FloatTensor:
ltx_video/models/transformers/transformer3d.py CHANGED
@@ -268,7 +268,7 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
            for key, value in state_dict.items()
            if key.startswith("model.diffusion_model.")
        }
-        super().load_state_dict(state_dict, **kwargs)
+        super().load_state_dict(state_dict, *args, **kwargs)
 
    @classmethod
    def from_pretrained(
ltx_video/pipelines/pipeline_ltx_video.py CHANGED
@@ -45,11 +45,6 @@ from ltx_video.models.autoencoders.vae_encode import (
 )
 
 
-try:
-    import torch_xla.distributed.spmd as xs
-except ImportError:
-    xs = None
-
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
@@ -795,6 +790,7 @@ class LTXVideoPipeline(DiffusionPipeline):
         text_encoder_max_tokens: int = 256,
         stochastic_sampling: bool = False,
         media_items: Optional[torch.Tensor] = None,
+        tone_map_compression_ratio: float = 0.0,
         **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
         """
@@ -876,6 +872,8 @@ class LTXVideoPipeline(DiffusionPipeline):
                 If set to `True`, the sampling is stochastic. If set to `False`, the sampling is deterministic.
             media_items ('torch.Tensor', *optional*):
                 The input media item used for image-to-image / video-to-video.
+            tone_map_compression_ratio: compression ratio for tone mapping, defaults to 0.0.
+                If set to 0.0, no tone mapping is applied. If set to 1.0, full compression is applied.
         Examples:
 
         Returns:
@@ -978,10 +976,6 @@
                 guidance_scale[guidance_mapping[i]] for i in range(len(timesteps))
             ]
 
-        # For simplicity, we are using a constant num_conds for all timesteps, so we need to zero
-        # out cases where the guidance scale should not be applied.
-        guidance_scale = [x if x > 1.0 else 0.0 for x in guidance_scale]
-
         if not isinstance(stg_scale, List):
             stg_scale = [stg_scale] * len(timesteps)
         else:
@@ -994,16 +988,6 @@
                 rescaling_scale[guidance_mapping[i]] for i in range(len(timesteps))
             ]
 
-        do_classifier_free_guidance = any(x > 1.0 for x in guidance_scale)
-        do_spatio_temporal_guidance = any(x > 0.0 for x in stg_scale)
-        do_rescaling = any(x != 1.0 for x in rescaling_scale)
-
-        num_conds = 1
-        if do_classifier_free_guidance:
-            num_conds += 1
-        if do_spatio_temporal_guidance:
-            num_conds += 1
-
         # Normalize skip_block_list to always be None or a list of lists matching timesteps
         if skip_block_list is not None:
             # Convert single list to list of lists if needed
@@ -1015,17 +999,6 @@
                 new_skip_block_list.append(skip_block_list[guidance_mapping[i]])
             skip_block_list = new_skip_block_list
 
-        # Prepare skip layer masks
-        skip_layer_masks: Optional[List[torch.Tensor]] = None
-        if do_spatio_temporal_guidance:
-            if skip_block_list is not None:
-                skip_layer_masks = [
-                    self.transformer.create_skip_layer_mask(
-                        batch_size, num_conds, num_conds - 1, skip_blocks
-                    )
-                    for skip_blocks in skip_block_list
-                ]
-
         if enhance_prompt:
             self.prompt_enhancer_image_caption_model = (
                 self.prompt_enhancer_image_caption_model.to(self._execution_device)
@@ -1055,7 +1028,7 @@
            negative_prompt_attention_mask,
         ) = self.encode_prompt(
             prompt,
-            do_classifier_free_guidance,
+            True,
             negative_prompt=negative_prompt,
             num_images_per_prompt=num_images_per_prompt,
             device=device,
@@ -1073,23 +1046,28 @@
 
         prompt_embeds_batch = prompt_embeds
         prompt_attention_mask_batch = prompt_attention_mask
-        if do_classifier_free_guidance:
-            prompt_embeds_batch = torch.cat(
-                [negative_prompt_embeds, prompt_embeds], dim=0
-            )
-            prompt_attention_mask_batch = torch.cat(
-                [negative_prompt_attention_mask, prompt_attention_mask], dim=0
-            )
-        if do_spatio_temporal_guidance:
-            prompt_embeds_batch = torch.cat([prompt_embeds_batch, prompt_embeds], dim=0)
-            prompt_attention_mask_batch = torch.cat(
-                [
-                    prompt_attention_mask_batch,
-                    prompt_attention_mask,
-                ],
-                dim=0,
-            )
+        negative_prompt_embeds = (
+            torch.zeros_like(prompt_embeds)
+            if negative_prompt_embeds is None
+            else negative_prompt_embeds
+        )
+        negative_prompt_attention_mask = (
+            torch.zeros_like(prompt_attention_mask)
+            if negative_prompt_attention_mask is None
+            else negative_prompt_attention_mask
+        )
 
+        prompt_embeds_batch = torch.cat(
+            [negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0
+        )
+        prompt_attention_mask_batch = torch.cat(
+            [
+                negative_prompt_attention_mask,
+                prompt_attention_mask,
+                prompt_attention_mask,
+            ],
+            dim=0,
+        )
         # 4. Prepare the initial latents using the provided media and conditioning items
 
         # Prepare the initial latents tensor, shape = (b, c, f, h, w)
@@ -1098,7 +1076,7 @@
             media_items=media_items,
             timestep=timesteps[0],
             latent_shape=latent_shape,
-            dtype=prompt_embeds_batch.dtype,
+            dtype=prompt_embeds.dtype,
             device=device,
             generator=generator,
             vae_per_channel_normalize=vae_per_channel_normalize,
@@ -1118,14 +1096,6 @@
         )
         init_latents = latents.clone()  # Used for image_cond_noise_update
 
-        pixel_coords = torch.cat([pixel_coords] * num_conds)
-        orig_conditioning_mask = conditioning_mask
-        if conditioning_mask is not None and is_video:
-            assert num_images_per_prompt == 1
-            conditioning_mask = torch.cat([conditioning_mask] * num_conds)
-        fractional_coords = pixel_coords.to(torch.float32)
-        fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
-
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
@@ -1134,8 +1104,50 @@
             len(timesteps) - num_inference_steps * self.scheduler.order, 0
         )
 
+        orig_conditioning_mask = conditioning_mask
+
+        # Before compiling this code please be aware:
+        # This code might generate different input shapes if some timesteps have no STG or CFG.
+        # This means that the code might need to be compiled multiple times.
+        # To avoid that, use the same STG and CFG values for all timesteps.
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                do_classifier_free_guidance = guidance_scale[i] > 1.0
+                do_spatio_temporal_guidance = stg_scale[i] > 0
+                do_rescaling = rescaling_scale[i] != 1.0
+
+                num_conds = 1
+                if do_classifier_free_guidance:
+                    num_conds += 1
+                if do_spatio_temporal_guidance:
+                    num_conds += 1
+
+                if do_classifier_free_guidance and do_spatio_temporal_guidance:
+                    indices = slice(batch_size * 0, batch_size * 3)
+                elif do_classifier_free_guidance:
+                    indices = slice(batch_size * 0, batch_size * 2)
+                elif do_spatio_temporal_guidance:
+                    indices = slice(batch_size * 1, batch_size * 3)
+                else:
+                    indices = slice(batch_size * 1, batch_size * 2)
+
+                # Prepare skip layer masks
+                skip_layer_mask: Optional[torch.Tensor] = None
+                if do_spatio_temporal_guidance:
+                    if skip_block_list is not None:
+                        skip_layer_mask = self.transformer.create_skip_layer_mask(
+                            batch_size, num_conds, num_conds - 1, skip_block_list[i]
+                        )
+
+                batch_pixel_coords = torch.cat([pixel_coords] * num_conds)
+                conditioning_mask = orig_conditioning_mask
+                if conditioning_mask is not None and is_video:
+                    assert num_images_per_prompt == 1
+                    conditioning_mask = torch.cat([conditioning_mask] * num_conds)
+                fractional_coords = batch_pixel_coords.to(torch.float32)
+                fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
+
                 if conditioning_mask is not None and image_cond_noise_scale > 0.0:
                     latents = self.add_noise_to_image_conditioning_latents(
                         t,
@@ -1194,16 +1206,12 @@
                 noise_pred = self.transformer(
                     latent_model_input.to(self.transformer.dtype),
                     indices_grid=fractional_coords,
-                    encoder_hidden_states=prompt_embeds_batch.to(
+                    encoder_hidden_states=prompt_embeds_batch[indices].to(
                         self.transformer.dtype
                     ),
-                    encoder_attention_mask=prompt_attention_mask_batch,
+                    encoder_attention_mask=prompt_attention_mask_batch[indices],
                     timestep=current_timestep,
-                    skip_layer_mask=(
-                        skip_layer_masks[i]
-                        if skip_layer_masks is not None
-                        else None
-                    ),
+                    skip_layer_mask=skip_layer_mask,
                     skip_layer_strategy=skip_layer_strategy,
                     return_dict=False,
                 )[0]
@@ -1315,6 +1323,7 @@
                 )
             else:
                 decode_timestep = None
+            latents = self.tone_map_latents(latents, tone_map_compression_ratio)
             image = vae_decode(
                 latents,
                 self.vae,
@@ -1736,6 +1745,47 @@
         num_frames = (num_frames - 1) // scale_factor * scale_factor + 1
         return num_frames
 
+    @staticmethod
+    def tone_map_latents(
+        latents: torch.Tensor,
+        compression: float,
+    ) -> torch.Tensor:
+        """
+        Applies a non-linear tone-mapping function to latent values to reduce their dynamic range
+        in a perceptually smooth way using a sigmoid-based compression.
+
+        This is useful for regularizing high-variance latents or for conditioning outputs
+        during generation, especially when controlling dynamic behavior with a `compression` factor.
+
+        Parameters:
+        ----------
+        latents : torch.Tensor
+            Input latent tensor with arbitrary shape. Expected to be roughly in [-1, 1] or [0, 1] range.
+        compression : float
+            Compression strength in the range [0, 1].
+            - 0.0: No tone-mapping (identity transform)
+            - 1.0: Full compression effect
+
+        Returns:
+        -------
+        torch.Tensor
+            The tone-mapped latent tensor of the same shape as input.
+        """
+        if not (0 <= compression <= 1):
+            raise ValueError("Compression must be in the range [0, 1]")
+
+        # Remap [0-1] to [0-0.75] and apply sigmoid compression in one shot
+        scale_factor = compression * 0.75
+        abs_latents = torch.abs(latents)
+
+        # Sigmoid compression: sigmoid shifts large values toward 0.2, small values stay ~1.0
+        # When scale_factor=0, sigmoid term vanishes, when scale_factor=0.75, full effect
+        sigmoid_term = torch.sigmoid(4.0 * scale_factor * (abs_latents - 1.0))
+        scales = 1.0 - 0.8 * scale_factor * sigmoid_term
+
+        filtered = latents * scales
+        return filtered
+
 
 def adain_filter_latent(
     latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0
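The tone_map_latents staticmethod added above is what the new tone_map_compression_ratio argument feeds into: at compression 0.0 the scale factor vanishes and the latents pass through unchanged, while the 0.9.8 13B distilled configs request 0.6 in their second pass. A self-contained sketch of the same math (formula copied from the diff; only torch is assumed):

    import torch

    def tone_map_latents(latents: torch.Tensor, compression: float) -> torch.Tensor:
        # Same sigmoid-based dynamic-range compression as the method added above.
        if not (0 <= compression <= 1):
            raise ValueError("Compression must be in the range [0, 1]")
        scale_factor = compression * 0.75
        sigmoid_term = torch.sigmoid(4.0 * scale_factor * (torch.abs(latents) - 1.0))
        return latents * (1.0 - 0.8 * scale_factor * sigmoid_term)

    x = torch.linspace(-3.0, 3.0, 7)
    print(tone_map_latents(x, 0.0))  # identity: scale_factor == 0, so every scale is 1.0
    print(tone_map_latents(x, 0.6))  # setting used by the 13B 0.9.8 distilled second pass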
ltx_video/schedulers/rf.py CHANGED
@@ -314,7 +314,7 @@ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
         """
         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
         process from the learned model outputs (most often the predicted noise).
-        z_{t_1} = z_t - \Delta_t * v
+        z_{t_1} = z_t - Delta_t * v
         The method finds the next timestep that is lower than the input timestep(s) and denoises the latents
         to that level. The input timestep(s) are not required to be one of the predefined timesteps.
 
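For reference, the docstring edit above only drops the stray backslash escape; the update it documents is the plain rectified-flow Euler step, which in LaTeX notation reads

    z_{t_1} = z_t - \Delta_t \cdot v

where \Delta_t is the gap between the current timestep and the next lower timestep chosen by the scheduler, and v is the learned model output described in the surrounding docstring.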