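# Source-page header residue (author, commit message, commit id), kept as
# comments so the module parses:
# huanngzh's picture
# update
# 30d688e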
import os
import random
import shutil
import subprocess
import sys
from typing import List

import gradio as gr
import numpy as np
import spaces
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

from inference_tg2mv_sdxl import prepare_pipeline, run_pipeline
from mvadapter.utils import get_orthogonal_camera, make_image_grid, tensor_to_image
# Install extra runtime packages at startup (presumably not available from the
# Space's requirements -- confirm). `--no-deps` stops pip from installing or
# upgrading spandrel's own dependencies.
# Invoke pip through the running interpreter (not whatever `pip` is first on
# PATH) and pass an argument list so no shell is involved.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "spandrel==0.4.1", "--no-deps"],
    check=True,
)

# Device and precision used to build and run the MV-Adapter pipeline.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16
# Upper bound for seeds (largest int32), used by the UI slider and the
# random-seed helper.
MAX_SEED = np.iinfo(np.int32).max
# Number of generated views and their height/width, passed to run_pipeline.
NUM_VIEWS = 6
HEIGHT = 768
WIDTH = 768

# Per-session scratch directories live under "<this file's directory>/tmp".
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
os.makedirs(TMP_DIR, exist_ok=True)
# Markdown banner rendered at the top of the page (note the leading and
# trailing newline).
HEADER = (
    "\n"
    "# 🔮 Text to Texture with [MV-Adapter](https://github.com/huanngzh/MV-Adapter)\n"
    "## State-of-the-art Open Source Texture Generation Using Multi-View Diffusion Model\n"
    '<p style="font-size: 1.1em;">By <a href="https://www.tripo3d.ai/" style="color: #1E90FF; text-decoration: none; font-weight: bold;">Tripo</a></p>\n'
)

# Example rows for gr.Examples: each row is [mesh path, prompt].
_EXAMPLE_MESHES = ("examples/001.glb", "examples/002.glb")
_EXAMPLE_PROMPTS = (
    "Mater, a rusty and beat-up tow truck from the 2006 Disney/Pixar animated film 'Cars', with a rusty brown exterior, big blue eyes.",
    "Optimus Prime, a character from Transformers, with blue, red and gray colors, and has a flame-like pattern on the body",
)
EXAMPLES = [[mesh, text] for mesh, text in zip(_EXAMPLE_MESHES, _EXAMPLE_PROMPTS)]
# MV-Adapter multi-view pipeline on top of SDXL (with the
# "sdxl-vae-fp16-fix" VAE). Built once at import time; run_mvadapter reuses
# this module-level object for every request.
pipe = prepare_pipeline(
    base_model="stabilityai/stable-diffusion-xl-base-1.0",
    vae_model="madebyollin/sdxl-vae-fp16-fix",
    unet_model=None,
    lora_model=None,
    adapter_path="huanngzh/mv-adapter",
    scheduler=None,
    num_views=NUM_VIEWS,
    device=DEVICE,
    dtype=DTYPE,
)

# Checkpoints later passed to the texture pipeline as its upscaler and
# inpainting weights.
if not os.path.exists("checkpoints/RealESRGAN_x2plus.pth"):
    hf_hub_download(
        "dtarnow/UPscaler", filename="RealESRGAN_x2plus.pth", local_dir="checkpoints"
    )

_BIG_LAMA_PATH = "checkpoints/big-lama.pt"
_BIG_LAMA_URL = (
    "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt"
)
if not os.path.exists(_BIG_LAMA_PATH):
    os.makedirs(os.path.dirname(_BIG_LAMA_PATH), exist_ok=True)
    # Download under a temporary name and rename only after wget succeeds.
    # Otherwise an interrupted download would leave a truncated file that the
    # existence check above would accept on the next start.
    _partial_path = _BIG_LAMA_PATH + ".part"
    subprocess.run(["wget", "-O", _partial_path, _BIG_LAMA_URL], check=True)
    os.replace(_partial_path, _BIG_LAMA_PATH)

# Lower-case alias kept for backward compatibility; same value as DEVICE.
device = DEVICE
def start_session(req: gr.Request):
    """Ensure the scratch directory for a newly loaded session exists.

    Registered with ``demo.load``. The directory is
    ``TMP_DIR/<session_hash>``; an existing directory is left untouched.
    """
    session_id = str(req.session_hash)
    session_dir = os.path.join(TMP_DIR, session_id)
    os.makedirs(session_dir, exist_ok=True)
    print("start session, mkdir", session_dir)
def end_session(req: gr.Request):
    """Delete the scratch directory of a session that is unloading.

    Registered with ``demo.unload``. A directory that no longer exists is
    ignored (e.g. TMP_DIR was cleaned, or the load hook never ran for this
    session), so the unload handler does not raise. Other errors such as
    permission problems still propagate.
    """
    save_dir = os.path.join(TMP_DIR, str(req.session_hash))
    try:
        shutil.rmtree(save_dir)
    except FileNotFoundError:
        pass
def get_random_hex():
    """Return a 16-character lowercase hex string from 8 OS-random bytes.

    Used to give generated files unique names.
    """
    return os.urandom(8).hex()
def get_random_seed(randomize_seed, seed):
    """Choose the seed for the next generation.

    Returns a fresh ``random.randint(0, MAX_SEED)`` value when
    *randomize_seed* is truthy; otherwise returns *seed* unchanged.
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed
def _coerce_seed(seed, default=42):
    """Normalise a seed coming from the UI to an int where possible.

    - str: stripped and parsed with ``int()``; unparsable text -> *default*.
    - float (e.g. a slider value delivered as a float): truncated to int;
      NaN/inf -> *default*.
    - anything else is returned unchanged.
    """
    if isinstance(seed, str):
        try:
            return int(seed.strip())
        except ValueError:
            return default
    if isinstance(seed, float):
        try:
            return int(seed)
        except (ValueError, OverflowError):
            return default
    return seed


@spaces.GPU(duration=90)
@torch.no_grad()
def run_mvadapter(
    mesh_path,
    prompt,
    seed=42,
    guidance_scale=7.0,
    num_inference_steps=30,
    negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
    progress=gr.Progress(track_tqdm=True),
):
    """Generate NUM_VIEWS images of *mesh_path* conditioned on *prompt*.

    Args:
        mesh_path: path of the input mesh from the Model3D input.
        prompt: text describing the desired appearance.
        seed: generation seed. str/float values from the UI are converted
            to int; unparsable strings fall back to 42.
        guidance_scale: classifier-free guidance scale passed to run_pipeline.
        num_inference_steps: number of denoising steps passed to run_pipeline.
        negative_prompt: negative text prompt passed to run_pipeline.
        progress: Gradio progress tracker (track_tqdm=True); not referenced
            directly in the body.

    Returns:
        The first element returned by run_pipeline, which is displayed in the
        multi-view gallery (presumably a list of PIL images -- confirm).

    Raises:
        gr.Error: if no mesh was provided.
    """
    if not mesh_path:
        raise gr.Error("Please upload a 3D mesh first.")
    seed = _coerce_seed(seed)

    images, _, _ = run_pipeline(
        pipe,
        mesh_path=mesh_path,
        num_views=NUM_VIEWS,
        text=prompt,
        height=HEIGHT,
        width=WIDTH,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        seed=seed,
        negative_prompt=negative_prompt,
        device=DEVICE,
    )
    # Release cached GPU memory between requests.
    torch.cuda.empty_cache()
    return images
@spaces.GPU(duration=90)
@torch.no_grad()
def run_texturing(
    mesh_path: str,
    mv_images: List[Image.Image],
    uv_unwarp: bool,
    preprocess_mesh: bool,
    uv_size: int,
    req: gr.Request,
):
    """Bake the generated multi-view images onto *mesh_path* as a texture.

    Args:
        mesh_path: path of the input mesh.
        mv_images: gallery value. Each item is indexed with ``[0]`` to get the
            image (Gradio passes (image, caption) pairs), despite the
            annotation.
        uv_unwarp: forwarded to TexturePipeline as ``uv_unwarp``.
        preprocess_mesh: forwarded to TexturePipeline as ``preprocess_mesh``.
        uv_size: forwarded to TexturePipeline as ``uv_size``.
        req: Gradio request; its session hash selects the scratch directory.

    Returns:
        ``(glb_path, glb_path)``: the textured model path, once for the 3D
        viewer output and once for the download button.

    Raises:
        gr.Error: if the mesh or the multi-view images are missing.
    """
    if not mesh_path:
        raise gr.Error("Please upload a 3D mesh first.")
    if not mv_images:
        raise gr.Error("No multi-view images available; generate them first.")

    save_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # Normally created by start_session; recreate it in case it is missing.
    os.makedirs(save_dir, exist_ok=True)

    # Save all views side by side in a single row as one PNG, which the
    # texture pipeline reads via rgb_path.
    mv_image_path = os.path.join(save_dir, f"mv_adapter_{get_random_hex()}.png")
    view_images = [item[0] for item in mv_images]
    make_image_grid(view_images, rows=1).save(mv_image_path)

    # Imported inside the function. Presumably this keeps the texture module
    # (and its dependencies, e.g. spandrel installed at startup) off the
    # import path until texturing runs -- confirm.
    from texture import ModProcessConfig, TexturePipeline

    texture_pipe = TexturePipeline(
        upscaler_ckpt_path="checkpoints/RealESRGAN_x2plus.pth",
        inpaint_ckpt_path="checkpoints/big-lama.pt",
        device=DEVICE,
    )
    textured_glb_path = texture_pipe(
        mesh_path=mesh_path,
        save_dir=save_dir,
        save_name=f"texture_mesh_{get_random_hex()}",
        uv_unwarp=uv_unwarp,
        preprocess_mesh=preprocess_mesh,
        uv_size=uv_size,
        rgb_path=mv_image_path,
        rgb_process_config=ModProcessConfig(view_upscale=True, inpaint_mode="view"),
        # One azimuth (degrees) per view, shifted by -90. The two trailing
        # 180 entries presumably belong to top/bottom views whose elevation
        # is set elsewhere -- confirm against the camera setup in run_pipeline.
        camera_azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
    ).shaded_model_save_path

    # Release cached GPU memory between requests.
    torch.cuda.empty_cache()
    return textured_glb_path, textured_glb_path
with gr.Blocks(title="MVAdapter") as demo:
    gr.Markdown(HEADER)

    with gr.Row():
        # Left column: inputs, settings and the generate button.
        with gr.Column():
            input_mesh = gr.Model3D(label="Input 3D mesh")
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt")

            with gr.Accordion("Generation Settings", open=False):
                # Seeds are integers, so step by 1.
                seed = gr.Slider(
                    label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=8,
                    maximum=50,
                    step=1,
                    value=30,
                )
                guidance_scale = gr.Slider(
                    label="CFG scale",
                    minimum=0.0,
                    maximum=20.0,
                    step=0.1,
                    value=7.0,
                )

            with gr.Accordion("Texture Settings", open=False):
                with gr.Row():
                    uv_unwarp = gr.Checkbox(label="Unwarp UV", value=True)
                    preprocess_mesh = gr.Checkbox(label="Preprocess Mesh", value=False)
                uv_size = gr.Slider(
                    label="UV Size", minimum=1024, maximum=8192, step=512, value=4096
                )

            gen_button = gr.Button("Generate Texture", variant="primary")

            examples = gr.Examples(examples=EXAMPLES, inputs=[input_mesh, prompt])

        # Right column: generated views, textured model and its download.
        with gr.Column():
            mv_result = gr.Gallery(
                label="Multi-View Results",
                show_label=False,
                columns=[3],
                rows=[2],
                object_fit="contain",
                height="auto",
                type="pil",
            )
            textured_model_output = gr.Model3D(label="Textured GLB", interactive=False)
            download_glb = gr.DownloadButton(label="Download GLB", interactive=False)

    # Event chain: pick the seed -> generate views -> bake texture -> enable
    # download. `.success` runs each step only if the previous one finished
    # without error, so a failed generation does not trigger texturing with
    # missing or stale gallery images.
    gen_button.click(
        get_random_seed, inputs=[randomize_seed, seed], outputs=[seed]
    ).success(
        run_mvadapter,
        inputs=[
            input_mesh,
            prompt,
            seed,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[mv_result],
    ).success(
        run_texturing,
        inputs=[input_mesh, mv_result, uv_unwarp, preprocess_mesh, uv_size],
        outputs=[textured_model_output, download_glb],
    ).success(
        lambda: gr.DownloadButton(interactive=True), outputs=[download_glb]
    )

    # Create / remove the per-session scratch directory.
    demo.load(start_session)
    demo.unload(end_session)

demo.launch()