blanchon committed
Commit c7ed5da · 1 Parent(s): 00e44cf
hi_diffusers/pipelines/hidream_image/pipeline_hidream_image.py CHANGED
@@ -1,31 +1,34 @@
1
  import inspect
2
- from typing import Any, Callable, Dict, List, Optional, Union
3
  import math
4
  import einops
5
  import torch
6
- from transformers import (
7
- CLIPTextModelWithProjection,
8
- CLIPTokenizer,
9
- T5EncoderModel,
10
- T5Tokenizer,
11
- LlamaForCausalLM,
12
- PreTrainedTokenizerFast
13
- )
14
-
15
  from diffusers.image_processor import VaeImageProcessor
16
  from diffusers.loaders import FromSingleFileMixin
17
  from diffusers.models.autoencoders import AutoencoderKL
 
18
  from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
19
  from diffusers.utils import (
20
- USE_PEFT_BACKEND,
21
  is_torch_xla_available,
22
  logging,
23
  )
24
  from diffusers.utils.torch_utils import randn_tensor
25
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
26
- from .pipeline_output import HiDreamImagePipelineOutput
27
- from ...models.transformers.transformer_hidream_image import HiDreamImageTransformer2DModel
28
  from ...schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
 
29
 
30
  if is_torch_xla_available():
31
  import torch_xla.core.xla_model as xm
@@ -36,6 +39,7 @@ else:
36
 
37
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
 
 
39
  # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
40
  def calculate_shift(
41
  image_seq_len,
@@ -49,13 +53,14 @@ def calculate_shift(
49
  mu = image_seq_len * m + b
50
  return mu
51
 
 
52
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
53
  def retrieve_timesteps(
54
  scheduler,
55
- num_inference_steps: Optional[int] = None,
56
- device: Optional[Union[str, torch.device]] = None,
57
- timesteps: Optional[List[int]] = None,
58
- sigmas: Optional[List[float]] = None,
59
  **kwargs,
60
  ):
61
  r"""
@@ -80,26 +85,34 @@ def retrieve_timesteps(
80
  Returns:
81
  `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
82
  second element is the number of inference steps.
 
83
  """
84
  if timesteps is not None and sigmas is not None:
85
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
 
86
  if timesteps is not None:
87
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
 
 
88
  if not accepts_timesteps:
89
- raise ValueError(
90
  f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
91
  f" timestep schedules. Please check whether you are using the correct scheduler."
92
  )
 
93
  scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
94
  timesteps = scheduler.timesteps
95
  num_inference_steps = len(timesteps)
96
  elif sigmas is not None:
97
- accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
 
 
98
  if not accept_sigmas:
99
- raise ValueError(
100
  f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
101
  f" sigmas schedules. Please check whether you are using the correct scheduler."
102
  )
 
103
  scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
104
  timesteps = scheduler.timesteps
105
  num_inference_steps = len(timesteps)
@@ -108,6 +121,7 @@ def retrieve_timesteps(
108
  timesteps = scheduler.timesteps
109
  return timesteps, num_inference_steps
110
 
 
111
  class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
112
  model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->image_encoder->transformer->vae"
113
  _optional_components = ["image_encoder", "feature_extractor"]
@@ -115,6 +129,7 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
115
 
116
  def __init__(
117
  self,
 
118
  scheduler: FlowMatchEulerDiscreteScheduler,
119
  vae: AutoencoderKL,
120
  text_encoder: CLIPTextModelWithProjection,
@@ -129,6 +144,7 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
129
  super().__init__()
130
 
131
  self.register_modules(
 
132
  vae=vae,
133
  text_encoder=text_encoder,
134
  text_encoder_2=text_encoder_2,
@@ -141,21 +157,25 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
141
  scheduler=scheduler,
142
  )
143
  self.vae_scale_factor = (
144
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
 
 
145
  )
146
  # HiDreamImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
147
  # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
148
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
 
 
149
  self.default_sample_size = 128
150
  self.tokenizer_4.pad_token = self.tokenizer_4.eos_token
151
 
152
  def _get_t5_prompt_embeds(
153
  self,
154
- prompt: Union[str, List[str]] = None,
155
  num_images_per_prompt: int = 1,
156
  max_sequence_length: int = 128,
157
- device: Optional[torch.device] = None,
158
- dtype: Optional[torch.dtype] = None,
159
  ):
160
  device = device or self._execution_device
161
  dtype = dtype or self.text_encoder_3.dtype
@@ -173,33 +193,47 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
173
  )
174
  text_input_ids = text_inputs.input_ids
175
  attention_mask = text_inputs.attention_mask
176
- untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
177
-
178
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
179
- removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, min(max_sequence_length, self.tokenizer_3.model_max_length) - 1 : -1])
180
  logger.warning(
181
  "The following part of your input was truncated because `max_sequence_length` is set to "
182
  f" {min(max_sequence_length, self.tokenizer_3.model_max_length)} tokens: {removed_text}"
183
  )
184
 
185
- prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=attention_mask.to(device))[0]
 
 
186
  prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
187
  _, seq_len, _ = prompt_embeds.shape
188
 
189
  # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
190
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
191
- prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
 
192
  return prompt_embeds
193
-
194
  def _get_clip_prompt_embeds(
195
  self,
196
  tokenizer,
197
  text_encoder,
198
- prompt: Union[str, List[str]],
199
  num_images_per_prompt: int = 1,
200
  max_sequence_length: int = 128,
201
- device: Optional[torch.device] = None,
202
- dtype: Optional[torch.dtype] = None,
203
  ):
204
  device = device or self._execution_device
205
  dtype = dtype or text_encoder.dtype
@@ -216,14 +250,20 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
216
  )
217
 
218
  text_input_ids = text_inputs.input_ids
219
- untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
220
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
221
  removed_text = tokenizer.batch_decode(untruncated_ids[:, 218 - 1 : -1])
222
  logger.warning(
223
  "The following part of your input was truncated because CLIP can only handle sequences up to"
224
  f" {218} tokens: {removed_text}"
225
  )
226
- prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
 
 
227
 
228
  # Use pooled output of CLIPTextModel
229
  prompt_embeds = prompt_embeds[0]
@@ -234,14 +274,14 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
234
  prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
235
 
236
  return prompt_embeds
237
-
238
  def _get_llama3_prompt_embeds(
239
  self,
240
- prompt: Union[str, List[str]] = None,
241
  num_images_per_prompt: int = 1,
242
  max_sequence_length: int = 128,
243
- device: Optional[torch.device] = None,
244
- dtype: Optional[torch.dtype] = None,
245
  ):
246
  device = device or self._execution_device
247
  dtype = dtype or self.text_encoder_4.dtype
@@ -259,20 +299,30 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
259
  )
260
  text_input_ids = text_inputs.input_ids
261
  attention_mask = text_inputs.attention_mask
262
- untruncated_ids = self.tokenizer_4(prompt, padding="longest", return_tensors="pt").input_ids
263
-
264
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
265
- removed_text = self.tokenizer_4.batch_decode(untruncated_ids[:, min(max_sequence_length, self.tokenizer_4.model_max_length) - 1 : -1])
266
  logger.warning(
267
  "The following part of your input was truncated because `max_sequence_length` is set to "
268
  f" {min(max_sequence_length, self.tokenizer_4.model_max_length)} tokens: {removed_text}"
269
  )
270
 
271
  outputs = self.text_encoder_4(
272
- text_input_ids.to(device),
273
- attention_mask=attention_mask.to(device),
274
  output_hidden_states=True,
275
- output_attentions=True
276
  )
277
 
278
  prompt_embeds = outputs.hidden_states[1:]
@@ -281,47 +331,46 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
281
 
282
  # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
283
  prompt_embeds = prompt_embeds.repeat(1, 1, num_images_per_prompt, 1)
284
- prompt_embeds = prompt_embeds.view(-1, batch_size * num_images_per_prompt, seq_len, dim)
 
 
285
  return prompt_embeds
286
-
287
  def encode_prompt(
288
  self,
289
- prompt: Union[str, List[str]],
290
- prompt_2: Union[str, List[str]],
291
- prompt_3: Union[str, List[str]],
292
- prompt_4: Union[str, List[str]],
293
- device: Optional[torch.device] = None,
294
- dtype: Optional[torch.dtype] = None,
295
  num_images_per_prompt: int = 1,
296
  do_classifier_free_guidance: bool = True,
297
- negative_prompt: Optional[Union[str, List[str]]] = None,
298
- negative_prompt_2: Optional[Union[str, List[str]]] = None,
299
- negative_prompt_3: Optional[Union[str, List[str]]] = None,
300
- negative_prompt_4: Optional[Union[str, List[str]]] = None,
301
- prompt_embeds: Optional[List[torch.FloatTensor]] = None,
302
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
303
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
304
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
305
  max_sequence_length: int = 128,
306
- lora_scale: Optional[float] = None,
307
  ):
308
  prompt = [prompt] if isinstance(prompt, str) else prompt
309
- if prompt is not None:
310
- batch_size = len(prompt)
311
- else:
312
- batch_size = prompt_embeds.shape[0]
313
 
314
  prompt_embeds, pooled_prompt_embeds = self._encode_prompt(
315
- prompt = prompt,
316
- prompt_2 = prompt_2,
317
- prompt_3 = prompt_3,
318
- prompt_4 = prompt_4,
319
- device = device,
320
- dtype = dtype,
321
- num_images_per_prompt = num_images_per_prompt,
322
- prompt_embeds = prompt_embeds,
323
- pooled_prompt_embeds = pooled_prompt_embeds,
324
- max_sequence_length = max_sequence_length,
325
  )
326
 
327
  if do_classifier_free_guidance and negative_prompt_embeds is None:
@@ -331,58 +380,75 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
331
  negative_prompt_4 = negative_prompt_4 or negative_prompt
332
 
333
  # normalize str to list
334
- negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
335
  negative_prompt_2 = (
336
- batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
 
 
337
  )
338
  negative_prompt_3 = (
339
- batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
 
 
340
  )
341
  negative_prompt_4 = (
342
- batch_size * [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4
 
 
343
  )
344
 
345
  if prompt is not None and type(prompt) is not type(negative_prompt):
346
- raise TypeError(
347
  f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
348
  f" {type(prompt)}."
349
  )
350
- elif batch_size != len(negative_prompt):
351
- raise ValueError(
 
352
  f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
353
  f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
354
  " the batch size of `prompt`."
355
  )
356
-
 
357
  negative_prompt_embeds, negative_pooled_prompt_embeds = self._encode_prompt(
358
- prompt = negative_prompt,
359
- prompt_2 = negative_prompt_2,
360
- prompt_3 = negative_prompt_3,
361
- prompt_4 = negative_prompt_4,
362
- device = device,
363
- dtype = dtype,
364
- num_images_per_prompt = num_images_per_prompt,
365
- prompt_embeds = negative_prompt_embeds,
366
- pooled_prompt_embeds = negative_pooled_prompt_embeds,
367
- max_sequence_length = max_sequence_length,
368
  )
369
- return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
370
 
371
  def _encode_prompt(
372
  self,
373
- prompt: Union[str, List[str]],
374
- prompt_2: Union[str, List[str]],
375
- prompt_3: Union[str, List[str]],
376
- prompt_4: Union[str, List[str]],
377
- device: Optional[torch.device] = None,
378
- dtype: Optional[torch.dtype] = None,
379
  num_images_per_prompt: int = 1,
380
- prompt_embeds: Optional[List[torch.FloatTensor]] = None,
381
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
382
  max_sequence_length: int = 128,
383
  ):
384
  device = device or self._execution_device
385
-
386
  if prompt_embeds is None:
387
  prompt_2 = prompt_2 or prompt
388
  prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
@@ -396,38 +462,40 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
396
  pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
397
  self.tokenizer,
398
  self.text_encoder,
399
- prompt = prompt,
400
- num_images_per_prompt = num_images_per_prompt,
401
- max_sequence_length = max_sequence_length,
402
- device = device,
403
- dtype = dtype,
404
  )
405
 
406
  pooled_prompt_embeds_2 = self._get_clip_prompt_embeds(
407
  self.tokenizer_2,
408
  self.text_encoder_2,
409
- prompt = prompt_2,
410
- num_images_per_prompt = num_images_per_prompt,
411
- max_sequence_length = max_sequence_length,
412
- device = device,
413
- dtype = dtype,
414
  )
415
 
416
- pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1)
 
 
417
 
418
  t5_prompt_embeds = self._get_t5_prompt_embeds(
419
- prompt = prompt_3,
420
- num_images_per_prompt = num_images_per_prompt,
421
- max_sequence_length = max_sequence_length,
422
- device = device,
423
- dtype = dtype
424
  )
425
  llama3_prompt_embeds = self._get_llama3_prompt_embeds(
426
- prompt = prompt_4,
427
- num_images_per_prompt = num_images_per_prompt,
428
- max_sequence_length = max_sequence_length,
429
- device = device,
430
- dtype = dtype
431
  )
432
  prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds]
433
 
@@ -481,25 +549,28 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
481
  shape = (batch_size, num_channels_latents, height, width)
482
 
483
  if latents is None:
484
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
 
 
485
  else:
486
  if latents.shape != shape:
487
- raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
 
488
  latents = latents.to(device)
489
  return latents
490
-
491
  @property
492
  def guidance_scale(self):
493
  return self._guidance_scale
494
-
495
  @property
496
  def do_classifier_free_guidance(self):
497
  return self._guidance_scale > 1
498
-
499
  @property
500
  def joint_attention_kwargs(self):
501
  return self._joint_attention_kwargs
502
-
503
  @property
504
  def num_timesteps(self):
505
  return self._num_timesteps
@@ -507,37 +578,39 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
507
  @property
508
  def interrupt(self):
509
  return self._interrupt
510
-
511
  @torch.no_grad()
512
  def __call__(
513
  self,
514
- prompt: Union[str, List[str]] = None,
515
- prompt_2: Optional[Union[str, List[str]]] = None,
516
- prompt_3: Optional[Union[str, List[str]]] = None,
517
- prompt_4: Optional[Union[str, List[str]]] = None,
518
- height: Optional[int] = None,
519
- width: Optional[int] = None,
520
  num_inference_steps: int = 50,
521
- sigmas: Optional[List[float]] = None,
522
  guidance_scale: float = 5.0,
523
- negative_prompt: Optional[Union[str, List[str]]] = None,
524
- negative_prompt_2: Optional[Union[str, List[str]]] = None,
525
- negative_prompt_3: Optional[Union[str, List[str]]] = None,
526
- negative_prompt_4: Optional[Union[str, List[str]]] = None,
527
- num_images_per_prompt: Optional[int] = 1,
528
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
529
- latents: Optional[torch.FloatTensor] = None,
530
- prompt_embeds: Optional[torch.FloatTensor] = None,
531
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
532
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
533
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
534
- output_type: Optional[str] = "pil",
535
  return_dict: bool = True,
536
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
537
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
538
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
539
  max_sequence_length: int = 128,
540
  ):
 
 
541
  height = height or self.default_sample_size * self.vae_scale_factor
542
  width = width or self.default_sample_size * self.vae_scale_factor
543
 
@@ -545,7 +618,10 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
545
  S_max = (self.default_sample_size * self.vae_scale_factor) ** 2
546
  scale = S_max / (width * height)
547
  scale = math.sqrt(scale)
548
- width, height = int(width * scale // division * division), int(height * scale // division * division)
549
 
550
  self._guidance_scale = guidance_scale
551
  self._joint_attention_kwargs = joint_attention_kwargs
@@ -562,7 +638,9 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
562
  device = self._execution_device
563
 
564
  lora_scale = (
565
- self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
 
 
566
  )
567
  (
568
  prompt_embeds,
@@ -591,13 +669,15 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
591
 
592
  if self.do_classifier_free_guidance:
593
  prompt_embeds_arr = []
594
- for n, p in zip(negative_prompt_embeds, prompt_embeds):
595
  if len(n.shape) == 3:
596
  prompt_embeds_arr.append(torch.cat([n, p], dim=0))
597
  else:
598
  prompt_embeds_arr.append(torch.cat([n, p], dim=1))
599
  prompt_embeds = prompt_embeds_arr
600
- pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
 
 
601
 
602
  # 4. Prepare latent variables
603
  num_channels_latents = self.transformer.config.in_channels
@@ -614,18 +694,21 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
614
 
615
  if latents.shape[-2] != latents.shape[-1]:
616
  B, C, H, W = latents.shape
617
- pH, pW = H // self.transformer.config.patch_size, W // self.transformer.config.patch_size
 
 
 
618
 
619
  img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)
620
  img_ids = torch.zeros(pH, pW, 3)
621
- img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None]
622
- img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :]
623
  img_ids = img_ids.reshape(pH * pW, -1)
624
  img_ids_pad = torch.zeros(self.transformer.max_seq, 3)
625
- img_ids_pad[:pH*pW, :] = img_ids
626
 
627
- img_sizes = img_sizes.unsqueeze(0).to(latents.device)
628
- img_ids = img_ids_pad.unsqueeze(0).to(latents.device)
629
  if self.do_classifier_free_guidance:
630
  img_sizes = img_sizes.repeat(2 * B, 1)
631
  img_ids = img_ids.repeat(2 * B, 1, 1)
@@ -636,7 +719,9 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
636
  mu = calculate_shift(self.transformer.max_seq)
637
  scheduler_kwargs = {"mu": mu}
638
  if isinstance(self.scheduler, FlowUniPCMultistepScheduler):
639
- self.scheduler.set_timesteps(num_inference_steps, device=device, shift=math.exp(mu))
 
 
640
  timesteps = self.scheduler.timesteps
641
  else:
642
  timesteps, num_inference_steps = retrieve_timesteps(
@@ -646,7 +731,9 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
646
  sigmas=sigmas,
647
  **scheduler_kwargs,
648
  )
649
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
 
 
650
  self._num_timesteps = len(timesteps)
651
 
652
  # 6. Denoising loop
@@ -656,7 +743,11 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
656
  continue
657
 
658
  # expand the latents if we are doing classifier free guidance
659
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
660
  # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
661
  timestep = t.expand(latent_model_input.shape[0])
662
 
@@ -665,33 +756,42 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
665
  patch_size = self.transformer.config.patch_size
666
  pH, pW = H // patch_size, W // patch_size
667
  out = torch.zeros(
668
- (B, C, self.transformer.max_seq, patch_size * patch_size),
669
- dtype=latent_model_input.dtype,
670
- device=latent_model_input.device
671
  )
672
- latent_model_input = einops.rearrange(latent_model_input, 'B C (H p1) (W p2) -> B C (H W) (p1 p2)', p1=patch_size, p2=patch_size)
673
- out[:, :, 0:pH*pW] = latent_model_input
674
  latent_model_input = out
675
 
676
  noise_pred = self.transformer(
677
- hidden_states = latent_model_input,
678
- timesteps = timestep,
679
- encoder_hidden_states = prompt_embeds,
680
- pooled_embeds = pooled_prompt_embeds,
681
- img_sizes = img_sizes,
682
- img_ids = img_ids,
683
- return_dict = False,
684
  )[0]
685
  noise_pred = -noise_pred
686
 
687
  # perform guidance
688
  if self.do_classifier_free_guidance:
689
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
690
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
 
 
691
 
692
  # compute the previous noisy sample x_t -> x_t-1
693
  latents_dtype = latents.dtype
694
- latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 
 
695
 
696
  if latents.dtype != latents_dtype:
697
  if torch.backends.mps.is_available():
@@ -706,10 +806,14 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
706
 
707
  latents = callback_outputs.pop("latents", latents)
708
  prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
709
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
 
 
710
 
711
  # call the callback, if provided
712
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
 
 
713
  progress_bar.update()
714
 
715
  if XLA_AVAILABLE:
@@ -719,7 +823,9 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
719
  image = latents
720
 
721
  else:
722
- latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
 
 
723
 
724
  image = self.vae.decode(latents, return_dict=False)[0]
725
  image = self.image_processor.postprocess(image, output_type=output_type)
@@ -730,4 +836,4 @@ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
730
  if not return_dict:
731
  return (image,)
732
 
733
- return HiDreamImagePipelineOutput(images=image)
 
1
  import inspect
 
2
  import math
3
+ from collections.abc import Callable
4
+ from typing import Any
5
+
6
  import einops
7
  import torch
8
  from diffusers.image_processor import VaeImageProcessor
9
  from diffusers.loaders import FromSingleFileMixin
10
  from diffusers.models.autoencoders import AutoencoderKL
11
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
12
  from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
13
  from diffusers.utils import (
 
14
  is_torch_xla_available,
15
  logging,
16
  )
17
  from diffusers.utils.torch_utils import randn_tensor
18
+ from transformers import (
19
+ CLIPTextModelWithProjection,
20
+ CLIPTokenizer,
21
+ LlamaForCausalLM,
22
+ PreTrainedTokenizerFast,
23
+ T5EncoderModel,
24
+ T5Tokenizer,
25
+ )
26
+
27
+ from ...models.transformers.transformer_hidream_image import (
28
+ HiDreamImageTransformer2DModel,
29
+ )
30
  from ...schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
31
+ from .pipeline_output import HiDreamImagePipelineOutput
32
 
33
  if is_torch_xla_available():
34
  import torch_xla.core.xla_model as xm
 
39
 
40
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
 
42
+
43
  # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
44
  def calculate_shift(
45
  image_seq_len,
 
53
  mu = image_seq_len * m + b
54
  return mu
55
 
56
+
57
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
58
  def retrieve_timesteps(
59
  scheduler,
60
+ num_inference_steps: int | None = None,
61
+ device: str | torch.device | None = None,
62
+ timesteps: list[int] | None = None,
63
+ sigmas: list[float] | None = None,
64
  **kwargs,
65
  ):
66
  r"""
 
85
  Returns:
86
  `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
87
  second element is the number of inference steps.
88
+
89
  """
90
  if timesteps is not None and sigmas is not None:
91
+ msg = "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
92
+ raise ValueError(msg)
93
  if timesteps is not None:
94
+ accepts_timesteps = "timesteps" in set(
95
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
96
+ )
97
  if not accepts_timesteps:
98
+ msg = (
99
  f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
100
  f" timestep schedules. Please check whether you are using the correct scheduler."
101
  )
102
+ raise ValueError(msg)
103
  scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
104
  timesteps = scheduler.timesteps
105
  num_inference_steps = len(timesteps)
106
  elif sigmas is not None:
107
+ accept_sigmas = "sigmas" in set(
108
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
109
+ )
110
  if not accept_sigmas:
111
+ msg = (
112
  f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
113
  f" sigmas schedules. Please check whether you are using the correct scheduler."
114
  )
115
+ raise ValueError(msg)
116
  scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
117
  timesteps = scheduler.timesteps
118
  num_inference_steps = len(timesteps)
 
121
  timesteps = scheduler.timesteps
122
  return timesteps, num_inference_steps
123
 
124
+
125
  class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
126
  model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->image_encoder->transformer->vae"
127
  _optional_components = ["image_encoder", "feature_extractor"]
 
129
 
130
  def __init__(
131
  self,
132
+ transformer: HiDreamImageTransformer2DModel,
133
  scheduler: FlowMatchEulerDiscreteScheduler,
134
  vae: AutoencoderKL,
135
  text_encoder: CLIPTextModelWithProjection,
 
144
  super().__init__()
145
 
146
  self.register_modules(
147
+ transformer=transformer,
148
  vae=vae,
149
  text_encoder=text_encoder,
150
  text_encoder_2=text_encoder_2,
 
157
  scheduler=scheduler,
158
  )
159
  self.vae_scale_factor = (
160
+ 2 ** (len(self.vae.config.block_out_channels) - 1)
161
+ if hasattr(self, "vae") and self.vae is not None
162
+ else 8
163
  )
164
  # HiDreamImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
165
  # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
166
+ self.image_processor = VaeImageProcessor(
167
+ vae_scale_factor=self.vae_scale_factor * 2
168
+ )
169
  self.default_sample_size = 128
170
  self.tokenizer_4.pad_token = self.tokenizer_4.eos_token
171
 
172
  def _get_t5_prompt_embeds(
173
  self,
174
+ prompt: str | list[str] | None = None,
175
  num_images_per_prompt: int = 1,
176
  max_sequence_length: int = 128,
177
+ device: torch.device | None = None,
178
+ dtype: torch.dtype | None = None,
179
  ):
180
  device = device or self._execution_device
181
  dtype = dtype or self.text_encoder_3.dtype
 
193
  )
194
  text_input_ids = text_inputs.input_ids
195
  attention_mask = text_inputs.attention_mask
196
+ untruncated_ids = self.tokenizer_3(
197
+ prompt, padding="longest", return_tensors="pt"
198
+ ).input_ids
199
+
200
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
201
+ text_input_ids, untruncated_ids
202
+ ):
203
+ removed_text = self.tokenizer_3.batch_decode(
204
+ untruncated_ids[
205
+ :,
206
+ min(max_sequence_length, self.tokenizer_3.model_max_length)
207
+ - 1 : -1,
208
+ ]
209
+ )
210
  logger.warning(
211
  "The following part of your input was truncated because `max_sequence_length` is set to "
212
  f" {min(max_sequence_length, self.tokenizer_3.model_max_length)} tokens: {removed_text}"
213
  )
214
 
215
+ prompt_embeds = self.text_encoder_3(
216
+ text_input_ids.to(device), attention_mask=attention_mask.to(device)
217
+ )[0]
218
  prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
219
  _, seq_len, _ = prompt_embeds.shape
220
 
221
  # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
222
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
223
+ prompt_embeds = prompt_embeds.view(
224
+ batch_size * num_images_per_prompt, seq_len, -1
225
+ )
226
  return prompt_embeds
227
+
228
  def _get_clip_prompt_embeds(
229
  self,
230
  tokenizer,
231
  text_encoder,
232
+ prompt: str | list[str],
233
  num_images_per_prompt: int = 1,
234
  max_sequence_length: int = 128,
235
+ device: torch.device | None = None,
236
+ dtype: torch.dtype | None = None,
237
  ):
238
  device = device or self._execution_device
239
  dtype = dtype or text_encoder.dtype
 
250
  )
251
 
252
  text_input_ids = text_inputs.input_ids
253
+ untruncated_ids = tokenizer(
254
+ prompt, padding="longest", return_tensors="pt"
255
+ ).input_ids
256
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
257
+ text_input_ids, untruncated_ids
258
+ ):
259
  removed_text = tokenizer.batch_decode(untruncated_ids[:, 218 - 1 : -1])
260
  logger.warning(
261
  "The following part of your input was truncated because CLIP can only handle sequences up to"
262
  f" {218} tokens: {removed_text}"
263
  )
264
+ prompt_embeds = text_encoder(
265
+ text_input_ids.to(device), output_hidden_states=True
266
+ )
267
 
268
  # Use pooled output of CLIPTextModel
269
  prompt_embeds = prompt_embeds[0]
 
274
  prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
275
 
276
  return prompt_embeds
277
+
278
  def _get_llama3_prompt_embeds(
279
  self,
280
+ prompt: str | list[str] | None = None,
281
  num_images_per_prompt: int = 1,
282
  max_sequence_length: int = 128,
283
+ device: torch.device | None = None,
284
+ dtype: torch.dtype | None = None,
285
  ):
286
  device = device or self._execution_device
287
  dtype = dtype or self.text_encoder_4.dtype
 
299
  )
300
  text_input_ids = text_inputs.input_ids
301
  attention_mask = text_inputs.attention_mask
302
+ untruncated_ids = self.tokenizer_4(
303
+ prompt, padding="longest", return_tensors="pt"
304
+ ).input_ids
305
+
306
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
307
+ text_input_ids, untruncated_ids
308
+ ):
309
+ removed_text = self.tokenizer_4.batch_decode(
310
+ untruncated_ids[
311
+ :,
312
+ min(max_sequence_length, self.tokenizer_4.model_max_length)
313
+ - 1 : -1,
314
+ ]
315
+ )
316
  logger.warning(
317
  "The following part of your input was truncated because `max_sequence_length` is set to "
318
  f" {min(max_sequence_length, self.tokenizer_4.model_max_length)} tokens: {removed_text}"
319
  )
320
 
321
  outputs = self.text_encoder_4(
322
+ text_input_ids.to(device),
323
+ attention_mask=attention_mask.to(device),
324
  output_hidden_states=True,
325
+ output_attentions=True,
326
  )
327
 
328
  prompt_embeds = outputs.hidden_states[1:]
 
331
 
332
  # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
333
  prompt_embeds = prompt_embeds.repeat(1, 1, num_images_per_prompt, 1)
334
+ prompt_embeds = prompt_embeds.view(
335
+ -1, batch_size * num_images_per_prompt, seq_len, dim
336
+ )
337
  return prompt_embeds
338
+
339
  def encode_prompt(
340
  self,
341
+ prompt: str | list[str],
342
+ prompt_2: str | list[str],
343
+ prompt_3: str | list[str],
344
+ prompt_4: str | list[str],
345
+ device: torch.device | None = None,
346
+ dtype: torch.dtype | None = None,
347
  num_images_per_prompt: int = 1,
348
  do_classifier_free_guidance: bool = True,
349
+ negative_prompt: str | list[str] | None = None,
350
+ negative_prompt_2: str | list[str] | None = None,
351
+ negative_prompt_3: str | list[str] | None = None,
352
+ negative_prompt_4: str | list[str] | None = None,
353
+ prompt_embeds: list[torch.FloatTensor] | None = None,
354
+ negative_prompt_embeds: torch.FloatTensor | None = None,
355
+ pooled_prompt_embeds: torch.FloatTensor | None = None,
356
+ negative_pooled_prompt_embeds: torch.FloatTensor | None = None,
357
  max_sequence_length: int = 128,
358
+ lora_scale: float | None = None,
359
  ):
360
  prompt = [prompt] if isinstance(prompt, str) else prompt
361
+ batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]
362
 
363
  prompt_embeds, pooled_prompt_embeds = self._encode_prompt(
364
+ prompt=prompt,
365
+ prompt_2=prompt_2,
366
+ prompt_3=prompt_3,
367
+ prompt_4=prompt_4,
368
+ device=device,
369
+ dtype=dtype,
370
+ num_images_per_prompt=num_images_per_prompt,
371
+ prompt_embeds=prompt_embeds,
372
+ pooled_prompt_embeds=pooled_prompt_embeds,
373
+ max_sequence_length=max_sequence_length,
374
  )
375
 
376
  if do_classifier_free_guidance and negative_prompt_embeds is None:
 
380
  negative_prompt_4 = negative_prompt_4 or negative_prompt
381
 
382
  # normalize str to list
383
+ negative_prompt = (
384
+ batch_size * [negative_prompt]
385
+ if isinstance(negative_prompt, str)
386
+ else negative_prompt
387
+ )
388
  negative_prompt_2 = (
389
+ batch_size * [negative_prompt_2]
390
+ if isinstance(negative_prompt_2, str)
391
+ else negative_prompt_2
392
  )
393
  negative_prompt_3 = (
394
+ batch_size * [negative_prompt_3]
395
+ if isinstance(negative_prompt_3, str)
396
+ else negative_prompt_3
397
  )
398
  negative_prompt_4 = (
399
+ batch_size * [negative_prompt_4]
400
+ if isinstance(negative_prompt_4, str)
401
+ else negative_prompt_4
402
  )
403
 
404
  if prompt is not None and type(prompt) is not type(negative_prompt):
405
+ msg = (
406
  f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
407
  f" {type(prompt)}."
408
  )
409
+ raise TypeError(msg)
410
+ if batch_size != len(negative_prompt):
411
+ msg = (
412
  f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
413
  f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
414
  " the batch size of `prompt`."
415
  )
416
+ raise ValueError(msg)
417
+
418
  negative_prompt_embeds, negative_pooled_prompt_embeds = self._encode_prompt(
419
+ prompt=negative_prompt,
420
+ prompt_2=negative_prompt_2,
421
+ prompt_3=negative_prompt_3,
422
+ prompt_4=negative_prompt_4,
423
+ device=device,
424
+ dtype=dtype,
425
+ num_images_per_prompt=num_images_per_prompt,
426
+ prompt_embeds=negative_prompt_embeds,
427
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
428
+ max_sequence_length=max_sequence_length,
429
  )
430
+ return (
431
+ prompt_embeds,
432
+ negative_prompt_embeds,
433
+ pooled_prompt_embeds,
434
+ negative_pooled_prompt_embeds,
435
+ )
436
 
437
  def _encode_prompt(
438
  self,
439
+ prompt: str | list[str],
440
+ prompt_2: str | list[str],
441
+ prompt_3: str | list[str],
442
+ prompt_4: str | list[str],
443
+ device: torch.device | None = None,
444
+ dtype: torch.dtype | None = None,
445
  num_images_per_prompt: int = 1,
446
+ prompt_embeds: list[torch.FloatTensor] | None = None,
447
+ pooled_prompt_embeds: torch.FloatTensor | None = None,
448
  max_sequence_length: int = 128,
449
  ):
450
  device = device or self._execution_device
451
+
452
  if prompt_embeds is None:
453
  prompt_2 = prompt_2 or prompt
454
  prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
 
462
  pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
463
  self.tokenizer,
464
  self.text_encoder,
465
+ prompt=prompt,
466
+ num_images_per_prompt=num_images_per_prompt,
467
+ max_sequence_length=max_sequence_length,
468
+ device=device,
469
+ dtype=dtype,
470
  )
471
 
472
  pooled_prompt_embeds_2 = self._get_clip_prompt_embeds(
473
  self.tokenizer_2,
474
  self.text_encoder_2,
475
+ prompt=prompt_2,
476
+ num_images_per_prompt=num_images_per_prompt,
477
+ max_sequence_length=max_sequence_length,
478
+ device=device,
479
+ dtype=dtype,
480
  )
481
 
482
+ pooled_prompt_embeds = torch.cat(
483
+ [pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1
484
+ )
485
 
486
  t5_prompt_embeds = self._get_t5_prompt_embeds(
487
+ prompt=prompt_3,
488
+ num_images_per_prompt=num_images_per_prompt,
489
+ max_sequence_length=max_sequence_length,
490
+ device=device,
491
+ dtype=dtype,
492
  )
493
  llama3_prompt_embeds = self._get_llama3_prompt_embeds(
494
+ prompt=prompt_4,
495
+ num_images_per_prompt=num_images_per_prompt,
496
+ max_sequence_length=max_sequence_length,
497
+ device=device,
498
+ dtype=dtype,
499
  )
500
  prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds]
501
 
 
549
  shape = (batch_size, num_channels_latents, height, width)
550
 
551
  if latents is None:
552
+ latents = randn_tensor(
553
+ shape, generator=generator, device=device, dtype=dtype
554
+ )
555
  else:
556
  if latents.shape != shape:
557
+ msg = f"Unexpected latents shape, got {latents.shape}, expected {shape}"
558
+ raise ValueError(msg)
559
  latents = latents.to(device)
560
  return latents
561
+
562
  @property
563
  def guidance_scale(self):
564
  return self._guidance_scale
565
+
566
  @property
567
  def do_classifier_free_guidance(self):
568
  return self._guidance_scale > 1
569
+
570
  @property
571
  def joint_attention_kwargs(self):
572
  return self._joint_attention_kwargs
573
+
574
  @property
575
  def num_timesteps(self):
576
  return self._num_timesteps
 
578
  @property
579
  def interrupt(self):
580
  return self._interrupt
581
+
582
  @torch.no_grad()
583
  def __call__(
584
  self,
585
+ prompt: str | list[str] | None = None,
586
+ prompt_2: str | list[str] | None = None,
587
+ prompt_3: str | list[str] | None = None,
588
+ prompt_4: str | list[str] | None = None,
589
+ height: int | None = None,
590
+ width: int | None = None,
591
  num_inference_steps: int = 50,
592
+ sigmas: list[float] | None = None,
593
  guidance_scale: float = 5.0,
594
+ negative_prompt: str | list[str] | None = None,
595
+ negative_prompt_2: str | list[str] | None = None,
596
+ negative_prompt_3: str | list[str] | None = None,
597
+ negative_prompt_4: str | list[str] | None = None,
598
+ num_images_per_prompt: int | None = 1,
599
+ generator: torch.Generator | list[torch.Generator] | None = None,
600
+ latents: torch.FloatTensor | None = None,
601
+ prompt_embeds: torch.FloatTensor | None = None,
602
+ negative_prompt_embeds: torch.FloatTensor | None = None,
603
+ pooled_prompt_embeds: torch.FloatTensor | None = None,
604
+ negative_pooled_prompt_embeds: torch.FloatTensor | None = None,
605
+ output_type: str | None = "pil",
606
  return_dict: bool = True,
607
+ joint_attention_kwargs: dict[str, Any] | None = None,
608
+ callback_on_step_end: Callable[[int, int, dict], None] | None = None,
609
+ callback_on_step_end_tensor_inputs: list[str] | None = None,
610
  max_sequence_length: int = 128,
611
  ):
612
+ if callback_on_step_end_tensor_inputs is None:
613
+ callback_on_step_end_tensor_inputs = ["latents"]
614
  height = height or self.default_sample_size * self.vae_scale_factor
615
  width = width or self.default_sample_size * self.vae_scale_factor
616
 
 
618
  S_max = (self.default_sample_size * self.vae_scale_factor) ** 2
619
  scale = S_max / (width * height)
620
  scale = math.sqrt(scale)
621
+ width, height = (
622
+ int(width * scale // division * division),
623
+ int(height * scale // division * division),
624
+ )
625
 
626
  self._guidance_scale = guidance_scale
627
  self._joint_attention_kwargs = joint_attention_kwargs
 
638
  device = self._execution_device
639
 
640
  lora_scale = (
641
+ self.joint_attention_kwargs.get("scale", None)
642
+ if self.joint_attention_kwargs is not None
643
+ else None
644
  )
645
  (
646
  prompt_embeds,
 
669
 
670
  if self.do_classifier_free_guidance:
671
  prompt_embeds_arr = []
672
+ for n, p in zip(negative_prompt_embeds, prompt_embeds, strict=False):
673
  if len(n.shape) == 3:
674
  prompt_embeds_arr.append(torch.cat([n, p], dim=0))
675
  else:
676
  prompt_embeds_arr.append(torch.cat([n, p], dim=1))
677
  prompt_embeds = prompt_embeds_arr
678
+ pooled_prompt_embeds = torch.cat(
679
+ [negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0
680
+ )
681
 
682
  # 4. Prepare latent variables
683
  num_channels_latents = self.transformer.config.in_channels
 
694
 
695
  if latents.shape[-2] != latents.shape[-1]:
696
  B, C, H, W = latents.shape
697
+ pH, pW = (
698
+ H // self.transformer.config.patch_size,
699
+ W // self.transformer.config.patch_size,
700
+ )
701
 
702
  img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)
703
  img_ids = torch.zeros(pH, pW, 3)
704
+ img_ids[..., 1] += torch.arange(pH)[:, None]
705
+ img_ids[..., 2] += torch.arange(pW)[None, :]
706
  img_ids = img_ids.reshape(pH * pW, -1)
707
  img_ids_pad = torch.zeros(self.transformer.max_seq, 3)
708
+ img_ids_pad[: pH * pW, :] = img_ids
709
 
710
+ img_sizes = img_sizes.unsqueeze(0).to(latents.device)
711
+ img_ids = img_ids_pad.unsqueeze(0).to(latents.device)
712
  if self.do_classifier_free_guidance:
713
  img_sizes = img_sizes.repeat(2 * B, 1)
714
  img_ids = img_ids.repeat(2 * B, 1, 1)
 
719
  mu = calculate_shift(self.transformer.max_seq)
720
  scheduler_kwargs = {"mu": mu}
721
  if isinstance(self.scheduler, FlowUniPCMultistepScheduler):
722
+ self.scheduler.set_timesteps(
723
+ num_inference_steps, device=device, shift=math.exp(mu)
724
+ )
725
  timesteps = self.scheduler.timesteps
726
  else:
727
  timesteps, num_inference_steps = retrieve_timesteps(
 
731
  sigmas=sigmas,
732
  **scheduler_kwargs,
733
  )
734
+ num_warmup_steps = max(
735
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
736
+ )
737
  self._num_timesteps = len(timesteps)
738
 
739
  # 6. Denoising loop
 
743
  continue
744
 
745
  # expand the latents if we are doing classifier free guidance
746
+ latent_model_input = (
747
+ torch.cat([latents] * 2)
748
+ if self.do_classifier_free_guidance
749
+ else latents
750
+ )
751
  # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
752
  timestep = t.expand(latent_model_input.shape[0])
753
 
 
756
  patch_size = self.transformer.config.patch_size
757
  pH, pW = H // patch_size, W // patch_size
758
  out = torch.zeros(
759
+ (B, C, self.transformer.max_seq, patch_size * patch_size),
760
+ dtype=latent_model_input.dtype,
761
+ device=latent_model_input.device,
762
  )
763
+ latent_model_input = einops.rearrange(
764
+ latent_model_input,
765
+ "B C (H p1) (W p2) -> B C (H W) (p1 p2)",
766
+ p1=patch_size,
767
+ p2=patch_size,
768
+ )
769
+ out[:, :, 0 : pH * pW] = latent_model_input
770
  latent_model_input = out
771
 
772
  noise_pred = self.transformer(
773
+ hidden_states=latent_model_input,
774
+ timesteps=timestep,
775
+ encoder_hidden_states=prompt_embeds,
776
+ pooled_embeds=pooled_prompt_embeds,
777
+ img_sizes=img_sizes,
778
+ img_ids=img_ids,
779
+ return_dict=False,
780
  )[0]
781
  noise_pred = -noise_pred
782
 
783
  # perform guidance
784
  if self.do_classifier_free_guidance:
785
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
786
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
787
+ noise_pred_text - noise_pred_uncond
788
+ )
789
 
790
  # compute the previous noisy sample x_t -> x_t-1
791
  latents_dtype = latents.dtype
792
+ latents = self.scheduler.step(
793
+ noise_pred, t, latents, return_dict=False
794
+ )[0]
795
 
796
  if latents.dtype != latents_dtype:
797
  if torch.backends.mps.is_available():
 
806
 
807
  latents = callback_outputs.pop("latents", latents)
808
  prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
809
+ negative_prompt_embeds = callback_outputs.pop(
810
+ "negative_prompt_embeds", negative_prompt_embeds
811
+ )
812
 
813
  # call the callback, if provided
814
+ if i == len(timesteps) - 1 or (
815
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
816
+ ):
817
  progress_bar.update()
818
 
819
  if XLA_AVAILABLE:
 
823
  image = latents
824
 
825
  else:
826
+ latents = (
827
+ latents / self.vae.config.scaling_factor
828
+ ) + self.vae.config.shift_factor
829
 
830
  image = self.vae.decode(latents, return_dict=False)[0]
831
  image = self.image_processor.postprocess(image, output_type=output_type)
 
836
  if not return_dict:
837
  return (image,)
838
 
839
+ return HiDreamImagePipelineOutput(images=image)
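
For orientation, here is a minimal usage sketch of the pipeline as it looks after this commit. It is not part of the commit itself: the component variables are placeholders for already-loaded models, and only the constructor wiring (transformer is now an explicit __init__ argument registered via register_modules) and the new callback_on_step_end_tensor_inputs default come from the diff above; the tokenizer/encoder keyword names follow the attribute names used elsewhere in the file.

from hi_diffusers.pipelines.hidream_image.pipeline_hidream_image import (
    HiDreamImagePipeline,
)

# Assumption: transformer, scheduler, vae, the four text encoders and their
# tokenizers have already been loaded elsewhere (e.g. via from_pretrained on
# a checkpoint of your choice); only the wiring below reflects this commit.
pipe = HiDreamImagePipeline(
    transformer=transformer,        # HiDreamImageTransformer2DModel, now an explicit __init__ argument
    scheduler=scheduler,            # FlowMatchEulerDiscreteScheduler (FlowUniPCMultistepScheduler is also handled)
    vae=vae,                        # AutoencoderKL
    text_encoder=text_encoder,      # CLIPTextModelWithProjection
    tokenizer=tokenizer,            # CLIPTokenizer
    text_encoder_2=text_encoder_2,  # second CLIP encoder
    tokenizer_2=tokenizer_2,
    text_encoder_3=text_encoder_3,  # T5EncoderModel
    tokenizer_3=tokenizer_3,        # T5Tokenizer
    text_encoder_4=text_encoder_4,  # LlamaForCausalLM
    tokenizer_4=tokenizer_4,        # PreTrainedTokenizerFast
)

output = pipe(
    prompt="a cat wearing a spacesuit",
    num_inference_steps=50,
    guidance_scale=5.0,
    # callback_on_step_end_tensor_inputs now defaults to None and is
    # normalized to ["latents"] inside __call__.
)
image = output.images[0]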