import torch from typing import Sequence import numpy as np import io import tarfile from pytorchvideo.data.encoded_video import EncodedVideo from omegaconf import DictConfig from tqdm import tqdm from .base_video_dataset import BaseVideoDataset from typing import Mapping, Sequence import os import math from packaging import version as pver from PIL import Image import random def euler_to_rotation_matrix(pitch, yaw): """ Convert euler angles (pitch, yaw) to a 3x3 rotation matrix. pitch: rotation around x-axis (in radians) yaw: rotation around y-axis (in radians) """ # Rotation matrix around x-axis (pitch) R_x = np.array([ [1, 0, 0], [0, math.cos(pitch), -math.sin(pitch)], [0, math.sin(pitch), math.cos(pitch)] ]) # Rotation matrix around y-axis (yaw) R_y = np.array([ [math.cos(yaw), 0, math.sin(yaw)], [0, 1, 0], [-math.sin(yaw), 0, math.cos(yaw)] ]) # Combined rotation matrix R = np.dot(R_y, R_x) return R def custom_meshgrid(*args): # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid if pver.parse(torch.__version__) < pver.parse('1.10'): return torch.meshgrid(*args) else: return torch.meshgrid(*args, indexing='ij') def camera_to_world_to_world_to_camera(camera_to_world): """ Convert Camera-to-World matrix to World-to-Camera matrix by inverting the transformation. """ # Extract rotation (R) and translation (T) R = camera_to_world[:3, :3] T = camera_to_world[:3, 3] # Calculate World-to-Camera (inverse) matrix world_to_camera = np.eye(4) # The rotation part of World-to-Camera is the transpose of Camera-to-World's rotation world_to_camera[:3, :3] = R.T # The translation part is the negative of the rotated translation world_to_camera[:3, 3] = -np.dot(R.T, T) return world_to_camera def euler_to_camera_to_world_matrix(pose): x, y, z, pitch, yaw = pose # Convert pitch and yaw to radians pitch = math.radians(pitch) yaw = math.radians(yaw) # Get the rotation matrix from Euler angles R = euler_to_rotation_matrix(pitch, yaw) # Create the 4x4 transformation matrix (rotation + translation) camera_to_world = np.eye(4) # Set the rotation part (upper 3x3) camera_to_world[:3, :3] = R # Set the translation part (last column) camera_to_world[:3, 3] = [x, y, z] return camera_to_world def tensor_to_gif(tensor, output_path, fps=10): """ Converts a PyTorch tensor of shape (F, 3, H, W) to a GIF. Args: tensor (torch.Tensor): Input tensor of shape (F, 3, H, W) with values in range [0, 1] or [0, 255]. output_path (str): Path to save the output GIF. fps (int): Frames per second for the GIF. """ # Ensure the tensor is in [0, 255] range if tensor.max() <= 1.0: tensor = (tensor * 255).byte() else: tensor = tensor.byte() # Convert tensor to numpy array and rearrange to (F, H, W, 3) frames = tensor.permute(0, 2, 3, 1).cpu().numpy() # Convert frames to PIL Images pil_frames = [Image.fromarray(frame) for frame in frames] # Save as GIF pil_frames[0].save( output_path, save_all=True, append_images=pil_frames[1:], duration=int(1000 / fps), loop=0 ) def get_relative_pose(cam_params, zero_first_frame_scale): abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params] abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params] source_cam_c2w = abs_c2ws[0] if zero_first_frame_scale: cam_to_origin = 0 else: cam_to_origin = np.linalg.norm(source_cam_c2w[:3, 3]) target_cam_c2w = np.array([ [1, 0, 0, 0], [0, 1, 0, -cam_to_origin], [0, 0, 1, 0], [0, 0, 0, 1] ]) abs2rel = target_cam_c2w @ abs_w2cs[0] ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]] ret_poses = np.array(ret_poses, dtype=np.float32) return ret_poses def ray_condition(K, c2w, H, W, device): # c2w: B, V, 4, 4 # K: B, V, 4 B = K.shape[0] j, i = custom_meshgrid( torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype), torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype), ) i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW] j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW] fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1 zs = torch.ones_like(i) # [B, HxW] xs = (i - cx) / fx * zs ys = (j - cy) / fy * zs zs = zs.expand_as(ys) directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3 directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3 rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW rays_o = c2w[..., :3, 3] # B, V, 3 rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW # c2w @ dirctions rays_dxo = torch.linalg.cross(rays_o, rays_d) plucker = torch.cat([rays_dxo, rays_d], dim=-1) plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6 return plucker class Camera(object): def __init__(self, entry, focal_length=0.35): self.fx = focal_length # 0.35 correspond to 110 fov self.fy = focal_length*640/360 self.cx = 0.5 self.cy = 0.5 self.c2w_mat = euler_to_camera_to_world_matrix(entry) self.w2c_mat = camera_to_world_to_world_to_camera(np.copy(self.c2w_mat)) ACTION_KEYS = [ "inventory", "ESC", "hotbar.1", "hotbar.2", "hotbar.3", "hotbar.4", "hotbar.5", "hotbar.6", "hotbar.7", "hotbar.8", "hotbar.9", "forward", "back", "left", "right", "cameraY", "cameraX", "jump", "sneak", "sprint", "swapHands", "attack", "use", "pickItem", "drop", ] def one_hot_actions(actions: Sequence[Mapping[str, int]]) -> torch.Tensor: actions_one_hot = torch.zeros(len(actions), len(ACTION_KEYS)) for i, current_actions in enumerate(actions): for j, action_key in enumerate(ACTION_KEYS): if action_key.startswith("camera"): if action_key == "cameraX": value = current_actions["camera"][0] elif action_key == "cameraY": value = current_actions["camera"][1] else: raise ValueError(f"Unknown camera action key: {action_key}") max_val = 20 bin_size = 0.5 num_buckets = int(max_val / bin_size) value = (value - num_buckets) / num_buckets assert -1 - 1e-3 <= value <= 1 + 1e-3, f"Camera action value must be in [-1, 1], got {value}" else: value = current_actions[action_key] assert 0 <= value <= 1, f"Action value must be in [0, 1] got {value}" actions_one_hot[i, j] = value return actions_one_hot def simpletomulti(actions): vec_25 = torch.zeros(len(actions), len(ACTION_KEYS)) vec_25[actions==1, 11] = 1 vec_25[actions==2, 16] = -1 vec_25[actions==3, 16] = 1 vec_25[actions==4, 15] = -1 vec_25[actions==5, 15] = 1 return vec_25 def simpletomulti2(actions): vec_25 = torch.zeros(len(actions), len(ACTION_KEYS)) vec_25[actions[:,0]==1, 11] = 1 vec_25[actions[:,0]==2, 12] = 1 vec_25[actions[:,4]==11, 16] = -1 vec_25[actions[:,4]==13, 16] = 1 vec_25[actions[:,3]==11, 15] = -1 vec_25[actions[:,3]==13, 15] = 1 vec_25[actions[:,5]==6, 24] = 1 vec_25[actions[:,5]==1, 24] = 1 vec_25[actions[:,1]==1, 13] = 1 vec_25[actions[:,1]==2, 14] = 1 vec_25[actions[:,7]==1, 2] = 1 return vec_25 class MinecraftVideoPoseDataset(BaseVideoDataset): """ Minecraft dataset """ def __init__(self, cfg: DictConfig, split: str = "training"): if split == "test": split = "validation" super().__init__(cfg, split) if hasattr(cfg, "n_frames_valid") and split == "validation": self.n_frames = cfg.n_frames_valid def get_data_paths(self, split): data_dir = self.save_dir / split paths = sorted(list(data_dir.glob("**/*.mp4")), key=lambda x: x.name) if len(paths) == 0: sub_path = os.listdir(data_dir) for sp in sub_path: data_dir = self.save_dir / split / sp paths = paths+sorted(list(data_dir.glob("**/*.mp4")), key=lambda x: x.name) return paths def get_data_lengths(self, split): lengths = [300] * len(self.get_data_paths(split)) return lengths def download_dataset(self) -> Sequence[int]: from internetarchive import download part_suffixes = [ "aa", "ab", "ac", "ad", "ae", "af", "ag", "ah", "ai", "aj", "ak", ] for part_suffix in part_suffixes: identifier = f"minecraft_marsh_dataset_{part_suffix}" file_name = f"minecraft.tar.part{part_suffix}" download(identifier, file_name, destdir=self.save_dir, verbose=True) combined_bytes = io.BytesIO() for part_suffix in part_suffixes: identifier = f"minecraft_marsh_dataset_{part_suffix}" file_name = f"minecraft.tar.part{part_suffix}" part_file = self.save_dir / identifier / file_name with open(part_file, "rb") as part: combined_bytes.write(part.read()) combined_bytes.seek(0) with tarfile.open(fileobj=combined_bytes, mode="r") as combined_archive: combined_archive.extractall(self.save_dir) (self.save_dir / "minecraft/test").rename(self.save_dir / "validation") (self.save_dir / "minecraft/train").rename(self.save_dir / "training") (self.save_dir / "minecraft").rmdir() for part_suffix in part_suffixes: identifier = f"minecraft_marsh_dataset_{part_suffix}" file_name = f"minecraft.tar.part{part_suffix}" part_file = self.save_dir / identifier / file_name part_file.rmdir() def __getitem__(self, idx): # return self.load_data(idx) max_retries = 1000 for mr in range(max_retries): try: return self.load_data(idx) except Exception as e: print(f"{mr} Error: {e}") # idx = self.idx_remap[idx] # file_idx, frame_idx = self.split_idx(idx) # video_path = self.data_paths[file_idx] # os.remove(video_path) idx = (idx + 1) % self.__len__() def load_data(self, idx): idx = self.idx_remap[idx] file_idx, frame_idx = self.split_idx(idx) action_path = self.data_paths[file_idx] video_path = self.data_paths[file_idx] action_path = video_path.with_suffix(".npz") actions_pool = np.load(action_path)['actions'] poses_pool = np.load(action_path)['poses'] poses_pool[0,1] = poses_pool[1,1] # wrong first in place assert poses_pool[:,1].max() - poses_pool[:,1].min() < 2, f"wrong~~~~{poses_pool[:,1].max() - poses_pool[:,1].min()}-{video_path}" if len(poses_pool) < len(actions_pool): poses_pool = np.pad(poses_pool, ((1, 0), (0, 0))) actions_pool = simpletomulti2(actions_pool) video_raw = EncodedVideo.from_path(video_path, decode_audio=False) frame_idx = frame_idx + 100 # avoid first frames # first frame is useless if self.split == "validation": frame_idx = 240 total_frame = video_raw.duration.numerator fps = 10 # video_raw.duration.denominator total_frame = total_frame * fps / video_raw.duration.denominator video = video_raw.get_clip(start_sec=frame_idx/fps, end_sec=(frame_idx+self.n_frames)/fps)["video"] video = video.permute(1, 2, 3, 0).numpy() if self.split != "validation" and 'degrees' in np.load(action_path).keys(): degrees = np.load(action_path)['degrees'] actions_pool[:,16] *= degrees actions = np.copy(actions_pool[frame_idx : frame_idx + self.n_frames]) # (t, ) poses = np.copy(poses_pool[frame_idx : frame_idx + self.n_frames]) pad_len = self.n_frames - len(video) poses_pool[:,:3] -= poses[:1,:3] # poses_pool[:,3:] = -poses_pool[:,3:] poses_pool[:,-1] = -poses_pool[:,-1] poses_pool[:,3:] %= 360 poses[:,:3] -= poses[:1,:3] # do not normalize angle # poses[:,3:] = -poses[:,3:] poses[:,-1] = -poses[:,-1] poses[:,3:] %= 360 nonterminal = np.ones(self.n_frames) if len(video) < self.n_frames: video = np.pad(video, ((0, pad_len), (0, 0), (0, 0), (0, 0))) actions = np.pad(actions, ((0, pad_len),)) poses = np.pad(actions, ((0, pad_len),)) nonterminal[-pad_len:] = 0 video = torch.from_numpy(video / 255.0).float().permute(0, 3, 1, 2).contiguous() return ( video[:: self.frame_skip], actions[:: self.frame_skip], poses[:: self.frame_skip] ) if __name__ == "__main__": import torch from unittest.mock import MagicMock import tqdm cfg = MagicMock() cfg.resolution = 64 cfg.external_cond_dim = 0 cfg.n_frames = 64 cfg.save_dir = "data/minecraft" cfg.validation_multiplier = 1 dataset = MinecraftVideoDataset(cfg, "training") dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=16) for batch in tqdm.tqdm(dataloader): pass