import os

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import spaces
import gradio as gr
import torch
import numpy as np
from PIL import Image
from typing import List
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
from transformers import (
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    AutoTokenizer,
)
from diffusers import DDPMScheduler, AutoencoderKL

import apply_net
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
from src.unet_hacked_tryon import UNet2DConditionModel
from utils_mask import get_mask_location
from preprocess.humanparsing.run_parsing import Parsing
from preprocess.openpose.run_openpose import OpenPose
from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation

# Shared image-to-tensor transform (normalizes pixel values to [-1, 1])
tensor_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])


def pil_to_binary_mask(pil_image, threshold=0):
    """Convert a (possibly colored) PIL image into a black/white binary mask."""
    np_image = np.array(pil_image)
    grayscale_image = Image.fromarray(np_image).convert("L")
    binary_mask = np.array(grayscale_image) > threshold
    mask = np.zeros(binary_mask.shape, dtype=np.uint8)
    mask[binary_mask] = 1
    return Image.fromarray((mask * 255).astype(np.uint8))


base_path = 'yisol/IDM-VTON'
example_path = os.path.join(os.path.dirname(__file__), 'example')

# Memory-optimized model loading: every sub-model is loaded in fp16
torch_dtype = torch.float16

unet = UNet2DConditionModel.from_pretrained(
    base_path,
    subfolder="unet",
    torch_dtype=torch_dtype,
)
UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
    base_path,
    subfolder="unet_encoder",
    torch_dtype=torch_dtype,
)
text_encoder_one = CLIPTextModel.from_pretrained(
    base_path,
    subfolder="text_encoder",
    torch_dtype=torch_dtype,
)
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
    base_path,
    subfolder="text_encoder_2",
    torch_dtype=torch_dtype,
)

# Keep the encoders on CPU until inference actually needs them
text_encoder_one.to("cpu")
text_encoder_two.to("cpu")
UNet_Encoder.to("cpu")

vae = AutoencoderKL.from_pretrained(
    base_path,
    subfolder="vae",
    torch_dtype=torch_dtype,
)

pipe = TryonPipeline.from_pretrained(
    base_path,
    unet=unet,
    vae=vae,
    text_encoder=text_encoder_one,
    text_encoder_2=text_encoder_two,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
)
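# Optional debugging aid (not part of the original app, assumed helpful when tuning the
# offloading settings below): a tiny helper that reports allocated/reserved CUDA memory.
# It should only be called where CUDA is initialized (e.g. inside the @spaces.GPU worker),
# for example: report_cuda_memory("after pipeline load").
def report_cuda_memory(tag: str = "") -> None:
    if not torch.cuda.is_available():
        return
    allocated = torch.cuda.memory_allocated() / 1024 ** 2
    reserved = torch.cuda.memory_reserved() / 1024 ** 2
    print(f"[cuda-mem] {tag} allocated={allocated:.0f}MiB reserved={reserved:.0f}MiB")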
# Runtime memory optimizations
pipe.enable_model_cpu_offload()   # Model-level CPU offload: faster than sequential offload at a modest memory cost
pipe.fuse_qkv_projections()       # Fuse attention Q/K/V projections into single matmuls
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()
pipe.enable_attention_slicing(1)
pipe.unet_encoder = UNet_Encoder

torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision('medium')

# Preprocessing models are created lazily inside the GPU worker
parsing_model = None
openpose_model = None


@spaces.GPU
def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed):
    global parsing_model, openpose_model
    if openpose_model is None:
        openpose_model = OpenPose(0)
    if parsing_model is None:
        parsing_model = Parsing(0)

    device = "cuda"
    PROCESS_SIZE = (384, 512)

    if is_checked:
        openpose_model.preprocessor.body_estimation.model.to(device)

    if pipe.device != torch.device(device):
        pipe.to(device)
        pipe.unet_encoder.to(device)
        text_encoder_one.to(device)
        text_encoder_two.to(device)

    # Move the VAE to GPU just before inference
    pipe.vae.to(device)

    garm_img = garm_img.convert("RGB").resize(PROCESS_SIZE)
    human_img_orig = dict["background"].convert("RGB")

    # Optionally center-crop the human image to a 3:4 aspect ratio before resizing
    if is_checked_crop:
        width, height = human_img_orig.size
        target_width = int(min(width, height * (3 / 4)))
        target_height = int(min(height, width * (4 / 3)))
        left = (width - target_width) // 2
        top = (height - target_height) // 2
        right = left + target_width
        bottom = top + target_height
        cropped_img = human_img_orig.crop((left, top, right, bottom))
        crop_size = cropped_img.size
        human_img = cropped_img.resize(PROCESS_SIZE)
    else:
        human_img = human_img_orig.resize(PROCESS_SIZE)

    # Mask generation: auto-mask from parsing + pose keypoints, or use the user-drawn layer
    if is_checked:
        resized_img = human_img.resize((384, 512))
        keypoints = openpose_model(resized_img)
        model_parse, _ = parsing_model(resized_img)
        mask, _ = get_mask_location('hd', "upper_body", model_parse, keypoints)
        mask = mask.resize(PROCESS_SIZE)
    else:
        mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize(PROCESS_SIZE))

    # Grey out the masked region of the human image for the visualization output
    mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transform(human_img)
    mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)

    # DensePose estimation for the pose conditioning image
    human_img_arg = _apply_exif_orientation(human_img.resize((384, 512)))
    human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
    args = apply_net.create_argument_parser().parse_args([
        'show',
        './configs/densepose_rcnn_R_50_FPN_s1x.yaml',
        './ckpt/densepose/model_final_162be9.pkl',
        'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda',
    ])
    pose_img = args.func(args, human_img_arg)[:, :, ::-1]
    pose_img = Image.fromarray(pose_img).resize(PROCESS_SIZE)

    # Inference
    with torch.inference_mode(), torch.autocast("cuda"):
        prompt = f"model is wearing {garment_des}"
        negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        # Prompt encoding
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = pipe.encode_prompt(
            prompt,
            num_images_per_prompt=1,
            do_classifier_free_guidance=True,
            negative_prompt=negative_prompt,
        )

        # Garment encoding
        cloth_prompt = f"a photo of {garment_des}"
        prompt_embeds_c, *_ = pipe.encode_prompt(
            [cloth_prompt],
            num_images_per_prompt=1,
            do_classifier_free_guidance=False,
        )

        # Prepare image inputs
        pose_tensor = tensor_transform(pose_img).unsqueeze(0).to(device, torch.float16)
        garm_tensor = tensor_transform(garm_img).unsqueeze(0).to(device, torch.float16)
        generator = torch.Generator(device).manual_seed(int(seed)) if seed != -1 else None
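        # The try-on UNet is conditioned on the DensePose map (pose_img), the garment
        # latents produced by the reference UNet (pipe.unet_encoder, fed via cloth and
        # text_embeds_cloth), and IP-Adapter image embeddings of the garment image.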
        images = pipe(
            prompt_embeds=prompt_embeds.to(device, torch.float16),
            negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
            pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
            num_inference_steps=int(denoise_steps),
            generator=generator,
            pose_img=pose_tensor,
            text_embeds_cloth=prompt_embeds_c.to(device, torch.float16),
            cloth=garm_tensor,
            mask_image=mask,
            image=human_img,
            height=PROCESS_SIZE[1],
            width=PROCESS_SIZE[0],
            ip_adapter_image=garm_img,
            guidance_scale=2.0,
        )[0]

        # Free intermediate tensors (keep mask_gray: it is returned below)
        del mask, pose_tensor, garm_tensor, prompt_embeds, prompt_embeds_c
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    # Cleanup
    torch.cuda.empty_cache()

    if is_checked_crop:
        out_img = images[0].resize(crop_size)
        human_img_orig.paste(out_img, (left, top))
        return human_img_orig, mask_gray
    return images[0], mask_gray


# Gradio interface setup
image_blocks = gr.Blocks()
with image_blocks as demo:
    gr.Markdown("## IDM-VTON 👕👔👚")
    gr.Markdown(
        "Virtual try-on with your image and a garment image. "
        "Check out the [source code](https://github.com/yisol/IDM-VTON) and the "
        "[model](https://huggingface.co/yisol/IDM-VTON)."
    )
    with gr.Row():
        with gr.Column():
            imgs = gr.ImageEditor(
                sources='upload',
                type="pil",
                label='Human. Mask with pen or use auto-masking',
                interactive=True,
            )
            with gr.Row():
                is_checked = gr.Checkbox(
                    label="Yes",
                    info="Use auto-generated mask (takes ~5 seconds)",
                    value=True,
                )
            with gr.Row():
                is_checked_crop = gr.Checkbox(
                    label="Yes",
                    info="Use auto-crop & resizing",
                    value=False,
                )

        with gr.Column():
            garm_img = gr.Image(label="Garment", sources='upload', type="pil")
            with gr.Row(elem_id="prompt-container"):
                with gr.Row():
                    prompt = gr.Textbox(
                        placeholder="Description of garment, e.g. short sleeve round neck T-shirt",
                        show_label=False,
                        elem_id="prompt",
                    )

        with gr.Column():
            masked_img = gr.Image(label="Masked image output", elem_id="masked-img", show_share_button=False)

        with gr.Column():
            image_out = gr.Image(label="Output", elem_id="output-img", show_share_button=False)

    with gr.Column():
        try_button = gr.Button(value="Try-on")
        with gr.Accordion(label="Advanced Settings", open=False):
            with gr.Row():
                denoise_steps = gr.Number(label="Denoising Steps", minimum=20, maximum=40, value=20, step=1)
                seed = gr.Number(label="Seed", minimum=-1, maximum=2147483647, step=1, value=42)

    try_button.click(
        fn=start_tryon,
        inputs=[imgs, garm_img, prompt, is_checked, is_checked_crop, denoise_steps, seed],
        outputs=[image_out, masked_img],
        api_name='tryon',
    )

demo.queue(max_size=2).launch(
    server_name="0.0.0.0",
    share=False,
    max_threads=1,
)
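# Rough client-side usage sketch (an assumption, not part of the original app): the
# /tryon endpoint exposed above can also be driven programmatically with gradio_client.
# The exact payload gr.ImageEditor expects ("background"/"layers"/"composite") depends
# on the installed Gradio version, so treat this only as a starting point.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://localhost:7860")
#   result = client.predict(
#       {"background": handle_file("person.jpg"), "layers": [], "composite": None},
#       handle_file("garment.jpg"),
#       "short sleeve round neck t-shirt",
#       True,    # use auto-generated mask
#       False,   # use auto-crop & resizing
#       20,      # denoising steps
#       42,      # seed
#       api_name="/tryon",
#   )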