import os
import io
import tarfile
import random
from typing import Sequence, Mapping

import numpy as np
import torch
from omegaconf import DictConfig
from pytorchvideo.data.encoded_video import EncodedVideo

from .base_video_dataset import BaseVideoDataset


# Names of the 25 dimensions of the converted action vector, in index order.
ACTION_KEYS = [
    "inventory",
    "ESC",
    "hotbar.1",
    "hotbar.2",
    "hotbar.3",
    "hotbar.4",
    "hotbar.5",
    "hotbar.6",
    "hotbar.7",
    "hotbar.8",
    "hotbar.9",
    "forward",
    "back",
    "left",
    "right",
    "cameraY",
    "cameraX",
    "jump",
    "sneak",
    "sprint",
    "swapHands",
    "attack",
    "use",
    "pickItem",
    "drop",
]


def convert_action_space(actions):
    """Map raw integer-coded actions to an (N, 25) tensor indexed by ACTION_KEYS."""
    vec_25 = torch.zeros(len(actions), len(ACTION_KEYS))
    vec_25[actions[:, 0] == 1, 11] = 1    # 11: forward
    vec_25[actions[:, 0] == 2, 12] = 1    # 12: back
    vec_25[actions[:, 4] == 11, 16] = -1  # 16: cameraX (negative)
    vec_25[actions[:, 4] == 13, 16] = 1   # 16: cameraX (positive)
    vec_25[actions[:, 3] == 11, 15] = -1  # 15: cameraY (negative)
    vec_25[actions[:, 3] == 13, 15] = 1   # 15: cameraY (positive)
    vec_25[actions[:, 5] == 6, 24] = 1    # 24: drop
    vec_25[actions[:, 5] == 1, 24] = 1    # 24: drop
    vec_25[actions[:, 1] == 1, 13] = 1    # 13: left
    vec_25[actions[:, 1] == 2, 14] = 1    # 14: right
    vec_25[actions[:, 7] == 1, 2] = 1     # 2: hotbar.1
    return vec_25

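# Illustrative check for convert_action_space (a sketch, not part of the original
# pipeline): assuming the raw action array has at least 8 integer columns, with
# column 0 encoding forward/back as 1/2, a single "forward" step lands in
# ACTION_KEYS index 11.
#
#   >>> raw = torch.zeros(1, 8, dtype=torch.long)
#   >>> raw[0, 0] = 1
#   >>> convert_action_space(raw)[0, 11]
#   tensor(1.)

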
class MinecraftVideoDataset(BaseVideoDataset):
    """
    Minecraft video dataset for training and validation.

    Each item is a tuple of (video, actions, poses, timestep); see __getitem__.

    Args:
        cfg (DictConfig): Configuration object.
        split (str): Dataset split ("training" or "validation").
    """

    def __init__(self, cfg: DictConfig, split: str = "training"):
        if split == "test":
            split = "validation"
        super().__init__(cfg, split)
        self.n_frames = (
            cfg.n_frames_valid
            if split == "validation" and hasattr(cfg, "n_frames_valid")
            else cfg.n_frames
        )
        self.use_plucker = cfg.use_plucker
        self.condition_similar_length = cfg.condition_similar_length
        self.customized_validation = cfg.customized_validation
        self.angle_range = cfg.angle_range
        self.pos_range = cfg.pos_range
        self.add_frame_timestep_embedder = cfg.add_frame_timestep_embedder
        # Probability of skipping the similarity-based reference-frame selection
        # during training (see load_data).
        self.training_dropout = 0.1
        self.sample_more_place = getattr(cfg, "sample_more_place", False)
        self.within_context = getattr(cfg, "within_context", False)
        self.sample_more_event = getattr(cfg, "sample_more_event", False)
        self.causal_frame = getattr(cfg, "causal_frame", False)

    def get_data_paths(self, split: str):
        """
        Retrieve all video file paths for the given split.

        Args:
            split (str): Dataset split ("training" or "validation").

        Returns:
            List[Path]: List of video file paths.
        """
        data_dir = self.save_dir / split
        paths = sorted(data_dir.glob("**/*.mp4"), key=lambda x: x.name)
        if not paths:
            for sub_dir in os.listdir(data_dir):
                sub_path = data_dir / sub_dir
                paths += sorted(sub_path.glob("**/*.mp4"), key=lambda x: x.name)
        return paths

    def download_dataset(self):
        # No automatic download; data is expected to already be present under save_dir.
        pass

    def __getitem__(self, idx: int):
        """
        Retrieve a single data sample by index.

        Falls back to the next index if loading fails.

        Args:
            idx (int): Index of the data sample.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, np.ndarray, np.ndarray]: Video, actions, poses, and timesteps.
        """
        max_retries = 1000
        for _ in range(max_retries):
            try:
                return self.load_data(idx)
            except Exception as e:
                print(f"Retrying due to error: {e}")
                idx = (idx + 1) % len(self)
        raise RuntimeError(f"Failed to load a valid sample after {max_retries} retries.")

    def load_data(self, idx):
        idx = self.idx_remap[idx]
        file_idx, frame_idx = self.split_idx(idx)
        video_path = self.data_paths[file_idx]
        action_path = video_path.with_suffix(".npz")

        action_data = np.load(action_path)
        actions_pool = action_data["actions"]
        poses_pool = action_data["poses"]

        # The first pose row can carry an invalid value in column 1; copy it from
        # the next frame, then check that the column stays within a small range.
        poses_pool[0, 1] = poses_pool[1, 1]
        assert poses_pool[:, 1].max() - poses_pool[:, 1].min() < 2, (
            f"pose column 1 varies too much "
            f"({poses_pool[:, 1].max() - poses_pool[:, 1].min()}) in {video_path}"
        )

        # Pad one pose row at the front if the pose track is shorter than the actions.
        if len(poses_pool) < len(actions_pool):
            poses_pool = np.pad(poses_pool, ((1, 0), (0, 0)))

        actions_pool = convert_action_space(actions_pool)
        video_raw = EncodedVideo.from_path(video_path, decode_audio=False)

        # Skip the first 100 frames of every clip.
        frame_idx = frame_idx + 100

        # Validation always evaluates the same fixed window of each clip.
        if self.split == "validation":
            frame_idx = 240

        # Optionally bias sampling toward windows that end shortly after a "place"
        # event (action index 24), half of the time.
        if self.sample_more_place and self.split == "training":
            if random.uniform(0, 1) > 0.5:
                place_mask = actions_pool[:, 24] == 1
                place_mask[:100] = 0
                valid_indices = np.where(place_mask)[0]
                if len(valid_indices) > 0:
                    random_index = np.random.choice(valid_indices)
                    frame_idx = random_index - random.randint(1, self.n_frames - 1)

        fps = 10  # clips are decoded assuming 10 frames per second
        # duration is a Fraction of seconds; convert it to a frame count.
        total_frame = video_raw.duration.numerator * fps / video_raw.duration.denominator
        video = video_raw.get_clip(
            start_sec=frame_idx / fps, end_sec=(frame_idx + self.n_frames) / fps
        )["video"]
        video = video.permute(1, 2, 3, 0).numpy()  # (T, H, W, C)

        # Optional per-clip scaling of the cameraX action (index 16).
        if self.split != "validation" and "degrees" in action_data.keys():
            degrees = action_data["degrees"]
            actions_pool[:, 16] *= degrees

        # Slice the sampled window of actions and poses.
        actions = np.copy(actions_pool[frame_idx : frame_idx + self.n_frames])
        poses = np.copy(poses_pool[frame_idx : frame_idx + self.n_frames])
        pad_len = self.n_frames - len(video)

        # Express positions relative to the first frame of the window, flip the
        # sign of the last pose channel, and wrap angles into [0, 360).
        poses_pool[:, :3] -= poses[:1, :3]
        poses_pool[:, -1] = -poses_pool[:, -1]
        poses_pool[:, 3:] %= 360

        poses[:, :3] -= poses[:1, :3]
        poses[:, -1] = -poses[:, -1]
        poses[:, 3:] %= 360

        assert len(video) >= self.n_frames, f"not enough frames in {video_path}"

if self.split == "training" and self.condition_similar_length>0: |
|
if random.uniform(0, 1) > self.training_dropout: |
|
refer_frame_dis = poses[:,None] - poses_pool[None,:] |
|
refer_frame_dis = np.abs(refer_frame_dis) |
|
refer_frame_dis[...,3:][refer_frame_dis[...,3:] > 180] = 360 - refer_frame_dis[...,3:][refer_frame_dis[...,3:] > 180] |
|
valid_index = ((((refer_frame_dis[..., :3] <= self.pos_range).sum(-1))>=3) & (((refer_frame_dis[..., 3:] <= self.angle_range).sum(-1))>=2) & \ |
|
((((refer_frame_dis[..., :3] > 0).sum(-1))>=1) | (((refer_frame_dis[..., 3:] > 0).sum(-1))>=1)) |
|
).sum(0) |
|
valid_index[:100] = 0 |
|
|
|
if self.add_frame_timestep_embedder and self.causal_frame and (actions_pool[:frame_idx,24]==1).sum() > 0: |
|
valid_index[frame_idx:] = 0 |
|
|
|
mask = valid_index >= 1 |
|
mask[0] = False |
|
candidate_indices = np.argwhere(mask) |
|
|
|
mask2 = valid_index >= 0 |
|
mask2[0] = False |
|
|
|
count = min(self.condition_similar_length, candidate_indices.shape[0]) |
|
selected_indices = candidate_indices[np.random.choice(candidate_indices.shape[0], count, replace=True)][:,0] |
|
|
|
if count < self.condition_similar_length: |
|
candidate_indices2 = np.argwhere(mask2) |
|
selected_indices2 = candidate_indices2[np.random.choice(candidate_indices2.shape[0], self.condition_similar_length-count, replace=True)][:,0] |
|
selected_indices = np.concatenate([selected_indices, selected_indices2]) |
|
|
|
if self.sample_more_event: |
|
if random.uniform(0, 1) > 0.3: |
|
valid_idx = torch.nonzero(actions_pool[:frame_idx,24]==1)[:,0] |
|
if len(valid_idx) > self.condition_similar_length //2: |
|
valid_idx = valid_idx[-self.condition_similar_length //2:] |
|
|
|
if len(valid_idx) > 0: |
|
selected_indices[-len(valid_idx):] = valid_idx + 4 |
|
|
|
else: |
|
selected_indices = np.array(list(range(self.condition_similar_length))) * 0 + random.randint(0, frame_idx) |
|
|
|
video_pool = [] |
|
for si in selected_indices: |
|
video_pool.append(video_raw.get_clip(start_sec=si/fps, end_sec=(si+1)/fps)["video"][:,0].permute(1,2,0)) |
|
|
|
video_pool = np.stack(video_pool) |
|
video = np.concatenate([video, video_pool]) |
|
actions = np.concatenate([actions, actions_pool[selected_indices]]) |
|
poses = np.concatenate([poses, poses_pool[selected_indices]]) |
|
|
|
timestep = np.concatenate([np.array(list(range(frame_idx, frame_idx + self.n_frames))), selected_indices]) |
|
|
|
else: |
|
timestep = np.array(list(range(self.n_frames))) |
|
|
|
        # Convert to (T, C, H, W) float video in [0, 1].
        video = torch.from_numpy(video / 255.0).float().permute(0, 3, 1, 2).contiguous()

if self.split == "validation" and not self.customized_validation: |
|
num_frame = actions.shape[0] |
|
|
|
actions[:] = 0 |
|
actions[:,16] = 1 |
|
poses[:] = 0 |
|
for ff in range(1, num_frame): |
|
poses[ff,4] = poses[ff-1,4] + actions[ff,16] * -15 |
|
|
|
if self.within_context: |
|
actions[:] = 0 |
|
actions[:self.n_frames//2+1,16] = 1 |
|
actions[self.n_frames//2+1:,16] = -1 |
|
poses[:] = 0 |
|
for ff in range(1, num_frame): |
|
poses[ff,4] = poses[ff-1,4] + actions[ff,16] * -15 |
|
|
|
        return (
            video[:: self.frame_skip],
            actions[:: self.frame_skip],
            poses[:: self.frame_skip],
            timestep,
        )
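

# Usage sketch (illustrative, with assumed config values): the fields below mirror
# the attributes read in __init__; everything else (save_dir, frame_skip, idx_remap,
# split_idx, ...) comes from BaseVideoDataset, and real .mp4/.npz pairs must exist
# on disk before the dataset can be indexed, so that part is left commented out.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "n_frames": 8,
            "use_plucker": False,
            "condition_similar_length": 0,
            "customized_validation": False,
            "angle_range": 30,
            "pos_range": 2.0,
            "add_frame_timestep_embedder": False,
        }
    )
    print(OmegaConf.to_yaml(cfg))
    # dataset = MinecraftVideoDataset(cfg, split="validation")
    # video, actions, poses, timestep = dataset[0]
    # print(video.shape, actions.shape, poses.shape, timestep.shape)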