|
import gradio as gr |
|
import time |
|
|
|
import sys |
|
import subprocess |
|
import time |
|
from pathlib import Path |
|
|
|
import hydra |
|
from omegaconf import DictConfig, OmegaConf |
|
from omegaconf.omegaconf import open_dict |
|
|
|
from utils.print_utils import cyan |
|
from utils.ckpt_utils import download_latest_checkpoint, is_run_id |
|
from utils.cluster_utils import submit_slurm_job |
|
from utils.distributed_utils import is_rank_zero |
|
import numpy as np |
|
import torch |
|
from datasets.video.minecraft_video_dataset import * |
|
import torchvision.transforms as transforms |
|
import cv2 |
|
import subprocess |
|
from PIL import Image |
|
from datetime import datetime |
|
|
|
# Names of the per-frame action channels, in the column order used by the
# action tensors built in this file; its length fixes the tensor width (25).
ACTION_KEYS = [
    "inventory",
    "ESC",
    "hotbar.1",
    "hotbar.2",
    "hotbar.3",
    "hotbar.4",
    "hotbar.5",
    "hotbar.6",
    "hotbar.7",
    "hotbar.8",
    "hotbar.9",
    "forward",
    "back",
    "left",
    "right",
    "cameraY",
    "cameraX",
    "jump",
    "sneak",
    "sprint",
    "swapHands",
    "attack",
    "use",
    "pickItem",
    "drop",
]


# Single-character shortcuts accepted in the demo's action string, mapped to
# (action name, value written into that action's column).
# "N" maps to "noop", which is deliberately not an ACTION_KEYS column, so it
# produces an all-zero row.
# NOTE(review): "U" maps to "drop", while the UI instructions call it
# "use item" — confirm which one is intended.
KEY_TO_ACTION = {
    "Q": ("forward", 1),
    "E": ("back", 1),
    "W": ("cameraY", -1),
    "S": ("cameraY", 1),
    "A": ("cameraX", -1),
    "D": ("cameraX", 1),
    "U": ("drop", 1),
    "N": ("noop", 1),
    "1": ("hotbar.1", 1),
}

# Column index of every action name, built once instead of calling
# ACTION_KEYS.index() for every character.
_ACTION_INDEX = {name: idx for idx, name in enumerate(ACTION_KEYS)}


def parse_input_to_tensor(input_str):
    """
    Convert an action string into a (sequence_length, len(ACTION_KEYS)) tensor.

    Each character becomes one row. It is looked up case-insensitively in
    KEY_TO_ACTION, and the mapped action's column is set to the mapped value
    (e.g. "W" -> cameraY = -1). All other columns stay 0.

    A character that is not in KEY_TO_ACTION (a space, a typo, ...) produces
    an all-zero row instead of raising. So does a key whose action is not an
    ACTION_KEYS column (e.g. "N" -> "noop").

    Args:
        input_str (str): Action string such as "QQWWDD1U".

    Returns:
        torch.Tensor: Float tensor of shape (len(input_str), len(ACTION_KEYS)),
        i.e. 25 columns with the current ACTION_KEYS.
    """
    action_tensor = torch.zeros((len(input_str), len(ACTION_KEYS)))

    for row, char in enumerate(input_str):
        # Fall back to (None, 0) for unmapped characters. A bare .get() returns
        # None, and unpacking None would raise TypeError.
        action, value = KEY_TO_ACTION.get(char.upper(), (None, 0))
        column = _ACTION_INDEX.get(action)
        if column is not None:
            action_tensor[row, column] = value

    return action_tensor
|
|
|
def load_image_as_tensor(image_path: "str | Path | Image.Image") -> torch.Tensor:
    """
    Load an image and convert it to a float tensor via ``transforms.ToTensor``.

    Args:
        image_path: One of three things:
            - A path to an image file (``str`` or ``pathlib.Path``). The file
              is opened, converted to RGB and closed again.
            - A PIL image, which is converted to RGB.
            - Any other object, which is handed to ``ToTensor`` unchanged
              (presumably an HxWxC uint8 numpy array — confirm with callers).

    Returns:
        torch.Tensor: Image tensor of shape (C, H, W). 8-bit inputs are scaled
        to [0, 1].
    """
    if isinstance(image_path, (str, Path)):
        # The context manager closes the file handle that Image.open keeps
        # open for lazy loading. convert() forces the pixel data to load first.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
    elif isinstance(image_path, Image.Image):
        # Normalize RGBA / palette / grayscale images to 3 channels so they
        # match images loaded from disk.
        image = image_path.convert("RGB")
    else:
        image = image_path

    return transforms.ToTensor()(image)
