blanchon commited on
Commit
0b72dec
·
1 Parent(s): c7ed5da
Files changed (1) hide show
  1. app-fast.py +1 -11
app-fast.py CHANGED
@@ -36,13 +36,10 @@ RESOLUTION_OPTIONS: list[str] = [
36
  "832 x 1248",
37
  ]
38
 
39
- device = torch.device("cuda")
40
-
41
  quant_config = TransformersBitsAndBytesConfig(
42
  load_in_4bit=True,
43
  )
44
 
45
-
46
  tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL_NAME, use_fast=False)
47
  text_encoder = AutoModelForCausalLM.from_pretrained(
48
  LLAMA_MODEL_NAME,
@@ -71,8 +68,8 @@ scheduler = MODEL_CONFIGS["scheduler"](
71
 
72
  pipe = HiDreamImagePipeline.from_pretrained(
73
  MODEL_PATH,
74
- scheduler=scheduler,
75
  transformer=transformer,
 
76
  tokenizer_4=tokenizer,
77
  text_encoder_4=text_encoder,
78
  device_map="balanced",
@@ -90,8 +87,6 @@ def generate_image(
90
  if seed == -1:
91
  seed = torch.randint(0, 1_000_000, (1,)).item()
92
 
93
- # msg = "ℹ️ This spaces currently crash because of the memory usage. Please help me fix 😅"
94
- # raise gr.Error(msg, duration=10)
95
  height, width = tuple(map(int, resolution.replace(" ", "").split("x")))
96
  generator = torch.Generator("cuda").manual_seed(seed)
97
 
@@ -128,11 +123,6 @@ with gr.Blocks(title="HiDream Image Generator Fast") as demo:
128
 
129
  seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
130
  generate_btn = gr.Button("Generate Image", variant="primary")
131
- # generate_btn = gr.Button(
132
- # "This space currently crash because of the memory usage. Please help me fix 😅",
133
- # variant="primary",
134
- # interactive=False,
135
- # )
136
  seed_used = gr.Number(label="Seed Used", interactive=False)
137
 
138
  with gr.Column():
 
36
  "832 x 1248",
37
  ]
38
 
 
 
39
  quant_config = TransformersBitsAndBytesConfig(
40
  load_in_4bit=True,
41
  )
42
 
 
43
  tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL_NAME, use_fast=False)
44
  text_encoder = AutoModelForCausalLM.from_pretrained(
45
  LLAMA_MODEL_NAME,
 
68
 
69
  pipe = HiDreamImagePipeline.from_pretrained(
70
  MODEL_PATH,
 
71
  transformer=transformer,
72
+ scheduler=scheduler,
73
  tokenizer_4=tokenizer,
74
  text_encoder_4=text_encoder,
75
  device_map="balanced",
 
87
  if seed == -1:
88
  seed = torch.randint(0, 1_000_000, (1,)).item()
89
 
 
 
90
  height, width = tuple(map(int, resolution.replace(" ", "").split("x")))
91
  generator = torch.Generator("cuda").manual_seed(seed)
92
 
 
123
 
124
  seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
125
  generate_btn = gr.Button("Generate Image", variant="primary")
 
 
 
 
 
126
  seed_used = gr.Number(label="Seed Used", interactive=False)
127
 
128
  with gr.Column():