LPX55 commited on
Commit
104f00c
·
verified ·
1 Parent(s): e21e3cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -12
app.py CHANGED
@@ -6,15 +6,94 @@ import logging
6
  from diffusers import DiffusionPipeline
7
  import subprocess
8
 
9
- # os.environ["CUDA_HOME"] = "/usr/local/cuda"
10
- # CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
11
- # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
12
- # subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
13
- # subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
14
- # subprocess.call([CUDA_TOOLKIT_FILE, "--toolkit"])
15
- # os.environ["CUDA_HOME"] = "/usr/local/cuda"
16
-
17
- pipe = DiffusionPipeline.from_pretrained("azaneko/HiDream-I1-Fast-nf4")
18
-
19
- prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
20
- image = pipe(prompt).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from diffusers import DiffusionPipeline
7
  import subprocess
8
 
9
+ from .nf4 import *
10
+
11
+ # Resolution options
12
+ RESOLUTION_OPTIONS = [
13
+ "1024 × 1024 (Square)",
14
+ "768 × 1360 (Portrait)",
15
+ "1360 × 768 (Landscape)",
16
+ "880 × 1168 (Portrait)",
17
+ "1168 × 880 (Landscape)",
18
+ "1248 × 832 (Landscape)",
19
+ "832 × 1248 (Portrait)"
20
+ ]
21
+
22
+ # Parse resolution string to get height and width
23
+ def parse_resolution(resolution_str):
24
+ return tuple(map(int, resolution_str.split("(")[0].strip().split(" × ")))
25
+
26
+ @spaces.GPU()
27
+ def gen_img_helper(model, prompt, res, seed):
28
+ global pipe, current_model
29
+
30
+ # 1. Check if the model matches loaded model, load the model if not
31
+ if model != current_model:
32
+ print(f"Unloading model {current_model}...")
33
+ del pipe
34
+ torch.cuda.empty_cache()
35
+
36
+ print(f"Loading model {model}...")
37
+ pipe, _ = load_models(model)
38
+ current_model = model
39
+ print("Model loaded successfully!")
40
+
41
+ # 2. Generate image
42
+ res = parse_resolution(res)
43
+ return generate_image(pipe, model, prompt, res, seed)
44
+
45
+
46
+ if __name__ == "__main__":
47
+ logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
48
+
49
+ # Initialize with default model
50
+ print("Loading default model (fast)...")
51
+ current_model = "fast"
52
+ pipe, _ = load_models(current_model)
53
+ print("Model loaded successfully!")
54
+
55
+ # Create Gradio interface
56
+ with gr.Blocks(title="HiDream-I1-nf4 Dashboard") as demo:
57
+ gr.Markdown("# HiDream-I1-nf4 Dashboard")
58
+
59
+ with gr.Row():
60
+ with gr.Column():
61
+ model_type = gr.Radio(
62
+ choices=list(MODEL_CONFIGS.keys()),
63
+ value="fast",
64
+ label="Model Type",
65
+ info="Select model variant"
66
+ )
67
+
68
+ prompt = gr.Textbox(
69
+ label="Prompt",
70
+ placeholder="A cat holding a sign that says \"Hi-Dreams.ai\".",
71
+ lines=3
72
+ )
73
+
74
+ resolution = gr.Radio(
75
+ choices=RESOLUTION_OPTIONS,
76
+ value=RESOLUTION_OPTIONS[0],
77
+ label="Resolution",
78
+ info="Select image resolution"
79
+ )
80
+
81
+ seed = gr.Number(
82
+ label="Seed (use -1 for random)",
83
+ value=-1,
84
+ precision=0
85
+ )
86
+
87
+ generate_btn = gr.Button("Generate Image")
88
+ seed_used = gr.Number(label="Seed Used", interactive=False)
89
+
90
+ with gr.Column():
91
+ output_image = gr.Image(label="Generated Image", type="pil")
92
+
93
+ generate_btn.click(
94
+ fn=gen_img_helper,
95
+ inputs=[model_type, prompt, resolution, seed],
96
+ outputs=[output_image, seed_used]
97
+ )
98
+
99
+ demo.launch()