import gc
import os
import uuid
from pathlib import Path

import gradio as gr
import imageio
import numpy as np
import spaces
import torch
import torch.nn.functional as F
from decord import cpu, VideoReader
from diffusers.training_utils import set_seed
from kornia.filters import canny
from kornia.morphology import dilation

from geometrycrafter import (
    GeometryCrafterDiffPipeline,
    GeometryCrafterDetermPipeline,
    PMapAutoencoderKLTemporalDecoder,
    UNetSpatioTemporalConditionModelVid2vid,
)
from third_party import MoGe
from utils.glb_utils import pmap_to_glb
from utils.disp_utils import pmap_to_disp
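
# Gradio demo for GeometryCrafter: estimates temporally consistent point maps
# (per-pixel XYZ) from an input video using a video diffusion prior, then
# renders a disparity video and per-frame GLB point clouds for preview.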
examples = [
    # video_path: str,
    # process_length: int,
    # max_res: int,
    # num_inference_steps: int,
    # guidance_scale: float,
    # window_size: int,
    # decode_chunk_size: int,
    # overlap: int,
    ["examples/video1.mp4", 60, 640, 5, 1.0, 110, 8, 25],
    ["examples/video2.mp4", 60, 640, 5, 1.0, 110, 8, 25],
    ["examples/video3.mp4", 60, 640, 5, 1.0, 110, 8, 25],
    ["examples/video4.mp4", 60, 640, 5, 1.0, 110, 8, 25],
]
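
# 'diff' selects the diffusion-based pipeline, 'determ' the deterministic
# variant; both run on top of the Stable Video Diffusion backbone loaded below.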
model_type = 'diff'
cache_dir = 'workspace/cache'

unet = UNetSpatioTemporalConditionModelVid2vid.from_pretrained(
    'TencentARC/GeometryCrafter',
    subfolder='unet_diff' if model_type == 'diff' else 'unet_determ',
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    cache_dir=cache_dir
).requires_grad_(False).to("cuda", dtype=torch.float16)
point_map_vae = PMapAutoencoderKLTemporalDecoder.from_pretrained(
    'TencentARC/GeometryCrafter',
    subfolder='point_map_vae',
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
    cache_dir=cache_dir
).requires_grad_(False).to("cuda", dtype=torch.float32)
prior_model = MoGe(
    cache_dir=cache_dir,
).requires_grad_(False).to('cuda', dtype=torch.float32)

if model_type == 'diff':
    pipe = GeometryCrafterDiffPipeline.from_pretrained(
        'stabilityai/stable-video-diffusion-img2vid-xt',
        unet=unet,
        torch_dtype=torch.float16,
        variant="fp16",
        cache_dir=cache_dir
    ).to("cuda")
else:
    pipe = GeometryCrafterDetermPipeline.from_pretrained(
        'stabilityai/stable-video-diffusion-img2vid-xt',
        unet=unet,
        torch_dtype=torch.float16,
        variant="fp16",
        cache_dir=cache_dir
    ).to("cuda")

try:
    pipe.enable_xformers_memory_efficient_attention()
except Exception as e:
    print(e)
    print("Xformers is not enabled")
    # known issue: https://github.com/continue-revolution/sd-webui-animatediff/issues/101
    # pipe.enable_xformers_memory_efficient_attention()
pipe.enable_attention_slicing()

def read_video_frames(video_path, process_length, max_res):
    print("==> processing video: ", video_path)
    vid = VideoReader(video_path, ctx=cpu(0))
    fps = vid.get_avg_fps()
    print("==> original video shape: ", (len(vid), *vid.get_batch([0]).shape[1:]))
    original_height, original_width = vid.get_batch([0]).shape[1:3]
    # downscale the longer side to max_res if it exceeds max_res
    if max(original_height, original_width) > max_res:
        scale = max_res / max(original_height, original_width)
        original_height, original_width = round(original_height * scale), round(original_width * scale)
    # the pipeline expects dimensions that are multiples of 64; the scale has
    # already been applied above, so it must not be applied a second time here
    height = round(original_height / 64) * 64
    width = round(original_width / 64) * 64
    vid = VideoReader(video_path, ctx=cpu(0), width=original_width, height=original_height)
    frames_idx = list(range(0, min(len(vid), process_length) if process_length != -1 else len(vid)))
    print(
        f"==> final processing shape: {len(frames_idx), *vid.get_batch([0]).shape[1:]}"
    )
    frames = vid.get_batch(frames_idx).asnumpy().astype("float32") / 255.0
    return frames, height, width, fps

def compute_edge_mask(depth: torch.Tensor, edge_dilation_radius: int):
    # detect depth discontinuities with a Canny filter, then dilate them so
    # that a small neighborhood around each edge is masked out as well
    _, edges = canny(depth[None, None, :, :], low_threshold=0.4, high_threshold=0.5)
    mask = (edges[0, 0] > 0).float()
    mask = dilation(mask[None, None, :, :], torch.ones((edge_dilation_radius, edge_dilation_radius), device=mask.device))
    return mask[0, 0] > 0.5
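
# Usage sketch (mirrors the call in infer_geometry below):
#   z = point_maps[i, :, :, 2]                  # depth channel of frame i
#   valid_masks[i] &= ~compute_edge_mask(z, 3)  # drop flying pixels at depth edges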

@spaces.GPU  # request a ZeroGPU device for the duration of this call
def infer_geometry(
    video: str,
    process_length: int,
    max_res: int,
    num_inference_steps: int,
    guidance_scale: float,
    window_size: int,
    decode_chunk_size: int,
    overlap: int,
    downsample_ratio: float = 1.0,  # downsample the point cloud for visualization
    remove_edge: bool = True,       # remove edge pixels for visualization
    save_folder: str = os.path.join('workspace', 'GeometryCrafterApp'),
):
    try:
        run_id = str(uuid.uuid4())
        set_seed(42)
        pipe.enable_xformers_memory_efficient_attention()
        frames, height, width, fps = read_video_frames(video, process_length, max_res)
        aspect_ratio = width / height
        assert 0.5 <= aspect_ratio <= 2.0, "Error! The aspect ratio of the video must fall within [0.5, 2]."
        frames_tensor = torch.tensor(frames.astype("float32"), device='cuda').float().permute(0, 3, 1, 2)
        window_size = min(window_size, len(frames))
        if window_size == len(frames):
            overlap = 0
        point_maps, valid_masks = pipe(
            frames_tensor,
            point_map_vae,
            prior_model,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            window_size=window_size,
            decode_chunk_size=decode_chunk_size,
            overlap=overlap,
            force_projection=True,
            force_fixed_focal=True,
        )
        frames_tensor = frames_tensor.cpu()
        point_maps = point_maps.cpu()
        valid_masks = valid_masks.cpu()
        os.makedirs(save_folder, exist_ok=True)
        gc.collect()
        torch.cuda.empty_cache()

        output_npz_path = Path(save_folder, run_id, 'point_maps.npz')
        output_npz_path.parent.mkdir(exist_ok=True)
        np.savez_compressed(
            output_npz_path,
            point_map=point_maps.cpu().numpy().astype(np.float16),
            mask=valid_masks.cpu().numpy().astype(np.bool_)
        )
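        # The archive can be reloaded later, e.g. (shapes as produced above):
        #   data = np.load('point_maps.npz')
        #   point_map = data['point_map']  # (T, H, W, 3) float16 per-pixel XYZ
        #   mask = data['mask']            # (T, H, W) bool validity mask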
        output_disp_path = Path(save_folder, run_id, 'disp.mp4')
        output_disp_path.parent.mkdir(exist_ok=True)
        colored_disp = pmap_to_disp(point_maps, valid_masks)
        imageio.mimsave(
            output_disp_path, (colored_disp * 255).cpu().numpy().astype(np.uint8), fps=fps, macro_block_size=1)

        # downsample for visualization
        if downsample_ratio > 1.0:
            H, W = point_maps.shape[1:3]
            H, W = round(H / downsample_ratio), round(W / downsample_ratio)
            point_maps = F.interpolate(point_maps.permute(0, 3, 1, 2), (H, W)).permute(0, 2, 3, 1)
            frames = F.interpolate(frames_tensor, (H, W)).permute(0, 2, 3, 1)
            valid_masks = F.interpolate(valid_masks.float()[:, None], (H, W))[:, 0] > 0.5
        else:
            H, W = point_maps.shape[1:3]
            frames = frames_tensor.permute(0, 2, 3, 1)

        if remove_edge:
            for i in range(len(valid_masks)):
                edge_mask = compute_edge_mask(point_maps[i, :, :, 2], 3)
                valid_masks[i] = valid_masks[i] & (~edge_mask)

        # export six uniformly sampled frames as GLB point clouds for preview
        indices = np.linspace(0, len(point_maps) - 1, 6)
        indices = np.round(indices).astype(np.int32)
        mesh_seqs, frame_seqs = [], []
        for index in indices:
            valid_mask = valid_masks[index].cpu().numpy()
            point_map = point_maps[index].cpu().numpy()
            frame = frames[index].cpu().numpy()
            output_glb_path = Path(save_folder, run_id, f'{index:04}.glb')
            output_glb_path.parent.mkdir(exist_ok=True)
            glbscene = pmap_to_glb(point_map, valid_mask, frame)
            glbscene.export(file_obj=output_glb_path)
            mesh_seqs.append(output_glb_path)
            frame_seqs.append(index)

        gc.collect()
        torch.cuda.empty_cache()

        return [
            gr.Model3D(value=mesh_seqs[idx], label=f"Frame: {frame_seqs[idx]}") for idx in range(len(frame_seqs))
        ] + [
            gr.Video(value=output_disp_path, label="Disparity", interactive=False),
            gr.DownloadButton("Download Npz File", value=output_npz_path)
        ]
    except Exception as e:
        gc.collect()
        torch.cuda.empty_cache()
        raise gr.Error(str(e))
        # return [
        #     gr.Model3D(
        #         label="Point Map",
        #         clear_color=[1.0, 1.0, 1.0, 1.0],
        #         interactive=False
        #     ),
        #     gr.Video(label="Disparity", interactive=False),
        #     gr.DownloadButton("Download Npz File", visible=False)
        # ]
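
# Headless usage sketch (bypasses the UI; argument values copied from `examples`):
#   outputs = infer_geometry("examples/video1.mp4", 60, 640, 5, 1.0, 110, 8, 25)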

def build_demo():
    with gr.Blocks(analytics_enabled=False) as gradio_demo:
        gr.HTML(
            """
            <div align='center'>
            <h1> GeometryCrafter: Consistent Geometry Estimation for Open-world Videos with Diffusion Priors </h1> \
            <span style='font-size:18px'>\
                <a href='https://scholar.google.com/citations?user=zHp0rMIAAAAJ'>Tian-Xing Xu</a>, \
                <a href='https://scholar.google.com/citations?user=qgdesEcAAAAJ'>Xiangjun Gao</a>, \
                <a href='https://wbhu.github.io'>Wenbo Hu</a>, \
                <a href='https://xiaoyu258.github.io/'>Xiaoyu Li</a>, \
                <a href='https://scholar.google.com/citations?user=AWtV-EQAAAAJ'>Song-Hai Zhang</a>, \
                <a href='https://scholar.google.com/citations?user=4oXBp9UAAAAJ'>Ying Shan</a>\
            </span> \
            <br>
            <br>
            <span style='font-size:18px'>If you find GeometryCrafter useful, please help ⭐ the \
            <a style='font-size:18px' href='https://github.com/TencentARC/GeometryCrafter/'>[Github Repo]</a>, \
            which is important to open-source projects. Thanks! \
            <a href='https://arxiv.org/abs/2504.01016'>[arXiv]</a> \
            <a href='https://geometrycrafter.github.io'>[Project Page]</a> \
            </span>
            </div>
            """
        )
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                input_video = gr.Video(
                    label="Input Video",
                    sources=['upload']
                )
                with gr.Row(equal_height=False):
                    with gr.Accordion("Advanced Settings", open=False):
                        process_length = gr.Slider(
                            label="process length",
                            minimum=-1,
                            maximum=280,
                            value=110,
                            step=1,
                        )
                        max_res = gr.Slider(
                            label="max resolution",
                            minimum=512,
                            maximum=2048,
                            value=1024,
                            step=64,
                        )
                        num_denoising_steps = gr.Slider(
                            label="num denoising steps",
                            minimum=1,
                            maximum=25,
                            value=5,
                            step=1,
                        )
                        guidance_scale = gr.Slider(
                            label="cfg scale",
                            minimum=1.0,
                            maximum=1.2,
                            value=1.0,
                            step=0.1,
                        )
                        window_size = gr.Slider(
                            label="shift window size",
                            minimum=10,
                            maximum=110,
                            value=110,
                            step=10,
                        )
                        decode_chunk_size = gr.Slider(
                            label="decode chunk size",
                            minimum=1,
                            maximum=16,
                            value=6,
                            step=1,
                        )
                        overlap = gr.Slider(
                            label="overlap",
                            minimum=1,
                            maximum=50,
                            value=25,
                            step=1,
                        )
                    generate_btn = gr.Button("Generate")
            with gr.Column(scale=1):
                output_disp_video = gr.Video(
                    label="Disparity",
                    interactive=False
                )
                download_btn = gr.DownloadButton("Download Npz File")
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                output_point_map0 = gr.Model3D(
                    label="Point Map 0",
                    clear_color=[1.0, 1.0, 1.0, 1.0],
                    # display_mode="solid"
                    interactive=False
                )
            with gr.Column(scale=1):
                output_point_map1 = gr.Model3D(
                    label="Point Map 1",
                    clear_color=[1.0, 1.0, 1.0, 1.0],
                    # display_mode="solid"
                    interactive=False
                )
            with gr.Column(scale=1):
                output_point_map2 = gr.Model3D(
                    label="Point Map 2",
                    clear_color=[1.0, 1.0, 1.0, 1.0],
                    # display_mode="solid"
                    interactive=False
                )
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                output_point_map3 = gr.Model3D(
                    label="Point Map 3",
                    clear_color=[1.0, 1.0, 1.0, 1.0],
                    # display_mode="solid"
                    interactive=False
                )
            with gr.Column(scale=1):
                output_point_map4 = gr.Model3D(
                    label="Point Map 4",
                    clear_color=[1.0, 1.0, 1.0, 1.0],
                    # display_mode="solid"
                    interactive=False
                )
            with gr.Column(scale=1):
                output_point_map5 = gr.Model3D(
                    label="Point Map 5",
                    clear_color=[1.0, 1.0, 1.0, 1.0],
                    # display_mode="solid"
                    interactive=False
                )
        gr.Markdown(
            """
            Parameters:
            - `process length`: process only the first `process length` frames for geometry estimation; `-1` means the whole video
            - `max resolution`: downsample the longer side to this resolution (if it exceeds) before processing, to save memory
            - `num denoising steps`: number of denoising iterations; `5` is enough for most cases
            - `cfg scale`: the default value `1.0` is recommended
            - `shift window size`: the default value `110` is recommended
            - `decode chunk size`: chunk size of the VAE decoder; set it to `4` or `6` to save memory
            - `overlap`: the default value `25` is recommended
            """
        )
        gr.Examples(
            examples=examples,
            fn=infer_geometry,
            inputs=[
                input_video,
                process_length,
                max_res,
                num_denoising_steps,
                guidance_scale,
                window_size,
                decode_chunk_size,
                overlap,
            ],
            outputs=[
                output_point_map0, output_point_map1, output_point_map2,
                output_point_map3, output_point_map4, output_point_map5,
                output_disp_video, download_btn
            ],
            # cache_examples="lazy",
        )
        gr.HTML(
            """
            <span style='font-size:18px'>Note: \
            To stay within the time quota, the default parameters here favor efficiency, \
            at the cost of shorter video length and slightly lower quality. \
            If you have enough time quota, you may adjust the parameters following our \
            <a href='https://github.com/TencentARC/GeometryCrafter/'>[Github Repo]</a> \
            for better results. This page only provides a simplified visualization, \
            as point cloud sequences are not supported here. For better visualization, \
            download the npz file and open it with the Viser backend in our repo. \
            </span>
            """
        )
        generate_btn.click(
            fn=infer_geometry,
            inputs=[
                input_video,
                process_length,
                max_res,
                num_denoising_steps,
                guidance_scale,
                window_size,
                decode_chunk_size,
                overlap,
            ],
            outputs=[
                output_point_map0, output_point_map1, output_point_map2,
                output_point_map3, output_point_map4, output_point_map5,
                output_disp_video, download_btn
            ],
        )
    return gradio_demo

if __name__ == "__main__":
    demo = build_demo()
    demo.queue()
    # demo.launch(server_name="0.0.0.0", server_port=12345, debug=True, share=False)
    demo.launch(share=True)