import io
import math
import os
import random
import tarfile
from typing import Mapping, Sequence

import numpy as np
import torch
from omegaconf import DictConfig
from packaging import version as pver
from PIL import Image
from pytorchvideo.data.encoded_video import EncodedVideo
from tqdm import tqdm

from .base_video_dataset import BaseVideoDataset
def euler_to_rotation_matrix(pitch, yaw):
    """
    Convert Euler angles (pitch, yaw) to a 3x3 rotation matrix.

    pitch: rotation around the x-axis (in radians)
    yaw: rotation around the y-axis (in radians)
    """
    # Rotation about the x-axis by `pitch`.
    R_x = np.array([
        [1, 0, 0],
        [0, math.cos(pitch), -math.sin(pitch)],
        [0, math.sin(pitch), math.cos(pitch)],
    ])
    # Rotation about the y-axis by `yaw`.
    R_y = np.array([
        [math.cos(yaw), 0, math.sin(yaw)],
        [0, 1, 0],
        [-math.sin(yaw), 0, math.cos(yaw)],
    ])
    # Pitch is applied first, then yaw.
    R = np.dot(R_y, R_x)
    return R
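
# Sketch (hypothetical check): the result is orthonormal, so R @ R.T ~ I.
#   R = euler_to_rotation_matrix(math.radians(30), math.radians(45))
#   np.allclose(R @ R.T, np.eye(3))  # True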


def custom_meshgrid(*args):
    # torch.meshgrid gained the `indexing` argument in 1.10; earlier versions
    # already used 'ij' indexing by default.
    if pver.parse(torch.__version__) < pver.parse('1.10'):
        return torch.meshgrid(*args)
    return torch.meshgrid(*args, indexing='ij')


def camera_to_world_to_world_to_camera(camera_to_world):
    """
    Convert a camera-to-world matrix to a world-to-camera matrix by
    inverting the rigid transformation: R' = R^T, T' = -R^T @ T.
    """
    R = camera_to_world[:3, :3]
    T = camera_to_world[:3, 3]
    world_to_camera = np.eye(4)
    world_to_camera[:3, :3] = R.T
    world_to_camera[:3, 3] = -np.dot(R.T, T)
    return world_to_camera


def euler_to_camera_to_world_matrix(pose):
    """
    Build a 4x4 camera-to-world matrix from an (x, y, z, pitch, yaw) pose,
    with pitch and yaw given in degrees.
    """
    x, y, z, pitch, yaw = pose
    pitch = math.radians(pitch)
    yaw = math.radians(yaw)

    R = euler_to_rotation_matrix(pitch, yaw)
    camera_to_world = np.eye(4)
    camera_to_world[:3, :3] = R
    camera_to_world[:3, 3] = [x, y, z]
    return camera_to_world
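
# Sketch (hypothetical check): the two conversions above are inverses:
#   c2w = euler_to_camera_to_world_matrix([1.0, 2.0, 3.0, 10.0, 20.0])
#   w2c = camera_to_world_to_world_to_camera(c2w)
#   np.allclose(w2c @ c2w, np.eye(4))  # True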


def tensor_to_gif(tensor, output_path, fps=10):
    """
    Convert a PyTorch tensor of shape (F, 3, H, W) to a GIF.

    Args:
        tensor (torch.Tensor): Input tensor of shape (F, 3, H, W) with values
            in the range [0, 1] or [0, 255].
        output_path (str): Path to save the output GIF.
        fps (int): Frames per second for the GIF.
    """
    # Heuristically detect [0, 1] inputs and rescale them to [0, 255].
    if tensor.max() <= 1.0:
        tensor = (tensor * 255).byte()
    else:
        tensor = tensor.byte()

    # (F, 3, H, W) -> (F, H, W, 3) for PIL.
    frames = tensor.permute(0, 2, 3, 1).cpu().numpy()
    pil_frames = [Image.fromarray(frame) for frame in frames]
    pil_frames[0].save(
        output_path,
        save_all=True,
        append_images=pil_frames[1:],
        duration=int(1000 / fps),
        loop=0,
    )
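
# Sketch (hypothetical usage): write a random 16-frame clip to disk.
#   tensor_to_gif(torch.rand(16, 3, 64, 64), "sample.gif", fps=10)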


def get_relative_pose(cam_params, zero_first_frame_scale):
    """
    Re-express a sequence of camera poses relative to the first camera.
    If `zero_first_frame_scale` is true, the first camera is placed at the
    origin; otherwise it is placed at its original distance from the origin
    along -y.
    """
    abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
    abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
    source_cam_c2w = abs_c2ws[0]
    if zero_first_frame_scale:
        cam_to_origin = 0
    else:
        cam_to_origin = np.linalg.norm(source_cam_c2w[:3, 3])
    target_cam_c2w = np.array([
        [1, 0, 0, 0],
        [0, 1, 0, -cam_to_origin],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ])
    # Map every absolute c2w into the rebased frame of the first camera.
    abs2rel = target_cam_c2w @ abs_w2cs[0]
    ret_poses = [target_cam_c2w] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
    ret_poses = np.array(ret_poses, dtype=np.float32)
    return ret_poses


def ray_condition(K, c2w, H, W, device):
    """
    Compute Plücker ray embeddings for a batch of cameras.

    K: (B, F, 4) intrinsics as (fx, fy, cx, cy); c2w: (B, F, 4, 4).
    Returns a (B, F, H, W, 6) tensor of (o x d, d) per pixel.
    """
    B = K.shape[0]

    j, i = custom_meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
    )
    # Pixel centers, flattened to (B, 1, H*W).
    i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5
    j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5

    fx, fy, cx, cy = K.chunk(4, dim=-1)  # each (B, F, 1)

    # Unit ray directions in camera space.
    zs = torch.ones_like(i)
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    zs = zs.expand_as(ys)

    directions = torch.stack((xs, ys, zs), dim=-1)
    directions = directions / directions.norm(dim=-1, keepdim=True)

    # Rotate directions into world space and broadcast the ray origins.
    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)
    rays_o = c2w[..., :3, 3]
    rays_o = rays_o[:, :, None].expand_as(rays_d)

    # Plücker coordinates: (o x d, d).
    rays_dxo = torch.linalg.cross(rays_o, rays_d)
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)
    return plucker
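
# Sketch (hypothetical shapes): with K of shape (B, F, 4) and c2w of shape
# (B, F, 4, 4),
#   plucker = ray_condition(K, c2w, H=360, W=640, device="cpu")
#   plucker.shape  # (B, F, 360, 640, 6)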


class Camera(object):
    def __init__(self, entry, focal_length=0.35):
        # Intrinsics in normalized image coordinates; fy is scaled by the
        # 640/360 aspect ratio so the focal lengths match in pixel units.
        self.fx = focal_length
        self.fy = focal_length * 640 / 360
        self.cx = 0.5
        self.cy = 0.5
        # `entry` is an (x, y, z, pitch, yaw) pose.
        self.c2w_mat = euler_to_camera_to_world_matrix(entry)
        self.w2c_mat = camera_to_world_to_world_to_camera(np.copy(self.c2w_mat))
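
# Sketch (hypothetical usage): camera conditioning for a clip given a
# `pose_list` of (x, y, z, pitch, yaw) entries.
#   cams = [Camera(p) for p in pose_list]
#   rel_c2ws = get_relative_pose(cams, zero_first_frame_scale=True)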


ACTION_KEYS = [
    "inventory",
    "ESC",
    "hotbar.1",
    "hotbar.2",
    "hotbar.3",
    "hotbar.4",
    "hotbar.5",
    "hotbar.6",
    "hotbar.7",
    "hotbar.8",
    "hotbar.9",
    "forward",
    "back",
    "left",
    "right",
    "cameraY",
    "cameraX",
    "jump",
    "sneak",
    "sprint",
    "swapHands",
    "attack",
    "use",
    "pickItem",
    "drop",
]


def one_hot_actions(actions: Sequence[Mapping[str, int]]) -> torch.Tensor:
    actions_one_hot = torch.zeros(len(actions), len(ACTION_KEYS))
    for i, current_actions in enumerate(actions):
        for j, action_key in enumerate(ACTION_KEYS):
            if action_key.startswith("camera"):
                if action_key == "cameraX":
                    value = current_actions["camera"][0]
                elif action_key == "cameraY":
                    value = current_actions["camera"][1]
                else:
                    raise ValueError(f"Unknown camera action key: {action_key}")
                # Camera values arrive as bucket indices in [0, 2 * num_buckets];
                # recenter and rescale them to [-1, 1].
                max_val = 20
                bin_size = 0.5
                num_buckets = int(max_val / bin_size)
                value = (value - num_buckets) / num_buckets
                assert -1 - 1e-3 <= value <= 1 + 1e-3, f"Camera action value must be in [-1, 1], got {value}"
            else:
                value = current_actions[action_key]
                assert 0 <= value <= 1, f"Action value must be in [0, 1], got {value}"
            actions_one_hot[i, j] = value

    return actions_one_hot
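
# Sketch (hypothetical usage): one frame with "forward" pressed and the
# camera bucket indices at (45, 40), i.e. a small horizontal turn.
#   a = {k: 0 for k in ACTION_KEYS if not k.startswith("camera")}
#   a["camera"] = (45, 40)
#   a["forward"] = 1
#   one_hot_actions([a]).shape  # torch.Size([1, 25])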


def simpletomulti(actions):
    # Expand a single discrete action id into the 25-dim ACTION_KEYS vector:
    # 1 -> forward, 2/3 -> cameraX -/+, 4/5 -> cameraY -/+.
    vec_25 = torch.zeros(len(actions), len(ACTION_KEYS))
    vec_25[actions == 1, 11] = 1
    vec_25[actions == 2, 16] = -1
    vec_25[actions == 3, 16] = 1
    vec_25[actions == 4, 15] = -1
    vec_25[actions == 5, 15] = 1
    return vec_25


def simpletomulti2(actions):
    # Expand a multi-discrete action array (one column per action group) into
    # the 25-dim ACTION_KEYS vector.
    vec_25 = torch.zeros(len(actions), len(ACTION_KEYS))
    vec_25[actions[:, 0] == 1, 11] = 1    # forward
    vec_25[actions[:, 0] == 2, 12] = 1    # back
    vec_25[actions[:, 4] == 11, 16] = -1  # cameraX
    vec_25[actions[:, 4] == 13, 16] = 1   # cameraX
    vec_25[actions[:, 3] == 11, 15] = -1  # cameraY
    vec_25[actions[:, 3] == 13, 15] = 1   # cameraY
    vec_25[actions[:, 5] == 6, 24] = 1    # drop
    vec_25[actions[:, 5] == 1, 24] = 1    # drop
    vec_25[actions[:, 1] == 1, 13] = 1    # left
    vec_25[actions[:, 1] == 2, 14] = 1    # right
    vec_25[actions[:, 7] == 1, 2] = 1     # hotbar.1
    return vec_25
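
# Sketch (hypothetical input): a row [1, 0, 0, 13, 11, 0, 0, 0] maps to
# forward = 1, cameraY = +1, cameraX = -1 in the 25-dim vector.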


class MinecraftVideoPoseDataset(BaseVideoDataset):
    """
    Minecraft video dataset with per-frame actions and camera poses.
    """

    def __init__(self, cfg: DictConfig, split: str = "training"):
        if split == "test":
            split = "validation"
        super().__init__(cfg, split)

        # Optionally use a different clip length for validation.
        if hasattr(cfg, "n_frames_valid") and split == "validation":
            self.n_frames = cfg.n_frames_valid

    def get_data_paths(self, split):
        data_dir = self.save_dir / split
        paths = sorted(data_dir.glob("**/*.mp4"), key=lambda x: x.name)

        # Fall back to scanning subdirectories when the split directory
        # itself yields no videos.
        if len(paths) == 0:
            for sub_path in os.listdir(data_dir):
                sub_dir = self.save_dir / split / sub_path
                paths = paths + sorted(sub_dir.glob("**/*.mp4"), key=lambda x: x.name)
        return paths

    def get_data_lengths(self, split):
        # Every video in the dataset is assumed to be 300 frames long.
        lengths = [300] * len(self.get_data_paths(split))
        return lengths

    def download_dataset(self):
        from internetarchive import download

        part_suffixes = ["aa", "ab", "ac", "ad", "ae", "af", "ag", "ah", "ai", "aj", "ak"]
        for part_suffix in part_suffixes:
            identifier = f"minecraft_marsh_dataset_{part_suffix}"
            file_name = f"minecraft.tar.part{part_suffix}"
            download(identifier, file_name, destdir=self.save_dir, verbose=True)

        # Concatenate the downloaded parts and extract the combined tarball.
        combined_bytes = io.BytesIO()
        for part_suffix in part_suffixes:
            identifier = f"minecraft_marsh_dataset_{part_suffix}"
            file_name = f"minecraft.tar.part{part_suffix}"
            part_file = self.save_dir / identifier / file_name
            with open(part_file, "rb") as part:
                combined_bytes.write(part.read())
        combined_bytes.seek(0)
        with tarfile.open(fileobj=combined_bytes, mode="r") as combined_archive:
            combined_archive.extractall(self.save_dir)
        (self.save_dir / "minecraft/test").rename(self.save_dir / "validation")
        (self.save_dir / "minecraft/train").rename(self.save_dir / "training")
        (self.save_dir / "minecraft").rmdir()
        # Delete the raw part files (unlink, since rmdir only removes
        # directories).
        for part_suffix in part_suffixes:
            identifier = f"minecraft_marsh_dataset_{part_suffix}"
            file_name = f"minecraft.tar.part{part_suffix}"
            part_file = self.save_dir / identifier / file_name
            part_file.unlink()

    def __getitem__(self, idx):
        # Videos occasionally fail to decode; fall back to the next clip.
        max_retries = 1000
        for mr in range(max_retries):
            try:
                return self.load_data(idx)
            except Exception as e:
                print(f"{mr} Error: {e}")
                idx = (idx + 1) % self.__len__()
        raise RuntimeError(f"Failed to load a clip after {max_retries} retries")

    def load_data(self, idx):
        idx = self.idx_remap[idx]
        file_idx, frame_idx = self.split_idx(idx)
        video_path = self.data_paths[file_idx]

        # Actions and poses are stored in an .npz next to each video.
        action_path = video_path.with_suffix(".npz")
        data = np.load(action_path)
        actions_pool = data["actions"]
        poses_pool = data["poses"]

        # The first frame's recorded height is unreliable; copy it from the
        # second frame.
        poses_pool[0, 1] = poses_pool[1, 1]

        assert poses_pool[:, 1].max() - poses_pool[:, 1].min() < 2, (
            f"height varies too much: "
            f"{poses_pool[:, 1].max() - poses_pool[:, 1].min()}-{video_path}"
        )

        # Align poses with actions when the poses are one entry short.
        if len(poses_pool) < len(actions_pool):
            poses_pool = np.pad(poses_pool, ((1, 0), (0, 0)))

        actions_pool = simpletomulti2(actions_pool)
        video_raw = EncodedVideo.from_path(video_path, decode_audio=False)

        # Skip the first 100 frames of each video; validation clips always
        # start at frame 240.
        frame_idx = frame_idx + 100
        if self.split == "validation":
            frame_idx = 240

        fps = 10
        # duration is a fractions.Fraction.
        total_frame = video_raw.duration.numerator * fps / video_raw.duration.denominator
        video = video_raw.get_clip(
            start_sec=frame_idx / fps, end_sec=(frame_idx + self.n_frames) / fps
        )["video"]
        # (C, T, H, W) -> (T, H, W, C)
        video = video.permute(1, 2, 3, 0).numpy()

        # Optionally rescale the cameraX action by a per-video degree factor.
        if self.split != "validation" and "degrees" in data.keys():
            actions_pool[:, 16] *= data["degrees"]

        actions = np.copy(actions_pool[frame_idx : frame_idx + self.n_frames])
        poses = np.copy(poses_pool[frame_idx : frame_idx + self.n_frames])
        pad_len = self.n_frames - len(video)

        # Express positions relative to the clip's first frame; negate the
        # last angle and wrap angles into [0, 360).
        poses_pool[:, :3] -= poses[:1, :3]
        poses_pool[:, -1] = -poses_pool[:, -1]
        poses_pool[:, 3:] %= 360

        poses[:, :3] -= poses[:1, :3]
        poses[:, -1] = -poses[:, -1]
        poses[:, 3:] %= 360

        # Pad short clips to n_frames and mark the padded tail as terminal.
        nonterminal = np.ones(self.n_frames)
        if len(video) < self.n_frames:
            video = np.pad(video, ((0, pad_len), (0, 0), (0, 0), (0, 0)))
            actions = np.pad(actions, ((0, pad_len), (0, 0)))
            poses = np.pad(poses, ((0, pad_len), (0, 0)))
            nonterminal[-pad_len:] = 0

        video = torch.from_numpy(video / 255.0).float().permute(0, 3, 1, 2).contiguous()

        return (
            video[:: self.frame_skip],
            actions[:: self.frame_skip],
            poses[:: self.frame_skip],
        )


if __name__ == "__main__":
    from unittest.mock import MagicMock

    cfg = MagicMock()
    cfg.resolution = 64
    cfg.external_cond_dim = 0
    cfg.n_frames = 64
    cfg.save_dir = "data/minecraft"
    cfg.validation_multiplier = 1

    dataset = MinecraftVideoPoseDataset(cfg, "training")
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=4, shuffle=True, num_workers=16
    )

    for batch in tqdm(dataloader):
        pass