fenfan commited on
Commit
3408cd5
·
verified ·
1 Parent(s): 0a6db9d

fix: update app.py to fix zero gpu in original style.

Browse files
Files changed (1) hide show
  1. app.py +81 -60
app.py CHANGED
@@ -11,75 +11,96 @@
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
 
 
 
14
  import gradio as gr
15
  import torch
16
  import spaces
17
 
18
  from uno.flux.pipeline import UNOPipeline
19
 
20
- model_type = "flux-dev"
21
- offload = False
22
- device = "cuda"
23
-
24
- pipeline = UNOPipeline(model_type, device, offload, only_lora=True, lora_rank=512)
25
-
26
-
27
- ## it seems must use decorator can be trigger zero GPU
28
- ## not work by mannualy decorate by fn = spaces.GPU(duration=120)(fn)
29
- @spaces.GPU(duration=120)
30
- def generate_callback(*args, **kwargs):
31
- return pipeline.gradio_generate(*args, **kwargs)
32
-
33
- with gr.Blocks() as demo:
34
- gr.Markdown(f"# UNO by UNO team")
35
- with gr.Row():
36
- with gr.Column():
37
- prompt = gr.Textbox(label="Prompt", value="handsome woman in the city")
38
- with gr.Row():
39
- image_prompt1 = gr.Image(label="ref img1", visible=True, interactive=True, type="pil")
40
- image_prompt2 = gr.Image(label="ref img2", visible=True, interactive=True, type="pil")
41
- image_prompt3 = gr.Image(label="ref img3", visible=True, interactive=True, type="pil")
42
- image_prompt4 = gr.Image(label="ref img4", visible=True, interactive=True, type="pil")
43
-
44
- with gr.Row():
45
- with gr.Column():
46
- ref_long_side = gr.Slider(128, 512, 512, step=16, label="Long side of Ref Images")
47
- with gr.Column():
48
- gr.Markdown("📌 **The recommended ref scale** is related to the ref img number.\n")
49
- gr.Markdown(" 1->512 / 2->320 / 3...n->256")
50
-
51
- with gr.Row():
52
- with gr.Column():
53
- width = gr.Slider(512, 2048, 512, step=16, label="Gneration Width")
54
- height = gr.Slider(512, 2048, 512, step=16, label="Gneration Height")
55
- with gr.Column():
56
- gr.Markdown("📌 The model trained on 512x512 resolution.\n")
57
- gr.Markdown(
58
- "The size closer to 512 is more stable,"
59
- " and the higher size gives a better visual effect but is less stable"
60
- )
61
-
62
- with gr.Accordion("Generation Options", open=False):
63
- with gr.Row():
64
- num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
65
- guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True)
66
- seed = gr.Number(-1, label="Seed (-1 for random)")
67
 
68
- generate_btn = gr.Button("Generate")
 
 
 
 
 
 
69
 
70
- with gr.Column():
71
- output_image = gr.Image(label="Generated Image")
72
- download_btn = gr.File(label="Download full-resolution", type="filepath", interactive=False)
 
 
 
 
 
 
 
73
 
 
 
 
 
 
 
74
 
75
- inputs = [
76
- prompt, width, height, guidance, num_steps,
77
- seed, ref_long_side, image_prompt1, image_prompt2, image_prompt3, image_prompt4
78
- ]
79
- generate_btn.click(
80
- fn=generate_callback,
81
- inputs=inputs,
82
- outputs=[output_image, download_btn],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  )
 
 
 
 
 
84
 
85
- demo.launch()
 
 
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
+
15
+ import dataclasses
16
+
17
  import gradio as gr
18
  import torch
19
  import spaces
20
 
21
  from uno.flux.pipeline import UNOPipeline
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
+ def create_demo(
25
+ model_type: str,
26
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
27
+ offload: bool = False,
28
+ ):
29
+ pipeline = UNOPipeline(model_type, device, offload, only_lora=True, lora_rank=512)
30
+ pipeline.gradio_generate = spaces.GPU(duratioin=120)(pipeline.gradio_generate)
31
 
32
+ with gr.Blocks() as demo:
33
+ gr.Markdown(f"# UNO by UNO team")
34
+ with gr.Row():
35
+ with gr.Column():
36
+ prompt = gr.Textbox(label="Prompt", value="handsome woman in the city")
37
+ with gr.Row():
38
+ image_prompt1 = gr.Image(label="ref img1", visible=True, interactive=True, type="pil")
39
+ image_prompt2 = gr.Image(label="ref img2", visible=True, interactive=True, type="pil")
40
+ image_prompt3 = gr.Image(label="ref img3", visible=True, interactive=True, type="pil")
41
+ image_prompt4 = gr.Image(label="ref img4", visible=True, interactive=True, type="pil")
42
 
43
+ with gr.Row():
44
+ with gr.Column():
45
+ ref_long_side = gr.Slider(128, 512, 512, step=16, label="Long side of Ref Images")
46
+ with gr.Column():
47
+ gr.Markdown("📌 **The recommended ref scale** is related to the ref img number.\n")
48
+ gr.Markdown(" 1->512 / 2->320 / 3...n->256")
49
 
50
+ with gr.Row():
51
+ with gr.Column():
52
+ width = gr.Slider(512, 2048, 512, step=16, label="Gneration Width")
53
+ height = gr.Slider(512, 2048, 512, step=16, label="Gneration Height")
54
+ with gr.Column():
55
+ gr.Markdown("📌 The model trained on 512x512 resolution.\n")
56
+ gr.Markdown(
57
+ "The size closer to 512 is more stable,"
58
+ " and the higher size gives a better visual effect but is less stable"
59
+ )
60
+
61
+ with gr.Accordion("Generation Options", open=False):
62
+ with gr.Row():
63
+ num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
64
+ guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True)
65
+ seed = gr.Number(-1, label="Seed (-1 for random)")
66
+
67
+ generate_btn = gr.Button("Generate")
68
+
69
+ with gr.Column():
70
+ output_image = gr.Image(label="Generated Image")
71
+ download_btn = gr.File(label="Download full-resolution", type="filepath", interactive=False)
72
+
73
+
74
+ inputs = [
75
+ prompt, width, height, guidance, num_steps,
76
+ seed, ref_long_side, image_prompt1, image_prompt2, image_prompt3, image_prompt4
77
+ ]
78
+ generate_btn.click(
79
+ fn=pipeline.gradio_generate,
80
+ inputs=inputs,
81
+ outputs=[output_image, download_btn],
82
+ )
83
+
84
+ return demo
85
+
86
+ if __name__ == "__main__":
87
+ from typing import Literal
88
+
89
+ from transformers import HfArgumentParser
90
+
91
+ @dataclasses.dataclass
92
+ class AppArgs:
93
+ name: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"
94
+ device: Literal["cuda", "cpu"] = "cuda" if torch.cuda.is_available() else "cpu"
95
+ offload: bool = dataclasses.field(
96
+ default=False,
97
+ metadata={"help": "If True, sequantial offload the models(ae, dit, text encoder) to CPU if not used."}
98
  )
99
+ port: int = 7860
100
+
101
+ parser = HfArgumentParser([AppArgs])
102
+ args_tuple = parser.parse_args_into_dataclasses() # type: tuple[AppArgs]
103
+ args = args_tuple[0]
104
 
105
+ demo = create_demo(args.name, args.device, args.offload)
106
+ demo.launch(server_port=args.port)