blanchon committed
Commit 8cfc368 · Parent: 478f981

update deps

Files changed (2):
  app.py            +6 -19
  requirements.txt  +2 -2
app.py CHANGED

@@ -1,4 +1,5 @@
 import gradio as gr
+import PIL
 import spaces
 import torch
 from hi_diffusers import HiDreamImagePipeline, HiDreamImageTransformer2DModel
@@ -48,10 +49,6 @@ RESOLUTION_OPTIONS: list[str] = [
 ]
 
 
-def parse_resolution(res_str: str) -> tuple[int, int]:
-    return tuple(map(int, res_str.replace(" ", "").split("x")))
-
-
 tokenizer = PreTrainedTokenizerFast.from_pretrained(LLAMA_MODEL_NAME, use_fast=False)
 text_encoder = LlamaForCausalLM.from_pretrained(
     LLAMA_MODEL_NAME,
@@ -85,25 +82,22 @@ pipe.transformer = transformer
 
 @spaces.GPU(duration=90)
 def generate_image(
-    model_type: str,
     prompt: str,
     resolution: str,
     seed: int,
-) -> tuple[object, int]:
-    config = MODEL_CONFIGS[model_type]
-
+) -> tuple[PIL.Image.Image, int]:
     if seed == -1:
         seed = torch.randint(0, 1_000_000, (1,)).item()
 
-    height, width = parse_resolution(resolution)
+    height, width = tuple(map(int, resolution.replace(" ", "").split("x")))
     generator = torch.Generator("cuda").manual_seed(seed)
 
     image = pipe(
         prompt=prompt,
         height=height,
         width=width,
-        guidance_scale=config["guidance_scale"],
-        num_inference_steps=config["num_inference_steps"],
+        guidance_scale=MODEL_CONFIGS["guidance_scale"],
+        num_inference_steps=MODEL_CONFIGS["num_inference_steps"],
         generator=generator,
     ).images[0]
 
@@ -117,13 +111,6 @@ with gr.Blocks(title="HiDream Image Generator") as demo:
 
     with gr.Row():
         with gr.Column():
-            model_type = gr.Radio(
-                choices=list(MODEL_CONFIGS.keys()),
-                value="full",
-                label="Model Type",
-                info="Choose between full, fast or dev variants",
-            )
-
             prompt = gr.Textbox(
                 label="Prompt",
                 placeholder="e.g. A futuristic city with floating cars at sunset",
@@ -145,7 +132,7 @@ with gr.Blocks(title="HiDream Image Generator") as demo:
 
     generate_btn.click(
         fn=generate_image,
-        inputs=[model_type, prompt, resolution, seed],
+        inputs=[prompt, resolution, seed],
         outputs=[output_image, seed_used],
     )
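
Note: the removed `parse_resolution` helper is now inlined in `generate_image`. Two things worth flagging. First, `import PIL` alone does not import the `PIL.Image` submodule; the `tuple[PIL.Image.Image, int]` annotation evaluates at definition time and only works because another import (gradio's, in practice) has already loaded `PIL.Image`; `import PIL.Image` would be the explicit form. Second, a quick sanity check of the inlined parsing, assuming resolution strings of the form `"<height> x <width>"` (the actual `RESOLUTION_OPTIONS` entries are not shown in this diff):

```python
# Minimal sketch of the inlined parsing; the option string below is a
# hypothetical example, not taken from RESOLUTION_OPTIONS.
resolution = "768 x 1360"
height, width = tuple(map(int, resolution.replace(" ", "").split("x")))
assert (height, width) == (768, 1360)
```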
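Note: with the model-type radio removed, `generate_image` indexes `MODEL_CONFIGS` directly with `"guidance_scale"` and `"num_inference_steps"`. That only resolves if `MODEL_CONFIGS` is now a flat config dict; this diff does not touch its definition, so if it is still the per-variant mapping implied by the old `MODEL_CONFIGS[model_type]` lookup, these lines raise `KeyError` at generation time. A minimal sketch of the shape the new code assumes (values are illustrative placeholders):

```python
# Flat layout assumed by the new generate_image; the values here are
# hypothetical placeholders, not taken from the repository.
MODEL_CONFIGS = {
    "guidance_scale": 5.0,
    "num_inference_steps": 50,
}

# Under the old per-variant layout the same lookup fails:
#   MODEL_CONFIGS = {"full": {...}, "fast": {...}, "dev": {...}}
#   MODEL_CONFIGS["guidance_scale"]  # KeyError: 'guidance_scale'
```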
requirements.txt CHANGED

@@ -4,7 +4,7 @@ diffusers
 transformers
 accelerate
 xformers
-https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
+https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
 einops
 gradio
-spaces
+spaces
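
Note: the flash-attn wheel swap from `cxx11abiTRUE` to `cxx11abiFALSE` matches the C++ ABI of the PyPI `torch` 2.4 wheels, which were built without the new libstdc++ ABI; a mismatched wheel fails at import with undefined-symbol errors. A quick way to confirm which prebuilt wheel an environment needs (assuming `torch` is installed):

```python
import torch

# Prebuilt flash-attn wheels must match the C++ ABI torch was built with;
# this returns False for the PyPI torch 2.4 wheels, matching cxx11abiFALSE.
print(torch.compiled_with_cxx11_abi())
```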