"""Cache Flux VAE latents for the images referenced by a dataset YAML.

For every image listed in the JSONL files named by the YAML config, the
script resizes/crops the image, encodes it with the Flux VAE, and stores
the resulting latent next to the image as ``<image_stem>_<resolution>.npz``
(array key ``latent``). Work can be sharded across several processes with
``--total_split`` / ``--current_worker_index``.
"""
import argparse
import contextlib
import functools
import hashlib
import json
import os

import numpy as np
import torch
import torch.nn.functional as F
import yaml
from diffusers import AutoencoderKL
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

from data import read_general
from imgproc import (
    generate_crop_size_list,
    to_rgb_if_rgba,
    var_center_crop,
)

# ---- Flux VAE scaling parameters ----
VAE_SCALE = 0.3611
VAE_SHIFT = 0.1159


def handle_image(image: Image.Image) -> Image.Image:
    """Return ``image`` as an RGB image.

    RGB images are returned unchanged, RGBA images go through
    ``to_rgb_if_rgba``, and L / P images are converted with
    ``Image.convert("RGB")``.

    Raises:
        ValueError: if the image mode is none of RGB, RGBA, L or P.
    """
    mode = image.mode.upper()
    if mode == "RGB":
        return image
    if mode == "RGBA":
        return to_rgb_if_rgba(image)
    if mode in ("L", "P"):
        return image.convert("RGB")
    raise ValueError(f"Unsupported image mode: {mode}")


