jamesliu1217 committed
Commit 014a7de (verified)
Parent(s): f1df3c1

support cfg-zero*

Files changed (1): src/pipeline.py (+10, −5)
src/pipeline.py CHANGED

@@ -526,9 +526,11 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
-        spatial_images=None,
-        subject_images=None,
+        spatial_images=[],
+        subject_images=[],
         cond_size=512,
+        use_zero_init: Optional[bool] = True,
+        zero_steps: Optional[int] = 0,
     ):
 
         height = height or self.default_sample_size * self.vae_scale_factor
@@ -656,7 +658,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
             guidance = guidance.expand(latents.shape[0])
         else:
             guidance = None
-
+
         ## Caching conditions
         # clean the cache
         for name, attn_processor in self.transformer.attn_processors.items():
@@ -679,7 +681,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
                 joint_attention_kwargs=self.joint_attention_kwargs,
                 return_dict=False,
             )[0]
-
+
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -700,6 +702,9 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                 )[0]
+
+                if (i <= zero_steps) and use_zero_init:
+                    noise_pred = noise_pred*0.
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
@@ -742,4 +747,4 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
         if not return_dict:
             return (image,)
 
-        return FluxPipelineOutput(images=image)
+        return FluxPipelineOutput(images=image)
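
For context on how the new arguments are meant to be used: when use_zero_init is enabled, the predicted noise is zeroed for every step index i <= zero_steps (the CFG-Zero* "zero-init" trick), so with the default zero_steps=0 only the first denoising step is affected. Below is a minimal usage sketch, assuming the patched FluxPipeline in src/pipeline.py is importable and loads like the upstream diffusers FluxPipeline; the checkpoint id and prompt are placeholders, not taken from this commit.

# Minimal usage sketch for the new CFG-Zero* arguments (assumptions: the patched
# FluxPipeline in src/pipeline.py loads like the upstream diffusers FluxPipeline;
# the checkpoint id and prompt are placeholders).
import torch
from src.pipeline import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",   # hypothetical base checkpoint
    torch_dtype=torch.bfloat16,
).to("cuda")

image = pipe(
    prompt="a photo of a cat",        # placeholder prompt
    height=1024,
    width=1024,
    num_inference_steps=25,
    max_sequence_length=512,
    spatial_images=[],                # new default: empty list instead of None
    subject_images=[],
    cond_size=512,
    use_zero_init=True,               # zero the prediction on early steps
    zero_steps=0,                     # steps with index i <= zero_steps are zeroed
).images[0]

With zero_steps=0 only the very first scheduler step sees a zeroed prediction, which for the flow-matching Euler update effectively leaves the latents unchanged on that step.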