|
|
|
def run_local(cfg: DictConfig):
    """
    Build the experiment described by ``cfg`` and return its interactive object.

    Steps:
      1. Record the Hydra config-group choices (experiment / dataset /
         algorithm) as ``_name`` on the matching sub-config.
      2. Build the experiment from ``cfg.checkpoint_path``.
      3. Return ``experiment.exec_interactive`` for the first configured task.
         The caller in this file uses that result as the model.

    Must be called inside a ``@hydra.main`` context, because it reads
    ``HydraConfig.get()``.
    """
    # Imported inside the function as in the original, presumably so that the
    # experiment stack is only loaded when actually needed.
    from experiments import build_experiment

    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
    cfg_choice = OmegaConf.to_container(hydra_cfg.runtime.choices)

    # open_dict lifts struct mode so the new `_name` keys may be added.
    with open_dict(cfg):
        for group in ("experiment", "dataset", "algorithm"):
            # .get(): a group missing from the defaults list is skipped
            # instead of raising KeyError.
            choice = cfg_choice.get(group)
            if choice is not None:
                cfg[group]._name = choice

    experiment = build_experiment(cfg, None, cfg.checkpoint_path)
    return experiment.exec_interactive(cfg.experiment.tasks[0])
|
|
|
# Session state shared (via `global`) by the Gradio callbacks defined in run():
#   memory_frames     - frames generated so far; the start image is first
#   memory_curr_frame - frame counter passed to algo.interactive
#   input_history     - every action string submitted since the last reset
memory_frames, memory_curr_frame, input_history = [], 0, ""

# Start images offered in the UI, all stored under one asset directory.
_ASSET_DIR = "assets"
ICE_PLAINS_IMAGE = f"{_ASSET_DIR}/ice_plains.png"
DESERT_IMAGE = f"{_ASSET_DIR}/desert.png"
SAVANNA_IMAGE = f"{_ASSET_DIR}/savanna.png"
# NOTE(review): "plans.png" may be a typo for "plains.png". This constant is
# not referenced in the visible code, so the path is kept as-is.
PLAINS_IMAGE = f"{_ASSET_DIR}/plans.png"
PLACE_IMAGE = f"{_ASSET_DIR}/place.png"
SUNFLOWERS_IMAGE = f"{_ASSET_DIR}/sunflower_plains.png"
SUNFLOWERS_RAIN_IMAGE = f"{_ASSET_DIR}/rain_sunflower_plains.png"

# Start image used at launch and after a reset.
# It is reassigned when the user clicks one of the image thumbnails.
DEFAULT_IMAGE = ICE_PLAINS_IMAGE