def encode(vae: AutoencoderKL, img_tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Encode one normalized image tensor into shifted and scaled latents.

    Args:
        vae: the VAE used for encoding (expected to already live on ``device``).
        img_tensor: image of shape (C, H, W) or (1, C, H, W); a 3-D tensor
            gets a leading batch dimension added.
        device: device the tensor is moved to before encoding.

    Returns:
        A float32 tensor holding the first batch element of the latent
        distribution's mode, after ``(latents - VAE_SHIFT) * VAE_SCALE``.
    """
    if img_tensor.dim() == 3:
        img_tensor = img_tensor.unsqueeze(0)  # (1, C, H, W)

    # Match the VAE's own parameter dtype so callers do not have to know
    # which precision the model was loaded in.
    img_tensor = img_tensor.to(device=device, dtype=vae.dtype, non_blocking=True)

    with torch.no_grad():
        latent_dist = vae.encode(img_tensor).latent_dist
        # Use .mode() for a deterministic latent; .sample() would draw a
        # random one instead. [0] drops the batch dimension.
        latents = latent_dist.mode()[0]
        latents = (latents - VAE_SHIFT) * VAE_SCALE

    return latents.float()


def load_image_paths_from_yaml(yaml_path: str) -> list:
    """Collect image paths referenced by a dataset YAML config.

    The YAML is expected to contain a ``META`` list whose items have a
    ``path`` (a .jsonl file) and a ``type``. Only items with
    ``type == "image_text"`` are read; every JSON object line that has an
    ``image_path`` field contributes that value to the result.

    Missing JSONL files and unparsable lines are reported with a warning
    and skipped. An empty YAML file or an empty ``META`` entry yields an
    empty list.

    Args:
        yaml_path: path to the YAML config file.

    Returns:
        The list of ``image_path`` values, in file order.
    """
    with open(yaml_path, "r", encoding="utf-8") as f:
        # safe_load returns None for an empty document.
        data = yaml.safe_load(f) or {}

    image_files = []
    # "META:" with no value parses to None, so fall back to an empty list.
    meta_list = data.get("META") or []

    for meta_item in meta_list:
        # Example: path=/data0/DanbooruWebp/booru1116Webp.jsonl
        #          type=image_text
        ftype = meta_item.get("type", "")
        fpath = meta_item.get("path", "")

        if ftype != "image_text":
            # skip unknown types
            continue

        if not os.path.isfile(fpath):
            print(f"[Warning] JSONL file not found: {fpath}")
            continue

        with open(fpath, "r", encoding="utf-8") as fin:
            for line in fin:
                line = line.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except json.JSONDecodeError as e:
                    print(f"[Warning] JSON parse error in {fpath}: {e}")
                    continue
                # Only JSON objects can carry an image_path field; scalars
                # and arrays are ignored.
                if isinstance(obj, dict) and "image_path" in obj:
                    # This is the actual disk path for the image
                    image_files.append(obj["image_path"])

    return image_files


def _load_square_rgb(image_path: str, res: int) -> Image.Image:
    """Open ``image_path``, convert it to RGB and resize it to (res, res).

    The source is opened in a ``with`` block so its file handle is released
    as soon as the resized copy exists. PIL decodes lazily, so decode
    errors surface here (during the resize) rather than at ``Image.open``.
    """
    with Image.open(read_general(image_path)) as raw_image:
        rgb_image = handle_image(raw_image)  # ensure RGB
        # NOTE(review): resizing to a square before var_center_crop discards
        # the original aspect ratio — confirm this matches how images are
        # prepared for training.
        return rgb_image.resize((res, res), Image.Resampling.LANCZOS)


def _save_npz_atomic(out_path: str, latents_np: np.ndarray) -> None:
    """Write ``latents_np`` to ``out_path`` without exposing partial files.

    The archive is written to a temporary sibling file and then moved into
    place with ``os.replace``, so an interrupted run never leaves a
    truncated ``out_path`` that a later run would mistake for a valid cache.
    """
    tmp_path = f"{out_path}.{os.getpid()}.tmp"
    try:
        # Passing a file object keeps numpy from appending ".npz" to the
        # temporary name.
        with open(tmp_path, "wb") as fout:
            np.savez_compressed(fout, latent=latents_np)
        os.replace(tmp_path, out_path)
    except BaseException:
        # Also clean up on KeyboardInterrupt, then let the error propagate.
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
        raise


def main():
    """Parse CLI arguments and cache latents for this worker's image shard."""
    parser = argparse.ArgumentParser(description="Cache image latents using Flux VAE")
    parser.add_argument("--data_yaml", type=str, required=True,
                        help="Path to dataset YAML config (with META -> .jsonl paths)")
    parser.add_argument("--resolution", type=int, required=True,
                        help="Target resolution (e.g., 256, 512, 1024) for center-crop/resize")
    parser.add_argument("--total_split", type=int, default=1,
                        help="Total number of parallel splits/workers")
    parser.add_argument("--current_worker_index", type=int, default=0,
                        help="Index of this worker (0-based)")
    parser.add_argument("--patch_size", type=int, default=8,
                        help="Patch size used for generating potential crop sizes")
    parser.add_argument("--random_top_k", type=int, default=1,
                        help="Number of top crop options from var_center_crop to randomly pick")
    args = parser.parse_args()

    # Fail fast on sharding arguments that would otherwise divide by zero
    # or silently select no images at all.
    if args.total_split < 1:
        parser.error("--total_split must be >= 1")
    if not 0 <= args.current_worker_index < args.total_split:
        parser.error("--current_worker_index must satisfy 0 <= index < --total_split")

    # ------------------------------------------------------------------
    # 1) Setup VAE model for encoding:
    # ------------------------------------------------------------------
    vae = AutoencoderKL.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="vae",
        torch_dtype=torch.float16,
    ).eval()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    vae.to(device)

    # ------------------------------------------------------------------
    # 2) Prepare your transform (crop -> tensor -> normalize).
    #    This must match how images are processed before training.
    # ------------------------------------------------------------------
    max_num_patches = round((args.resolution / (args.patch_size * 1.0)) ** 2)
    crop_size_list = generate_crop_size_list(max_num_patches, args.patch_size)

    image_transform = transforms.Compose([
        transforms.Lambda(functools.partial(var_center_crop,
                                            crop_size_list=crop_size_list,
                                            random_top_k=args.random_top_k)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    # ------------------------------------------------------------------
    # 3) Load image paths from YAML / JSONL references:
    # ------------------------------------------------------------------
    image_files = load_image_paths_from_yaml(args.data_yaml)
    if not image_files:
        print("[INFO] No image files found. Check your YAML & JSONL contents.")
        return

    # ------------------------------------------------------------------
    # 4) Process each image => transform => encode => save .npz
    # ------------------------------------------------------------------
    worker_idx = args.current_worker_index
    total_split = args.total_split
    res = args.resolution

    for image_path in tqdm(image_files, desc=f"Worker {worker_idx}"):
        # 4.a) Determine if this file belongs to the current worker.
        #      Hashing the path gives a stable shard assignment that does
        #      not depend on the order of the list.
        hash_val = int(hashlib.sha1(image_path.encode("utf-8")).hexdigest(), 16)
        if hash_val % total_split != worker_idx:
            continue

        # 4.b) Construct cache path; an existing file means already cached.
        base, _ = os.path.splitext(image_path)
        out_path = f"{base}_{res}.npz"
        if os.path.exists(out_path):
            continue

        # 4.c) Read the image from disk, handle mode and resize.
        #      Broad catch is deliberate: one bad image must not stop the run.
        try:
            pil_image = _load_square_rgb(image_path, res)
        except Exception as e:
            print(f"[Warning] Could not open image {image_path}: {e}")
            continue

        # 4.d) Apply var_center_crop -> toTensor -> normalize
        try:
            transformed_tensor = image_transform(pil_image)  # (3, H, W)
        except Exception as e:
            print(f"[Warning] Skipping {image_path} due to transform error: {e}")
            continue

        # 4.e) Encode with Flux VAE (shift+scale) => latent.
        #      encode() casts the tensor to the VAE's dtype.
        latents = encode(vae, transformed_tensor, device=device)
        latents_np = latents.cpu().numpy()

        # 4.f) Save latents to .npz
        try:
            _save_npz_atomic(out_path, latents_np)
        except Exception as e:
            print(f"[Error] Saving .npz for {image_path} failed: {e}")


if __name__ == "__main__":
    main()