import os import io import tarfile import numpy as np import torch from typing import Sequence, Mapping from omegaconf import DictConfig from pytorchvideo.data.encoded_video import EncodedVideo import random from .base_video_dataset import BaseVideoDataset ACTION_KEYS = [ "inventory", "ESC", "hotbar.1", "hotbar.2", "hotbar.3", "hotbar.4", "hotbar.5", "hotbar.6", "hotbar.7", "hotbar.8", "hotbar.9", "forward", "back", "left", "right", "cameraY", "cameraX", "jump", "sneak", "sprint", "swapHands", "attack", "use", "pickItem", "drop", ] def convert_action_space(actions): vec_25 = torch.zeros(len(actions), len(ACTION_KEYS)) vec_25[actions[:,0]==1, 11] = 1 vec_25[actions[:,0]==2, 12] = 1 vec_25[actions[:,4]==11, 16] = -1 vec_25[actions[:,4]==13, 16] = 1 vec_25[actions[:,3]==11, 15] = -1 vec_25[actions[:,3]==13, 15] = 1 vec_25[actions[:,5]==6, 24] = 1 vec_25[actions[:,5]==1, 24] = 1 vec_25[actions[:,1]==1, 13] = 1 vec_25[actions[:,1]==2, 14] = 1 vec_25[actions[:,7]==1, 2] = 1 return vec_25 # Dataset class class MinecraftVideoDataset(BaseVideoDataset): """ Minecraft video dataset for training and validation. Args: cfg (DictConfig): Configuration object. split (str): Dataset split ("training" or "validation"). """ def __init__(self, cfg: DictConfig, split: str = "training"): if split == "test": split = "validation" super().__init__(cfg, split) self.n_frames = cfg.n_frames_valid if split == "validation" and hasattr(cfg, "n_frames_valid") else cfg.n_frames self.use_plucker = cfg.use_plucker self.condition_similar_length = cfg.condition_similar_length self.customized_validation = cfg.customized_validation self.angle_range = cfg.angle_range self.pos_range = cfg.pos_range self.add_frame_timestep_embedder = cfg.add_frame_timestep_embedder self.training_dropout = 0.1 self.sample_more_place = getattr(cfg, "sample_more_place", False) self.within_context = getattr(cfg, "within_context", False) self.sample_more_event = getattr(cfg, "sample_more_event", False) self.causal_frame = getattr(cfg, "causal_frame", False) def get_data_paths(self, split: str): """ Retrieve all video file paths for the given split. Args: split (str): Dataset split ("training" or "validation"). Returns: List[Path]: List of video file paths. """ data_dir = self.save_dir / split paths = sorted(list(data_dir.glob("**/*.mp4")), key=lambda x: x.name) if not paths: sub_dirs = os.listdir(data_dir) for sub_dir in sub_dirs: sub_path = data_dir / sub_dir paths += sorted(list(sub_path.glob("**/*.mp4")), key=lambda x: x.name) return paths def download_dataset(self): pass def __getitem__(self, idx: int): """ Retrieve a single data sample by index. Args: idx (int): Index of the data sample. Returns: Tuple[torch.Tensor, torch.Tensor, np.ndarray, np.ndarray]: Video, actions, poses, and timesteps. """ max_retries = 1000 for _ in range(max_retries): try: return self.load_data(idx) except Exception as e: print(f"Retrying due to error: {e}") idx = (idx + 1) % len(self) def load_data(self, idx): idx = self.idx_remap[idx] file_idx, frame_idx = self.split_idx(idx) action_path = self.data_paths[file_idx] video_path = self.data_paths[file_idx] action_path = video_path.with_suffix(".npz") actions_pool = np.load(action_path)['actions'] poses_pool = np.load(action_path)['poses'] poses_pool[0,1] = poses_pool[1,1] # wrong first in place assert poses_pool[:,1].max() - poses_pool[:,1].min() < 2, f"wrong~~~~{poses_pool[:,1].max() - poses_pool[:,1].min()}-{video_path}" if len(poses_pool) < len(actions_pool): poses_pool = np.pad(poses_pool, ((1, 0), (0, 0))) actions_pool = convert_action_space(actions_pool) video_raw = EncodedVideo.from_path(video_path, decode_audio=False) frame_idx = frame_idx + 100 # avoid first frames # first frame is useless if self.split == "validation": frame_idx = 240 if self.sample_more_place and self.split == "training": if random.uniform(0, 1) > 0.5: place_mask = (actions_pool[:,24]==1) place_mask[:100] = 0 valid_indices = np.where(place_mask)[0] random_index = np.random.choice(valid_indices) frame_idx = random_index - random.randint(1, self.n_frames-1) total_frame = video_raw.duration.numerator fps = 10 # video_raw.duration.denominator total_frame = total_frame * fps / video_raw.duration.denominator video = video_raw.get_clip(start_sec=frame_idx/fps, end_sec=(frame_idx+self.n_frames)/fps)["video"] video = video.permute(1, 2, 3, 0).numpy() if self.split != "validation" and 'degrees' in np.load(action_path).keys(): degrees = np.load(action_path)['degrees'] actions_pool[:,16] *= degrees actions = np.copy(actions_pool[frame_idx : frame_idx + self.n_frames]) poses = np.copy(poses_pool[frame_idx : frame_idx + self.n_frames]) pad_len = self.n_frames - len(video) poses_pool[:,:3] -= poses[:1,:3] poses_pool[:,-1] = -poses_pool[:,-1] poses_pool[:,3:] %= 360 poses[:,:3] -= poses[:1,:3] # do not normalize angle poses[:,-1] = -poses[:,-1] poses[:,3:] %= 360 assert len(video) >= self.n_frames, f"{video_path}" if self.split == "training" and self.condition_similar_length>0: if random.uniform(0, 1) > self.training_dropout: refer_frame_dis = poses[:,None] - poses_pool[None,:] refer_frame_dis = np.abs(refer_frame_dis) refer_frame_dis[...,3:][refer_frame_dis[...,3:] > 180] = 360 - refer_frame_dis[...,3:][refer_frame_dis[...,3:] > 180] valid_index = ((((refer_frame_dis[..., :3] <= self.pos_range).sum(-1))>=3) & (((refer_frame_dis[..., 3:] <= self.angle_range).sum(-1))>=2) & \ ((((refer_frame_dis[..., :3] > 0).sum(-1))>=1) | (((refer_frame_dis[..., 3:] > 0).sum(-1))>=1)) ).sum(0) valid_index[:100] = 0 # mute bad initial scene if self.add_frame_timestep_embedder and self.causal_frame and (actions_pool[:frame_idx,24]==1).sum() > 0: valid_index[frame_idx:] = 0 mask = valid_index >= 1 mask[0] = False candidate_indices = np.argwhere(mask) mask2 = valid_index >= 0 mask2[0] = False count = min(self.condition_similar_length, candidate_indices.shape[0]) selected_indices = candidate_indices[np.random.choice(candidate_indices.shape[0], count, replace=True)][:,0] if count < self.condition_similar_length: candidate_indices2 = np.argwhere(mask2) selected_indices2 = candidate_indices2[np.random.choice(candidate_indices2.shape[0], self.condition_similar_length-count, replace=True)][:,0] selected_indices = np.concatenate([selected_indices, selected_indices2]) if self.sample_more_event: if random.uniform(0, 1) > 0.3: valid_idx = torch.nonzero(actions_pool[:frame_idx,24]==1)[:,0] if len(valid_idx) > self.condition_similar_length //2: valid_idx = valid_idx[-self.condition_similar_length //2:] if len(valid_idx) > 0: selected_indices[-len(valid_idx):] = valid_idx + 4 else: selected_indices = np.array(list(range(self.condition_similar_length))) * 0 + random.randint(0, frame_idx) video_pool = [] for si in selected_indices: video_pool.append(video_raw.get_clip(start_sec=si/fps, end_sec=(si+1)/fps)["video"][:,0].permute(1,2,0)) video_pool = np.stack(video_pool) video = np.concatenate([video, video_pool]) actions = np.concatenate([actions, actions_pool[selected_indices]]) poses = np.concatenate([poses, poses_pool[selected_indices]]) timestep = np.concatenate([np.array(list(range(frame_idx, frame_idx + self.n_frames))), selected_indices]) else: timestep = np.array(list(range(self.n_frames))) video = torch.from_numpy(video / 255.0).float().permute(0, 3, 1, 2).contiguous() if self.split == "validation" and not self.customized_validation: num_frame = actions.shape[0] actions[:] = 0 actions[:,16] = 1 poses[:] = 0 for ff in range(1, num_frame): poses[ff,4] = poses[ff-1,4] + actions[ff,16] * -15 if self.within_context: actions[:] = 0 actions[:self.n_frames//2+1,16] = 1 actions[self.n_frames//2+1:,16] = -1 poses[:] = 0 for ff in range(1, num_frame): poses[ff,4] = poses[ff-1,4] + actions[ff,16] * -15 return ( video[:: self.frame_skip], actions[:: self.frame_skip], poses[:: self.frame_skip], timestep )