blanchon committed on
Commit
6cac7b5
·
1 Parent(s): 18bbde3
Files changed (3) hide show
  1. app-fast.py +14 -15
  2. pyproject.toml +0 -1
  3. requirements.txt +0 -1
app-fast.py CHANGED
@@ -2,19 +2,16 @@ import gradio as gr
2
  import PIL
3
  import spaces
4
  import torch
5
- from diffusers import TorchAoConfig as DiffusersTorchAoConfig
6
  from hi_diffusers import HiDreamImagePipeline, HiDreamImageTransformer2DModel
7
  from hi_diffusers.schedulers.flash_flow_match import (
8
  FlashFlowMatchEulerDiscreteScheduler,
9
  )
10
- from torchao.quantization import Int4WeightOnlyConfig
11
  from transformers import (
12
  AutoModelForCausalLM,
13
  AutoTokenizer,
14
  )
15
- from transformers import (
16
- TorchAoConfig as TransformersTorchAoConfig,
17
- )
18
 
19
  # Constants
20
  MODEL_PREFIX: str = "HiDream-ai"
@@ -41,8 +38,10 @@ RESOLUTION_OPTIONS: list[str] = [
41
 
42
  device = torch.device("cuda")
43
 
44
- quant_config = Int4WeightOnlyConfig(group_size=128)
45
- quantization_config = TransformersTorchAoConfig(quant_type=quant_config)
 
 
46
 
47
  tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL_NAME, use_fast=False)
48
  text_encoder = AutoModelForCausalLM.from_pretrained(
@@ -50,18 +49,18 @@ text_encoder = AutoModelForCausalLM.from_pretrained(
50
  output_hidden_states=True,
51
  output_attentions=True,
52
  low_cpu_mem_usage=True,
53
- quantization_config=quantization_config,
54
- torch_dtype=torch.bfloat16, # Explicitly set dtype
55
- device_map="auto", # Still use auto, but ensure device consistency
56
  ).to(device) # Move model to the correct device after loading
57
 
58
- quantization_config = DiffusersTorchAoConfig("int8wo")
 
 
59
  transformer = HiDreamImageTransformer2DModel.from_pretrained(
60
  MODEL_PATH,
61
  subfolder="transformer",
62
- quantization_config=quantization_config,
63
- device_map="auto",
64
- torch_dtype=torch.bfloat16,
65
  ).to(device)
66
 
67
  scheduler = MODEL_CONFIGS["scheduler"](
@@ -75,7 +74,7 @@ pipe = HiDreamImagePipeline.from_pretrained(
75
  scheduler=scheduler,
76
  tokenizer_4=tokenizer,
77
  text_encoder_4=text_encoder,
78
- torch_dtype=torch.bfloat16,
79
  ).to(device)
80
 
81
  pipe.transformer = transformer
 
2
  import PIL
3
  import spaces
4
  import torch
5
+ from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
6
  from hi_diffusers import HiDreamImagePipeline, HiDreamImageTransformer2DModel
7
  from hi_diffusers.schedulers.flash_flow_match import (
8
  FlashFlowMatchEulerDiscreteScheduler,
9
  )
 
10
  from transformers import (
11
  AutoModelForCausalLM,
12
  AutoTokenizer,
13
  )
14
+ from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
 
 
15
 
16
  # Constants
17
  MODEL_PREFIX: str = "HiDream-ai"
 
38
 
39
  device = torch.device("cuda")
40
 
41
+ quant_config = TransformersBitsAndBytesConfig(
42
+ load_in_8bit=True,
43
+ )
44
+
45
 
46
  tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL_NAME, use_fast=False)
47
  text_encoder = AutoModelForCausalLM.from_pretrained(
 
49
  output_hidden_states=True,
50
  output_attentions=True,
51
  low_cpu_mem_usage=True,
52
+ quantization_config=quant_config,
53
+ torch_dtype=torch.float16,
 
54
  ).to(device) # Move model to the correct device after loading
55
 
56
+ quant_config = DiffusersBitsAndBytesConfig(
57
+ load_in_8bit=True,
58
+ )
59
  transformer = HiDreamImageTransformer2DModel.from_pretrained(
60
  MODEL_PATH,
61
  subfolder="transformer",
62
+ quantization_config=quant_config,
63
+ torch_dtype=torch.float16,
 
64
  ).to(device)
65
 
66
  scheduler = MODEL_CONFIGS["scheduler"](
 
74
  scheduler=scheduler,
75
  tokenizer_4=tokenizer,
76
  text_encoder_4=text_encoder,
77
+ torch_dtype=torch.float16,
78
  ).to(device)
79
 
80
  pipe.transformer = transformer
pyproject.toml CHANGED
@@ -9,7 +9,6 @@ dependencies = [
9
  "diffusers>=0.32.1",
10
  "einops>=0.7.0",
11
  "torch>=2.5.1",
12
- "torchao>=0.10.0",
13
  "torchvision>=0.20.1",
14
  "transformers>=4.47.1",
15
  ]
 
9
  "diffusers>=0.32.1",
10
  "einops>=0.7.0",
11
  "torch>=2.5.1",
 
12
  "torchvision>=0.20.1",
13
  "transformers>=4.47.1",
14
  ]
requirements.txt CHANGED
@@ -10,4 +10,3 @@ einops
10
  gradio
11
  spaces
12
  sentencepiece
13
- torchao
 
10
  gradio
11
  spaces
12
  sentencepiece