LPX55 committed
Commit f7aaef1 · 1 Parent(s): 78d50d2

fix: structuring

Files changed (5):
  1. app.py +1 -1
  2. nf4.py +4 -4
  3. pipeline_hidream_image.py +686 -479
  4. pipeline_output.py +21 -0
  5. transformer_hidream_image.py +526 -0
app.py CHANGED
@@ -5,7 +5,7 @@ import torch
  import logging
  from diffusers import DiffusionPipeline
  from transformer_hidream_image import HiDreamImageTransformer2DModel
- from pipelines.hidream_image.pipeline_hidream_image import HiDreamImagePipeline
+ from pipeline_hidream_image import HiDreamImagePipeline
  import subprocess
 
  try:
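
The app.py change simply points the pipeline import at the flattened module layout introduced by this commit: pipeline_hidream_image.py now sits next to app.py instead of under pipelines/hidream_image/. A minimal sketch of how these imports are typically wired together is shown below; the checkpoint id "HiDream-ai/HiDream-I1-Full" and the dtype are assumptions for illustration, not taken from this diff.

import torch

from transformer_hidream_image import HiDreamImageTransformer2DModel
from pipeline_hidream_image import HiDreamImagePipeline

# Assumed checkpoint id, used only to illustrate the flattened imports.
transformer = HiDreamImageTransformer2DModel.from_pretrained(
    "HiDream-ai/HiDream-I1-Full", subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")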
nf4.py CHANGED
@@ -1,10 +1,10 @@
  import torch
  from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
 
- from . import HiDreamImagePipeline
- from . import HiDreamImageTransformer2DModel
- from .schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
- from .schedulers.flash_flow_match import FlashFlowMatchEulerDiscreteScheduler
+ from pipeline_hidream_image import HiDreamImagePipeline
+ from transformer_hidream_image import HiDreamImageTransformer2DModel
+ from schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
+ from schedulers.flash_flow_match import FlashFlowMatchEulerDiscreteScheduler
 
 
  MODEL_PREFIX = "azaneko"
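
nf4.py gets the same flattening: the pipeline, the transformer, and both flow-matching schedulers are now imported from top-level modules rather than package-relative paths. A hedged sketch of how those schedulers might be swapped onto an already-loaded pipeline follows; `pipe` and the constructor arguments are assumptions, since this hunk only shows the imports.

from schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
from schedulers.flash_flow_match import FlashFlowMatchEulerDiscreteScheduler

# Full-step sampling with the UniPC flow-matching solver (arguments assumed).
pipe.scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=1000, shift=3.0, use_dynamic_shifting=False)

# Or few-step sampling with the flash flow-match Euler scheduler (arguments assumed).
# pipe.scheduler = FlashFlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)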
pipeline_hidream_image.py CHANGED
@@ -1,526 +1,733 @@
1
- from typing import Any, Dict, Optional, Tuple, List
2
-
3
- import torch
4
- import torch.nn as nn
5
  import einops
6
- from einops import repeat
7
-
8
- from diffusers.configuration_utils import ConfigMixin, register_to_config
9
- from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
10
- from diffusers.models.modeling_utils import ModelMixin
11
- from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
12
- from diffusers.utils.torch_utils import maybe_allow_in_graph
13
- from diffusers.models.modeling_outputs import Transformer2DModelOutput
14
- from models.embeddings import PatchEmbed, PooledEmbed, TimestepEmbed, EmbedND, OutEmbed
15
- from models.attention import HiDreamAttention, FeedForwardSwiGLU
16
- from models.attention_processor import HiDreamAttnProcessor_flashattn
17
- from models.moe import MOEFeedForwardSwiGLU
18
 
19
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
20
 
21
- class TextProjection(nn.Module):
22
- def __init__(self, in_features, hidden_size):
23
- super().__init__()
24
- self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False)
25
-
26
- def forward(self, caption):
27
- hidden_states = self.linear(caption)
28
- return hidden_states
29
-
30
- class BlockType:
31
- TransformerBlock = 1
32
- SingleTransformerBlock = 2
33
 
34
- @maybe_allow_in_graph
35
- class HiDreamImageSingleTransformerBlock(nn.Module):
36
  def __init__(
37
  self,
38
- dim: int,
39
- num_attention_heads: int,
40
- attention_head_dim: int,
41
- num_routed_experts: int = 4,
42
- num_activated_experts: int = 2
43
  ):
44
  super().__init__()
45
- self.num_attention_heads = num_attention_heads
46
- self.adaLN_modulation = nn.Sequential(
47
- nn.SiLU(),
48
- nn.Linear(dim, 6 * dim, bias=True)
49
  )
50
- nn.init.zeros_(self.adaLN_modulation[1].weight)
51
- nn.init.zeros_(self.adaLN_modulation[1].bias)
52
-
53
- # 1. Attention
54
- self.norm1_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
55
- self.attn1 = HiDreamAttention(
56
- query_dim=dim,
57
- heads=num_attention_heads,
58
- dim_head=attention_head_dim,
59
- processor = HiDreamAttnProcessor_flashattn(),
60
- single = True
61
  )
62
 
63
- # 3. Feed-forward
64
- self.norm3_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
65
- if num_routed_experts > 0:
66
- self.ff_i = MOEFeedForwardSwiGLU(
67
- dim = dim,
68
- hidden_dim = 4 * dim,
69
- num_routed_experts = num_routed_experts,
70
- num_activated_experts = num_activated_experts,
71
- )
72
- else:
73
- self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim)
74
-
75
- def forward(
76
  self,
77
- image_tokens: torch.FloatTensor,
78
- image_tokens_masks: Optional[torch.FloatTensor] = None,
79
- text_tokens: Optional[torch.FloatTensor] = None,
80
- adaln_input: Optional[torch.FloatTensor] = None,
81
- rope: torch.FloatTensor = None,
82
-
83
- ) -> torch.FloatTensor:
84
- wtype = image_tokens.dtype
85
- shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
86
- self.adaLN_modulation(adaln_input)[:,None].chunk(6, dim=-1)
87
-
88
- # 1. MM-Attention
89
- norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
90
- norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
91
- attn_output_i = self.attn1(
92
- norm_image_tokens,
93
- image_tokens_masks,
94
- rope = rope,
 
95
  )
96
- image_tokens = gate_msa_i * attn_output_i + image_tokens
97
-
98
- # 2. Feed-forward
99
- norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
100
- norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
101
- ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype))
102
- image_tokens = ff_output_i + image_tokens
103
- return image_tokens
104
-
105
- @maybe_allow_in_graph
106
- class HiDreamImageTransformerBlock(nn.Module):
107
- def __init__(
108
  self,
109
- dim: int,
110
- num_attention_heads: int,
111
- attention_head_dim: int,
112
- num_routed_experts: int = 4,
113
- num_activated_experts: int = 2
 
 
114
  ):
115
- super().__init__()
116
- self.num_attention_heads = num_attention_heads
117
- self.adaLN_modulation = nn.Sequential(
118
- nn.SiLU(),
119
- nn.Linear(dim, 12 * dim, bias=True)
120
- )
121
- nn.init.zeros_(self.adaLN_modulation[1].weight)
122
- nn.init.zeros_(self.adaLN_modulation[1].bias)
123
-
124
- # 1. Attention
125
- self.norm1_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
126
- self.norm1_t = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
127
- self.attn1 = HiDreamAttention(
128
- query_dim=dim,
129
- heads=num_attention_heads,
130
- dim_head=attention_head_dim,
131
- processor = HiDreamAttnProcessor_flashattn(),
132
- single = False
133
  )
134
 
135
- # 3. Feed-forward
136
- self.norm3_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
137
- if num_routed_experts > 0:
138
- self.ff_i = MOEFeedForwardSwiGLU(
139
- dim = dim,
140
- hidden_dim = 4 * dim,
141
- num_routed_experts = num_routed_experts,
142
- num_activated_experts = num_activated_experts,
143
  )
144
- else:
145
- self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim)
146
- self.norm3_t = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
147
- self.ff_t = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim)
148
 
149
- def forward(
150
  self,
151
- image_tokens: torch.FloatTensor,
152
- image_tokens_masks: Optional[torch.FloatTensor] = None,
153
- text_tokens: Optional[torch.FloatTensor] = None,
154
- adaln_input: Optional[torch.FloatTensor] = None,
155
- rope: torch.FloatTensor = None,
156
- ) -> torch.FloatTensor:
157
- wtype = image_tokens.dtype
158
- shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
159
- shift_msa_t, scale_msa_t, gate_msa_t, shift_mlp_t, scale_mlp_t, gate_mlp_t = \
160
- self.adaLN_modulation(adaln_input)[:,None].chunk(12, dim=-1)
161
-
162
- # 1. MM-Attention
163
- norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
164
- norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
165
- norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype)
166
- norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t
167
-
168
- attn_output_i, attn_output_t = self.attn1(
169
- norm_image_tokens,
170
- image_tokens_masks,
171
- norm_text_tokens,
172
- rope = rope,
173
  )
174
 
175
- image_tokens = gate_msa_i * attn_output_i + image_tokens
176
- text_tokens = gate_msa_t * attn_output_t + text_tokens
177
-
178
- # 2. Feed-forward
179
- norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
180
- norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
181
- norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype)
182
- norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t
183
-
184
- ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens)
185
- ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens)
186
- image_tokens = ff_output_i + image_tokens
187
- text_tokens = ff_output_t + text_tokens
188
- return image_tokens, text_tokens
 
189
 
190
- @maybe_allow_in_graph
191
- class HiDreamImageBlock(nn.Module):
192
- def __init__(
193
  self,
194
- dim: int,
195
- num_attention_heads: int,
196
- attention_head_dim: int,
197
- num_routed_experts: int = 4,
198
- num_activated_experts: int = 2,
199
- block_type: BlockType = BlockType.TransformerBlock,
200
  ):
201
- super().__init__()
202
- block_classes = {
203
- BlockType.TransformerBlock: HiDreamImageTransformerBlock,
204
- BlockType.SingleTransformerBlock: HiDreamImageSingleTransformerBlock,
205
- }
206
- self.block = block_classes[block_type](
207
- dim,
208
- num_attention_heads,
209
- attention_head_dim,
210
- num_routed_experts,
211
- num_activated_experts
212
  )
213
-
214
- def forward(
215
  self,
216
- image_tokens: torch.FloatTensor,
217
- image_tokens_masks: Optional[torch.FloatTensor] = None,
218
- text_tokens: Optional[torch.FloatTensor] = None,
219
- adaln_input: torch.FloatTensor = None,
220
- rope: torch.FloatTensor = None,
221
- ) -> torch.FloatTensor:
222
- return self.block(
223
- image_tokens,
224
- image_tokens_masks,
225
- text_tokens,
226
- adaln_input,
227
- rope,
228
- )
229
 
230
- class HiDreamImageTransformer2DModel(
231
- ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
232
- ):
233
- _supports_gradient_checkpointing = True
234
- _no_split_modules = ["HiDreamImageBlock"]
235
 
236
- @register_to_config
237
- def __init__(
238
  self,
239
- patch_size: Optional[int] = None,
240
- in_channels: int = 64,
241
- out_channels: Optional[int] = None,
242
- num_layers: int = 16,
243
- num_single_layers: int = 32,
244
- attention_head_dim: int = 128,
245
- num_attention_heads: int = 20,
246
- caption_channels: List[int] = None,
247
- text_emb_dim: int = 2048,
248
- num_routed_experts: int = 4,
249
- num_activated_experts: int = 2,
250
- axes_dims_rope: Tuple[int, int] = (32, 32),
251
- max_resolution: Tuple[int, int] = (128, 128),
252
- llama_layers: List[int] = None,
253
  ):
254
- super().__init__()
255
- self.out_channels = out_channels or in_channels
256
- self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
257
- self.llama_layers = llama_layers
258
-
259
- self.t_embedder = TimestepEmbed(self.inner_dim)
260
- self.p_embedder = PooledEmbed(text_emb_dim, self.inner_dim)
261
- self.x_embedder = PatchEmbed(
262
- patch_size = patch_size,
263
- in_channels = in_channels,
264
- out_channels = self.inner_dim,
265
- )
266
- self.pe_embedder = EmbedND(theta=10000, axes_dim=axes_dims_rope)
267
-
268
- self.double_stream_blocks = nn.ModuleList(
269
- [
270
- HiDreamImageBlock(
271
- dim = self.inner_dim,
272
- num_attention_heads = self.config.num_attention_heads,
273
- attention_head_dim = self.config.attention_head_dim,
274
- num_routed_experts = num_routed_experts,
275
- num_activated_experts = num_activated_experts,
276
- block_type = BlockType.TransformerBlock
277
- )
278
- for i in range(self.config.num_layers)
279
- ]
280
- )
281
 
282
- self.single_stream_blocks = nn.ModuleList(
283
- [
284
- HiDreamImageBlock(
285
- dim = self.inner_dim,
286
- num_attention_heads = self.config.num_attention_heads,
287
- attention_head_dim = self.config.attention_head_dim,
288
- num_routed_experts = num_routed_experts,
289
- num_activated_experts = num_activated_experts,
290
- block_type = BlockType.SingleTransformerBlock
291
- )
292
- for i in range(self.config.num_single_layers)
293
- ]
294
- )
295
 
296
- self.final_layer = OutEmbed(self.inner_dim, patch_size, self.out_channels)
297
-
298
- caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ]
299
- caption_projection = []
300
- for caption_channel in caption_channels:
301
- caption_projection.append(TextProjection(in_features = caption_channel, hidden_size = self.inner_dim))
302
- self.caption_projection = nn.ModuleList(caption_projection)
303
- self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)
304
-
305
- self.gradient_checkpointing = False
306
-
307
- def _set_gradient_checkpointing(self, module, value=False):
308
- if hasattr(module, "gradient_checkpointing"):
309
- module.gradient_checkpointing = value
310
-
311
- def expand_timesteps(self, timesteps, batch_size, device):
312
- if not torch.is_tensor(timesteps):
313
- is_mps = device.type == "mps"
314
- if isinstance(timesteps, float):
315
- dtype = torch.float32 if is_mps else torch.float64
316
- else:
317
- dtype = torch.int32 if is_mps else torch.int64
318
- timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
319
- elif len(timesteps.shape) == 0:
320
- timesteps = timesteps[None].to(device)
321
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
322
- timesteps = timesteps.expand(batch_size)
323
- return timesteps
324
-
325
- def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
326
- if is_training:
327
- x = einops.rearrange(x, 'B S (p1 p2 C) -> B C S (p1 p2)', p1=self.config.patch_size, p2=self.config.patch_size)
328
  else:
329
- x_arr = []
330
- for i, img_size in enumerate(img_sizes):
331
- pH, pW = img_size
332
- x_arr.append(
333
- einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)',
334
- p1=self.config.patch_size, p2=self.config.patch_size)
335
- )
336
- x = torch.cat(x_arr, dim=0)
337
- return x
338
-
339
- def patchify(self, x, max_seq, img_sizes=None):
340
- pz2 = self.config.patch_size * self.config.patch_size
341
- if isinstance(x, torch.Tensor):
342
- B, C = x.shape[0], x.shape[1]
343
- device = x.device
344
- dtype = x.dtype
345
- else:
346
- B, C = len(x), x[0].shape[0]
347
- device = x[0].device
348
- dtype = x[0].dtype
349
- x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)
350
-
351
- if img_sizes is not None:
352
- for i, img_size in enumerate(img_sizes):
353
- x_masks[i, 0:img_size[0] * img_size[1]] = 1
354
- x = einops.rearrange(x, 'B C S p -> B S (p C)', p=pz2)
355
- elif isinstance(x, torch.Tensor):
356
- pH, pW = x.shape[-2] // self.config.patch_size, x.shape[-1] // self.config.patch_size
357
- x = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=self.config.patch_size, p2=self.config.patch_size)
358
- img_sizes = [[pH, pW]] * B
359
- x_masks = None
360
- else:
361
- raise NotImplementedError
362
- return x, x_masks, img_sizes
363
 
364
- def forward(
365
  self,
366
- hidden_states: torch.Tensor,
367
- timesteps: torch.LongTensor = None,
368
- encoder_hidden_states: torch.Tensor = None,
369
- pooled_embeds: torch.Tensor = None,
370
- img_sizes: Optional[List[Tuple[int, int]]] = None,
371
- img_ids: Optional[torch.Tensor] = None,
372
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
373
  return_dict: bool = True,
374
  ):
375
- if joint_attention_kwargs is not None:
376
- joint_attention_kwargs = joint_attention_kwargs.copy()
377
- lora_scale = joint_attention_kwargs.pop("scale", 1.0)
378
  else:
379
- lora_scale = 1.0
380
 
381
- if USE_PEFT_BACKEND:
382
- # weight the lora layers by setting `lora_scale` for each PEFT layer
383
- scale_lora_layers(self, lora_scale)
384
- else:
385
- if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
386
- logger.warning(
387
- "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
388
- )
389
 
390
- # spatial forward
391
- batch_size = hidden_states.shape[0]
392
- hidden_states_type = hidden_states.dtype
393
-
394
- # 0. time
395
- timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
396
- timesteps = self.t_embedder(timesteps, hidden_states_type)
397
- p_embedder = self.p_embedder(pooled_embeds)
398
- adaln_input = timesteps + p_embedder
399
-
400
- hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
401
- if image_tokens_masks is None:
402
- pH, pW = img_sizes[0]
403
- img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
404
- img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
405
- img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
406
- img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
407
- hidden_states = self.x_embedder(hidden_states)
408
-
409
- T5_encoder_hidden_states = encoder_hidden_states[0]
410
- encoder_hidden_states = encoder_hidden_states[-1]
411
- encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]
412
-
413
- if self.caption_projection is not None:
414
- new_encoder_hidden_states = []
415
- for i, enc_hidden_state in enumerate(encoder_hidden_states):
416
- enc_hidden_state = self.caption_projection[i](enc_hidden_state)
417
- enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
418
- new_encoder_hidden_states.append(enc_hidden_state)
419
- encoder_hidden_states = new_encoder_hidden_states
420
- T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
421
- T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
422
- encoder_hidden_states.append(T5_encoder_hidden_states)
423
-
424
- txt_ids = torch.zeros(
425
- batch_size,
426
- encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1],
427
- 3,
428
- device=img_ids.device, dtype=img_ids.dtype
429
  )
430
- ids = torch.cat((img_ids, txt_ids), dim=1)
431
- rope = self.pe_embedder(ids)
432
-
433
- # 2. Blocks
434
- block_id = 0
435
- initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
436
- initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
437
- for bid, block in enumerate(self.double_stream_blocks):
438
- cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
439
- cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1)
440
- if self.training and self.gradient_checkpointing:
441
- def create_custom_forward(module, return_dict=None):
442
- def custom_forward(*inputs):
443
- if return_dict is not None:
444
- return module(*inputs, return_dict=return_dict)
445
- else:
446
- return module(*inputs)
447
- return custom_forward
448
-
449
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
450
- hidden_states, initial_encoder_hidden_states = torch.utils.checkpoint.checkpoint(
451
- create_custom_forward(block),
452
- hidden_states,
453
- image_tokens_masks,
454
- cur_encoder_hidden_states,
455
- adaln_input,
456
- rope,
457
- **ckpt_kwargs,
458
- )
459
- else:
460
- hidden_states, initial_encoder_hidden_states = block(
461
- image_tokens = hidden_states,
462
- image_tokens_masks = image_tokens_masks,
463
- text_tokens = cur_encoder_hidden_states,
464
- adaln_input = adaln_input,
465
- rope = rope,
466
- )
467
- initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
468
- block_id += 1
469
-
470
- image_tokens_seq_len = hidden_states.shape[1]
471
- hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
472
- hidden_states_seq_len = hidden_states.shape[1]
473
- if image_tokens_masks is not None:
474
- encoder_attention_mask_ones = torch.ones(
475
- (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
476
- device=image_tokens_masks.device, dtype=image_tokens_masks.dtype
477
  )
478
- image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)
479
-
480
- for bid, block in enumerate(self.single_stream_blocks):
481
- cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
482
- hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
483
- if self.training and self.gradient_checkpointing:
484
- def create_custom_forward(module, return_dict=None):
485
- def custom_forward(*inputs):
486
- if return_dict is not None:
487
- return module(*inputs, return_dict=return_dict)
488
- else:
489
- return module(*inputs)
490
- return custom_forward
491
-
492
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
493
- hidden_states = torch.utils.checkpoint.checkpoint(
494
- create_custom_forward(block),
495
- hidden_states,
496
- image_tokens_masks,
497
- None,
498
- adaln_input,
499
- rope,
500
- **ckpt_kwargs,
501
- )
502
- else:
503
- hidden_states = block(
504
- image_tokens = hidden_states,
505
- image_tokens_masks = image_tokens_masks,
506
- text_tokens = None,
507
- adaln_input = adaln_input,
508
- rope = rope,
509
- )
510
- hidden_states = hidden_states[:, :hidden_states_seq_len]
511
- block_id += 1
512
-
513
- hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
514
- output = self.final_layer(hidden_states, adaln_input)
515
- output = self.unpatchify(output, img_sizes, self.training)
516
- if image_tokens_masks is not None:
517
- image_tokens_masks = image_tokens_masks[:, :image_tokens_seq_len]
518
 
519
- if USE_PEFT_BACKEND:
520
- # remove `lora_scale` from each PEFT layer
521
- unscale_lora_layers(self, lora_scale)
522
 
523
  if not return_dict:
524
- return (output, image_tokens_masks)
525
- return Transformer2DModelOutput(sample=output, mask=image_tokens_masks)
526
-
 
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Union
3
+ import math
 
4
  import einops
5
+ import torch
6
+ from transformers import (
7
+ CLIPTextModelWithProjection,
8
+ CLIPTokenizer,
9
+ T5EncoderModel,
10
+ T5Tokenizer,
11
+ LlamaForCausalLM,
12
+ PreTrainedTokenizerFast
13
+ )
14
+
15
+ from diffusers.image_processor import VaeImageProcessor
16
+ from diffusers.loaders import FromSingleFileMixin
17
+ from diffusers.models.autoencoders import AutoencoderKL
18
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
19
+ from diffusers.utils import (
20
+ USE_PEFT_BACKEND,
21
+ is_torch_xla_available,
22
+ logging,
23
+ )
24
+ from diffusers.utils.torch_utils import randn_tensor
25
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
26
+ from pipeline_output import HiDreamImagePipelineOutput
27
+ from transformer_hidream_image import HiDreamImageTransformer2DModel
28
+ from schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
29
+
30
+ if is_torch_xla_available():
31
+ import torch_xla.core.xla_model as xm
32
+
33
+ XLA_AVAILABLE = True
34
+ else:
35
+ XLA_AVAILABLE = False
36
 
37
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
 
39
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
40
+ def calculate_shift(
41
+ image_seq_len,
42
+ base_seq_len: int = 256,
43
+ max_seq_len: int = 4096,
44
+ base_shift: float = 0.5,
45
+ max_shift: float = 1.15,
46
+ ):
47
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
48
+ b = base_shift - m * base_seq_len
49
+ mu = image_seq_len * m + b
50
+ return mu
51
+
52
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
53
+ def retrieve_timesteps(
54
+ scheduler,
55
+ num_inference_steps: Optional[int] = None,
56
+ device: Optional[Union[str, torch.device]] = None,
57
+ timesteps: Optional[List[int]] = None,
58
+ sigmas: Optional[List[float]] = None,
59
+ **kwargs,
60
+ ):
61
+ r"""
62
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
63
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
64
+
65
+ Args:
66
+ scheduler (`SchedulerMixin`):
67
+ The scheduler to get timesteps from.
68
+ num_inference_steps (`int`):
69
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
70
+ must be `None`.
71
+ device (`str` or `torch.device`, *optional*):
72
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
73
+ timesteps (`List[int]`, *optional*):
74
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
75
+ `num_inference_steps` and `sigmas` must be `None`.
76
+ sigmas (`List[float]`, *optional*):
77
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
78
+ `num_inference_steps` and `timesteps` must be `None`.
79
+
80
+ Returns:
81
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
82
+ second element is the number of inference steps.
83
+ """
84
+ if timesteps is not None and sigmas is not None:
85
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
86
+ if timesteps is not None:
87
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
88
+ if not accepts_timesteps:
89
+ raise ValueError(
90
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
91
+ f" timestep schedules. Please check whether you are using the correct scheduler."
92
+ )
93
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
94
+ timesteps = scheduler.timesteps
95
+ num_inference_steps = len(timesteps)
96
+ elif sigmas is not None:
97
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
98
+ if not accept_sigmas:
99
+ raise ValueError(
100
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
101
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
102
+ )
103
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
104
+ timesteps = scheduler.timesteps
105
+ num_inference_steps = len(timesteps)
106
+ else:
107
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
108
+ timesteps = scheduler.timesteps
109
+ return timesteps, num_inference_steps
110
+
111
+ class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
112
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->image_encoder->transformer->vae"
113
+ _optional_components = ["image_encoder", "feature_extractor"]
114
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
115
 
 
 
116
  def __init__(
117
  self,
118
+ scheduler: FlowMatchEulerDiscreteScheduler,
119
+ vae: AutoencoderKL,
120
+ text_encoder: CLIPTextModelWithProjection,
121
+ tokenizer: CLIPTokenizer,
122
+ text_encoder_2: CLIPTextModelWithProjection,
123
+ tokenizer_2: CLIPTokenizer,
124
+ text_encoder_3: T5EncoderModel,
125
+ tokenizer_3: T5Tokenizer,
126
+ text_encoder_4: LlamaForCausalLM,
127
+ tokenizer_4: PreTrainedTokenizerFast,
128
  ):
129
  super().__init__()
130
+
131
+ self.register_modules(
132
+ vae=vae,
133
+ text_encoder=text_encoder,
134
+ text_encoder_2=text_encoder_2,
135
+ text_encoder_3=text_encoder_3,
136
+ text_encoder_4=text_encoder_4,
137
+ tokenizer=tokenizer,
138
+ tokenizer_2=tokenizer_2,
139
+ tokenizer_3=tokenizer_3,
140
+ tokenizer_4=tokenizer_4,
141
+ scheduler=scheduler,
142
  )
143
+ self.vae_scale_factor = (
144
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
145
  )
146
+ # HiDreamImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
147
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
148
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
149
+ self.default_sample_size = 128
150
+ self.tokenizer_4.pad_token = self.tokenizer_4.eos_token
151
 
152
+ def _get_t5_prompt_embeds(
153
  self,
154
+ prompt: Union[str, List[str]] = None,
155
+ num_images_per_prompt: int = 1,
156
+ max_sequence_length: int = 128,
157
+ device: Optional[torch.device] = None,
158
+ dtype: Optional[torch.dtype] = None,
159
+ ):
160
+ device = device or self._execution_device
161
+ dtype = dtype or self.text_encoder_3.dtype
162
+
163
+ prompt = [prompt] if isinstance(prompt, str) else prompt
164
+ batch_size = len(prompt)
165
+
166
+ text_inputs = self.tokenizer_3(
167
+ prompt,
168
+ padding="max_length",
169
+ max_length=min(max_sequence_length, self.tokenizer_3.model_max_length),
170
+ truncation=True,
171
+ add_special_tokens=True,
172
+ return_tensors="pt",
173
  )
174
+ text_input_ids = text_inputs.input_ids
175
+ attention_mask = text_inputs.attention_mask
176
+ untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
177
+
178
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
179
+ removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, min(max_sequence_length, self.tokenizer_3.model_max_length) - 1 : -1])
180
+ logger.warning(
181
+ "The following part of your input was truncated because `max_sequence_length` is set to "
182
+ f" {min(max_sequence_length, self.tokenizer_3.model_max_length)} tokens: {removed_text}"
183
+ )
184
+
185
+ prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=attention_mask.to(device))[0]
186
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
187
+ _, seq_len, _ = prompt_embeds.shape
188
+
189
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
190
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
191
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
192
+ return prompt_embeds
193
+
194
+ def _get_clip_prompt_embeds(
195
  self,
196
+ tokenizer,
197
+ text_encoder,
198
+ prompt: Union[str, List[str]],
199
+ num_images_per_prompt: int = 1,
200
+ max_sequence_length: int = 128,
201
+ device: Optional[torch.device] = None,
202
+ dtype: Optional[torch.dtype] = None,
203
  ):
204
+ device = device or self._execution_device
205
+ dtype = dtype or text_encoder.dtype
206
+
207
+ prompt = [prompt] if isinstance(prompt, str) else prompt
208
+ batch_size = len(prompt)
209
+
210
+ text_inputs = tokenizer(
211
+ prompt,
212
+ padding="max_length",
213
+ max_length=min(max_sequence_length, 218),
214
+ truncation=True,
215
+ return_tensors="pt",
216
  )
217
 
218
+ text_input_ids = text_inputs.input_ids
219
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
220
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
221
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, 218 - 1 : -1])
222
+ logger.warning(
223
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
224
+ f" {218} tokens: {removed_text}"
 
225
  )
226
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
227
+
228
+ # Use pooled output of CLIPTextModel
229
+ prompt_embeds = prompt_embeds[0]
230
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
231
+
232
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
233
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
234
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
235
+
236
+ return prompt_embeds
237
 
238
+ def _get_llama3_prompt_embeds(
239
  self,
240
+ prompt: Union[str, List[str]] = None,
241
+ num_images_per_prompt: int = 1,
242
+ max_sequence_length: int = 128,
243
+ device: Optional[torch.device] = None,
244
+ dtype: Optional[torch.dtype] = None,
245
+ ):
246
+ device = device or self._execution_device
247
+ dtype = dtype or self.text_encoder_4.dtype
248
+
249
+ prompt = [prompt] if isinstance(prompt, str) else prompt
250
+ batch_size = len(prompt)
251
+
252
+ text_inputs = self.tokenizer_4(
253
+ prompt,
254
+ padding="max_length",
255
+ max_length=min(max_sequence_length, self.tokenizer_4.model_max_length),
256
+ truncation=True,
257
+ add_special_tokens=True,
258
+ return_tensors="pt",
 
 
 
259
  )
260
+ text_input_ids = text_inputs.input_ids
261
+ attention_mask = text_inputs.attention_mask
262
+ untruncated_ids = self.tokenizer_4(prompt, padding="longest", return_tensors="pt").input_ids
263
+
264
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
265
+ removed_text = self.tokenizer_4.batch_decode(untruncated_ids[:, min(max_sequence_length, self.tokenizer_4.model_max_length) - 1 : -1])
266
+ logger.warning(
267
+ "The following part of your input was truncated because `max_sequence_length` is set to "
268
+ f" {min(max_sequence_length, self.tokenizer_4.model_max_length)} tokens: {removed_text}"
269
+ )
270
 
271
+ outputs = self.text_encoder_4(
272
+ text_input_ids.to(device),
273
+ attention_mask=attention_mask.to(device),
274
+ output_hidden_states=True,
275
+ output_attentions=True
276
+ )
277
+
278
+ prompt_embeds = outputs.hidden_states[1:]
279
+ prompt_embeds = torch.stack(prompt_embeds, dim=0)
280
+ _, _, seq_len, dim = prompt_embeds.shape
281
+
282
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
283
+ prompt_embeds = prompt_embeds.repeat(1, 1, num_images_per_prompt, 1)
284
+ prompt_embeds = prompt_embeds.view(-1, batch_size * num_images_per_prompt, seq_len, dim)
285
+ return prompt_embeds
286
 
287
+ def encode_prompt(
 
 
288
  self,
289
+ prompt: Union[str, List[str]],
290
+ prompt_2: Union[str, List[str]],
291
+ prompt_3: Union[str, List[str]],
292
+ prompt_4: Union[str, List[str]],
293
+ device: Optional[torch.device] = None,
294
+ dtype: Optional[torch.dtype] = None,
295
+ num_images_per_prompt: int = 1,
296
+ do_classifier_free_guidance: bool = True,
297
+ negative_prompt: Optional[Union[str, List[str]]] = None,
298
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
299
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
300
+ negative_prompt_4: Optional[Union[str, List[str]]] = None,
301
+ prompt_embeds: Optional[List[torch.FloatTensor]] = None,
302
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
303
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
304
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
305
+ max_sequence_length: int = 128,
306
+ lora_scale: Optional[float] = None,
307
  ):
308
+ prompt = [prompt] if isinstance(prompt, str) else prompt
309
+ if prompt is not None:
310
+ batch_size = len(prompt)
311
+ else:
312
+ batch_size = prompt_embeds.shape[0]
313
+
314
+ prompt_embeds, pooled_prompt_embeds = self._encode_prompt(
315
+ prompt = prompt,
316
+ prompt_2 = prompt_2,
317
+ prompt_3 = prompt_3,
318
+ prompt_4 = prompt_4,
319
+ device = device,
320
+ dtype = dtype,
321
+ num_images_per_prompt = num_images_per_prompt,
322
+ prompt_embeds = prompt_embeds,
323
+ pooled_prompt_embeds = pooled_prompt_embeds,
324
+ max_sequence_length = max_sequence_length,
325
  )
326
+
327
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
328
+ negative_prompt = negative_prompt or ""
329
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
330
+ negative_prompt_3 = negative_prompt_3 or negative_prompt
331
+ negative_prompt_4 = negative_prompt_4 or negative_prompt
332
+
333
+ # normalize str to list
334
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
335
+ negative_prompt_2 = (
336
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
337
+ )
338
+ negative_prompt_3 = (
339
+ batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
340
+ )
341
+ negative_prompt_4 = (
342
+ batch_size * [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4
343
+ )
344
+
345
+ if prompt is not None and type(prompt) is not type(negative_prompt):
346
+ raise TypeError(
347
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
348
+ f" {type(prompt)}."
349
+ )
350
+ elif batch_size != len(negative_prompt):
351
+ raise ValueError(
352
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
353
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
354
+ " the batch size of `prompt`."
355
+ )
356
+
357
+ negative_prompt_embeds, negative_pooled_prompt_embeds = self._encode_prompt(
358
+ prompt = negative_prompt,
359
+ prompt_2 = negative_prompt_2,
360
+ prompt_3 = negative_prompt_3,
361
+ prompt_4 = negative_prompt_4,
362
+ device = device,
363
+ dtype = dtype,
364
+ num_images_per_prompt = num_images_per_prompt,
365
+ prompt_embeds = negative_prompt_embeds,
366
+ pooled_prompt_embeds = negative_pooled_prompt_embeds,
367
+ max_sequence_length = max_sequence_length,
368
+ )
369
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
370
+
371
+ def _encode_prompt(
372
  self,
373
+ prompt: Union[str, List[str]],
374
+ prompt_2: Union[str, List[str]],
375
+ prompt_3: Union[str, List[str]],
376
+ prompt_4: Union[str, List[str]],
377
+ device: Optional[torch.device] = None,
378
+ dtype: Optional[torch.dtype] = None,
379
+ num_images_per_prompt: int = 1,
380
+ prompt_embeds: Optional[List[torch.FloatTensor]] = None,
381
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
382
+ max_sequence_length: int = 128,
383
+ ):
384
+ device = device or self._execution_device
385
+
386
+ if prompt_embeds is None:
387
+ prompt_2 = prompt_2 or prompt
388
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
389
+
390
+ prompt_3 = prompt_3 or prompt
391
+ prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
392
+
393
+ prompt_4 = prompt_4 or prompt
394
+ prompt_4 = [prompt_4] if isinstance(prompt_4, str) else prompt_4
395
+
396
+ pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
397
+ self.tokenizer,
398
+ self.text_encoder,
399
+ prompt = prompt,
400
+ num_images_per_prompt = num_images_per_prompt,
401
+ max_sequence_length = max_sequence_length,
402
+ device = device,
403
+ dtype = dtype,
404
+ )
405
 
406
+ pooled_prompt_embeds_2 = self._get_clip_prompt_embeds(
407
+ self.tokenizer_2,
408
+ self.text_encoder_2,
409
+ prompt = prompt_2,
410
+ num_images_per_prompt = num_images_per_prompt,
411
+ max_sequence_length = max_sequence_length,
412
+ device = device,
413
+ dtype = dtype,
414
+ )
415
 
416
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1)
417
+
418
+ t5_prompt_embeds = self._get_t5_prompt_embeds(
419
+ prompt = prompt_3,
420
+ num_images_per_prompt = num_images_per_prompt,
421
+ max_sequence_length = max_sequence_length,
422
+ device = device,
423
+ dtype = dtype
424
+ )
425
+ llama3_prompt_embeds = self._get_llama3_prompt_embeds(
426
+ prompt = prompt_4,
427
+ num_images_per_prompt = num_images_per_prompt,
428
+ max_sequence_length = max_sequence_length,
429
+ device = device,
430
+ dtype = dtype
431
+ )
432
+ prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds]
433
+
434
+ return prompt_embeds, pooled_prompt_embeds
435
+
436
+ def enable_vae_slicing(self):
437
+ r"""
438
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
439
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
440
+ """
441
+ self.vae.enable_slicing()
442
+
443
+ def disable_vae_slicing(self):
444
+ r"""
445
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
446
+ computing decoding in one step.
447
+ """
448
+ self.vae.disable_slicing()
449
+
450
+ def enable_vae_tiling(self):
451
+ r"""
452
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
453
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
454
+ processing larger images.
455
+ """
456
+ self.vae.enable_tiling()
457
+
458
+ def disable_vae_tiling(self):
459
+ r"""
460
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
461
+ computing decoding in one step.
462
+ """
463
+ self.vae.disable_tiling()
464
+
465
+ def prepare_latents(
466
  self,
467
+ batch_size,
468
+ num_channels_latents,
469
+ height,
470
+ width,
471
+ dtype,
472
+ device,
473
+ generator,
474
+ latents=None,
475
  ):
476
+ # VAE applies 8x compression on images but we must also account for packing which requires
477
+ # latent height and width to be divisible by 2.
478
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
479
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
480
 
481
+ shape = (batch_size, num_channels_latents, height, width)
482
 
483
+ if latents is None:
484
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
485
  else:
486
+ if latents.shape != shape:
487
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
488
+ latents = latents.to(device)
489
+ return latents
490
+
491
+ @property
492
+ def guidance_scale(self):
493
+ return self._guidance_scale
494
+
495
+ @property
496
+ def do_classifier_free_guidance(self):
497
+ return self._guidance_scale > 1
498
+
499
+ @property
500
+ def joint_attention_kwargs(self):
501
+ return self._joint_attention_kwargs
502
+
503
+ @property
504
+ def num_timesteps(self):
505
+ return self._num_timesteps
506
 
507
+ @property
508
+ def interrupt(self):
509
+ return self._interrupt
510
+
511
+ @torch.no_grad()
512
+ def __call__(
513
  self,
514
+ prompt: Union[str, List[str]] = None,
515
+ prompt_2: Optional[Union[str, List[str]]] = None,
516
+ prompt_3: Optional[Union[str, List[str]]] = None,
517
+ prompt_4: Optional[Union[str, List[str]]] = None,
518
+ height: Optional[int] = None,
519
+ width: Optional[int] = None,
520
+ num_inference_steps: int = 50,
521
+ sigmas: Optional[List[float]] = None,
522
+ guidance_scale: float = 5.0,
523
+ negative_prompt: Optional[Union[str, List[str]]] = None,
524
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
525
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
526
+ negative_prompt_4: Optional[Union[str, List[str]]] = None,
527
+ num_images_per_prompt: Optional[int] = 1,
528
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
529
+ latents: Optional[torch.FloatTensor] = None,
530
+ prompt_embeds: Optional[torch.FloatTensor] = None,
531
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
532
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
533
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
534
+ output_type: Optional[str] = "pil",
535
  return_dict: bool = True,
536
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
537
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
538
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
539
+ max_sequence_length: int = 128,
540
  ):
541
+ height = height or self.default_sample_size * self.vae_scale_factor
542
+ width = width or self.default_sample_size * self.vae_scale_factor
543
+
544
+ division = self.vae_scale_factor * 2
545
+ S_max = (self.default_sample_size * self.vae_scale_factor) ** 2
546
+ scale = S_max / (width * height)
547
+ scale = math.sqrt(scale)
548
+ width, height = int(width * scale // division * division), int(height * scale // division * division)
549
+
550
+ self._guidance_scale = guidance_scale
551
+ self._joint_attention_kwargs = joint_attention_kwargs
552
+ self._interrupt = False
553
+
554
+ # 2. Define call parameters
555
+ if prompt is not None and isinstance(prompt, str):
556
+ batch_size = 1
557
+ elif prompt is not None and isinstance(prompt, list):
558
+ batch_size = len(prompt)
559
  else:
560
+ batch_size = prompt_embeds.shape[0]
561
 
562
+ device = self._execution_device
563
 
564
+ lora_scale = (
565
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
566
  )
567
+ (
568
+ prompt_embeds,
569
+ negative_prompt_embeds,
570
+ pooled_prompt_embeds,
571
+ negative_pooled_prompt_embeds,
572
+ ) = self.encode_prompt(
573
+ prompt=prompt,
574
+ prompt_2=prompt_2,
575
+ prompt_3=prompt_3,
576
+ prompt_4=prompt_4,
577
+ negative_prompt=negative_prompt,
578
+ negative_prompt_2=negative_prompt_2,
579
+ negative_prompt_3=negative_prompt_3,
580
+ negative_prompt_4=negative_prompt_4,
581
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
582
+ prompt_embeds=prompt_embeds,
583
+ negative_prompt_embeds=negative_prompt_embeds,
584
+ pooled_prompt_embeds=pooled_prompt_embeds,
585
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
586
+ device=device,
587
+ num_images_per_prompt=num_images_per_prompt,
588
+ max_sequence_length=max_sequence_length,
589
+ lora_scale=lora_scale,
590
+ )
591
+
592
+ if self.do_classifier_free_guidance:
593
+ prompt_embeds_arr = []
594
+ for n, p in zip(negative_prompt_embeds, prompt_embeds):
595
+ if len(n.shape) == 3:
596
+ prompt_embeds_arr.append(torch.cat([n, p], dim=0))
597
+ else:
598
+ prompt_embeds_arr.append(torch.cat([n, p], dim=1))
599
+ prompt_embeds = prompt_embeds_arr
600
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
601
+
602
+ # 4. Prepare latent variables
603
+ num_channels_latents = self.transformer.config.in_channels
604
+ latents = self.prepare_latents(
605
+ batch_size * num_images_per_prompt,
606
+ num_channels_latents,
607
+ height,
608
+ width,
609
+ pooled_prompt_embeds.dtype,
610
+ device,
611
+ generator,
612
+ latents,
613
+ )
614
+
615
+ if latents.shape[-2] != latents.shape[-1]:
616
+ B, C, H, W = latents.shape
617
+ pH, pW = H // self.transformer.config.patch_size, W // self.transformer.config.patch_size
618
+
619
+ img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)
620
+ img_ids = torch.zeros(pH, pW, 3)
621
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None]
622
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :]
623
+ img_ids = img_ids.reshape(pH * pW, -1)
624
+ img_ids_pad = torch.zeros(self.transformer.max_seq, 3)
625
+ img_ids_pad[:pH*pW, :] = img_ids
626
+
627
+ img_sizes = img_sizes.unsqueeze(0).to(latents.device)
628
+ img_ids = img_ids_pad.unsqueeze(0).to(latents.device)
629
+ if self.do_classifier_free_guidance:
630
+ img_sizes = img_sizes.repeat(2 * B, 1)
631
+ img_ids = img_ids.repeat(2 * B, 1, 1)
632
+ else:
633
+ img_sizes = img_ids = None
634
+
635
+ # 5. Prepare timesteps
636
+ mu = calculate_shift(self.transformer.max_seq)
637
+ scheduler_kwargs = {"mu": mu}
638
+ if isinstance(self.scheduler, FlowUniPCMultistepScheduler):
639
+ self.scheduler.set_timesteps(num_inference_steps, device=device, shift=math.exp(mu))
640
+ timesteps = self.scheduler.timesteps
641
+ else:
642
+ timesteps, num_inference_steps = retrieve_timesteps(
643
+ self.scheduler,
644
+ num_inference_steps,
645
+ device,
646
+ sigmas=sigmas,
647
+ **scheduler_kwargs,
648
  )
649
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
650
+ self._num_timesteps = len(timesteps)
651
+
652
+ # 6. Denoising loop
653
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
654
+ for i, t in enumerate(timesteps):
655
+ if self.interrupt:
656
+ continue
657
+
658
+ # expand the latents if we are doing classifier free guidance
659
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
660
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
661
+ timestep = t.expand(latent_model_input.shape[0])
662
+
663
+ if latent_model_input.shape[-2] != latent_model_input.shape[-1]:
664
+ B, C, H, W = latent_model_input.shape
665
+ patch_size = self.transformer.config.patch_size
666
+ pH, pW = H // patch_size, W // patch_size
667
+ out = torch.zeros(
668
+ (B, C, self.transformer.max_seq, patch_size * patch_size),
669
+ dtype=latent_model_input.dtype,
670
+ device=latent_model_input.device
671
+ )
672
+ latent_model_input = einops.rearrange(latent_model_input, 'B C (H p1) (W p2) -> B C (H W) (p1 p2)', p1=patch_size, p2=patch_size)
673
+ out[:, :, 0:pH*pW] = latent_model_input
674
+ latent_model_input = out
675
+
676
+ noise_pred = self.transformer(
677
+ hidden_states = latent_model_input,
678
+ timesteps = timestep,
679
+ encoder_hidden_states = prompt_embeds,
680
+ pooled_embeds = pooled_prompt_embeds,
681
+ img_sizes = img_sizes,
682
+ img_ids = img_ids,
683
+ return_dict = False,
684
+ )[0]
685
+ noise_pred = -noise_pred
686
+
687
+ # perform guidance
688
+ if self.do_classifier_free_guidance:
689
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
690
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
691
+
692
+ # compute the previous noisy sample x_t -> x_t-1
693
+ latents_dtype = latents.dtype
694
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
695
+
696
+ if latents.dtype != latents_dtype:
697
+ if torch.backends.mps.is_available():
698
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
699
+ latents = latents.to(latents_dtype)
700
+
701
+ if callback_on_step_end is not None:
702
+ callback_kwargs = {}
703
+ for k in callback_on_step_end_tensor_inputs:
704
+ callback_kwargs[k] = locals()[k]
705
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
706
+
707
+ latents = callback_outputs.pop("latents", latents)
708
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
709
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
710
+
711
+ # call the callback, if provided
712
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
713
+ progress_bar.update()
714
+
715
+ if XLA_AVAILABLE:
716
+ xm.mark_step()
717
+
718
+ if output_type == "latent":
719
+ image = latents
720
 
721
+ else:
722
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
723
+
724
+ image = self.vae.decode(latents, return_dict=False)[0]
725
+ image = self.image_processor.postprocess(image, output_type=output_type)
726
+
727
+ # Offload all models
728
+ self.maybe_free_model_hooks()
729
 
730
  if not return_dict:
731
+ return (image,)
732
+
733
+ return HiDreamImagePipelineOutput(images=image)
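
The rewritten pipeline_hidream_image.py drops the old transformer definition (now moved to transformer_hidream_image.py below) and becomes a full DiffusionPipeline that wires two CLIP encoders, a T5 encoder, and a Llama text encoder into the flow-matching denoising loop. A hedged usage sketch against the __call__ signature shown above; `pipe` is an assumed, already-assembled HiDreamImagePipeline, and the prompt and seed are illustrative only.

import torch

image = pipe(
    prompt="a cat holding a sign that says hello world",  # illustrative prompt
    height=1024,
    width=1024,
    num_inference_steps=50,   # default from the signature above
    guidance_scale=5.0,       # default from the signature above
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("output.png")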
pipeline_output.py ADDED
@@ -0,0 +1,21 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Union
3
+
4
+ import numpy as np
5
+ import PIL.Image
6
+
7
+ from diffusers.utils import BaseOutput
8
+
9
+
10
+ @dataclass
11
+ class HiDreamImagePipelineOutput(BaseOutput):
12
+ """
13
+ Output class for HiDreamImage pipelines.
14
+
15
+ Args:
16
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
17
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
18
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
19
+ """
20
+
21
+ images: Union[List[PIL.Image.Image], np.ndarray]
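
pipeline_output.py is a small dataclass so the pipeline can return its images either wrapped or as a plain tuple, mirroring other diffusers pipelines. A brief sketch of both return paths, with `pipe` again an assumed HiDreamImagePipeline instance:

from pipeline_output import HiDreamImagePipelineOutput

out = pipe(prompt="a watercolor fox")                 # return_dict=True (default)
assert isinstance(out, HiDreamImagePipelineOutput)
images = out.images                                   # list of PIL images for output_type="pil"

images, = pipe(prompt="a watercolor fox", return_dict=False)   # plain one-element tuple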
transformer_hidream_image.py ADDED
@@ -0,0 +1,526 @@
1
+ from typing import Any, Dict, Optional, Tuple, List
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import einops
6
+ from einops import repeat
7
+
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
10
+ from diffusers.models.modeling_utils import ModelMixin
11
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
12
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
13
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
14
+ from models.embeddings import PatchEmbed, PooledEmbed, TimestepEmbed, EmbedND, OutEmbed
15
+ from models.attention import HiDreamAttention, FeedForwardSwiGLU
16
+ from models.attention_processor import HiDreamAttnProcessor_flashattn
17
+ from models.moe import MOEFeedForwardSwiGLU
18
+
19
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
20
+
21
+ class TextProjection(nn.Module):
22
+ def __init__(self, in_features, hidden_size):
23
+ super().__init__()
24
+ self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False)
25
+
26
+ def forward(self, caption):
27
+ hidden_states = self.linear(caption)
28
+ return hidden_states
29
+
30
+ class BlockType:
31
+ TransformerBlock = 1
32
+ SingleTransformerBlock = 2
33
+
34
+ @maybe_allow_in_graph
35
+ class HiDreamImageSingleTransformerBlock(nn.Module):
36
+ def __init__(
37
+ self,
38
+ dim: int,
39
+ num_attention_heads: int,
40
+ attention_head_dim: int,
41
+ num_routed_experts: int = 4,
42
+ num_activated_experts: int = 2
43
+ ):
44
+ super().__init__()
45
+ self.num_attention_heads = num_attention_heads
46
+ self.adaLN_modulation = nn.Sequential(
47
+ nn.SiLU(),
48
+ nn.Linear(dim, 6 * dim, bias=True)
49
+ )
50
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
51
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
52
+
53
+ # 1. Attention
54
+ self.norm1_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
55
+ self.attn1 = HiDreamAttention(
56
+ query_dim=dim,
57
+ heads=num_attention_heads,
58
+ dim_head=attention_head_dim,
59
+ processor = HiDreamAttnProcessor_flashattn(),
60
+ single = True
61
+ )
62
+
63
+ # 3. Feed-forward
64
+ self.norm3_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
65
+ if num_routed_experts > 0:
66
+ self.ff_i = MOEFeedForwardSwiGLU(
67
+ dim = dim,
68
+ hidden_dim = 4 * dim,
69
+ num_routed_experts = num_routed_experts,
70
+ num_activated_experts = num_activated_experts,
71
+ )
72
+ else:
73
+ self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim)
74
+
75
+ def forward(
76
+ self,
77
+ image_tokens: torch.FloatTensor,
78
+ image_tokens_masks: Optional[torch.FloatTensor] = None,
79
+ text_tokens: Optional[torch.FloatTensor] = None,
80
+ adaln_input: Optional[torch.FloatTensor] = None,
81
+ rope: torch.FloatTensor = None,
82
+
83
+ ) -> torch.FloatTensor:
84
+ wtype = image_tokens.dtype
85
+ shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
86
+ self.adaLN_modulation(adaln_input)[:,None].chunk(6, dim=-1)
87
+
88
+ # 1. MM-Attention
89
+ norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
90
+ norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
91
+ attn_output_i = self.attn1(
92
+ norm_image_tokens,
93
+ image_tokens_masks,
94
+ rope = rope,
95
+ )
96
+ image_tokens = gate_msa_i * attn_output_i + image_tokens
97
+
98
+ # 2. Feed-forward
99
+ norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
100
+ norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
101
+ ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype))
102
+ image_tokens = ff_output_i + image_tokens
103
+ return image_tokens
104
+
105
+ @maybe_allow_in_graph
106
+ class HiDreamImageTransformerBlock(nn.Module):
107
+ def __init__(
108
+ self,
109
+ dim: int,
110
+ num_attention_heads: int,
111
+ attention_head_dim: int,
112
+ num_routed_experts: int = 4,
113
+ num_activated_experts: int = 2
114
+ ):
115
+ super().__init__()
116
+ self.num_attention_heads = num_attention_heads
117
+ self.adaLN_modulation = nn.Sequential(
118
+ nn.SiLU(),
119
+ nn.Linear(dim, 12 * dim, bias=True)
120
+ )
121
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
122
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
123
+
124
+ # 1. Attention
125
+ self.norm1_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
126
+ self.norm1_t = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
127
+ self.attn1 = HiDreamAttention(
128
+ query_dim=dim,
129
+ heads=num_attention_heads,
130
+ dim_head=attention_head_dim,
131
+ processor = HiDreamAttnProcessor_flashattn(),
132
+ single = False
133
+ )
134
+
135
+ # 3. Feed-forward
136
+ self.norm3_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
137
+ if num_routed_experts > 0:
138
+ self.ff_i = MOEFeedForwardSwiGLU(
139
+ dim = dim,
140
+ hidden_dim = 4 * dim,
141
+ num_routed_experts = num_routed_experts,
142
+ num_activated_experts = num_activated_experts,
143
+ )
144
+ else:
145
+ self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim)
146
+ self.norm3_t = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
147
+ self.ff_t = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim)
148
+
149
+ def forward(
150
+ self,
151
+ image_tokens: torch.FloatTensor,
152
+ image_tokens_masks: Optional[torch.FloatTensor] = None,
153
+ text_tokens: Optional[torch.FloatTensor] = None,
154
+ adaln_input: Optional[torch.FloatTensor] = None,
155
+ rope: torch.FloatTensor = None,
156
+ ) -> torch.FloatTensor:
157
+ wtype = image_tokens.dtype
158
+ shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
159
+ shift_msa_t, scale_msa_t, gate_msa_t, shift_mlp_t, scale_mlp_t, gate_mlp_t = \
160
+ self.adaLN_modulation(adaln_input)[:,None].chunk(12, dim=-1)
161
+
162
+ # 1. MM-Attention
163
+ norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
164
+ norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
165
+ norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype)
166
+ norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t
167
+
168
+ attn_output_i, attn_output_t = self.attn1(
169
+ norm_image_tokens,
170
+ image_tokens_masks,
171
+ norm_text_tokens,
172
+ rope = rope,
173
+ )
174
+
175
+ image_tokens = gate_msa_i * attn_output_i + image_tokens
176
+ text_tokens = gate_msa_t * attn_output_t + text_tokens
177
+
178
+ # 2. Feed-forward
179
+ norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
180
+ norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
181
+ norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype)
182
+ norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t
183
+
184
+ ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens)
185
+ ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens)
186
+ image_tokens = ff_output_i + image_tokens
187
+ text_tokens = ff_output_t + text_tokens
188
+ return image_tokens, text_tokens
189
+
190
+ @maybe_allow_in_graph
191
+ class HiDreamImageBlock(nn.Module):
192
+ def __init__(
193
+ self,
194
+ dim: int,
195
+ num_attention_heads: int,
196
+ attention_head_dim: int,
197
+ num_routed_experts: int = 4,
198
+ num_activated_experts: int = 2,
199
+ block_type: BlockType = BlockType.TransformerBlock,
200
+ ):
201
+ super().__init__()
202
+ block_classes = {
203
+ BlockType.TransformerBlock: HiDreamImageTransformerBlock,
204
+ BlockType.SingleTransformerBlock: HiDreamImageSingleTransformerBlock,
205
+ }
206
+ self.block = block_classes[block_type](
207
+ dim,
208
+ num_attention_heads,
209
+ attention_head_dim,
210
+ num_routed_experts,
211
+ num_activated_experts
212
+ )
213
+
214
+ def forward(
215
+ self,
216
+ image_tokens: torch.FloatTensor,
217
+ image_tokens_masks: Optional[torch.FloatTensor] = None,
218
+ text_tokens: Optional[torch.FloatTensor] = None,
219
+ adaln_input: torch.FloatTensor = None,
220
+ rope: torch.FloatTensor = None,
221
+ ) -> torch.FloatTensor:
222
+ return self.block(
223
+ image_tokens,
224
+ image_tokens_masks,
225
+ text_tokens,
226
+ adaln_input,
227
+ rope,
228
+ )
229
+
230
+ class HiDreamImageTransformer2DModel(
231
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
232
+ ):
233
+ _supports_gradient_checkpointing = True
234
+ _no_split_modules = ["HiDreamImageBlock"]
235
+
236
+ @register_to_config
237
+ def __init__(
238
+ self,
239
+ patch_size: Optional[int] = None,
240
+ in_channels: int = 64,
241
+ out_channels: Optional[int] = None,
242
+ num_layers: int = 16,
243
+ num_single_layers: int = 32,
244
+ attention_head_dim: int = 128,
245
+ num_attention_heads: int = 20,
246
+ caption_channels: List[int] = None,
247
+ text_emb_dim: int = 2048,
248
+ num_routed_experts: int = 4,
249
+ num_activated_experts: int = 2,
250
+ axes_dims_rope: Tuple[int, int] = (32, 32),
251
+ max_resolution: Tuple[int, int] = (128, 128),
252
+ llama_layers: List[int] = None,
253
+ ):
254
+ super().__init__()
255
+ self.out_channels = out_channels or in_channels
256
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
257
+ self.llama_layers = llama_layers
258
+
259
+ self.t_embedder = TimestepEmbed(self.inner_dim)
260
+ self.p_embedder = PooledEmbed(text_emb_dim, self.inner_dim)
261
+ self.x_embedder = PatchEmbed(
262
+ patch_size = patch_size,
263
+ in_channels = in_channels,
264
+ out_channels = self.inner_dim,
265
+ )
266
+ self.pe_embedder = EmbedND(theta=10000, axes_dim=axes_dims_rope)
267
+
268
+ self.double_stream_blocks = nn.ModuleList(
269
+ [
270
+ HiDreamImageBlock(
271
+ dim = self.inner_dim,
272
+ num_attention_heads = self.config.num_attention_heads,
273
+ attention_head_dim = self.config.attention_head_dim,
274
+ num_routed_experts = num_routed_experts,
275
+ num_activated_experts = num_activated_experts,
276
+ block_type = BlockType.TransformerBlock
277
+ )
278
+ for i in range(self.config.num_layers)
279
+ ]
280
+ )
281
+
282
+ self.single_stream_blocks = nn.ModuleList(
283
+ [
284
+ HiDreamImageBlock(
285
+ dim = self.inner_dim,
286
+ num_attention_heads = self.config.num_attention_heads,
287
+ attention_head_dim = self.config.attention_head_dim,
288
+ num_routed_experts = num_routed_experts,
289
+ num_activated_experts = num_activated_experts,
290
+ block_type = BlockType.SingleTransformerBlock
291
+ )
292
+ for i in range(self.config.num_single_layers)
293
+ ]
294
+ )
295
+
296
+ self.final_layer = OutEmbed(self.inner_dim, patch_size, self.out_channels)
297
+
298
+ caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ]
299
+ caption_projection = []
300
+ for caption_channel in caption_channels:
301
+ caption_projection.append(TextProjection(in_features = caption_channel, hidden_size = self.inner_dim))
302
+ self.caption_projection = nn.ModuleList(caption_projection)
303
+ self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)
304
+
305
+ self.gradient_checkpointing = False
306
+
307
+ def _set_gradient_checkpointing(self, module, value=False):
308
+ if hasattr(module, "gradient_checkpointing"):
309
+ module.gradient_checkpointing = value
310
+
311
+ def expand_timesteps(self, timesteps, batch_size, device):
312
+ if not torch.is_tensor(timesteps):
313
+ is_mps = device.type == "mps"
314
+ if isinstance(timesteps, float):
315
+ dtype = torch.float32 if is_mps else torch.float64
316
+ else:
317
+ dtype = torch.int32 if is_mps else torch.int64
318
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
319
+ elif len(timesteps.shape) == 0:
320
+ timesteps = timesteps[None].to(device)
321
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
322
+ timesteps = timesteps.expand(batch_size)
323
+ return timesteps
324
+
325
+ def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
326
+ if is_training:
327
+ x = einops.rearrange(x, 'B S (p1 p2 C) -> B C S (p1 p2)', p1=self.config.patch_size, p2=self.config.patch_size)
328
+ else:
329
+ x_arr = []
330
+ for i, img_size in enumerate(img_sizes):
331
+ pH, pW = img_size
332
+ x_arr.append(
333
+ einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)',
334
+ p1=self.config.patch_size, p2=self.config.patch_size)
335
+ )
336
+ x = torch.cat(x_arr, dim=0)
337
+ return x
338
+
339
+ def patchify(self, x, max_seq, img_sizes=None):
340
+ pz2 = self.config.patch_size * self.config.patch_size
341
+ if isinstance(x, torch.Tensor):
342
+ B, C = x.shape[0], x.shape[1]
343
+ device = x.device
344
+ dtype = x.dtype
345
+ else:
346
+ B, C = len(x), x[0].shape[0]
347
+ device = x[0].device
348
+ dtype = x[0].dtype
349
+ x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)
350
+
351
+ if img_sizes is not None:
352
+ for i, img_size in enumerate(img_sizes):
353
+ x_masks[i, 0:img_size[0] * img_size[1]] = 1
354
+ x = einops.rearrange(x, 'B C S p -> B S (p C)', p=pz2)
355
+ elif isinstance(x, torch.Tensor):
356
+ pH, pW = x.shape[-2] // self.config.patch_size, x.shape[-1] // self.config.patch_size
357
+ x = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=self.config.patch_size, p2=self.config.patch_size)
358
+ img_sizes = [[pH, pW]] * B
359
+ x_masks = None
360
+ else:
361
+ raise NotImplementedError
362
+ return x, x_masks, img_sizes
363
+
364
+ def forward(
365
+ self,
366
+ hidden_states: torch.Tensor,
367
+ timesteps: torch.LongTensor = None,
368
+ encoder_hidden_states: torch.Tensor = None,
369
+ pooled_embeds: torch.Tensor = None,
370
+ img_sizes: Optional[List[Tuple[int, int]]] = None,
371
+ img_ids: Optional[torch.Tensor] = None,
372
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
373
+ return_dict: bool = True,
374
+ ):
375
+ if joint_attention_kwargs is not None:
376
+ joint_attention_kwargs = joint_attention_kwargs.copy()
377
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
378
+ else:
379
+ lora_scale = 1.0
380
+
381
+ if USE_PEFT_BACKEND:
382
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
383
+ scale_lora_layers(self, lora_scale)
384
+ else:
385
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
386
+ logger.warning(
387
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
388
+ )
389
+
390
+ # spatial forward
391
+ batch_size = hidden_states.shape[0]
392
+ hidden_states_type = hidden_states.dtype
393
+
394
+ # 0. time
395
+ timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
396
+ timesteps = self.t_embedder(timesteps, hidden_states_type)
397
+ p_embedder = self.p_embedder(pooled_embeds)
398
+ adaln_input = timesteps + p_embedder
399
+
400
+ hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
401
+ if image_tokens_masks is None:
402
+ pH, pW = img_sizes[0]
403
+ img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
404
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
405
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
406
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
407
+ hidden_states = self.x_embedder(hidden_states)
408
+
409
+ T5_encoder_hidden_states = encoder_hidden_states[0]
410
+ encoder_hidden_states = encoder_hidden_states[-1]
411
+ encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]
412
+
413
+ if self.caption_projection is not None:
414
+ new_encoder_hidden_states = []
415
+ for i, enc_hidden_state in enumerate(encoder_hidden_states):
416
+ enc_hidden_state = self.caption_projection[i](enc_hidden_state)
417
+ enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
418
+ new_encoder_hidden_states.append(enc_hidden_state)
419
+ encoder_hidden_states = new_encoder_hidden_states
420
+ T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
421
+ T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
422
+ encoder_hidden_states.append(T5_encoder_hidden_states)
423
+
424
+ txt_ids = torch.zeros(
425
+ batch_size,
426
+ encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1],
427
+ 3,
428
+ device=img_ids.device, dtype=img_ids.dtype
429
+ )
430
+ ids = torch.cat((img_ids, txt_ids), dim=1)
431
+ rope = self.pe_embedder(ids)
432
+
433
+ # 2. Blocks
434
+ block_id = 0
435
+ initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
436
+ initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
437
+ for bid, block in enumerate(self.double_stream_blocks):
438
+ cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
439
+ cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1)
440
+ if self.training and self.gradient_checkpointing:
441
+ def create_custom_forward(module, return_dict=None):
442
+ def custom_forward(*inputs):
443
+ if return_dict is not None:
444
+ return module(*inputs, return_dict=return_dict)
445
+ else:
446
+ return module(*inputs)
447
+ return custom_forward
448
+
449
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
450
+ hidden_states, initial_encoder_hidden_states = torch.utils.checkpoint.checkpoint(
451
+ create_custom_forward(block),
452
+ hidden_states,
453
+ image_tokens_masks,
454
+ cur_encoder_hidden_states,
455
+ adaln_input,
456
+ rope,
457
+ **ckpt_kwargs,
458
+ )
459
+ else:
460
+ hidden_states, initial_encoder_hidden_states = block(
461
+ image_tokens = hidden_states,
462
+ image_tokens_masks = image_tokens_masks,
463
+ text_tokens = cur_encoder_hidden_states,
464
+ adaln_input = adaln_input,
465
+ rope = rope,
466
+ )
467
+ initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
468
+ block_id += 1
469
+
470
+ image_tokens_seq_len = hidden_states.shape[1]
471
+ hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
472
+ hidden_states_seq_len = hidden_states.shape[1]
473
+ if image_tokens_masks is not None:
474
+ encoder_attention_mask_ones = torch.ones(
475
+ (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
476
+ device=image_tokens_masks.device, dtype=image_tokens_masks.dtype
477
+ )
478
+ image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)
479
+
480
+ for bid, block in enumerate(self.single_stream_blocks):
481
+ cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
482
+ hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
483
+ if self.training and self.gradient_checkpointing:
484
+ def create_custom_forward(module, return_dict=None):
485
+ def custom_forward(*inputs):
486
+ if return_dict is not None:
487
+ return module(*inputs, return_dict=return_dict)
488
+ else:
489
+ return module(*inputs)
490
+ return custom_forward
491
+
492
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
493
+ hidden_states = torch.utils.checkpoint.checkpoint(
494
+ create_custom_forward(block),
495
+ hidden_states,
496
+ image_tokens_masks,
497
+ None,
498
+ adaln_input,
499
+ rope,
500
+ **ckpt_kwargs,
501
+ )
502
+ else:
503
+ hidden_states = block(
504
+ image_tokens = hidden_states,
505
+ image_tokens_masks = image_tokens_masks,
506
+ text_tokens = None,
507
+ adaln_input = adaln_input,
508
+ rope = rope,
509
+ )
510
+ hidden_states = hidden_states[:, :hidden_states_seq_len]
511
+ block_id += 1
512
+
513
+ hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
514
+ output = self.final_layer(hidden_states, adaln_input)
515
+ output = self.unpatchify(output, img_sizes, self.training)
516
+ if image_tokens_masks is not None:
517
+ image_tokens_masks = image_tokens_masks[:, :image_tokens_seq_len]
518
+
519
+ if USE_PEFT_BACKEND:
520
+ # remove `lora_scale` from each PEFT layer
521
+ unscale_lora_layers(self, lora_scale)
522
+
523
+ if not return_dict:
524
+ return (output, image_tokens_masks)
525
+ return Transformer2DModelOutput(sample=output, mask=image_tokens_masks)
526
+
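The transformer added above routes image and text tokens through the double-stream blocks first, then concatenates everything into a single sequence for the single-stream blocks, injecting a different projected Llama hidden state at each block via `caption_projection`. A minimal loading sketch, assuming the standard diffusers `ModelMixin.from_pretrained` checkpoint layout that the class inherits (the repository id below is a placeholder):

import torch
from transformer_hidream_image import HiDreamImageTransformer2DModel

# Placeholder repository id; substitute the checkpoint actually being served.
transformer = HiDreamImageTransformer2DModel.from_pretrained(
    "your-org/your-hidream-checkpoint",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
print(transformer.config.num_layers)         # double-stream blocks (default 16)
print(transformer.config.num_single_layers)  # single-stream blocks (default 32)

The loaded instance can then be passed to a diffusers pipeline's `from_pretrained` call as its `transformer` component override.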