Spaces:
Running
on
Zero
Running
on
Zero
import argparse | |
import os | |
import hashlib | |
import functools | |
import json | |
import yaml | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from PIL import Image | |
from diffusers import AutoencoderKL | |
from torchvision import transforms | |
from tqdm import tqdm | |
from imgproc import ( | |
generate_crop_size_list, | |
to_rgb_if_rgba, | |
var_center_crop, | |
) | |
from data import read_general | |
# ---- Flux VAE latent normalization ----
# `encode` maps a raw VAE latent z to (z - VAE_SHIFT) * VAE_SCALE before it is
# cached. NOTE(review): these appear to mirror the `shift_factor` /
# `scaling_factor` of the FLUX.1 VAE config -- re-check if the VAE changes.
VAE_SHIFT: float = 0.1159
VAE_SCALE: float = 0.3611
def handle_image(image: Image.Image) -> Image.Image:
    """
    Return `image` as an RGB PIL image.

    Handling by (upper-cased) mode:
      - "RGB": returned unchanged (the same object).
      - "RGBA": passed to ``to_rgb_if_rgba``.
      - "LA", "PA" (luminance / palette plus alpha): promoted to "RGBA" and then
        passed to ``to_rgb_if_rgba``, so their alpha goes through the same
        handling as RGBA images.
      - "L", "P", "1", "CMYK", "YCBCR" (no alpha channel): ``image.convert("RGB")``.

    Raises:
        ValueError: for any other mode (e.g. 16-bit or floating-point modes such
            as "I;16", "I", "F"), which this function does not try to convert.
    """
    mode = image.mode.upper()
    if mode == "RGB":
        return image
    if mode == "RGBA":
        return to_rgb_if_rgba(image)
    if mode in ("LA", "PA"):
        # Route alpha-carrying modes through the same path as RGBA.
        return to_rgb_if_rgba(image.convert("RGBA"))
    if mode in ("L", "P", "1", "CMYK", "YCBCR"):
        return image.convert("RGB")
    raise ValueError(f"Unsupported image mode: {mode}")
