# worldmem/datasets/video/minecraft_video_dataset_pose.py

import io
import math
import os
import tarfile
from typing import Mapping, Sequence

import numpy as np
import torch
from omegaconf import DictConfig
from packaging import version as pver
from PIL import Image
from pytorchvideo.data.encoded_video import EncodedVideo

from .base_video_dataset import BaseVideoDataset

def euler_to_rotation_matrix(pitch, yaw):
    """
    Convert Euler angles (pitch, yaw) to a 3x3 rotation matrix.
    pitch: rotation around the x-axis (in radians)
    yaw: rotation around the y-axis (in radians)
    """
    # Rotation matrix around the x-axis (pitch)
    R_x = np.array([
        [1, 0, 0],
        [0, math.cos(pitch), -math.sin(pitch)],
        [0, math.sin(pitch), math.cos(pitch)]
    ])
    # Rotation matrix around the y-axis (yaw)
    R_y = np.array([
        [math.cos(yaw), 0, math.sin(yaw)],
        [0, 1, 0],
        [-math.sin(yaw), 0, math.cos(yaw)]
    ])
    # Combined rotation matrix (pitch applied first, then yaw)
    R = np.dot(R_y, R_x)
    return R

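# A minimal sanity check for euler_to_rotation_matrix (a hypothetical helper,
# not part of the original file): a 90-degree yaw should rotate the +z axis
# onto +x. Wrapped in a function so nothing runs at import time.
def _check_euler_to_rotation_matrix():
    R = euler_to_rotation_matrix(0.0, math.radians(90))
    assert np.allclose(R @ np.array([0.0, 0.0, 1.0]), [1.0, 0.0, 0.0], atol=1e-8)
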
def custom_meshgrid(*args):
    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    if pver.parse(torch.__version__) < pver.parse('1.10'):
        return torch.meshgrid(*args)
    else:
        return torch.meshgrid(*args, indexing='ij')

def camera_to_world_to_world_to_camera(camera_to_world):
    """
    Convert a camera-to-world matrix to a world-to-camera matrix by inverting
    the rigid transformation.
    """
    # Extract rotation (R) and translation (T)
    R = camera_to_world[:3, :3]
    T = camera_to_world[:3, 3]
    # Build the world-to-camera (inverse) matrix
    world_to_camera = np.eye(4)
    # The rotation part of world-to-camera is the transpose of camera-to-world's rotation
    world_to_camera[:3, :3] = R.T
    # The translation part is the negative of the rotated translation
    world_to_camera[:3, 3] = -np.dot(R.T, T)
    return world_to_camera

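# Hypothetical sanity check (not in the original file): composing a rigid
# transform with its inverse should give the identity.
def _check_world_to_camera_inverse():
    c2w = np.eye(4)
    c2w[:3, :3] = euler_to_rotation_matrix(0.3, -1.2)
    c2w[:3, 3] = [1.0, 2.0, 3.0]
    w2c = camera_to_world_to_world_to_camera(c2w)
    assert np.allclose(w2c @ c2w, np.eye(4), atol=1e-10)
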
def euler_to_camera_to_world_matrix(pose):
    x, y, z, pitch, yaw = pose
    # Convert pitch and yaw to radians
    pitch = math.radians(pitch)
    yaw = math.radians(yaw)
    # Get the rotation matrix from the Euler angles
    R = euler_to_rotation_matrix(pitch, yaw)
    # Create the 4x4 transformation matrix (rotation + translation)
    camera_to_world = np.eye(4)
    # Set the rotation part (upper-left 3x3)
    camera_to_world[:3, :3] = R
    # Set the translation part (last column)
    camera_to_world[:3, 3] = [x, y, z]
    return camera_to_world

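# Minimal usage sketch (hypothetical, not in the original file): build a
# camera-to-world matrix from an (x, y, z, pitch, yaw) pose with angles in
# degrees, and check that the camera position lands in the last column.
def _check_euler_to_camera_to_world():
    c2w = euler_to_camera_to_world_matrix(np.array([10.0, 64.0, -3.0, 30.0, 45.0]))
    assert np.allclose(c2w[:3, 3], [10.0, 64.0, -3.0])
    # The rotation block stays orthonormal
    assert np.allclose(c2w[:3, :3] @ c2w[:3, :3].T, np.eye(3), atol=1e-10)
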
def tensor_to_gif(tensor, output_path, fps=10):
    """
    Converts a PyTorch tensor of shape (F, 3, H, W) to a GIF.
    Args:
        tensor (torch.Tensor): Input tensor of shape (F, 3, H, W) with values in range [0, 1] or [0, 255].
        output_path (str): Path to save the output GIF.
        fps (int): Frames per second for the GIF.
    """
    # Ensure the tensor is in the [0, 255] range
    if tensor.max() <= 1.0:
        tensor = (tensor * 255).byte()
    else:
        tensor = tensor.byte()
    # Convert the tensor to a numpy array and rearrange to (F, H, W, 3)
    frames = tensor.permute(0, 2, 3, 1).cpu().numpy()
    # Convert frames to PIL Images
    pil_frames = [Image.fromarray(frame) for frame in frames]
    # Save as GIF
    pil_frames[0].save(
        output_path,
        save_all=True,
        append_images=pil_frames[1:],
        duration=int(1000 / fps),
        loop=0
    )

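# Example usage (hypothetical, not in the original file): dump a short random
# clip to disk to eyeball the helper. Only runs when called explicitly.
def _demo_tensor_to_gif(path="debug_clip.gif"):
    tensor_to_gif(torch.rand(8, 3, 64, 64), path, fps=5)
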
def get_relative_pose(cam_params, zero_first_frame_scale):
    """Re-express absolute camera-to-world poses relative to the first camera."""
    abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
    abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
    source_cam_c2w = abs_c2ws[0]
    if zero_first_frame_scale:
        cam_to_origin = 0
    else:
        cam_to_origin = np.linalg.norm(source_cam_c2w[:3, 3])
    target_cam_c2w = np.array([
        [1, 0, 0, 0],
        [0, 1, 0, -cam_to_origin],
        [0, 0, 1, 0],
        [0, 0, 0, 1]
    ])
    abs2rel = target_cam_c2w @ abs_w2cs[0]
    ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
    ret_poses = np.array(ret_poses, dtype=np.float32)
    return ret_poses

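# Hypothetical sanity check (not in the original file): with
# zero_first_frame_scale=True the first relative pose is the identity and
# every later pose is expressed in the first camera's frame. Camera is
# defined further down in this module; the reference resolves at call time.
def _check_get_relative_pose():
    cams = [Camera(np.array([x, 64.0, 0.0, 0.0, 30.0])) for x in (0.0, 1.0, 2.0)]
    rel = get_relative_pose(cams, zero_first_frame_scale=True)
    assert rel.shape == (3, 4, 4)
    assert np.allclose(rel[0], np.eye(4))
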
def ray_condition(K, c2w, H, W, device):
    # c2w: B, V, 4, 4
    # K: B, V, 4
    B = K.shape[0]
    j, i = custom_meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
    )
    i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, 1, HxW]
    j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, 1, HxW]
    fx, fy, cx, cy = K.chunk(4, dim=-1)  # each B, V, 1
    zs = torch.ones_like(i)  # [B, 1, HxW]
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    zs = zs.expand_as(ys)
    directions = torch.stack((xs, ys, zs), dim=-1)  # B, V, HW, 3
    directions = directions / directions.norm(dim=-1, keepdim=True)  # B, V, HW, 3
    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # B, V, HW, 3
    rays_o = c2w[..., :3, 3]  # B, V, 3
    rays_o = rays_o[:, :, None].expand_as(rays_d)  # B, V, HW, 3
    # Plucker embedding: moment (o x d) concatenated with direction d
    rays_dxo = torch.linalg.cross(rays_o, rays_d)
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)  # B, V, H, W, 6
    return plucker

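# Hypothetical shape check (not in the original file): Plucker embeddings for
# a batch of 2 clips with 4 views each at 8x8 resolution, using the same
# normalized intrinsics as the Camera class below.
def _check_ray_condition(B=2, V=4, H=8, W=8):
    K = torch.tensor([0.35, 0.35 * 640 / 360, 0.5, 0.5]).repeat(B, V, 1)
    c2w = torch.eye(4).repeat(B, V, 1, 1)
    plucker = ray_condition(K, c2w, H, W, device="cpu")
    assert plucker.shape == (B, V, H, W, 6)
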
class Camera(object):
    def __init__(self, entry, focal_length=0.35):
        self.fx = focal_length  # a normalized focal length of 0.35 corresponds to a 110-degree FOV
        self.fy = focal_length * 640 / 360  # correct for the 640x360 aspect ratio
        self.cx = 0.5
        self.cy = 0.5
        self.c2w_mat = euler_to_camera_to_world_matrix(entry)
        self.w2c_mat = camera_to_world_to_world_to_camera(np.copy(self.c2w_mat))

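# Hypothetical usage sketch (not in the original file): build one Camera per
# (x, y, z, pitch, yaw) entry and stack the normalized intrinsics; these can
# be scaled by the frame resolution before use in ray_condition.
def _demo_camera():
    poses = [np.array([0.0, 64.0, 0.0, 10.0, -45.0])]
    cams = [Camera(p) for p in poses]
    K = torch.tensor([[c.fx, c.fy, c.cx, c.cy] for c in cams])  # (V, 4)
    return cams, K
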
ACTION_KEYS = [
    "inventory",
    "ESC",
    "hotbar.1",
    "hotbar.2",
    "hotbar.3",
    "hotbar.4",
    "hotbar.5",
    "hotbar.6",
    "hotbar.7",
    "hotbar.8",
    "hotbar.9",
    "forward",
    "back",
    "left",
    "right",
    "cameraY",
    "cameraX",
    "jump",
    "sneak",
    "sprint",
    "swapHands",
    "attack",
    "use",
    "pickItem",
    "drop",
]

def one_hot_actions(actions: Sequence[Mapping[str, int]]) -> torch.Tensor:
    actions_one_hot = torch.zeros(len(actions), len(ACTION_KEYS))
    for i, current_actions in enumerate(actions):
        for j, action_key in enumerate(ACTION_KEYS):
            if action_key.startswith("camera"):
                if action_key == "cameraX":
                    value = current_actions["camera"][0]
                elif action_key == "cameraY":
                    value = current_actions["camera"][1]
                else:
                    raise ValueError(f"Unknown camera action key: {action_key}")
                # The camera value appears to be a discrete bucket index;
                # recenter and rescale it to [-1, 1]
                max_val = 20
                bin_size = 0.5
                num_buckets = int(max_val / bin_size)
                value = (value - num_buckets) / num_buckets
                assert -1 - 1e-3 <= value <= 1 + 1e-3, f"Camera action value must be in [-1, 1], got {value}"
            else:
                value = current_actions[action_key]
                assert 0 <= value <= 1, f"Action value must be in [0, 1], got {value}"
            actions_one_hot[i, j] = value
    return actions_one_hot

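# Hypothetical usage sketch (not in the original file): encode a single
# "walk forward, camera centered" action dict. With num_buckets = 40, a camera
# bucket index of 40 maps to 0.0.
def _demo_one_hot_actions():
    step = {key: 0 for key in ACTION_KEYS if not key.startswith("camera")}
    step.update(forward=1, camera=(40, 40))
    encoded = one_hot_actions([step])
    assert encoded.shape == (1, len(ACTION_KEYS))
    return encoded
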
def simpletomulti(actions):
    """Expand scalar action codes into the 25-dim ACTION_KEYS layout."""
    vec_25 = torch.zeros(len(actions), len(ACTION_KEYS))
    vec_25[actions == 1, 11] = 1   # forward
    vec_25[actions == 2, 16] = -1  # cameraX (negative)
    vec_25[actions == 3, 16] = 1   # cameraX (positive)
    vec_25[actions == 4, 15] = -1  # cameraY (negative)
    vec_25[actions == 5, 15] = 1   # cameraY (positive)
    return vec_25

def simpletomulti2(actions):
    """Expand multi-column action codes into the 25-dim ACTION_KEYS layout."""
    vec_25 = torch.zeros(len(actions), len(ACTION_KEYS))
    vec_25[actions[:, 0] == 1, 11] = 1    # forward
    vec_25[actions[:, 0] == 2, 12] = 1    # back
    vec_25[actions[:, 4] == 11, 16] = -1  # cameraX (negative)
    vec_25[actions[:, 4] == 13, 16] = 1   # cameraX (positive)
    vec_25[actions[:, 3] == 11, 15] = -1  # cameraY (negative)
    vec_25[actions[:, 3] == 13, 15] = 1   # cameraY (positive)
    vec_25[actions[:, 5] == 6, 24] = 1    # drop
    vec_25[actions[:, 5] == 1, 24] = 1    # drop
    vec_25[actions[:, 1] == 1, 13] = 1    # left
    vec_25[actions[:, 1] == 2, 14] = 1    # right
    vec_25[actions[:, 7] == 1, 2] = 1     # hotbar.1
    return vec_25

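# Hypothetical sanity check (not in the original file): a single step that
# presses forward (column 0 == 1) and turns the camera in the positive X
# direction (column 4 == 13) should light up the matching ACTION_KEYS slots.
def _check_simpletomulti2():
    step = torch.zeros(1, 8, dtype=torch.long)
    step[0, 0] = 1   # forward
    step[0, 4] = 13  # cameraX (positive)
    vec = simpletomulti2(step)
    assert vec[0, ACTION_KEYS.index("forward")] == 1
    assert vec[0, ACTION_KEYS.index("cameraX")] == 1
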
class MinecraftVideoPoseDataset(BaseVideoDataset):
    """
    Minecraft video dataset with per-frame camera poses.
    """
    def __init__(self, cfg: DictConfig, split: str = "training"):
        if split == "test":
            split = "validation"
        super().__init__(cfg, split)
        if hasattr(cfg, "n_frames_valid") and split == "validation":
            self.n_frames = cfg.n_frames_valid

    def get_data_paths(self, split):
        data_dir = self.save_dir / split
        paths = sorted(list(data_dir.glob("**/*.mp4")), key=lambda x: x.name)
        if len(paths) == 0:
            # Fall back to scanning one level of subdirectories
            for sp in os.listdir(data_dir):
                sub_dir = self.save_dir / split / sp
                paths = paths + sorted(list(sub_dir.glob("**/*.mp4")), key=lambda x: x.name)
        return paths

    def get_data_lengths(self, split):
        # Every clip is assumed to be 300 frames long
        lengths = [300] * len(self.get_data_paths(split))
        return lengths

    def download_dataset(self) -> Sequence[int]:
        from internetarchive import download

        part_suffixes = ["aa", "ab", "ac", "ad", "ae", "af", "ag", "ah", "ai", "aj", "ak"]
        for part_suffix in part_suffixes:
            identifier = f"minecraft_marsh_dataset_{part_suffix}"
            file_name = f"minecraft.tar.part{part_suffix}"
            download(identifier, file_name, destdir=self.save_dir, verbose=True)

        # Concatenate the downloaded parts and extract the combined tar archive
        combined_bytes = io.BytesIO()
        for part_suffix in part_suffixes:
            identifier = f"minecraft_marsh_dataset_{part_suffix}"
            file_name = f"minecraft.tar.part{part_suffix}"
            part_file = self.save_dir / identifier / file_name
            with open(part_file, "rb") as part:
                combined_bytes.write(part.read())
        combined_bytes.seek(0)
        with tarfile.open(fileobj=combined_bytes, mode="r") as combined_archive:
            combined_archive.extractall(self.save_dir)
        (self.save_dir / "minecraft/test").rename(self.save_dir / "validation")
        (self.save_dir / "minecraft/train").rename(self.save_dir / "training")
        (self.save_dir / "minecraft").rmdir()
        for part_suffix in part_suffixes:
            identifier = f"minecraft_marsh_dataset_{part_suffix}"
            file_name = f"minecraft.tar.part{part_suffix}"
            part_file = self.save_dir / identifier / file_name
            part_file.unlink()  # the parts are files, so unlink() rather than rmdir()

    def __getitem__(self, idx):
        # Some clips fail to decode; retry, moving to the next index each time
        max_retries = 1000
        for mr in range(max_retries):
            try:
                return self.load_data(idx)
            except Exception as e:
                print(f"{mr} Error: {e}")
                idx = (idx + 1) % self.__len__()

    def load_data(self, idx):
        idx = self.idx_remap[idx]
        file_idx, frame_idx = self.split_idx(idx)
        video_path = self.data_paths[file_idx]
        action_path = video_path.with_suffix(".npz")
        data = np.load(action_path)
        actions_pool = data['actions']
        poses_pool = data['poses']
        # The first frame's height is recorded incorrectly; copy it from the second frame
        poses_pool[0, 1] = poses_pool[1, 1]
        assert poses_pool[:, 1].max() - poses_pool[:, 1].min() < 2, \
            f"height varies too much: {poses_pool[:, 1].max() - poses_pool[:, 1].min()} - {video_path}"
        if len(poses_pool) < len(actions_pool):
            poses_pool = np.pad(poses_pool, ((1, 0), (0, 0)))
        actions_pool = simpletomulti2(actions_pool)
        video_raw = EncodedVideo.from_path(video_path, decode_audio=False)
        frame_idx = frame_idx + 100  # skip the first 100 frames, which are not useful
        if self.split == "validation":
            frame_idx = 240
        total_frame = video_raw.duration.numerator
        fps = 10  # video_raw.duration.denominator
        total_frame = total_frame * fps / video_raw.duration.denominator
        video = video_raw.get_clip(start_sec=frame_idx / fps, end_sec=(frame_idx + self.n_frames) / fps)["video"]
        video = video.permute(1, 2, 3, 0).numpy()
        if self.split != "validation" and 'degrees' in data:
            degrees = data['degrees']
            actions_pool[:, 16] *= degrees
        actions = np.copy(actions_pool[frame_idx : frame_idx + self.n_frames])  # (t, 25)
        poses = np.copy(poses_pool[frame_idx : frame_idx + self.n_frames])  # (t, 5)
        pad_len = self.n_frames - len(video)
        # Make positions relative to the first frame; flip the yaw sign and wrap angles into [0, 360)
        poses_pool[:, :3] -= poses[:1, :3]
        poses_pool[:, -1] = -poses_pool[:, -1]
        poses_pool[:, 3:] %= 360
        poses[:, :3] -= poses[:1, :3]
        poses[:, -1] = -poses[:, -1]
        poses[:, 3:] %= 360
        nonterminal = np.ones(self.n_frames)
        if len(video) < self.n_frames:
            video = np.pad(video, ((0, pad_len), (0, 0), (0, 0), (0, 0)))
            actions = np.pad(actions, ((0, pad_len), (0, 0)))
            poses = np.pad(poses, ((0, pad_len), (0, 0)))  # pad poses, not a second copy of actions
            nonterminal[-pad_len:] = 0
        video = torch.from_numpy(video / 255.0).float().permute(0, 3, 1, 2).contiguous()
        return (
            video[:: self.frame_skip],
            actions[:: self.frame_skip],
            poses[:: self.frame_skip],
        )


if __name__ == "__main__":
    import torch
    from unittest.mock import MagicMock
    import tqdm

    cfg = MagicMock()
    cfg.resolution = 64
    cfg.external_cond_dim = 0
    cfg.n_frames = 64
    cfg.save_dir = "data/minecraft"
    cfg.validation_multiplier = 1
    dataset = MinecraftVideoPoseDataset(cfg, "training")
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=16)
    for batch in tqdm.tqdm(dataloader):
        pass