# Device string for torch placement.
device = "cuda:0"
|
|
|
def save_video(frames, path="output.mp4", fps=10):
    """
    Encode RGB frames into an H.264 video file at ``path``.

    Encoding happens in two steps:
      1. OpenCV writes the frames (XVID) to a temporary ``<stem>_raw.avi``
         next to ``path``.
      2. ffmpeg transcodes that file to H.264 at ``path``.

    The intermediate file is required because ffmpeg refuses to use the same
    file as both input and output. If ffmpeg is missing or fails, the XVID
    file is moved to ``path`` instead, so a video is still produced.

    Args:
        frames: Non-empty sequence of HxWx3 uint8 RGB frames, e.g. a
            (T, H, W, 3) numpy array.
        path: Output file path.
        fps: Frame rate of the written video.

    Returns:
        The ``path`` argument, unchanged.

    Raises:
        ValueError: If ``frames`` is empty.
    """
    # len() rather than truthiness: `frames` may be a numpy array.
    if len(frames) == 0:
        raise ValueError("save_video requires at least one frame")

    h, w = frames[0].shape[:2]
    out_path = Path(path)
    raw_path = out_path.with_name(f"{out_path.stem}_raw.avi")

    writer = cv2.VideoWriter(str(raw_path), cv2.VideoWriter_fourcc(*"XVID"), fps, (w, h))
    try:
        for frame in frames:
            # OpenCV expects BGR channel order.
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        writer.release()

    ffmpeg_cmd = [
        "ffmpeg", "-y", "-i", str(raw_path),
        "-c:v", "libx264", "-crf", "23", "-preset", "medium",
        str(out_path),
    ]
    try:
        result = subprocess.run(ffmpeg_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        transcoded = result.returncode == 0
    except FileNotFoundError:
        # ffmpeg binary not installed.
        transcoded = False

    if transcoded:
        raw_path.unlink(missing_ok=True)
    else:
        # Best effort: keep the OpenCV-encoded video at the requested path.
        raw_path.replace(out_path)

    return path
|
|
|
@hydra.main(
    version_base=None,
    config_path="configurations",
    config_name="config",
)
def run(cfg: DictConfig):
    """
    Hydra entry point: load the model and serve the interactive Gradio demo.

    Setup:
      - The object returned by ``run_local`` is moved to ``device``.
      - It is primed once with the start image (``DEFAULT_IMAGE``) and a zero
        action/pose.

    The Gradio callbacks defined below then drive generation. They share
    state through the module-level ``memory_frames``, ``memory_curr_frame``
    and ``input_history`` globals, so all browser sessions share one
    sequence.
    """
    algo = run_local(cfg)
    algo.to(device)

    # Zero action (25 columns, one per ACTION_KEYS entry) and zero pose, used
    # for the priming call on the start image and again after every reset.
    actions = torch.zeros((1, 25))
    poses = torch.zeros((1, 5))

    memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))

    _ = algo.interactive(memory_frames[0],
                         actions[0],
                         poses[0],
                         memory_curr_frame,
                         device=device)

    def set_denoising_steps(denoising_steps, sampling_timesteps_state):
        """Slider callback: apply the new number of denoising steps to the model."""
        algo.sampling_timesteps = denoising_steps
        algo.diffusion_model.sampling_timesteps = denoising_steps
        sampling_timesteps_state = denoising_steps
        print("set denoising steps to", algo.sampling_timesteps)
        return sampling_timesteps_state

    def update_image_and_log(keys):
        """
        Generate one frame per action character and write the full video.

        Returns:
            Tuple of (last frame, path of the written video, full history).
        """
        global input_history
        global memory_curr_frame

        actions = parse_input_to_tensor(keys)
        for i in range(len(actions)):
            memory_curr_frame += 1
            new_frame = algo.interactive(memory_frames[0],
                                         actions[i],
                                         None,
                                         memory_curr_frame,
                                         device=device)
            memory_frames.append(new_frame)

        # Stack frames along a new time axis and move channels last, then
        # convert to uint8 for the video writer and the image preview.
        out_video = torch.stack(memory_frames)
        out_video = out_video.permute(0, 2, 3, 1).numpy()
        out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
        out_video = (out_video * 255).astype(np.uint8)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # pathlib instead of os.makedirs: `os` is not imported explicitly in
        # this file.
        Path("outputs_gradio").mkdir(parents=True, exist_ok=True)
        filename = f"outputs_gradio/{timestamp}.mp4"
        save_video(out_video, filename)

        input_history += keys
        return out_video[-1], filename, input_history

    def reset():
        """
        Restart the sequence from DEFAULT_IMAGE and re-prime the model.

        Returns:
            Tuple of (empty history, start image path).
        """
        global memory_curr_frame
        global input_history
        global memory_frames

        algo.reset()
        memory_frames = []
        memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
        memory_curr_frame = 0
        input_history = ""

        _ = algo.interactive(memory_frames[0],
                             actions[0],
                             poses[0],
                             memory_curr_frame,
                             device=device)
        return input_history, DEFAULT_IMAGE

    def on_image_click(SELECTED_IMAGE):
        """
        Make SELECTED_IMAGE the start image and reset the sequence.

        Returns:
            Tuple of (image, cleared history), so the History Log textbox is
            cleared along with the sequence.
        """
        global DEFAULT_IMAGE
        DEFAULT_IMAGE = SELECTED_IMAGE
        history, _ = reset()
        return SELECTED_IMAGE, history

    css = """
    h1 {
        text-align: center;
        display: block;
    }
    """

    with gr.Blocks(css=css) as demo:
        # NOTE(review): the badges below link to the GaussianAnything project
        # pages and a placeholder arXiv id — confirm the intended URLs.
        gr.Markdown(
            """
            # WORLDMEM: Long-term Consistent World Generation with Memory

            <div style="text-align: center;">
            <!-- Public Website -->
            <a style="display:inline-block" href="https://nirvanalan.github.io/projects/GA/">
                <img src="https://img.shields.io/badge/public_website-8A2BE2">
            </a>
            <!-- GitHub Stars -->
            <a style="display:inline-block; margin-left: .5em" href="https://github.com/NIRVANALAN/GaussianAnything">
                <img src="https://img.shields.io/github/stars/NIRVANALAN/GaussianAnything?style=social">
            </a>
            <!-- Project Page -->
            <a style="display:inline-block; margin-left: .5em" href="https://nirvanalan.github.io/projects/GA/">
                <img src="https://img.shields.io/badge/project_page-blue">
            </a>
            <!-- arXiv Paper -->
            <a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/XXXX.XXXXX">
                <img src="https://img.shields.io/badge/arXiv-paper-red">
            </a>
            </div>
            """
        )

        with gr.Row(variant="panel"):
            video_display = gr.Video(autoplay=True, loop=True)
            image_display = gr.Image(value=DEFAULT_IMAGE, interactive=False, label="Last Frame")

        with gr.Row(variant="panel"):
            with gr.Column(scale=2):
                input_box = gr.Textbox(label="Action Sequence", placeholder="Enter action sequence here...", lines=1, max_lines=1)
                log_output = gr.Textbox(label="History Log", interactive=False)
            with gr.Column(scale=1):
                slider = gr.Slider(minimum=10, maximum=50, value=algo.sampling_timesteps, step=1, label="Denoising Steps")
                submit_button = gr.Button("Generate")
                reset_btn = gr.Button("Reset")

        sampling_timesteps_state = gr.State(algo.sampling_timesteps)

        example_actions = ["DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW", "DDDDDDDDDDDDQQQQQQQQQQQQQQQDDDDDDDDDDDD",
                           "DDDDWWWDDDDDDDDDDDDDDDDDDDDSSSAAAAAAAAAAAAAAAAAAAAAAAA", "SSUNNWWEEEEEEEEEAAASSUNNWWEEEEEEEEEAAAAAAAAAAAAAAAAAAAAAA"]

        def set_action(action):
            """Copy an example action string into the input box."""
            return action

        gr.Markdown("### Action sequence examples.")
        buttons = []
        # Two example buttons per row, each column scaled by its string
        # length. Only as many rows as needed are created (no empty rows).
        for start in range(0, len(example_actions), 2):
            with gr.Row():
                for action in example_actions[start:start + 2]:
                    with gr.Column(scale=len(action)):
                        buttons.append(gr.Button(action))

        for button, action in zip(buttons, example_actions):
            button.click(set_action, inputs=[gr.State(value=action)], outputs=input_box)

        gr.Markdown("### Click on the images below to reset the sequence and generate from the new image.")

        with gr.Row():
            image_display_1 = gr.Image(value=SUNFLOWERS_IMAGE, interactive=False, label="Sunflower Plains")
            image_display_2 = gr.Image(value=DESERT_IMAGE, interactive=False, label="Desert")
            image_display_3 = gr.Image(value=SAVANNA_IMAGE, interactive=False, label="Savanna")
            image_display_4 = gr.Image(value=ICE_PLAINS_IMAGE, interactive=False, label="Ice Plains")
            image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains")
            image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place")

        # NOTE(review): KEY_TO_ACTION maps "U" to "drop", but the list below
        # calls it "use item"; the GitHub link in item 6 is empty. Confirm both.
        gr.Markdown(
            """
            ## Instructions & Notes:

            1. Enter an action sequence in the **"Action Sequence"** text box and click **"Generate"** to begin.
            2. You can continue generation by clicking **"Generate"** again and again. Previous sequences are logged in the history panel.
            3. Click **"Reset"** to clear the current sequence and start fresh.
            4. Action sequences can be composed using the following keys:
                - W: turn up
                - S: turn down
                - A: turn left
                - D: turn right
                - Q: move forward
                - E: move backward
                - N: no-op (do nothing)
                - 1: switch to hotbar 1
                - U: use item
            5. Higher denoising steps produce more detailed results but take longer. **20 steps** is a good balance between quality and speed.
            6. If you find this project interesting or useful, please consider giving it a ⭐️ on [GitHub]()!
            7. For feedback or suggestions, feel free to open a GitHub issue or contact me directly at **zeqixiao1@gmail.com**.
            """
        )

        submit_button.click(update_image_and_log, inputs=[input_box], outputs=[image_display, video_display, log_output])
        reset_btn.click(reset, outputs=[log_output, image_display])
        # Clicking a start image also clears the History Log, matching the
        # reset of input_history.
        image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=[image_display, log_output])
        image_display_2.select(lambda: on_image_click(DESERT_IMAGE), outputs=[image_display, log_output])
        image_display_3.select(lambda: on_image_click(SAVANNA_IMAGE), outputs=[image_display, log_output])
        image_display_4.select(lambda: on_image_click(ICE_PLAINS_IMAGE), outputs=[image_display, log_output])
        image_display_5.select(lambda: on_image_click(SUNFLOWERS_RAIN_IMAGE), outputs=[image_display, log_output])
        image_display_6.select(lambda: on_image_click(PLACE_IMAGE), outputs=[image_display, log_output])

        slider.change(fn=set_denoising_steps, inputs=[slider, sampling_timesteps_state], outputs=sampling_timesteps_state)

    # Launch exactly once. launch() blocks when run as a script, so a second
    # launch() call would only start another server after this one shut down.
    demo.launch(share=True)


if __name__ == "__main__":
    # Hydra supplies `cfg` from the config files and command-line overrides.
    run()
|
|