def encode(vae: AutoencoderKL, img_tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
    """
    Encode an image tensor into normalized Flux VAE latents.

    Args:
        vae: VAE whose ``encode(x)`` returns an object exposing ``latent_dist``.
        img_tensor: image tensor of shape (C, H, W) or (1, C, H, W); a 3-D input
            gets a batch dimension added. Presumably normalized to [-1, 1] like
            the transform in ``main`` -- TODO confirm for other callers.
        device: device on which to run the encoder.

    Returns:
        float32 tensor holding the mode of the latent distribution for the first
        batch element (batch dimension dropped), mapped through
        ``(z - VAE_SHIFT) * VAE_SCALE``.
    """
    if img_tensor.dim() == 3:
        img_tensor = img_tensor.unsqueeze(0)  # (1, C, H, W)
    # Cast to the VAE's own dtype as well as its device so callers need not
    # pre-cast; an input/weight dtype mismatch would typically make the first
    # convolution raise.
    img_tensor = img_tensor.to(device=device, dtype=vae.dtype, non_blocking=True)
    with torch.no_grad():
        latent_dist = vae.encode(img_tensor).latent_dist
        # Use the distribution mode (not a random sample) so re-running the cache
        # reproduces the same latents; [0] drops the batch dimension.
        latents = latent_dist.mode()[0]
    # Upcast before the affine normalization so it is not rounded in half precision.
    latents = latents.float()
    return (latents - VAE_SHIFT) * VAE_SCALE
def load_image_paths_from_yaml(yaml_path: str) -> list:
    """
    Collect image paths referenced by a dataset YAML config.

    The YAML is expected to hold a top-level ``META`` list whose entries are
    mappings with a ``path`` (to a .jsonl file) and a ``type``. Only entries with
    ``type == "image_text"`` are read. Each non-blank line of such a .jsonl file
    is parsed as JSON, and the string value of its ``image_path`` key (if present)
    is collected.

    Missing .jsonl files, malformed META entries, unparsable JSON lines and
    non-string ``image_path`` values are reported on stdout and skipped.

    Args:
        yaml_path: path to the dataset YAML file.

    Returns:
        List of image paths in file/line order (duplicates preserved). An empty
        YAML document or an empty/absent ``META`` yields an empty list.

    Raises:
        ValueError: if the YAML document is non-empty but not a mapping.
    """
    with open(yaml_path, "r", encoding="utf-8") as f:
        # An empty document loads as None; treat it as a config with no entries.
        data = yaml.safe_load(f) or {}
    if not isinstance(data, dict):
        raise ValueError(
            f"Expected a mapping at the top level of {yaml_path}, got {type(data).__name__}"
        )

    image_files = []
    # `META:` with no value loads as None, so fall back to an empty list.
    for meta_item in data.get("META") or []:
        if not isinstance(meta_item, dict):
            print(f"[Warning] Skipping malformed META entry in {yaml_path}: {meta_item!r}")
            continue
        # Example entry:
        #   path: /data0/DanbooruWebp/booru1116Webp.jsonl
        #   type: image_text
        if meta_item.get("type", "") != "image_text":
            continue  # skip unknown types
        # `or ""` covers `path:` with no value (None), which os.path.isfile
        # would reject with a TypeError instead of returning False.
        fpath = meta_item.get("path") or ""
        if not os.path.isfile(fpath):
            print(f"[Warning] JSONL file not found: {fpath}")
            continue
        with open(fpath, "r", encoding="utf-8") as fin:
            for lineno, line in enumerate(fin, start=1):
                line = line.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except json.JSONDecodeError as e:
                    print(f"[Warning] JSON parse error in {fpath} (line {lineno}): {e}")
                    continue
                # Only JSON objects have fields; arrays, numbers and strings are
                # ignored rather than probed with `in` (a substring test on str).
                if not isinstance(obj, dict) or "image_path" not in obj:
                    continue
                image_path = obj["image_path"]
                if isinstance(image_path, str):
                    # This is the actual disk path for the image.
                    image_files.append(image_path)
                else:
                    print(
                        f"[Warning] Non-string image_path in {fpath} (line {lineno}): "
                        f"{image_path!r}"
                    )
    return image_files
def _parse_args() -> argparse.Namespace:
    """Parse and validate this script's command-line options."""
    parser = argparse.ArgumentParser(description="Cache image latents using Flux VAE")
    parser.add_argument("--data_yaml", type=str, required=True,
                        help="Path to dataset YAML config (with META -> .jsonl paths)")
    parser.add_argument("--resolution", type=int, required=True,
                        help="Target resolution (e.g., 256, 512, 1024) for center-crop/resize")
    parser.add_argument("--total_split", type=int, default=1,
                        help="Total number of parallel splits/workers")
    parser.add_argument("--current_worker_index", type=int, default=0,
                        help="Index of this worker (0-based)")
    parser.add_argument("--patch_size", type=int, default=8,
                        help="Patch size used for generating potential crop sizes")
    parser.add_argument("--random_top_k", type=int, default=1,
                        help="Number of top crop options from var_center_crop to randomly pick")
    parser.add_argument("--vae_path", type=str, default="black-forest-labs/FLUX.1-dev",
                        help="Repo id or local directory whose 'vae' subfolder holds the Flux VAE")
    parser.add_argument("--device", type=str, default=None,
                        help="Torch device for encoding (default: cuda:0 if available, else cpu)")
    parser.add_argument("--no_square_resize", action="store_true",
                        help="Skip resizing every image to (resolution, resolution) before "
                             "var_center_crop; the default square resize ignores aspect ratio")
    args = parser.parse_args()

    # Fail fast on values that would crash mid-run (modulo by zero) or make
    # this worker silently process nothing.
    if args.total_split < 1:
        parser.error("--total_split must be >= 1")
    if not 0 <= args.current_worker_index < args.total_split:
        parser.error("--current_worker_index must satisfy 0 <= index < total_split")
    if args.resolution < 1 or args.patch_size < 1:
        parser.error("--resolution and --patch_size must be positive")
    return args


def _belongs_to_worker(image_path: str, worker_idx: int, total_split: int) -> bool:
    """Deterministically assign `image_path` to one of `total_split` workers by SHA-1."""
    digest = hashlib.sha1(image_path.encode("utf-8")).hexdigest()
    return int(digest, 16) % total_split == worker_idx


def _build_image_transform(resolution: int, patch_size: int, random_top_k: int):
    """Build the crop -> tensor -> normalize pipeline (must match training preprocessing)."""
    max_num_patches = round((resolution / patch_size) ** 2)
    crop_size_list = generate_crop_size_list(max_num_patches, patch_size)
    return transforms.Compose([
        transforms.Lambda(functools.partial(var_center_crop,
                                            crop_size_list=crop_size_list,
                                            random_top_k=random_top_k)),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps ToTensor's [0, 1] range to [-1, 1].
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])


def _load_rgb_image(image_path: str):
    """Open, fully decode and RGB-convert `image_path`; return None (with a warning) on failure."""
    try:
        pil_image = Image.open(read_general(image_path))
        # Image.open is lazy: force the decode here so truncated/corrupt files are
        # reported and skipped instead of raising later outside any handler.
        pil_image.load()
        return handle_image(pil_image)  # ensure RGB
    except Exception as e:
        print(f"[Warning] Could not open image {image_path}: {e}")
        return None


def _save_npz_atomic(out_path: str, latents_np: np.ndarray) -> None:
    """Write `latents_np` under key "latent" to `out_path` via a temp file + atomic rename.

    An interrupted write therefore never leaves a truncated `out_path`, which
    later runs would otherwise treat as already cached and skip.
    """
    tmp_path = out_path + ".tmp"
    try:
        # Passing an open file (not a path) keeps numpy from appending ".npz"
        # to the temporary name.
        with open(tmp_path, "wb") as f:
            np.savez_compressed(f, latent=latents_np)
        os.replace(tmp_path, out_path)
    except BaseException:
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def main():
    """
    Cache Flux VAE latents for this worker's shard of the dataset.

    Every image path listed via the dataset YAML whose SHA-1 hash maps to this
    worker is cropped/normalized, encoded, and saved next to the image as
    ``<path without extension>_<resolution>.npz`` (array key ``latent``).
    Images whose cache file already exists are skipped; unreadable images,
    transform failures and save failures are reported and skipped.
    """
    args = _parse_args()
    worker_idx = args.current_worker_index
    total_split = args.total_split
    res = args.resolution

    # ------------------------------------------------------------------
    # 1) Collect image paths before loading the VAE, so an empty or
    #    misconfigured dataset exits without the expensive model load.
    # ------------------------------------------------------------------
    image_files = load_image_paths_from_yaml(args.data_yaml)
    if not image_files:
        print("[INFO] No image files found. Check your YAML & JSONL contents.")
        return
    # Shard up front so the progress bar reflects only this worker's files;
    # non-string entries can be neither hashed nor opened.
    my_files = [
        path for path in image_files
        if isinstance(path, str) and _belongs_to_worker(path, worker_idx, total_split)
    ]

    # ------------------------------------------------------------------
    # 2) Setup VAE model for encoding.
    # ------------------------------------------------------------------
    if args.device:
        device = torch.device(args.device)
    else:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # float16 on GPU; float32 elsewhere, since half-precision conv kernels are
    # missing or slow on CPU in many PyTorch builds.
    vae_dtype = torch.float16 if device.type == "cuda" else torch.float32
    vae = AutoencoderKL.from_pretrained(
        args.vae_path,
        subfolder="vae",
        torch_dtype=vae_dtype,
    ).eval()
    vae.to(device)

    # ------------------------------------------------------------------
    # 3) Preprocessing; must match how images are processed for training.
    # ------------------------------------------------------------------
    image_transform = _build_image_transform(res, args.patch_size, args.random_top_k)

    # ------------------------------------------------------------------
    # 4) Process each image => transform => encode => save .npz
    # ------------------------------------------------------------------
    for image_path in tqdm(my_files, desc=f"Worker {worker_idx}"):
        base, _ = os.path.splitext(image_path)
        out_path = f"{base}_{res}.npz"
        if os.path.exists(out_path):
            continue

        pil_image = _load_rgb_image(image_path)
        if pil_image is None:
            continue

        try:
            if not args.no_square_resize:
                # Squashes to a square regardless of the source aspect ratio.
                pil_image = pil_image.resize((res, res), Image.Resampling.LANCZOS)
            transformed_tensor = image_transform(pil_image)  # shape=(3, H, W)
        except Exception as e:
            print(f"[Warning] Skipping {image_path} due to transform error: {e}")
            continue

        # Match the VAE's dtype so input and encoder weights agree.
        latents = encode(vae, transformed_tensor.to(vae.dtype), device=device)
        latents_np = latents.cpu().numpy()

        try:
            _save_npz_atomic(out_path, latents_np)
        except Exception as e:
            print(f"[Error] Saving .npz for {image_path} failed: {e}")
if __name__ == "__main__":
    # Script entry point. main() returns None, so SystemExit(None) exits with
    # status 0 exactly as falling off the end would; exceptions still propagate.
    raise SystemExit(main())