Duskfallcrew committed
Commit e9917a9 · verified · 1 Parent(s): 592aa6b

Create app.py

I HAVE NO CLUE WHAT IM DOING

Files changed (1)
  1. app.py +300 -0
app.py ADDED
@@ -0,0 +1,300 @@
+ import os
+ import sys
+ import gradio as gr
+ import torch
+ from diffusers import StableDiffusionXLPipeline
+ from huggingface_hub import HfApi, login
+ from huggingface_hub.utils import validate_repo_id, HfHubHTTPError
+ import re
+ import json
+ import glob
+ import gdown
+ import requests
+ import subprocess
+ from urllib.parse import urlparse, unquote
+ from pathlib import Path
+
+ # ---------------------- DEPENDENCIES ----------------------
+
+ def install_dependencies_gradio():
+     """Installs the necessary dependencies for the Gradio app. Run this ONCE."""
+     try:
+         # "!pip" is IPython/Colab cell magic and is a syntax error in a plain
+         # .py file; shell out to pip through the current interpreter instead.
+         subprocess.check_call([
+             sys.executable, "-m", "pip", "install", "-U",
+             "torch", "diffusers", "transformers", "accelerate",
+             "safetensors", "huggingface_hub", "xformers",
+         ])
+         print("Dependencies installed successfully.")
+     except Exception as e:
+         print(f"Error installing dependencies: {e}")
+
+ # ---------------------- UTILITY FUNCTIONS ----------------------
+
+ def get_save_dtype(save_precision_as):
+     """Determines the save dtype based on the user's choice."""
+     if save_precision_as == "fp16":
+         return torch.float16
+     elif save_precision_as == "bf16":
+         return torch.bfloat16
+     elif save_precision_as == "float":
+         return torch.float32  # Using float32 for the "float" option
+     else:
+         return None
+
+ def determine_load_checkpoint(model_to_load):
+     """Returns True for a checkpoint file, False for a Diffusers directory, None if neither."""
+     if model_to_load.endswith('.ckpt') or model_to_load.endswith('.safetensors'):
+         return True
+     elif os.path.isdir(model_to_load):
+         required_folders = {"unet", "text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2", "scheduler", "vae"}
+         if required_folders.issubset(set(os.listdir(model_to_load))) and os.path.isfile(os.path.join(model_to_load, "model_index.json")):
+             return False
+     return None  # neither a checkpoint file nor a Diffusers model directory
+
+ def increment_filename(filename):
+     """Increments the filename to avoid overwriting existing files."""
+     base, ext = os.path.splitext(filename)
+     counter = 1
+     while os.path.exists(filename):
+         filename = f"{base}({counter}){ext}"
+         counter += 1
+     return filename
+
+ def create_model_repo(api, user, orgs_name, model_name, make_private=False):
+     """Creates a Hugging Face model repository if it doesn't exist."""
+     if orgs_name == "":
+         repo_id = user["name"] + "/" + model_name.strip()
+     else:
+         repo_id = orgs_name + "/" + model_name.strip()
+
+     try:
+         validate_repo_id(repo_id)
+         api.create_repo(repo_id=repo_id, repo_type="model", private=make_private)
+         print(f"Model repo '{repo_id}' didn't exist, creating repo")
+     except HfHubHTTPError:
+         print(f"Model repo '{repo_id}' exists, skipping create repo")
+
+     print(f"Model repo '{repo_id}' link: https://huggingface.co/{repo_id}\n")
+
+     return repo_id
+
+ def is_diffusers_model(model_path):
+     """Checks if a given path is a valid Diffusers model directory."""
+     if not os.path.isdir(model_path):
+         # os.listdir would raise on a plain file (e.g., a .safetensors checkpoint)
+         return False
+     required_folders = {"unet", "text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2", "scheduler", "vae"}
+     return required_folders.issubset(set(os.listdir(model_path))) and os.path.isfile(os.path.join(model_path, "model_index.json"))
+
+ # ---------------------- CONVERSION AND UPLOAD FUNCTIONS ----------------------
+
+ def load_sdxl_model(args, is_load_checkpoint, load_dtype):
+     """Loads the SDXL model from a checkpoint or Diffusers model."""
+     # gr.Markdown is not a context manager, so progress is printed to the console
+     # instead of the old "with output_widget:" pattern.
+     model_load_message = "checkpoint" if is_load_checkpoint else ("Diffusers" + (" as fp16" if args.fp16 else ""))
+     print(f"Loading {model_load_message}: {args.model_to_load}")
+
+     if is_load_checkpoint:
+         loaded_model_data = load_from_sdxl_checkpoint(args)
+     else:
+         loaded_model_data = load_sdxl_from_diffusers(args, load_dtype)
+
+     return loaded_model_data
+
+ def load_from_sdxl_checkpoint(args):
+     """Loads the SDXL model components from a checkpoint file (placeholder)."""
+     # text_encoder1, text_encoder2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
+     #     "sdxl_base_v1-0", args.model_to_load, "cpu"
+     # )
+
+     # Implement loading the model from .ckpt or .safetensors here
+     text_encoder1, text_encoder2, vae, unet = None, None, None, None
+
+     print("Loading from Checkpoint not implemented, please implement based on your model needs.")
+
+     return text_encoder1, text_encoder2, vae, unet
+
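+ # NOTE: a minimal, hedged sketch (not the original author's code) of one way to
+ # fill the placeholder above using diffusers' single-file loader; the component
+ # names mirror the return signature of load_from_sdxl_checkpoint:
+ #
+ #     pipe = StableDiffusionXLPipeline.from_single_file(args.model_to_load)
+ #     return pipe.text_encoder, pipe.text_encoder_2, pipe.vae, pipe.unet
+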
+ def load_sdxl_from_diffusers(args, load_dtype):
+     """Loads an SDXL model from a Diffusers model directory."""
+     pipeline = StableDiffusionXLPipeline.from_pretrained(
+         args.model_to_load, torch_dtype=load_dtype, tokenizer=None, tokenizer_2=None, scheduler=None
+     )
+     text_encoder1 = pipeline.text_encoder
+     text_encoder2 = pipeline.text_encoder_2
+     vae = pipeline.vae
+     unet = pipeline.unet
+
+     return text_encoder1, text_encoder2, vae, unet
+
+ def convert_and_save_sdxl_model(args, is_save_checkpoint, loaded_model_data, save_dtype):
+     """Converts and saves the SDXL model as either a checkpoint or a Diffusers model."""
+     text_encoder1, text_encoder2, vae, unet = loaded_model_data
+     model_save_message = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_checkpoint else "Diffusers"
+
+     print(f"Converting and saving as {model_save_message}: {args.model_to_save}")
+
+     if is_save_checkpoint:
+         save_sdxl_as_checkpoint(args, text_encoder1, text_encoder2, vae, unet, save_dtype)
+     else:
+         save_sdxl_as_diffusers(args, text_encoder1, text_encoder2, vae, unet, save_dtype)
+
+ def save_sdxl_as_checkpoint(args, text_encoder1, text_encoder2, vae, unet, save_dtype):
+     """Saves the SDXL model components as a checkpoint file (placeholder)."""
+     # logit_scale = None
+     # ckpt_info = None
+
+     # key_count = sdxl_model_util.save_stable_diffusion_checkpoint(
+     #     args.model_to_save, text_encoder1, text_encoder2, unet, args.epoch, args.global_step, ckpt_info, vae, logit_scale, save_dtype
+     # )
+
+     print("Saving as Checkpoint not implemented, please implement based on your model needs.")
+     # print(f"Model saved. Total converted state_dict keys: {key_count}")
+
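+ # NOTE: a hedged sketch (an assumption, not the author's implementation) of
+ # persisting raw component weights with safetensors. This does NOT produce the
+ # standard single-file SDXL key layout, which needs a Diffusers->LDM key
+ # conversion such as the kohya sdxl_model_util call referenced above:
+ #
+ #     from safetensors.torch import save_file
+ #     state_dict = {f"unet.{k}": v.contiguous() for k, v in unet.state_dict().items()}
+ #     save_file(state_dict, args.model_to_save)
+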
+ def save_sdxl_as_diffusers(args, text_encoder1, text_encoder2, vae, unet, save_dtype):
+     """Saves the SDXL model as a Diffusers model."""
+     reference_model_message = args.reference_model if args.reference_model is not None else 'default model'
+     print(f"Copying scheduler/tokenizer config from: {reference_model_message}")
+
+     # Save diffusers pipeline
+     pipeline = StableDiffusionXLPipeline(
+         vae=vae,
+         text_encoder=text_encoder1,
+         text_encoder_2=text_encoder2,
+         unet=unet,
+         scheduler=None,   # Replace None if there is a scheduler
+         tokenizer=None,   # Replace None if there is a tokenizer
+         tokenizer_2=None  # Replace None if there is a tokenizer_2
+     )
+
+     if save_dtype is not None:
+         # Cast the weights so the requested precision is actually applied on disk
+         pipeline.to(torch_dtype=save_dtype)
+
+     pipeline.save_pretrained(args.model_to_save)
+
+     print(f"Model saved as {save_dtype}.")
+
+ def convert_model(model_to_load, save_precision_as, epoch, global_step, reference_model, output_path, fp16):
+     """Main conversion function."""
+     class Args:  # Defining Args locally within convert_model
+         def __init__(self, model_to_load, save_precision_as, epoch, global_step, reference_model, output_path, fp16):
+             self.model_to_load = model_to_load
+             self.save_precision_as = save_precision_as
+             self.epoch = epoch
+             self.global_step = global_step
+             self.reference_model = reference_model
+             self.output_path = output_path
+             self.fp16 = fp16
+
+     args = Args(model_to_load, save_precision_as, epoch, global_step, reference_model, output_path, fp16)
+     # Save into output_path (the upload step reads from there), not next to the input model
+     os.makedirs(output_path, exist_ok=True)
+     base_name = os.path.splitext(os.path.basename(model_to_load))[0]
+     args.model_to_save = increment_filename(os.path.join(output_path, base_name + ".safetensors"))
+
+     try:
+         load_dtype = torch.float16 if fp16 else None
+         save_dtype = get_save_dtype(save_precision_as)
+
+         is_load_checkpoint = determine_load_checkpoint(model_to_load)
+         if is_load_checkpoint is None:
+             return f"Could not determine whether '{model_to_load}' is a checkpoint or a Diffusers model."
+         is_save_checkpoint = not is_load_checkpoint  # reverse of load model
+
+         loaded_model_data = load_sdxl_model(args, is_load_checkpoint, load_dtype)
+         convert_and_save_sdxl_model(args, is_save_checkpoint, loaded_model_data, save_dtype)
+
+         return f"Conversion complete. Model saved to {args.model_to_save}"
+
+     except Exception as e:
+         return f"Conversion failed: {e}"
+
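+ # Example call (hypothetical paths, shown for illustration only):
+ #
+ #     convert_model("/content/my-diffusers-model", "fp16", 0, 0,
+ #                   "stabilityai/stable-diffusion-xl-base-1.0",
+ #                   "/content/output", fp16=False)
+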
+ def upload_to_huggingface(model_path, hf_token, orgs_name, model_name, make_private):
+     """Uploads a model to the Hugging Face Hub."""
+     try:
+         login(hf_token, add_to_git_credential=True)
+         api = HfApi()
+         user = api.whoami(hf_token)
+         model_repo = create_model_repo(api, user, orgs_name, model_name, make_private)
+
+         # Determine upload parameters (adjust as needed)
+         path_in_repo = ""
+         trained_model = os.path.basename(model_path)
+
+         path_in_repo_local = path_in_repo if path_in_repo and not is_diffusers_model(model_path) else ""
+
+         print(f"Uploading {trained_model} from {model_path} to https://huggingface.co/{model_repo}")
+
+         if os.path.isdir(model_path):
+             if is_diffusers_model(model_path):
+                 commit_message = f"Upload diffusers format: {trained_model}"
+                 print("Detected diffusers model. Adjusting upload parameters.")
+             else:
+                 commit_message = f"Upload checkpoint: {trained_model}"
+                 print("Detected regular model. Adjusting upload parameters.")
+
+             api.upload_folder(
+                 folder_path=model_path,
+                 path_in_repo=path_in_repo_local,
+                 repo_id=model_repo,
+                 commit_message=commit_message,
+                 ignore_patterns=".ipynb_checkpoints",
+             )
+         else:
+             commit_message = f"Upload file: {trained_model}"
+             api.upload_file(
+                 path_or_fileobj=model_path,
+                 # upload_file needs a real destination filename, not an empty string
+                 path_in_repo=path_in_repo_local or trained_model,
+                 repo_id=model_repo,
+                 commit_message=commit_message,
+             )
+         return f"Model upload complete! Check it out at https://huggingface.co/{model_repo}/tree/main"
+
+     except Exception as e:
+         return f"Upload failed: {e}"
+
+ # ---------------------- GRADIO INTERFACE ----------------------
+
+ def main(model_to_load, save_precision_as, epoch, global_step, reference_model, output_path, fp16, hf_token, orgs_name, model_name, make_private):
+     """Main function orchestrating the entire process."""
+     conversion_output = convert_model(model_to_load, save_precision_as, epoch, global_step, reference_model, output_path, fp16)
+
+     if hf_token:
+         upload_output = upload_to_huggingface(output_path, hf_token, orgs_name, model_name, make_private)
+     else:
+         upload_output = "Skipping upload: no Hugging Face token provided."
+
+     # Return a combined output
+     return f"{conversion_output}\n\n{upload_output}"
+
+ with gr.Blocks() as demo:
+
+     # Add initial warnings (only once)
+     gr.Markdown("""
+     ## **⚠️ IMPORTANT WARNINGS ⚠️**
+     This app may violate Google Colab AUP. Use at your own risk. `xformers` may cause issues.
+     """)
+
+     model_to_load = gr.Textbox(label="Model to Load (Checkpoint or Diffusers)", placeholder="Path to model")
+     with gr.Row():
+         save_precision_as = gr.Dropdown(
+             choices=["fp16", "bf16", "float"], value="fp16", label="Save Precision As"
+         )
+         fp16 = gr.Checkbox(label="Load as fp16 (Diffusers only)")
+     with gr.Row():
+         epoch = gr.Number(value=0, label="Epoch to Write (Checkpoint)")
+         global_step = gr.Number(value=0, label="Global Step to Write (Checkpoint)")
+
+     reference_model = gr.Textbox(label="Reference Diffusers Model",
+                                  placeholder="e.g., stabilityai/stable-diffusion-xl-base-1.0")
+     output_path = gr.Textbox(label="Output Path", value="/content/output")
+
+     gr.Markdown("## Hugging Face Hub Configuration")
+     hf_token = gr.Textbox(label="Hugging Face Token", type="password", placeholder="Your Hugging Face write token")
+     with gr.Row():
+         orgs_name = gr.Textbox(label="Organization Name (Optional)", placeholder="Your organization name")
+         model_name = gr.Textbox(label="Model Name", placeholder="The name of your model on Hugging Face")
+     make_private = gr.Checkbox(label="Make Repository Private", value=False)
+
+     convert_button = gr.Button("Convert and Upload")
+     output = gr.Markdown()
+
+     convert_button.click(fn=main,
+                          inputs=[model_to_load, save_precision_as, epoch, global_step, reference_model,
+                                  output_path, fp16, hf_token, orgs_name, model_name, make_private],
+                          outputs=output)
+
+ demo.launch()