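"""
Cache image latents with the FLUX.1-dev VAE so training can skip online encoding.

Example invocation (illustrative; the script filename is an assumption):
    python cache_flux_latents.py --data_yaml data.yaml --resolution 512 \
        --total_split 4 --current_worker_index 0
"""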
import argparse
import os
import hashlib
import functools
import json
import yaml
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from diffusers import AutoencoderKL
from torchvision import transforms
from tqdm import tqdm
from imgproc import (
generate_crop_size_list,
to_rgb_if_rgba,
var_center_crop,
)
from data import read_general
# ---- Flux VAE scaling parameters ----
VAE_SCALE = 0.3611
VAE_SHIFT = 0.1159
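# The inverse transform (a minimal sketch, assuming the same constants) is what a
# consumer would apply to a cached latent before handing it to vae.decode():
def unscale_latents(latents: torch.Tensor) -> torch.Tensor:
    """Undo the shift+scale applied in encode(); the result is what vae.decode() expects."""
    return latents / VAE_SCALE + VAE_SHIFT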
def handle_image(image: Image.Image) -> Image.Image:
"""
Ensure the image is in RGB format, converting from RGBA, L, P, etc.
Raise ValueError if unrecognized mode.
"""
mode = image.mode.upper()
if mode == "RGB":
return image
elif mode == "RGBA":
return to_rgb_if_rgba(image)
elif mode in ("L", "P"):
return image.convert("RGB")
else:
raise ValueError(f"Unsupported image mode: {mode}")
def encode(vae: AutoencoderKL, img_tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
"""
Encode a normalized image tensor to latents using the Flux VAE, applying SHIFT+SCALE.
img_tensor shape: (C, H, W) or (1,C,H,W). We'll reshape to (1,C,H,W) if needed.
"""
if img_tensor.dim() == 3:
img_tensor = img_tensor.unsqueeze(0) # (1,C,H,W)
img_tensor = img_tensor.to(device, non_blocking=True)
with torch.no_grad():
        # the VAE was loaded in float16, so the input tensor must already match its dtype
        latent_dist = vae.encode(img_tensor).latent_dist
        # .mode() returns the deterministic distribution mean; use .sample() for a random draw
        latents = latent_dist.mode()[0]  # drop the batch dim -> (C, h, w)
        latents = (latents - VAE_SHIFT) * VAE_SCALE
return latents.float()
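# Convenience sketch (not called by this script) of how a training-time loader
# might read back the .npz files written in main(); the 'latent' key matches the
# np.savez_compressed call below.
def load_cached_latent(npz_path: str) -> torch.Tensor:
    """Load a latent cached by this script and return it as a float32 tensor."""
    with np.load(npz_path) as data:
        return torch.from_numpy(data["latent"]).float()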
def load_image_paths_from_yaml(yaml_path: str) -> list:
"""
Parse a YAML containing a 'META' key with paths to .jsonl files.
For each .jsonl (with 'type' == 'image_text'), read lines of JSON
where we expect an 'image_path' field. Collect these paths in a list.
"""
with open(yaml_path, "r", encoding="utf-8") as f:
data = yaml.safe_load(f)
image_files = []
meta_list = data.get("META", [])
for meta_item in meta_list:
# Example: path=/data0/DanbooruWebp/booru1116Webp.jsonl
# type=image_text
ftype = meta_item.get("type", "")
fpath = meta_item.get("path", "")
if ftype != "image_text":
# skip unknown types
continue
if not os.path.isfile(fpath):
print(f"[Warning] JSONL file not found: {fpath}")
continue
# Open .jsonl and parse lines
with open(fpath, "r", encoding="utf-8") as fin:
for line in fin:
line = line.strip()
if not line:
continue
try:
obj = json.loads(line)
if "image_path" in obj:
# This is the actual disk path for the image
image_files.append(obj["image_path"])
except Exception as e:
print(f"[Warning] JSON parse error in {fpath}: {e}")
continue
return image_files
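# Example of the YAML layout the parser above expects (illustrative; entries
# whose 'type' is not "image_text" are skipped):
#
# META:
#   - path: /data0/DanbooruWebp/booru1116Webp.jsonl
#     type: image_text
#
# Each referenced .jsonl carries one JSON object per line with at least an
# "image_path" field, e.g. {"image_path": "/data0/images/000001.webp"}.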
def main():
parser = argparse.ArgumentParser(description="Cache image latents using Flux VAE")
parser.add_argument("--data_yaml", type=str, required=True,
help="Path to dataset YAML config (with META -> .jsonl paths)")
parser.add_argument("--resolution", type=int, required=True,
help="Target resolution (e.g., 256, 512, 1024) for center-crop/resize")
parser.add_argument("--total_split", type=int, default=1,
help="Total number of parallel splits/workers")
parser.add_argument("--current_worker_index", type=int, default=0,
help="Index of this worker (0-based)")
parser.add_argument("--patch_size", type=int, default=8,
help="Patch size used for generating potential crop sizes")
parser.add_argument("--random_top_k", type=int, default=1,
help="Number of top crop options from var_center_crop to randomly pick")
args = parser.parse_args()
# ------------------------------------------------------------------
# 1) Setup VAE model for encoding:
# ------------------------------------------------------------------
vae = AutoencoderKL.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="vae",
torch_dtype=torch.float16
).eval()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
vae.to(device)
# ------------------------------------------------------------------
# 2) Prepare your transform (crop -> tensor -> normalize).
# This must match how images are processed before training.
# ------------------------------------------------------------------
    max_num_patches = round((args.resolution / args.patch_size) ** 2)
crop_size_list = generate_crop_size_list(max_num_patches, args.patch_size)
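    # For example, --resolution 512 with --patch_size 8 yields
    # max_num_patches = (512 / 8) ** 2 = 4096; generate_crop_size_list then
    # (presumably) enumerates candidate crop sizes within that patch budget.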
image_transform = transforms.Compose([
transforms.Lambda(functools.partial(var_center_crop,
crop_size_list=crop_size_list,
random_top_k=args.random_top_k)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
# ------------------------------------------------------------------
# 3) Load image paths from YAML / JSONL references:
# ------------------------------------------------------------------
image_files = load_image_paths_from_yaml(args.data_yaml)
if not image_files:
print("[INFO] No image files found. Check your YAML & JSONL contents.")
return
# ------------------------------------------------------------------
# 4) Process each image => transform => encode => save .npz
# ------------------------------------------------------------------
worker_idx = args.current_worker_index
total_split = args.total_split
res = args.resolution
for image_path in tqdm(image_files, desc=f"Worker {worker_idx}"):
# 4.a) Determine if this file belongs to the current worker
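        # e.g. with --total_split 4, a path whose SHA-1 digest is congruent to 2
        # (mod 4) is handled only by the worker launched with
        # --current_worker_index 2, so the shard assignment is deterministic.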
hash_val = int(hashlib.sha1(image_path.encode("utf-8")).hexdigest(), 16)
if hash_val % total_split != worker_idx:
continue
# 4.b) Construct cache path
base, _ = os.path.splitext(image_path)
out_path = f"{base}_{res}.npz"
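        # e.g. /data0/images/000001.webp -> /data0/images/000001_512.npz at --resolution 512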
if os.path.exists(out_path):
continue
# 4.c) Read the image from disk & handle mode
try:
pil_image = Image.open(read_general(image_path))
pil_image = handle_image(pil_image) # ensure RGB
except Exception as e:
print(f"[Warning] Could not open image {image_path}: {e}")
continue
        # NOTE: this unconditionally resizes to a square (res x res), discarding the
        # original aspect ratio before var_center_crop runs. Remove this line to rely
        # solely on var_center_crop to pick the final crop geometry.
        pil_image = pil_image.resize((res, res), Image.Resampling.LANCZOS)
# 4.d) Apply var_center_crop -> toTensor -> normalize
try:
transformed_tensor = image_transform(pil_image) # shape=(3,H,W)
except Exception as e:
print(f"[Warning] Skipping {image_path} due to transform error: {e}")
continue
transformed_tensor = transformed_tensor.to(torch.float16)
# 4.e) Encode with Flux VAE (shift+scale) => latent
latents = encode(vae, transformed_tensor, device=device)
latents_np = latents.cpu().numpy() # shape=(C, H//8, W//8) typically
# 4.f) Save latents to .npz
try:
np.savez_compressed(out_path, latent=latents_np)
except Exception as e:
print(f"[Error] Saving .npz for {image_path} failed: {e}")
if __name__ == "__main__":
main()