Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes.
- .gitattributes +1 -0
- README.md +74 -3
- config.json +26 -0
- model.safetensors +3 -0
- soccer_qa_inference.py +304 -0
- special_tokens_map.json +26 -0
- src/datasets/data_manager.py +90 -0
- src/datasets/imagenet1k.py +152 -0
- src/datasets/utils/dataloader.py +234 -0
- src/datasets/utils/utils.py +21 -0
- src/datasets/utils/video/__pycache__/functional.cpython-312.pyc +0 -0
- src/datasets/utils/video/__pycache__/randaugment.cpython-312.pyc +0 -0
- src/datasets/utils/video/__pycache__/transforms.cpython-312.pyc +0 -0
- src/datasets/utils/video/__pycache__/volume_transforms.cpython-312.pyc +0 -0
- src/datasets/utils/video/functional.py +110 -0
- src/datasets/utils/video/randaugment.py +536 -0
- src/datasets/utils/video/randerase.py +170 -0
- src/datasets/utils/video/transforms.py +1161 -0
- src/datasets/utils/video/transforms_builder.py +165 -0
- src/datasets/utils/video/volume_transforms.py +159 -0
- src/datasets/utils/weighted_sampler.py +336 -0
- src/datasets/utils/worker_init_fn.py +76 -0
- src/datasets/video_dataset.py +373 -0
- src/hub/__init__.py +0 -0
- src/hub/backbones.py +177 -0
- src/masks/__pycache__/utils.cpython-312.pyc +0 -0
- src/masks/default.py +18 -0
- src/masks/multiseq_multiblock3d.py +239 -0
- src/masks/utils.py +21 -0
- src/models/__pycache__/attentive_pooler.cpython-312.pyc +0 -0
- src/models/__pycache__/vision_transformer.cpython-312.pyc +0 -0
- src/models/ac_predictor.py +200 -0
- src/models/attentive_pooler.py +137 -0
- src/models/predictor.py +253 -0
- src/models/utils/__pycache__/modules.cpython-312.pyc +0 -0
- src/models/utils/__pycache__/patch_embed.cpython-312.pyc +0 -0
- src/models/utils/__pycache__/pos_embs.cpython-312.pyc +0 -0
- src/models/utils/modules.py +610 -0
- src/models/utils/patch_embed.py +52 -0
- src/models/utils/pos_embs.py +93 -0
- src/models/vision_transformer.py +487 -0
- src/utils/__pycache__/tensors.cpython-312.pyc +0 -0
- src/utils/checkpoint_loader.py +37 -0
- src/utils/distributed.py +101 -0
- src/utils/logging.py +108 -0
- src/utils/monitoring.py +171 -0
- src/utils/schedulers.py +93 -0
- src/utils/tensors.py +53 -0
- src/utils/wrappers.py +43 -0
- tokenizer.json +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,3 +1,74 @@
---
language: en
license: cc-by-nc-4.0
tags:
- soccer
- video-qa
- question-answering
- vision-language
- multimodal
- sports-analysis
library_name: transformers
pipeline_tag: video-text-to-text
---

# Soccer QA 4B - Soccer Video Question Answering Model

**⚠️ RESEARCH USE ONLY - NON-COMMERCIAL LICENSE**

Soccer QA 4B is a unified video question-answering model designed specifically for soccer video understanding.

## Model Description

This model answers questions about soccer videos by analyzing the visual content and generating natural language responses.

**Example:**
- **Input**: Video + "What unfolded during the game in the video?"
- **Output**: "During the game, there was a foul committed by a player from the yellow-jerseyed team, leading to a yellow card being issued..."

## Architecture
- **Vision Encoder**: DWT-VJEPA2-based video encoder (vit_giant, 1408 dim)
- **Text Model**: LLaMA 3.2-3B with LoRA fine-tuning
- **Vision-Text Bridge**: Learned projection layer (1408 → 2048 → 3072)
- **Specialization**: Fine-tuned on soccer video QA data

## Usage (helper functions are included in this repo)

```python
from soccer_qa_inference import SoccerQA

model = SoccerQA("/path/to/model")
answer = model.ask("video.mp4", "Was this a Foul?", max_tokens=45)
print(answer)
```

## Model Details
- **Parameters**: ~4B total
- **Input**: Video files (16 frames, 256x256) + text questions
- **Output**: Natural language answers
- **Domain**: Soccer/football video analysis
- **Context**: Handles complex game situations, player actions, fouls, etc.

## Training Data
- Soccer video clips with question-answer pairs
- Covers various game situations: fouls, shots, saves, player actions
- Annotated with detailed descriptions of game events

## Limitations
- Research use only, no commercial applications
- Optimized specifically for soccer content
- May not generalize well to other sports or video domains
- Requires high-quality video input for best results

## License
CC-BY-NC-4.0 - Research use only, no commercial applications permitted.

## Citation
```bibtex
@misc{soccer-qa-4b-2025,
  title={Soccer QA 4B: Video Question Answering for Soccer Analysis},
  author={Varun Kodathala and Sports Vision},
  year={2025},
  note={Research model for soccer video understanding}
}
```
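The README shows single-question usage; the uploaded `soccer_qa_inference.py` also exposes a `batch_ask` helper for asking several questions about the same clip. A minimal sketch (the model directory and video path below are placeholders):

```python
from soccer_qa_inference import SoccerQA

# Illustrative paths; point them at the downloaded repo and a local clip.
model = SoccerQA("/path/to/soccer-qa-4b")
results = model.batch_ask(
    "clip.mp4",
    ["Was this a foul?", "Which team was awarded the free kick?"],
    max_tokens=64,
)
for r in results:
    print(r["question"], "->", r["answer"])
```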
config.json
ADDED
@@ -0,0 +1,26 @@
{
  "model_type": "soccer_qa_4b",
  "architectures": [
    "SoccerQA4BModel"
  ],
  "vision_dim": 1408,
  "projection_dim": 2048,
  "text_dim": 3072,
  "img_size": 256,
  "num_frames": 16,
  "max_length": 256,
  "temperature": 0.7,
  "imagenet_mean": [
    0.485,
    0.456,
    0.406
  ],
  "imagenet_std": [
    0.229,
    0.224,
    0.225
  ],
  "hidden_size": 3072,
  "vocab_size": 128257,
  "model_description": "Soccer video question answering model"
}
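The inference class reads `vision_dim`, `projection_dim`, `text_dim`, `img_size`, and `num_frames` from this file (see `_load_config` in `soccer_qa_inference.py`). A short sketch of that lookup, assuming the config sits next to the weights in an illustrative model directory:

```python
import json
import os

model_dir = "/path/to/soccer-qa-4b"  # illustrative path
with open(os.path.join(model_dir, "config.json")) as f:
    cfg = json.load(f)

# 1408-dim vision features are projected through 2048 into the 3072-dim LLaMA embedding space.
print(cfg["vision_dim"], cfg["projection_dim"], cfg["text_dim"])
```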
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f268918412af5ab623937ca776bf5c91eb26f04c3ff7e4cc257598aeda61b7cc
size 18512562808
soccer_qa_inference.py
ADDED
@@ -0,0 +1,304 @@
#!/usr/bin/env python3
"""
Soccer QA Inference - Single Class, Clean API

Usage in Colab:
    from soccer_qa_inference import SoccerQA
    model = SoccerQA("soccer-qa-3b-unified")
    answer = model.ask("video.mp4", "What happened?", max_tokens=128)
"""

import os
import json
import torch
import torch.nn as nn
import numpy as np
from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModelForCausalLM
from decord import VideoReader

# Import your existing modules
import src.datasets.utils.video.transforms as video_transforms
import src.datasets.utils.video.volume_transforms as volume_transforms
from src.models.vision_transformer import vit_giant_rope

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


def get_video(fname, num_frames=16):
    """Load video and sample frames uniformly"""
    vr = VideoReader(fname)
    frame_idx = np.linspace(0, len(vr) - 1, num=num_frames).astype(np.int64)
    video = vr.get_batch(frame_idx).asnumpy()
    return video


def build_video_transform(img_size):
    """Build video preprocessing transform"""
    short_side_size = int(256.0 / 224 * img_size)
    eval_transform = video_transforms.Compose([
        video_transforms.Resize(short_side_size, interpolation="bilinear"),
        video_transforms.CenterCrop(size=(img_size, img_size)),
        volume_transforms.ClipToTensor(),
        video_transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])
    return eval_transform


class SoccerQA:
    """Single class for Soccer QA inference - Clean Colab API"""

    def __init__(self, model_dir="/home/varunkodathala/jepa_llm/soccer_pretrain/soccer-qa-3b-unified"):
        """Initialize Soccer QA model

        Args:
            model_dir: Path to merged model directory
        """
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        print(f"🚀 Loading Soccer QA from {model_dir}...")

        # Load config and tokenizer
        self._load_config()
        self._load_tokenizer()

        # Build models
        self._build_vision_model()
        self._build_text_model()
        self._build_projection()

        # Load all weights
        self._load_weights()

        # Build video transforms
        self.video_transform = build_video_transform(self.img_size)

        print("✅ Soccer QA ready!")

    def _load_config(self):
        """Load model configuration"""
        config_path = os.path.join(self.model_dir, "config.json")
        with open(config_path, 'r') as f:
            self.config = json.load(f)

        self.vision_dim = self.config["vision_dim"]          # 1408
        self.projection_dim = self.config["projection_dim"]  # 2048
        self.text_dim = self.config["text_dim"]              # 3072
        self.img_size = self.config["img_size"]              # 256
        self.num_frames = self.config["num_frames"]          # 16

    def _load_tokenizer(self):
        """Load tokenizer with <video> token"""
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _build_vision_model(self):
        """Build vision transformer using your src modules"""
        self.vision_model = vit_giant_rope(
            img_size=(self.img_size, self.img_size),
            num_frames=self.num_frames
        )
        self.vision_model.to(self.device).eval()

        # Freeze vision model
        for param in self.vision_model.parameters():
            param.requires_grad = False

    def _build_text_model(self):
        """Build text model - we'll load merged weights later"""
        self.text_model = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-3.2-3B",
            torch_dtype=torch.float32,
            device_map=self.device,
            trust_remote_code=True
        )

        # Resize for <video> token to match saved model
        self.text_model.resize_token_embeddings(len(self.tokenizer))
        self.text_model.eval()

    def _build_projection(self):
        """Build vision projection layer"""
        self.vision_projection = nn.Sequential(
            nn.Linear(self.vision_dim, self.projection_dim),  # 1408 -> 2048
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(self.projection_dim, self.text_dim),    # 2048 -> 3072
            nn.LayerNorm(self.text_dim)
        ).to(self.device)

    def _load_weights(self):
        """Load all weights from safetensors - optimized approach"""
        model_path = os.path.join(self.model_dir, "model.safetensors")
        print(f"Loading weights from: {model_path}")
        state_dict = load_file(model_path, device=str(self.device))

        # Load vision encoder weights
        vision_state = {}
        for key, value in state_dict.items():
            if key.startswith("vision_encoder."):
                new_key = key.replace("vision_encoder.", "")
                vision_state[new_key] = value

        msg = self.vision_model.load_state_dict(vision_state, strict=False)
        print(f"Vision model loaded: {msg}")

        # Load projection weights
        projection_state = {}
        for key, value in state_dict.items():
            if key.startswith("vision_projection."):
                new_key = key.replace("vision_projection.", "")
                projection_state[new_key] = value

        self.vision_projection.load_state_dict(projection_state)
        print("Projection layer loaded")

        # Load text model weights directly from merged state_dict
        text_state = {}
        for key, value in state_dict.items():
            if key.startswith("text_model."):
                new_key = key.replace("text_model.", "")
                text_state[new_key] = value

        # Apply merged weights directly to text model
        missing_keys, unexpected_keys = self.text_model.load_state_dict(text_state, strict=False)
        if missing_keys:
            print(f"Missing keys in text model: {len(missing_keys)} (this is normal)")
        if unexpected_keys:
            print(f"Unexpected keys in text model: {len(unexpected_keys)}")
        print("✅ Text model loaded with merged weights")

        # Clear state_dict from memory
        del state_dict
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

    def _get_video_embeddings(self, video_path):
        """Extract video embeddings from video file"""
        with torch.inference_mode():
            # Load video
            video = get_video(video_path, self.num_frames)
            video = torch.from_numpy(video).permute(0, 3, 1, 2)  # T x C x H x W

            # Preprocess
            x = self.video_transform(video).to(self.device).unsqueeze(0)  # [1, 16, 3, 256, 256]

            # Extract features
            features = self.vision_model(x)  # [1, 2048, 1408]

            # Handle reshaping
            squeezed = features.squeeze(0)  # [2048, 1408]
            if squeezed.shape[0] % 2048 == 0:
                num_clips = squeezed.shape[0] // 2048
                reshaped = squeezed.view(num_clips, 2048, 1408)
            else:
                reshaped = squeezed.unsqueeze(0)  # [1, 2048, 1408]

            return reshaped

    def _project_vision_features(self, vision_features):
        """Project vision features to text embedding space"""
        # vision_features: [num_clips, 2048, 1408]
        num_clips, patches_per_clip, feature_dim = vision_features.shape

        # Flatten: [num_clips * 2048, 1408]
        flattened = vision_features.view(-1, feature_dim)

        # Project: [num_clips * 2048, 3072]
        projected = self.vision_projection(flattened)

        # Return flattened for sequence: [total_patches, 3072]
        return projected

    def ask(self, video_path, question, max_tokens=128, temperature=0.7, top_p=0.9,
            repetition_penalty=1.2, no_repeat_ngram_size=3):
        """Ask a question about a video

        Args:
            video_path: Path to video file
            question: Question about the video
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            repetition_penalty: Penalty for repetition
            no_repeat_ngram_size: N-gram size for repetition blocking

        Returns:
            Generated answer as string
        """
        with torch.no_grad():
            # Get video embeddings
            video_features = self._get_video_embeddings(video_path)        # [num_clips, 2048, 1408]
            vision_embeds = self._project_vision_features(video_features)  # [total_patches, 3072]
            vision_embeds = vision_embeds.unsqueeze(0)                     # [1, total_patches, 3072]

            # Process question (remove <video> token if present)
            question_clean = question.replace("<video>", "").strip()

            # Tokenize question
            question_tokens = self.tokenizer(
                question_clean,
                return_tensors="pt",
                add_special_tokens=True
            ).to(self.device)

            # Get text embeddings
            text_embeds = self.text_model.get_input_embeddings()(question_tokens.input_ids)

            # Combine vision and text embeddings
            combined_embeds = torch.cat([vision_embeds, text_embeds], dim=1)

            # Create attention mask
            vision_attention = torch.ones(
                1, vision_embeds.shape[1],
                dtype=question_tokens.attention_mask.dtype,
                device=self.device
            )
            combined_attention_mask = torch.cat([vision_attention, question_tokens.attention_mask], dim=1)

            # Generate response
            generated_ids = self.text_model.generate(
                inputs_embeds=combined_embeds,
                attention_mask=combined_attention_mask,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                use_cache=True,
                return_dict_in_generate=False
            )

            # Handle different return formats from generate()
            if generated_ids.shape[1] > combined_embeds.shape[1]:
                # Full sequence returned - slice from combined length
                new_tokens = generated_ids[:, combined_embeds.shape[1]:]
            else:
                # Only new tokens returned - use all
                new_tokens = generated_ids

            generated_text = self.tokenizer.batch_decode(
                new_tokens,
                skip_special_tokens=True
            )[0]

            return generated_text.strip()

    def batch_ask(self, video_path, questions, **kwargs):
        """Ask multiple questions about the same video

        Args:
            video_path: Path to video file
            questions: List of questions
            **kwargs: Generation parameters

        Returns:
            List of {"question": str, "answer": str} dicts
        """
        results = []
        for question in questions:
            answer = self.ask(video_path, question, **kwargs)
            results.append({"question": question, "answer": answer})
        return results
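The preprocessing helpers above can also be exercised on their own, which is handy for checking that a clip decodes to the expected `[1, 16, 3, 256, 256]` tensor before running the full model. A sketch, assuming decord and the repo's transform modules are importable and `clip.mp4` is a local video:

```python
import torch
from soccer_qa_inference import get_video, build_video_transform

frames = get_video("clip.mp4", num_frames=16)          # (16, H, W, 3) uint8 array
video = torch.from_numpy(frames).permute(0, 3, 1, 2)   # T x C x H x W
x = build_video_transform(256)(video).unsqueeze(0)     # [1, 16, 3, 256, 256]
print(x.shape)
```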
special_tokens_map.json
ADDED
@@ -0,0 +1,26 @@
{
  "additional_special_tokens": [
    "<video>"
  ],
  "bos_token": {
    "content": "<|begin_of_text|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "<|end_of_text|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<|end_of_text|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
src/datasets/data_manager.py
ADDED
@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from logging import getLogger

_GLOBAL_SEED = 0
logger = getLogger()


def init_data(
    batch_size,
    transform=None,
    shared_transform=None,
    data="ImageNet",
    collator=None,
    pin_mem=True,
    num_workers=8,
    world_size=1,
    rank=0,
    root_path=None,
    image_folder=None,
    training=True,
    drop_last=True,
    subset_file=None,
    clip_len=None,
    dataset_fpcs=None,
    frame_sample_rate=None,
    duration=None,
    fps=None,
    num_clips=1,
    random_clip_sampling=True,
    allow_clip_overlap=False,
    filter_short_videos=False,
    filter_long_videos=int(1e9),
    datasets_weights=None,
    persistent_workers=False,
    deterministic=True,
    log_dir=None,
):
    if data.lower() == "imagenet":
        from src.datasets.imagenet1k import make_imagenet1k

        dataset, data_loader, dist_sampler = make_imagenet1k(
            transform=transform,
            batch_size=batch_size,
            collator=collator,
            pin_mem=pin_mem,
            training=training,
            num_workers=num_workers,
            world_size=world_size,
            rank=rank,
            root_path=root_path,
            image_folder=image_folder,
            persistent_workers=persistent_workers,
            drop_last=drop_last,
            subset_file=subset_file,
        )

    elif data.lower() == "videodataset":
        from src.datasets.video_dataset import make_videodataset

        dataset, data_loader, dist_sampler = make_videodataset(
            data_paths=root_path,
            batch_size=batch_size,
            frames_per_clip=clip_len,
            dataset_fpcs=dataset_fpcs,
            frame_step=frame_sample_rate,
            duration=duration,
            fps=fps,
            num_clips=num_clips,
            random_clip_sampling=random_clip_sampling,
            allow_clip_overlap=allow_clip_overlap,
            filter_short_videos=filter_short_videos,
            filter_long_videos=filter_long_videos,
            shared_transform=shared_transform,
            transform=transform,
            datasets_weights=datasets_weights,
            collator=collator,
            num_workers=num_workers,
            pin_mem=pin_mem,
            persistent_workers=persistent_workers,
            world_size=world_size,
            rank=rank,
            deterministic=deterministic,
            log_dir=log_dir,
        )

    return (data_loader, dist_sampler)
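`init_data` is a thin dispatcher: `data="ImageNet"` routes to `make_imagenet1k`, while `data="VideoDataset"` routes to `make_videodataset` in `src/datasets/video_dataset.py`. A hedged call sketch (the annotation path and clip settings below are placeholders, not values from this repo):

```python
from src.datasets.data_manager import init_data

# Returns (data_loader, dist_sampler); the dataset object itself is not returned.
loader, sampler = init_data(
    batch_size=8,
    data="VideoDataset",
    root_path=["/path/to/video_index.csv"],  # illustrative annotation path
    clip_len=16,
    frame_sample_rate=4,
    num_workers=4,
)
```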
src/datasets/imagenet1k.py
ADDED
@@ -0,0 +1,152 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import subprocess
import time
from logging import getLogger

import numpy as np
import torch
import torchvision

_GLOBAL_SEED = 0
logger = getLogger()


class ImageNet(torchvision.datasets.ImageFolder):

    def __init__(
        self,
        root,
        image_folder="imagenet_full_size/061417/",
        tar_file="imagenet_full_size-061417.tar.gz",
        transform=None,
        train=True,
        job_id=None,
        local_rank=None,
        index_targets=False,
    ):
        """
        ImageNet

        Dataset wrapper

        :param root: root network directory for ImageNet data
        :param image_folder: path to images inside root network directory
        :param tar_file: zipped image_folder inside root network directory
        :param train: whether to load train data (or validation)
        :param job_id: scheduler job-id used to create dir on local machine
        :param index_targets: whether to index the id of each labeled image
        """

        suffix = "train/" if train else "val/"
        data_path = os.path.join(root, image_folder, suffix)
        logger.info(f"data-path {data_path}")

        super(ImageNet, self).__init__(root=data_path, transform=transform)
        logger.info("Initialized ImageNet")

        if index_targets:
            self.targets = []
            for sample in self.samples:
                self.targets.append(sample[1])
            self.targets = np.array(self.targets)
            self.samples = np.array(self.samples)

            mint = None
            self.target_indices = []
            for t in range(len(self.classes)):
                indices = np.squeeze(np.argwhere(self.targets == t)).tolist()
                self.target_indices.append(indices)
                mint = len(indices) if mint is None else min(mint, len(indices))
                logger.debug(f"num-labeled target {t} {len(indices)}")
            logger.info(f"min. labeled indices {mint}")


class ImageNetSubset(object):

    def __init__(self, dataset, subset_file):
        """
        ImageNetSubset

        :param dataset: ImageNet dataset object
        :param subset_file: '.txt' file containing IDs of IN1K images to keep
        """
        self.dataset = dataset
        self.subset_file = subset_file
        self.filter_dataset_(subset_file)

    def filter_dataset_(self, subset_file):
        """Filter self.dataset to a subset"""
        root = self.dataset.root
        class_to_idx = self.dataset.class_to_idx
        # -- update samples to subset of IN1k targets/samples
        new_samples = []
        logger.info(f"Using {subset_file}")
        with open(subset_file, "r") as rfile:
            for line in rfile:
                class_name = line.split("_")[0]
                target = class_to_idx[class_name]
                img = line.split("\n")[0]
                new_samples.append((os.path.join(root, class_name, img), target))
        self.samples = new_samples

    @property
    def classes(self):
        return self.dataset.classes

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, target = self.samples[index]
        img = self.dataset.loader(path)
        if self.dataset.transform is not None:
            img = self.dataset.transform(img)
        if self.dataset.target_transform is not None:
            target = self.dataset.target_transform(target)
        return img, target


def make_imagenet1k(
    transform,
    batch_size,
    collator=None,
    pin_mem=True,
    num_workers=8,
    world_size=1,
    rank=0,
    root_path=None,
    image_folder=None,
    training=True,
    drop_last=True,
    persistent_workers=False,
    subset_file=None,
):
    dataset = ImageNet(
        root=root_path,
        image_folder=image_folder,
        transform=transform,
        train=training,
        index_targets=False,
    )
    if subset_file is not None:
        dataset = ImageNetSubset(dataset, subset_file)
    logger.info("ImageNet dataset created")
    dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset=dataset, num_replicas=world_size, rank=rank)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=collator,
        sampler=dist_sampler,
        batch_size=batch_size,
        drop_last=drop_last,
        pin_memory=pin_mem,
        num_workers=num_workers,
        persistent_workers=persistent_workers,
    )
    logger.info("ImageNet unsupervised data loader created")

    return dataset, data_loader, dist_sampler
src/datasets/utils/dataloader.py
ADDED
@@ -0,0 +1,234 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import bisect
import csv
import io
import time

import numpy as np
import torch
from torch.utils.data import _utils
from torch.utils.data.dataloader import ExceptionWrapper, _DatasetKind, _MultiProcessingDataLoaderIter

from src.utils.monitoring import ResourceMonitoringThread


class ConcatIndices:
    """Helper to map indices of concatenated/mixed datasets to the sample index for the corresponding dataset."""

    cumulative_sizes: np.ndarray

    def __init__(self, sizes):
        self.cumulative_sizes = np.cumsum(sizes)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # Returns a pair (dataset_idx, sample_idx)
        if idx < 0 or idx >= len(self):
            raise ValueError(f"index must be between 0 and the total size ({len(self)})")
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            return dataset_idx, idx
        return dataset_idx, idx - self.cumulative_sizes[dataset_idx - 1]


class CSVLogger(object):
    """An append-to CSV abstraction. File I/O requires a flush."""

    def __init__(self, fname, header):
        """Write header to internal buffers."""
        self.fname = fname
        self.buffer = io.StringIO()
        self.writer = csv.writer(self.buffer, quoting=csv.QUOTE_NONNUMERIC)
        self.writer.writerow(header)
        self.initialized = False

    def writerow(self, row) -> None:
        """Write row to internal buffers."""
        self.writer.writerow(row)

    def flush(self) -> None:
        """Flush buffer to file."""
        # Overwrite old file
        mode = "a+" if self.initialized else "w"

        with open(self.fname, mode, newline="") as f:
            f.write(self.buffer.getvalue())

        self.buffer = io.StringIO()
        self.writer = csv.writer(self.buffer, quoting=csv.QUOTE_NONNUMERIC)
        self.initialized = True


class MonitoredDataset(torch.utils.data.Dataset):
    """Implement resource monitoring on a per-worker basis.

    The sampling occurs every monitor_interval seconds and writes the log
    every log_interval seconds to a file specified by log_filename, which
    maps a worker id to a file using the '%w' placeholder.

    Warning: Do not call this dataset before it is consumed in the DataLoader.
    """

    def __init__(
        self, dataset: torch.utils.data.Dataset, log_filename: str, log_interval: float, monitor_interval: float
    ):
        self.dataset = dataset
        self.log_filename = str(log_filename)
        self.log_interval = log_interval
        self.monitor_interval = monitor_interval
        self._csv_log = None
        self._monitoring_thread = None
        self._last_log_time = None
        # Patch getitems dynamically
        if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:

            def __getitems__(self, index):
                self.maybe_start_resource_monitoring()
                return self.dataset.__getitems__(index)

            self.__getitems__ = __getitems__

    def __del__(self):
        self.stop_resource_monitoring()

    def __getitem__(self, index):
        self.maybe_start_resource_monitoring()
        return self.dataset.__getitem__(index)

    def __len__(self):
        return len(self.dataset)

    def _elapsed_log_time(self):
        if self._last_log_time is None:
            return float("inf")
        else:
            return time.perf_counter() - self._last_log_time

    def _update_log_time(self):
        self._last_log_time = time.perf_counter()

    def maybe_start_resource_monitoring(self):
        if self._monitoring_thread is None:

            def callback_fn(resource_sample):
                worker_info = torch.utils.data.get_worker_info()
                worker_id = worker_info.id

                if self._csv_log is None:
                    header = [f.name for f in resource_sample.fields()]
                    log_filename = self.log_filename.replace("%w", str(worker_id))
                    self._csv_log = CSVLogger(log_filename, header)
                row_values = resource_sample.as_tuple()
                self._csv_log.writerow(row_values)

                if self._elapsed_log_time() > self.log_interval:
                    self._csv_log.flush()
                    self._update_log_time()

            self._monitoring_thread = ResourceMonitoringThread(
                None, self.monitor_interval, stats_callback_fn=callback_fn
            )
            self._monitoring_thread.start()

    def stop_resource_monitoring(self):
        if self._monitoring_thread:
            self._monitoring_thread.stop()


class NondeterministicDataLoader(torch.utils.data.DataLoader):
    """Override torch dataloader to return out of order."""

    def __init__(self, *args, **kwargs):
        """Pass through constructor."""
        super().__init__(*args, **kwargs)

    def _get_iterator(self):
        if self.num_workers:
            self.check_worker_number_rationality()
            return _SloppyMultiProcessingDataLoaderIter(self)
        else:
            return super()._get_iterator()


class _SloppyMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter):

    def __init__(self, *args, **kwargs):
        """Pass through constructor."""
        super().__init__(*args, **kwargs)

    def _next_data(self):
        """Adds out of order returns."""
        while True:
            # If the worker responsible for `self._rcvd_idx` has already ended
            # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
            # we try to advance `self._rcvd_idx` to find the next valid index.
            #
            # This part needs to run in the loop because both the `self._get_data()`
            # call and `_IterableDatasetStopIteration` check below can mark
            # extra worker(s) as dead.
            while self._rcvd_idx < self._send_idx:
                info = self._task_info[self._rcvd_idx]
                if info is None:
                    # Found a reordered tombstone
                    del self._task_info[self._rcvd_idx]
                    self._rcvd_idx += 1
                    self._try_put_index()
                else:
                    worker_id = info[0]
                    # has data or is still active
                    if len(info) == 2 or self._workers_status[worker_id]:
                        break
                    del self._task_info[self._rcvd_idx]
                    self._rcvd_idx += 1
            else:
                # no valid `self._rcvd_idx` is found (i.e., didn't break)
                if not self._persistent_workers:
                    self._shutdown_workers()
                raise StopIteration

            # Now `self._rcvd_idx` is the batch index we want to fetch

            # Check if the next sample has already been generated
            if len(self._task_info[self._rcvd_idx]) == 2:
                data = self._task_info.pop(self._rcvd_idx)[1]
                return self._process_data(data)

            assert not self._shutdown and self._tasks_outstanding > 0
            idx, data = self._get_data()
            self._tasks_outstanding -= 1
            if self._dataset_kind == _DatasetKind.Iterable:
                # Check for _IterableDatasetStopIteration
                if isinstance(data, _utils.worker._IterableDatasetStopIteration):
                    if self._persistent_workers:
                        self._workers_status[data.worker_id] = False
                    else:
                        self._mark_worker_as_unavailable(data.worker_id)
                    self._try_put_index()
                    continue

            if idx != self._rcvd_idx:
                # Tombstone to receive later
                self._task_info[idx] = None
                if isinstance(data, ExceptionWrapper):
                    data.reraise()
                return data
            else:
                del self._task_info[idx]
                return self._process_data(data)


def get_worker_info():
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        num_workers = 1
        worker_id = 0
    else:
        num_workers = worker_info.num_workers
        worker_id = worker_info.id
    return num_workers, worker_id
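`ConcatIndices` maps a flat sample index over several concatenated datasets back to a `(dataset_idx, sample_idx)` pair via a cumulative-size table. A small self-contained example:

```python
from src.datasets.utils.dataloader import ConcatIndices

indices = ConcatIndices([100, 50, 25])   # three datasets of sizes 100, 50, 25
print(len(indices))                      # 175
print(indices[0])                        # (0, 0)  -> first dataset, first sample
print(indices[120])                      # (1, 20) -> second dataset, sample 20
print(indices[170])                      # (2, 20) -> third dataset, sample 20
```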
src/datasets/utils/utils.py
ADDED
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from src.utils.cluster import dataset_paths
from src.utils.logging import get_logger

logger = get_logger("Datasets utils")


def get_dataset_paths(datasets: list[str]):
    paths = []
    for d in datasets:
        try:
            path = dataset_paths().get(d)
        except Exception:
            raise Exception(f"Unknown dataset: {d}")
        paths.append(path)
    logger.info(f"Datapaths {paths}")
    return paths
src/datasets/utils/video/__pycache__/functional.cpython-312.pyc
ADDED
Binary file (5.52 kB)
src/datasets/utils/video/__pycache__/randaugment.cpython-312.pyc
ADDED
Binary file (18.1 kB)
src/datasets/utils/video/__pycache__/transforms.cpython-312.pyc
ADDED
Binary file (53.7 kB)
src/datasets/utils/video/__pycache__/volume_transforms.cpython-312.pyc
ADDED
Binary file (6.57 kB)
src/datasets/utils/video/functional.py
ADDED
@@ -0,0 +1,110 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numbers

import cv2
import numpy as np
import PIL
import torch
from torchvision.transforms import functional as tvf


def _is_tensor_clip(clip):
    return torch.is_tensor(clip) and clip.ndimension() == 4


def crop_clip(clip, min_h, min_w, h, w):
    if isinstance(clip[0], np.ndarray) or isinstance(clip[0], torch.Tensor):
        if clip[0].shape[-1] == 3:
            cropped = [img[min_h : min_h + h, min_w : min_w + w, :] for img in clip]
        else:
            assert clip[0].shape[0] == 3
            cropped = [img[:, min_h : min_h + h, min_w : min_w + w] for img in clip]

    elif isinstance(clip[0], PIL.Image.Image):
        cropped = [img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip]

    else:
        raise TypeError(
            "Expected numpy.ndarray or PIL.Image or torch.Tensor):" + "but got list of {0}".format(type(clip[0]))
        )
    return cropped


def resize_clip(clip, size, interpolation="bilinear"):
    if isinstance(clip[0], np.ndarray) or isinstance(clip[0], torch.Tensor):
        if isinstance(size, numbers.Number):
            if clip[0].shape[-1] == 3:
                im_h, im_w, im_c = clip[0].shape
            else:
                assert clip[0].shape[0] == 3
                im_c, im_h, im_w = clip[0].shape
            # Min spatial dim already matches minimal size
            if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size):
                return clip
            new_h, new_w = get_resize_sizes(im_h, im_w, size)
            size = (new_w, new_h)
        else:
            size = size[0], size[1]

        if isinstance(clip[0], np.ndarray):
            if interpolation == "bilinear":
                np_inter = cv2.INTER_LINEAR
            else:
                np_inter = cv2.INTER_NEAREST
            scaled = [cv2.resize(img, size, interpolation=np_inter) for img in clip]
        else:  # isinstance(clip[0], torch.Tensor)
            if interpolation == "bilinear":
                np_inter = tvf.InterpolationMode.BILINEAR
            else:
                np_inter = tvf.InterpolationMode.NEAREST
            size = (size[1], size[0])  # torchvision transforms expect the size in (h, w) order.
            scaled = [tvf.resize(img, size, interpolation=np_inter) for img in clip]
    elif isinstance(clip[0], PIL.Image.Image):
        if isinstance(size, numbers.Number):
            im_w, im_h = clip[0].size
            # Min spatial dim already matches minimal size
            if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size):
                return clip
            new_h, new_w = get_resize_sizes(im_h, im_w, size)
            size = (new_w, new_h)
        else:
            size = size[1], size[0]
        if interpolation == "bilinear":
            pil_inter = PIL.Image.BILINEAR
        else:
            pil_inter = PIL.Image.NEAREST
        scaled = [img.resize(size, pil_inter) for img in clip]
    else:
        raise TypeError(
            "Expected numpy.ndarray or PIL.Image or torch.Tensor" + "but got list of {0}".format(type(clip[0]))
        )
    return scaled


def get_resize_sizes(im_h, im_w, size):
    if im_w < im_h:
        ow = size
        oh = int(size * im_h / im_w)
    else:
        oh = size
        ow = int(size * im_w / im_h)
    return oh, ow


def normalize(clip, mean, std, inplace=False):
    if not _is_tensor_clip(clip):
        raise TypeError("tensor is not a torch clip.")

    if not inplace:
        clip = clip.clone()

    dtype = clip.dtype
    mean = torch.as_tensor(mean, dtype=dtype, device=clip.device)
    std = torch.as_tensor(std, dtype=dtype, device=clip.device)
    clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])

    return clip
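These helpers operate on a *list* of frames (NumPy arrays, PIL images, or tensors) rather than a stacked array. A quick sketch with random NumPy frames, assuming OpenCV is installed:

```python
import numpy as np
from src.datasets.utils.video.functional import resize_clip, crop_clip

clip = [np.random.randint(0, 255, (720, 1280, 3), dtype=np.uint8) for _ in range(16)]
small = resize_clip(clip, 256, interpolation="bilinear")  # short side becomes 256
patch = crop_clip(small, 0, 0, 224, 224)                  # 224x224 crop at the top-left
print(small[0].shape, patch[0].shape)
```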
src/datasets/utils/video/randaugment.py
ADDED
@@ -0,0 +1,536 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# Copyright 2020 Ross Wightman

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This implementation is based on
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py
# published under an Apache License 2.0.

# COMMENT FROM ORIGINAL:
# AutoAugment, RandAugment, and AugMix for PyTorch
# This code implements the searched ImageNet policies with various tweaks and
# improvements and does not include any of the search code. AA and RA
# Implementation adapted from:
# https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
# AugMix adapted from:
# https://github.com/google-research/augmix
# Papers:
# AutoAugment: Learning Augmentation Policies from Data
# https://arxiv.org/abs/1805.09501
# Learning Data Augmentation Strategies for Object Detection
# https://arxiv.org/abs/1906.11172
# RandAugment: Practical automated data augmentation...
# https://arxiv.org/abs/1909.13719
# AugMix: A Simple Data Processing Method to Improve Robustness and
# Uncertainty https://arxiv.org/abs/1912.02781

import math
import random
import re

import numpy as np
import PIL
from PIL import Image, ImageEnhance, ImageOps

_PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]])

_FILL = (128, 128, 128)

# This signifies the max integer that the controller RNN could predict for the
# augmentation scheme.
_MAX_LEVEL = 10.0

_HPARAMS_DEFAULT = {
    "translate_const": 250,
    "img_mean": _FILL,
}

_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)


def _interpolation(kwargs):
    interpolation = kwargs.pop("resample", Image.BILINEAR)
    if isinstance(interpolation, (list, tuple)):
        return random.choice(interpolation)
    else:
        return interpolation


def _check_args_tf(kwargs):
    if "fillcolor" in kwargs and _PIL_VER < (5, 0):
        kwargs.pop("fillcolor")
    kwargs["resample"] = _interpolation(kwargs)


def shear_x(img, factor, **kwargs):
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)


def shear_y(img, factor, **kwargs):
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)


def translate_x_rel(img, pct, **kwargs):
    pixels = pct * img.size[0]
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)


def translate_y_rel(img, pct, **kwargs):
    pixels = pct * img.size[1]
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)


def translate_x_abs(img, pixels, **kwargs):
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)


def translate_y_abs(img, pixels, **kwargs):
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)


def rotate(img, degrees, **kwargs):
    _check_args_tf(kwargs)
    if _PIL_VER >= (5, 2):
        return img.rotate(degrees, **kwargs)
    elif _PIL_VER >= (5, 0):
        w, h = img.size
        post_trans = (0, 0)
        rotn_center = (w / 2.0, h / 2.0)
        angle = -math.radians(degrees)
        matrix = [
            round(math.cos(angle), 15),
            round(math.sin(angle), 15),
            0.0,
            round(-math.sin(angle), 15),
            round(math.cos(angle), 15),
            0.0,
        ]

        def transform(x, y, matrix):
            (a, b, c, d, e, f) = matrix
            return a * x + b * y + c, d * x + e * y + f

        matrix[2], matrix[5] = transform(
            -rotn_center[0] - post_trans[0],
            -rotn_center[1] - post_trans[1],
            matrix,
        )
        matrix[2] += rotn_center[0]
        matrix[5] += rotn_center[1]
        return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
    else:
        return img.rotate(degrees, resample=kwargs["resample"])


def auto_contrast(img, **__):
    return ImageOps.autocontrast(img)


def invert(img, **__):
    return ImageOps.invert(img)


def equalize(img, **__):
    return ImageOps.equalize(img)


def solarize(img, thresh, **__):
    return ImageOps.solarize(img, thresh)


def solarize_add(img, add, thresh=128, **__):
    lut = []
    for i in range(256):
        if i < thresh:
            lut.append(min(255, i + add))
        else:
            lut.append(i)
    if img.mode in ("L", "RGB"):
        if img.mode == "RGB" and len(lut) == 256:
            lut = lut + lut + lut
        return img.point(lut)
    else:
        return img


def posterize(img, bits_to_keep, **__):
    if bits_to_keep >= 8:
        return img
    return ImageOps.posterize(img, bits_to_keep)


def contrast(img, factor, **__):
    return ImageEnhance.Contrast(img).enhance(factor)


def color(img, factor, **__):
    return ImageEnhance.Color(img).enhance(factor)


def brightness(img, factor, **__):
    return ImageEnhance.Brightness(img).enhance(factor)


def sharpness(img, factor, **__):
    return ImageEnhance.Sharpness(img).enhance(factor)


def _randomly_negate(v):
    """With 50% prob, negate the value"""
    return -v if random.random() > 0.5 else v


def _rotate_level_to_arg(level, _hparams):
    # range [-30, 30]
    level = (level / _MAX_LEVEL) * 30.0
    level = _randomly_negate(level)
    return (level,)


def _enhance_level_to_arg(level, _hparams):
    # range [0.1, 1.9]
    return ((level / _MAX_LEVEL) * 1.8 + 0.1,)


def _enhance_increasing_level_to_arg(level, _hparams):
    # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
    # range [0.1, 1.9]
    level = (level / _MAX_LEVEL) * 0.9
    level = 1.0 + _randomly_negate(level)
    return (level,)


def _shear_level_to_arg(level, _hparams):
    # range [-0.3, 0.3]
    level = (level / _MAX_LEVEL) * 0.3
    level = _randomly_negate(level)
    return (level,)


def _translate_abs_level_to_arg(level, hparams):
    translate_const = hparams["translate_const"]
    level = (level / _MAX_LEVEL) * float(translate_const)
    level = _randomly_negate(level)
    return (level,)


def _translate_rel_level_to_arg(level, hparams):
    # default range [-0.45, 0.45]
    translate_pct = hparams.get("translate_pct", 0.45)
    level = (level / _MAX_LEVEL) * translate_pct
    level = _randomly_negate(level)
    return (level,)


def _posterize_level_to_arg(level, _hparams):
    # As per Tensorflow TPU EfficientNet impl
    # range [0, 4], 'keep 0 up to 4 MSB of original image'
    # intensity/severity of augmentation decreases with level
    return (int((level / _MAX_LEVEL) * 4),)


def _posterize_increasing_level_to_arg(level, hparams):
    # As per Tensorflow models research and UDA impl
    # range [4, 0], 'keep 4 down to 0 MSB of original image',
    # intensity/severity of augmentation increases with level
    return (4 - _posterize_level_to_arg(level, hparams)[0],)


def _posterize_original_level_to_arg(level, _hparams):
    # As per original AutoAugment paper description
    # range [4, 8], 'keep 4 up to 8 MSB of image'
    # intensity/severity of augmentation decreases with level
    return (int((level / _MAX_LEVEL) * 4) + 4,)


def _solarize_level_to_arg(level, _hparams):
    # range [0, 256]
    # intensity/severity of augmentation decreases with level
    return (int((level / _MAX_LEVEL) * 256),)


def _solarize_increasing_level_to_arg(level, _hparams):
    # range [0, 256]
    # intensity/severity of augmentation increases with level
    return (256 - _solarize_level_to_arg(level, _hparams)[0],)


def _solarize_add_level_to_arg(level, _hparams):
    # range [0, 110]
    return (int((level / _MAX_LEVEL) * 110),)


LEVEL_TO_ARG = {
    "AutoContrast": None,
    "Equalize": None,
    "Invert": None,
    "Rotate": _rotate_level_to_arg,
    # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
    "Posterize": _posterize_level_to_arg,
    "PosterizeIncreasing": _posterize_increasing_level_to_arg,
    "PosterizeOriginal": _posterize_original_level_to_arg,
    "Solarize": _solarize_level_to_arg,
    "SolarizeIncreasing": _solarize_increasing_level_to_arg,
    "SolarizeAdd": _solarize_add_level_to_arg,
    "Color": _enhance_level_to_arg,
    "ColorIncreasing": _enhance_increasing_level_to_arg,
    "Contrast": _enhance_level_to_arg,
    "ContrastIncreasing": _enhance_increasing_level_to_arg,
    "Brightness": _enhance_level_to_arg,
    "BrightnessIncreasing": _enhance_increasing_level_to_arg,
    "Sharpness": _enhance_level_to_arg,
    "SharpnessIncreasing": _enhance_increasing_level_to_arg,
    "ShearX": _shear_level_to_arg,
    "ShearY": _shear_level_to_arg,
    "TranslateX": _translate_abs_level_to_arg,
    "TranslateY": _translate_abs_level_to_arg,
    "TranslateXRel": _translate_rel_level_to_arg,
    "TranslateYRel": _translate_rel_level_to_arg,
}


NAME_TO_OP = {
    "AutoContrast": auto_contrast,
    "Equalize": equalize,
    "Invert": invert,
    "Rotate": rotate,
    "Posterize": posterize,
    "PosterizeIncreasing": posterize,
    "PosterizeOriginal": posterize,
    "Solarize": solarize,
    "SolarizeIncreasing": solarize,
    "SolarizeAdd": solarize_add,
    "Color": color,
    "ColorIncreasing": color,
    "Contrast": contrast,
    "ContrastIncreasing": contrast,
    "Brightness": brightness,
|
326 |
+
"BrightnessIncreasing": brightness,
|
327 |
+
"Sharpness": sharpness,
|
328 |
+
"SharpnessIncreasing": sharpness,
|
329 |
+
"ShearX": shear_x,
|
330 |
+
"ShearY": shear_y,
|
331 |
+
"TranslateX": translate_x_abs,
|
332 |
+
"TranslateY": translate_y_abs,
|
333 |
+
"TranslateXRel": translate_x_rel,
|
334 |
+
"TranslateYRel": translate_y_rel,
|
335 |
+
}
|
336 |
+
|
337 |
+
|
338 |
+
class AugmentOp:
|
339 |
+
"""
|
340 |
+
Apply for video.
|
341 |
+
"""
|
342 |
+
|
343 |
+
def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
|
344 |
+
hparams = hparams or _HPARAMS_DEFAULT
|
345 |
+
self.aug_fn = NAME_TO_OP[name]
|
346 |
+
self.level_fn = LEVEL_TO_ARG[name]
|
347 |
+
self.prob = prob
|
348 |
+
self.magnitude = magnitude
|
349 |
+
self.hparams = hparams.copy()
|
350 |
+
self.kwargs = {
|
351 |
+
"fillcolor": hparams["img_mean"] if "img_mean" in hparams else _FILL,
|
352 |
+
"resample": hparams["interpolation"] if "interpolation" in hparams else _RANDOM_INTERPOLATION,
|
353 |
+
}
|
354 |
+
|
355 |
+
# If magnitude_std is > 0, we introduce some randomness
|
356 |
+
# in the usually fixed policy and sample magnitude from a normal distribution
|
357 |
+
# with mean `magnitude` and std-dev of `magnitude_std`.
|
358 |
+
# NOTE This is my own hack, being tested, not in papers or reference impls.
|
359 |
+
self.magnitude_std = self.hparams.get("magnitude_std", 0)
|
360 |
+
|
361 |
+
def __call__(self, img_list):
|
362 |
+
if self.prob < 1.0 and random.random() > self.prob:
|
363 |
+
return img_list
|
364 |
+
magnitude = self.magnitude
|
365 |
+
if self.magnitude_std and self.magnitude_std > 0:
|
366 |
+
magnitude = random.gauss(magnitude, self.magnitude_std)
|
367 |
+
magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
|
368 |
+
level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else ()
|
369 |
+
|
370 |
+
if isinstance(img_list, list):
|
371 |
+
return [self.aug_fn(img, *level_args, **self.kwargs) for img in img_list]
|
372 |
+
else:
|
373 |
+
return self.aug_fn(img_list, *level_args, **self.kwargs)
|
374 |
+
|
375 |
+
|
376 |
+
_RAND_TRANSFORMS = [
|
377 |
+
"AutoContrast",
|
378 |
+
"Equalize",
|
379 |
+
"Invert",
|
380 |
+
"Rotate",
|
381 |
+
"Posterize",
|
382 |
+
"Solarize",
|
383 |
+
"SolarizeAdd",
|
384 |
+
"Color",
|
385 |
+
"Contrast",
|
386 |
+
"Brightness",
|
387 |
+
"Sharpness",
|
388 |
+
"ShearX",
|
389 |
+
"ShearY",
|
390 |
+
"TranslateXRel",
|
391 |
+
"TranslateYRel",
|
392 |
+
]
|
393 |
+
|
394 |
+
|
395 |
+
_RAND_INCREASING_TRANSFORMS = [
|
396 |
+
"AutoContrast",
|
397 |
+
"Equalize",
|
398 |
+
"Invert",
|
399 |
+
"Rotate",
|
400 |
+
"PosterizeIncreasing",
|
401 |
+
"SolarizeIncreasing",
|
402 |
+
"SolarizeAdd",
|
403 |
+
"ColorIncreasing",
|
404 |
+
"ContrastIncreasing",
|
405 |
+
"BrightnessIncreasing",
|
406 |
+
"SharpnessIncreasing",
|
407 |
+
"ShearX",
|
408 |
+
"ShearY",
|
409 |
+
"TranslateXRel",
|
410 |
+
"TranslateYRel",
|
411 |
+
]
|
412 |
+
|
413 |
+
|
414 |
+
# These experimental weights are based loosely on the relative improvements mentioned in paper.
|
415 |
+
# They may not result in increased performance, but could likely be tuned to so.
|
416 |
+
_RAND_CHOICE_WEIGHTS_0 = {
|
417 |
+
"Rotate": 0.3,
|
418 |
+
"ShearX": 0.2,
|
419 |
+
"ShearY": 0.2,
|
420 |
+
"TranslateXRel": 0.1,
|
421 |
+
"TranslateYRel": 0.1,
|
422 |
+
"Color": 0.025,
|
423 |
+
"Sharpness": 0.025,
|
424 |
+
"AutoContrast": 0.025,
|
425 |
+
"Solarize": 0.005,
|
426 |
+
"SolarizeAdd": 0.005,
|
427 |
+
"Contrast": 0.005,
|
428 |
+
"Brightness": 0.005,
|
429 |
+
"Equalize": 0.005,
|
430 |
+
"Posterize": 0,
|
431 |
+
"Invert": 0,
|
432 |
+
}
|
433 |
+
|
434 |
+
_RAND_CHOICE_WEIGHTS_1 = {
|
435 |
+
"Rotate": 0.0,
|
436 |
+
"ShearX": 0.0,
|
437 |
+
"ShearY": 0.0,
|
438 |
+
"TranslateXRel": 0.0,
|
439 |
+
"TranslateYRel": 0.0,
|
440 |
+
"Color": 0.25,
|
441 |
+
"Sharpness": 0.25,
|
442 |
+
"AutoContrast": 0.25,
|
443 |
+
"Solarize": 0.05,
|
444 |
+
"SolarizeAdd": 0.05,
|
445 |
+
"Contrast": 0.05,
|
446 |
+
"Brightness": 0.05,
|
447 |
+
"Equalize": 0.05,
|
448 |
+
"Posterize": 0,
|
449 |
+
"Invert": 0,
|
450 |
+
}
|
451 |
+
|
452 |
+
|
453 |
+
def _select_rand_weights(weight_idx=0, transforms=None):
|
454 |
+
transforms = transforms or _RAND_TRANSFORMS
|
455 |
+
assert weight_idx == 0 or weight_idx == 1 # only two sets of weights currently
|
456 |
+
if weight_idx == 0:
|
457 |
+
rand_weights = _RAND_CHOICE_WEIGHTS_0
|
458 |
+
elif weight_idx == 1:
|
459 |
+
rand_weights = _RAND_CHOICE_WEIGHTS_1
|
460 |
+
probs = [rand_weights[k] for k in transforms]
|
461 |
+
probs /= np.sum(probs)
|
462 |
+
return probs
|
463 |
+
|
464 |
+
|
465 |
+
def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
|
466 |
+
hparams = hparams or _HPARAMS_DEFAULT
|
467 |
+
transforms = transforms or _RAND_TRANSFORMS
|
468 |
+
return [AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
|
469 |
+
|
470 |
+
|
471 |
+
class RandAugment:
|
472 |
+
def __init__(self, ops, num_layers=2, choice_weights=None):
|
473 |
+
self.ops = ops
|
474 |
+
self.num_layers = num_layers
|
475 |
+
self.choice_weights = choice_weights
|
476 |
+
|
477 |
+
def __call__(self, img):
|
478 |
+
# no replacement when using weighted choice
|
479 |
+
ops = np.random.choice(
|
480 |
+
self.ops,
|
481 |
+
self.num_layers,
|
482 |
+
replace=self.choice_weights is None,
|
483 |
+
p=self.choice_weights,
|
484 |
+
)
|
485 |
+
for op in ops:
|
486 |
+
img = op(img)
|
487 |
+
return img
|
488 |
+
|
489 |
+
|
490 |
+
def rand_augment_transform(config_str, hparams):
|
491 |
+
"""
|
492 |
+
RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
|
493 |
+
|
494 |
+
Create a RandAugment transform
|
495 |
+
:param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
|
496 |
+
dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
|
497 |
+
sections, not order sepecific determine
|
498 |
+
'm' - integer magnitude of rand augment
|
499 |
+
'n' - integer num layers (number of transform ops selected per image)
|
500 |
+
'w' - integer probabiliy weight index (index of a set of weights to influence choice of op)
|
501 |
+
'mstd' - float std deviation of magnitude noise applied
|
502 |
+
'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
|
503 |
+
Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
|
504 |
+
'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
|
505 |
+
:param hparams: Other hparams (kwargs) for the RandAugmentation scheme
|
506 |
+
:return: A PyTorch compatible Transform
|
507 |
+
"""
|
508 |
+
magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
|
509 |
+
num_layers = 2 # default to 2 ops per image
|
510 |
+
weight_idx = None # default to no probability weights for op choice
|
511 |
+
transforms = _RAND_TRANSFORMS
|
512 |
+
config = config_str.split("-")
|
513 |
+
assert config[0] == "rand"
|
514 |
+
config = config[1:]
|
515 |
+
for c in config:
|
516 |
+
cs = re.split(r"(\d.*)", c)
|
517 |
+
if len(cs) < 2:
|
518 |
+
continue
|
519 |
+
key, val = cs[:2]
|
520 |
+
if key == "mstd":
|
521 |
+
# noise param injected via hparams for now
|
522 |
+
hparams.setdefault("magnitude_std", float(val))
|
523 |
+
elif key == "inc":
|
524 |
+
if bool(val):
|
525 |
+
transforms = _RAND_INCREASING_TRANSFORMS
|
526 |
+
elif key == "m":
|
527 |
+
magnitude = int(val)
|
528 |
+
elif key == "n":
|
529 |
+
num_layers = int(val)
|
530 |
+
elif key == "w":
|
531 |
+
weight_idx = int(val)
|
532 |
+
else:
|
533 |
+
assert NotImplementedError
|
534 |
+
ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
|
535 |
+
choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
|
536 |
+
return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
|
src/datasets/utils/video/randerase.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
|
3 |
+
# Copyright 2020 Ross Wightman
|
4 |
+
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
# This implementation is based on
|
18 |
+
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py
|
19 |
+
# published under an Apache License 2.0.
|
20 |
+
|
21 |
+
|
22 |
+
import math
|
23 |
+
import random
|
24 |
+
|
25 |
+
import torch
|
26 |
+
|
27 |
+
|
28 |
+
def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda"):
|
29 |
+
# NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
|
30 |
+
# paths, flip the order so normal is run on CPU if this becomes a problem
|
31 |
+
# Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
|
32 |
+
if per_pixel:
|
33 |
+
return torch.empty(patch_size, dtype=dtype, device=device).normal_()
|
34 |
+
elif rand_color:
|
35 |
+
return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_()
|
36 |
+
else:
|
37 |
+
return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
|
38 |
+
|
39 |
+
|
40 |
+
class RandomErasing:
|
41 |
+
"""Randomly selects a rectangle region in an image and erases its pixels.
|
42 |
+
'Random Erasing Data Augmentation' by Zhong et al.
|
43 |
+
See https://arxiv.org/pdf/1708.04896.pdf
|
44 |
+
This variant of RandomErasing is intended to be applied to either a batch
|
45 |
+
or single image tensor after it has been normalized by dataset mean and std.
|
46 |
+
Args:
|
47 |
+
probability: Probability that the Random Erasing operation will be performed.
|
48 |
+
min_area: Minimum percentage of erased area wrt input image area.
|
49 |
+
max_area: Maximum percentage of erased area wrt input image area.
|
50 |
+
min_aspect: Minimum aspect ratio of erased area.
|
51 |
+
mode: pixel color mode, one of 'const', 'rand', or 'pixel'
|
52 |
+
'const' - erase block is constant color of 0 for all channels
|
53 |
+
'rand' - erase block is same per-channel random (normal) color
|
54 |
+
'pixel' - erase block is per-pixel random (normal) color
|
55 |
+
max_count: maximum number of erasing blocks per image, area per box is scaled by count.
|
56 |
+
per-image count is randomly chosen between 1 and this value.
|
57 |
+
"""
|
58 |
+
|
59 |
+
def __init__(
|
60 |
+
self,
|
61 |
+
probability=0.5,
|
62 |
+
min_area=0.02,
|
63 |
+
max_area=1 / 3,
|
64 |
+
min_aspect=0.3,
|
65 |
+
max_aspect=None,
|
66 |
+
mode="const",
|
67 |
+
min_count=1,
|
68 |
+
max_count=None,
|
69 |
+
num_splits=0,
|
70 |
+
device="cuda",
|
71 |
+
cube=True,
|
72 |
+
):
|
73 |
+
self.probability = probability
|
74 |
+
self.min_area = min_area
|
75 |
+
self.max_area = max_area
|
76 |
+
max_aspect = max_aspect or 1 / min_aspect
|
77 |
+
self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
|
78 |
+
self.min_count = min_count
|
79 |
+
self.max_count = max_count or min_count
|
80 |
+
self.num_splits = num_splits
|
81 |
+
mode = mode.lower()
|
82 |
+
self.rand_color = False
|
83 |
+
self.per_pixel = False
|
84 |
+
self.cube = cube
|
85 |
+
if mode == "rand":
|
86 |
+
self.rand_color = True # per block random normal
|
87 |
+
elif mode == "pixel":
|
88 |
+
self.per_pixel = True # per pixel random normal
|
89 |
+
else:
|
90 |
+
assert not mode or mode == "const"
|
91 |
+
self.device = device
|
92 |
+
|
93 |
+
def _erase(self, img, chan, img_h, img_w, dtype):
|
94 |
+
if random.random() > self.probability:
|
95 |
+
return
|
96 |
+
area = img_h * img_w
|
97 |
+
count = self.min_count if self.min_count == self.max_count else random.randint(self.min_count, self.max_count)
|
98 |
+
for _ in range(count):
|
99 |
+
for _ in range(10):
|
100 |
+
target_area = random.uniform(self.min_area, self.max_area) * area / count
|
101 |
+
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
|
102 |
+
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
103 |
+
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
104 |
+
if w < img_w and h < img_h:
|
105 |
+
top = random.randint(0, img_h - h)
|
106 |
+
left = random.randint(0, img_w - w)
|
107 |
+
img[:, top : top + h, left : left + w] = _get_pixels(
|
108 |
+
self.per_pixel,
|
109 |
+
self.rand_color,
|
110 |
+
(chan, h, w),
|
111 |
+
dtype=dtype,
|
112 |
+
device=self.device,
|
113 |
+
)
|
114 |
+
break
|
115 |
+
|
116 |
+
def _erase_cube(
|
117 |
+
self,
|
118 |
+
img,
|
119 |
+
batch_start,
|
120 |
+
batch_size,
|
121 |
+
chan,
|
122 |
+
img_h,
|
123 |
+
img_w,
|
124 |
+
dtype,
|
125 |
+
):
|
126 |
+
if random.random() > self.probability:
|
127 |
+
return
|
128 |
+
area = img_h * img_w
|
129 |
+
count = self.min_count if self.min_count == self.max_count else random.randint(self.min_count, self.max_count)
|
130 |
+
for _ in range(count):
|
131 |
+
for _ in range(100):
|
132 |
+
target_area = random.uniform(self.min_area, self.max_area) * area / count
|
133 |
+
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
|
134 |
+
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
135 |
+
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
136 |
+
if w < img_w and h < img_h:
|
137 |
+
top = random.randint(0, img_h - h)
|
138 |
+
left = random.randint(0, img_w - w)
|
139 |
+
for i in range(batch_start, batch_size):
|
140 |
+
img_instance = img[i]
|
141 |
+
img_instance[:, top : top + h, left : left + w] = _get_pixels(
|
142 |
+
self.per_pixel,
|
143 |
+
self.rand_color,
|
144 |
+
(chan, h, w),
|
145 |
+
dtype=dtype,
|
146 |
+
device=self.device,
|
147 |
+
)
|
148 |
+
break
|
149 |
+
|
150 |
+
def __call__(self, input):
|
151 |
+
if len(input.size()) == 3:
|
152 |
+
self._erase(input, *input.size(), input.dtype)
|
153 |
+
else:
|
154 |
+
batch_size, chan, img_h, img_w = input.size()
|
155 |
+
# skip first slice of batch if num_splits is set (for clean portion of samples)
|
156 |
+
batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
|
157 |
+
if self.cube:
|
158 |
+
self._erase_cube(
|
159 |
+
input,
|
160 |
+
batch_start,
|
161 |
+
batch_size,
|
162 |
+
chan,
|
163 |
+
img_h,
|
164 |
+
img_w,
|
165 |
+
input.dtype,
|
166 |
+
)
|
167 |
+
else:
|
168 |
+
for i in range(batch_start, batch_size):
|
169 |
+
self._erase(input[i], chan, img_h, img_w, input.dtype)
|
170 |
+
return input
|
src/datasets/utils/video/transforms.py
ADDED
@@ -0,0 +1,1161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
import numbers
|
8 |
+
import random
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import PIL
|
12 |
+
import torch
|
13 |
+
import torchvision
|
14 |
+
import torchvision.transforms.functional as F
|
15 |
+
from PIL import Image
|
16 |
+
from torch import Tensor
|
17 |
+
from torchvision import transforms
|
18 |
+
|
19 |
+
import src.datasets.utils.video.functional as FF
|
20 |
+
from src.datasets.utils.video.randaugment import rand_augment_transform
|
21 |
+
|
22 |
+
_pil_interpolation_to_str = {
|
23 |
+
Image.NEAREST: "PIL.Image.NEAREST",
|
24 |
+
Image.BILINEAR: "PIL.Image.BILINEAR",
|
25 |
+
Image.BICUBIC: "PIL.Image.BICUBIC",
|
26 |
+
Image.LANCZOS: "PIL.Image.LANCZOS",
|
27 |
+
Image.HAMMING: "PIL.Image.HAMMING",
|
28 |
+
Image.BOX: "PIL.Image.BOX",
|
29 |
+
}
|
30 |
+
|
31 |
+
|
32 |
+
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
|
33 |
+
PAD_FRAME_METHODS = ["circulant"]
|
34 |
+
|
35 |
+
|
36 |
+
def _pil_interp(method):
|
37 |
+
if method == "bicubic":
|
38 |
+
return Image.BICUBIC
|
39 |
+
elif method == "lanczos":
|
40 |
+
return Image.LANCZOS
|
41 |
+
elif method == "hamming":
|
42 |
+
return Image.HAMMING
|
43 |
+
else:
|
44 |
+
return Image.BILINEAR
|
45 |
+
|
46 |
+
|
47 |
+
def random_short_side_scale_jitter(images, min_size, max_size, boxes=None, inverse_uniform_sampling=False):
|
48 |
+
"""
|
49 |
+
Perform a spatial short scale jittering on the given images and
|
50 |
+
corresponding boxes.
|
51 |
+
Args:
|
52 |
+
images (tensor): images to perform scale jitter. Dimension is
|
53 |
+
`num frames` x `channel` x `height` x `width`.
|
54 |
+
min_size (int): the minimal size to scale the frames.
|
55 |
+
max_size (int): the maximal size to scale the frames.
|
56 |
+
boxes (ndarray): optional. Corresponding boxes to images.
|
57 |
+
Dimension is `num boxes` x 4.
|
58 |
+
inverse_uniform_sampling (bool): if True, sample uniformly in
|
59 |
+
[1 / max_scale, 1 / min_scale] and take a reciprocal to get the
|
60 |
+
scale. If False, take a uniform sample from [min_scale, max_scale].
|
61 |
+
Returns:
|
62 |
+
(tensor): the scaled images with dimension of
|
63 |
+
`num frames` x `channel` x `new height` x `new width`.
|
64 |
+
(ndarray or None): the scaled boxes with dimension of
|
65 |
+
`num boxes` x 4.
|
66 |
+
"""
|
67 |
+
if inverse_uniform_sampling:
|
68 |
+
size = int(round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size)))
|
69 |
+
else:
|
70 |
+
size = int(round(np.random.uniform(min_size, max_size)))
|
71 |
+
|
72 |
+
height = images.shape[2]
|
73 |
+
width = images.shape[3]
|
74 |
+
if (width <= height and width == size) or (height <= width and height == size):
|
75 |
+
return images, boxes
|
76 |
+
new_width = size
|
77 |
+
new_height = size
|
78 |
+
if width < height:
|
79 |
+
new_height = int(math.floor((float(height) / width) * size))
|
80 |
+
if boxes is not None:
|
81 |
+
boxes = boxes * float(new_height) / height
|
82 |
+
else:
|
83 |
+
new_width = int(math.floor((float(width) / height) * size))
|
84 |
+
if boxes is not None:
|
85 |
+
boxes = boxes * float(new_width) / width
|
86 |
+
|
87 |
+
return (
|
88 |
+
torch.nn.functional.interpolate(
|
89 |
+
images,
|
90 |
+
size=(new_height, new_width),
|
91 |
+
mode="bilinear",
|
92 |
+
align_corners=False,
|
93 |
+
),
|
94 |
+
boxes,
|
95 |
+
)
|
96 |
+
|
97 |
+
|
98 |
+
def crop_boxes(boxes, x_offset, y_offset):
|
99 |
+
"""
|
100 |
+
Peform crop on the bounding boxes given the offsets.
|
101 |
+
Args:
|
102 |
+
boxes (ndarray or None): bounding boxes to peform crop. The dimension
|
103 |
+
is `num boxes` x 4.
|
104 |
+
x_offset (int): cropping offset in the x axis.
|
105 |
+
y_offset (int): cropping offset in the y axis.
|
106 |
+
Returns:
|
107 |
+
cropped_boxes (ndarray or None): the cropped boxes with dimension of
|
108 |
+
`num boxes` x 4.
|
109 |
+
"""
|
110 |
+
cropped_boxes = boxes.copy()
|
111 |
+
cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
|
112 |
+
cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
|
113 |
+
|
114 |
+
return cropped_boxes
|
115 |
+
|
116 |
+
|
117 |
+
def random_crop(images, size, boxes=None):
|
118 |
+
"""
|
119 |
+
Perform random spatial crop on the given images and corresponding boxes.
|
120 |
+
Args:
|
121 |
+
images (tensor): images to perform random crop. The dimension is
|
122 |
+
`num frames` x `channel` x `height` x `width`.
|
123 |
+
size (int): the size of height and width to crop on the image.
|
124 |
+
boxes (ndarray or None): optional. Corresponding boxes to images.
|
125 |
+
Dimension is `num boxes` x 4.
|
126 |
+
Returns:
|
127 |
+
cropped (tensor): cropped images with dimension of
|
128 |
+
`num frames` x `channel` x `size` x `size`.
|
129 |
+
cropped_boxes (ndarray or None): the cropped boxes with dimension of
|
130 |
+
`num boxes` x 4.
|
131 |
+
"""
|
132 |
+
if images.shape[2] == size and images.shape[3] == size:
|
133 |
+
return images
|
134 |
+
height = images.shape[2]
|
135 |
+
width = images.shape[3]
|
136 |
+
y_offset = 0
|
137 |
+
if height > size:
|
138 |
+
y_offset = int(np.random.randint(0, height - size))
|
139 |
+
x_offset = 0
|
140 |
+
if width > size:
|
141 |
+
x_offset = int(np.random.randint(0, width - size))
|
142 |
+
cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
|
143 |
+
|
144 |
+
cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
|
145 |
+
|
146 |
+
return cropped, cropped_boxes
|
147 |
+
|
148 |
+
|
149 |
+
def horizontal_flip(prob, images, boxes=None):
|
150 |
+
"""
|
151 |
+
Perform horizontal flip on the given images and corresponding boxes.
|
152 |
+
Args:
|
153 |
+
prob (float): probility to flip the images.
|
154 |
+
images (tensor): images to perform horizontal flip, the dimension is
|
155 |
+
`num frames` x `channel` x `height` x `width`.
|
156 |
+
boxes (ndarray or None): optional. Corresponding boxes to images.
|
157 |
+
Dimension is `num boxes` x 4.
|
158 |
+
Returns:
|
159 |
+
images (tensor): images with dimension of
|
160 |
+
`num frames` x `channel` x `height` x `width`.
|
161 |
+
flipped_boxes (ndarray or None): the flipped boxes with dimension of
|
162 |
+
`num boxes` x 4.
|
163 |
+
"""
|
164 |
+
if boxes is None:
|
165 |
+
flipped_boxes = None
|
166 |
+
else:
|
167 |
+
flipped_boxes = boxes.copy()
|
168 |
+
|
169 |
+
if np.random.uniform() < prob:
|
170 |
+
images = images.flip((-1))
|
171 |
+
|
172 |
+
if len(images.shape) == 3:
|
173 |
+
width = images.shape[2]
|
174 |
+
elif len(images.shape) == 4:
|
175 |
+
width = images.shape[3]
|
176 |
+
else:
|
177 |
+
raise NotImplementedError("Dimension does not supported")
|
178 |
+
if boxes is not None:
|
179 |
+
flipped_boxes[:, [0, 2]] = width - boxes[:, [2, 0]] - 1
|
180 |
+
|
181 |
+
return images, flipped_boxes
|
182 |
+
|
183 |
+
|
184 |
+
def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
|
185 |
+
"""
|
186 |
+
Perform uniform spatial sampling on the images and corresponding boxes.
|
187 |
+
Args:
|
188 |
+
images (tensor): images to perform uniform crop. The dimension is
|
189 |
+
`num frames` x `channel` x `height` x `width`.
|
190 |
+
size (int): size of height and weight to crop the images.
|
191 |
+
spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
|
192 |
+
is larger than height. Or 0, 1, or 2 for top, center, and bottom
|
193 |
+
crop if height is larger than width.
|
194 |
+
boxes (ndarray or None): optional. Corresponding boxes to images.
|
195 |
+
Dimension is `num boxes` x 4.
|
196 |
+
scale_size (int): optinal. If not None, resize the images to scale_size before
|
197 |
+
performing any crop.
|
198 |
+
Returns:
|
199 |
+
cropped (tensor): images with dimension of
|
200 |
+
`num frames` x `channel` x `size` x `size`.
|
201 |
+
cropped_boxes (ndarray or None): the cropped boxes with dimension of
|
202 |
+
`num boxes` x 4.
|
203 |
+
"""
|
204 |
+
assert spatial_idx in [0, 1, 2]
|
205 |
+
ndim = len(images.shape)
|
206 |
+
if ndim == 3:
|
207 |
+
images = images.unsqueeze(0)
|
208 |
+
height = images.shape[2]
|
209 |
+
width = images.shape[3]
|
210 |
+
|
211 |
+
if scale_size is not None:
|
212 |
+
if width <= height:
|
213 |
+
width, height = scale_size, int(height / width * scale_size)
|
214 |
+
else:
|
215 |
+
width, height = int(width / height * scale_size), scale_size
|
216 |
+
images = torch.nn.functional.interpolate(
|
217 |
+
images,
|
218 |
+
size=(height, width),
|
219 |
+
mode="bilinear",
|
220 |
+
align_corners=False,
|
221 |
+
)
|
222 |
+
|
223 |
+
y_offset = int(math.ceil((height - size) / 2))
|
224 |
+
x_offset = int(math.ceil((width - size) / 2))
|
225 |
+
|
226 |
+
if height > width:
|
227 |
+
if spatial_idx == 0:
|
228 |
+
y_offset = 0
|
229 |
+
elif spatial_idx == 2:
|
230 |
+
y_offset = height - size
|
231 |
+
else:
|
232 |
+
if spatial_idx == 0:
|
233 |
+
x_offset = 0
|
234 |
+
elif spatial_idx == 2:
|
235 |
+
x_offset = width - size
|
236 |
+
cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
|
237 |
+
cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
|
238 |
+
if ndim == 3:
|
239 |
+
cropped = cropped.squeeze(0)
|
240 |
+
return cropped, cropped_boxes
|
241 |
+
|
242 |
+
|
243 |
+
def clip_boxes_to_image(boxes, height, width):
|
244 |
+
"""
|
245 |
+
Clip an array of boxes to an image with the given height and width.
|
246 |
+
Args:
|
247 |
+
boxes (ndarray): bounding boxes to perform clipping.
|
248 |
+
Dimension is `num boxes` x 4.
|
249 |
+
height (int): given image height.
|
250 |
+
width (int): given image width.
|
251 |
+
Returns:
|
252 |
+
clipped_boxes (ndarray): the clipped boxes with dimension of
|
253 |
+
`num boxes` x 4.
|
254 |
+
"""
|
255 |
+
clipped_boxes = boxes.copy()
|
256 |
+
clipped_boxes[:, [0, 2]] = np.minimum(width - 1.0, np.maximum(0.0, boxes[:, [0, 2]]))
|
257 |
+
clipped_boxes[:, [1, 3]] = np.minimum(height - 1.0, np.maximum(0.0, boxes[:, [1, 3]]))
|
258 |
+
return clipped_boxes
|
259 |
+
|
260 |
+
|
261 |
+
def blend(images1, images2, alpha):
|
262 |
+
"""
|
263 |
+
Blend two images with a given weight alpha.
|
264 |
+
Args:
|
265 |
+
images1 (tensor): the first images to be blended, the dimension is
|
266 |
+
`num frames` x `channel` x `height` x `width`.
|
267 |
+
images2 (tensor): the second images to be blended, the dimension is
|
268 |
+
`num frames` x `channel` x `height` x `width`.
|
269 |
+
alpha (float): the blending weight.
|
270 |
+
Returns:
|
271 |
+
(tensor): blended images, the dimension is
|
272 |
+
`num frames` x `channel` x `height` x `width`.
|
273 |
+
"""
|
274 |
+
return images1 * alpha + images2 * (1 - alpha)
|
275 |
+
|
276 |
+
|
277 |
+
def grayscale(images):
|
278 |
+
"""
|
279 |
+
Get the grayscale for the input images. The channels of images should be
|
280 |
+
in order BGR.
|
281 |
+
Args:
|
282 |
+
images (tensor): the input images for getting grayscale. Dimension is
|
283 |
+
`num frames` x `channel` x `height` x `width`.
|
284 |
+
Returns:
|
285 |
+
img_gray (tensor): blended images, the dimension is
|
286 |
+
`num frames` x `channel` x `height` x `width`.
|
287 |
+
"""
|
288 |
+
# R -> 0.299, G -> 0.587, B -> 0.114.
|
289 |
+
img_gray = torch.tensor(images)
|
290 |
+
gray_channel = 0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0]
|
291 |
+
img_gray[:, 0] = gray_channel
|
292 |
+
img_gray[:, 1] = gray_channel
|
293 |
+
img_gray[:, 2] = gray_channel
|
294 |
+
return img_gray
|
295 |
+
|
296 |
+
|
297 |
+
def color_jitter(images, img_brightness=0, img_contrast=0, img_saturation=0):
|
298 |
+
"""
|
299 |
+
Perfrom a color jittering on the input images. The channels of images
|
300 |
+
should be in order BGR.
|
301 |
+
Args:
|
302 |
+
images (tensor): images to perform color jitter. Dimension is
|
303 |
+
`num frames` x `channel` x `height` x `width`.
|
304 |
+
img_brightness (float): jitter ratio for brightness.
|
305 |
+
img_contrast (float): jitter ratio for contrast.
|
306 |
+
img_saturation (float): jitter ratio for saturation.
|
307 |
+
Returns:
|
308 |
+
images (tensor): the jittered images, the dimension is
|
309 |
+
`num frames` x `channel` x `height` x `width`.
|
310 |
+
"""
|
311 |
+
|
312 |
+
jitter = []
|
313 |
+
if img_brightness != 0:
|
314 |
+
jitter.append("brightness")
|
315 |
+
if img_contrast != 0:
|
316 |
+
jitter.append("contrast")
|
317 |
+
if img_saturation != 0:
|
318 |
+
jitter.append("saturation")
|
319 |
+
|
320 |
+
if len(jitter) > 0:
|
321 |
+
order = np.random.permutation(np.arange(len(jitter)))
|
322 |
+
for idx in range(0, len(jitter)):
|
323 |
+
if jitter[order[idx]] == "brightness":
|
324 |
+
images = brightness_jitter(img_brightness, images)
|
325 |
+
elif jitter[order[idx]] == "contrast":
|
326 |
+
images = contrast_jitter(img_contrast, images)
|
327 |
+
elif jitter[order[idx]] == "saturation":
|
328 |
+
images = saturation_jitter(img_saturation, images)
|
329 |
+
return images
|
330 |
+
|
331 |
+
|
332 |
+
def brightness_jitter(var, images):
|
333 |
+
"""
|
334 |
+
Perfrom brightness jittering on the input images. The channels of images
|
335 |
+
should be in order BGR.
|
336 |
+
Args:
|
337 |
+
var (float): jitter ratio for brightness.
|
338 |
+
images (tensor): images to perform color jitter. Dimension is
|
339 |
+
`num frames` x `channel` x `height` x `width`.
|
340 |
+
Returns:
|
341 |
+
images (tensor): the jittered images, the dimension is
|
342 |
+
`num frames` x `channel` x `height` x `width`.
|
343 |
+
"""
|
344 |
+
alpha = 1.0 + np.random.uniform(-var, var)
|
345 |
+
|
346 |
+
img_bright = torch.zeros(images.shape)
|
347 |
+
images = blend(images, img_bright, alpha)
|
348 |
+
return images
|
349 |
+
|
350 |
+
|
351 |
+
def contrast_jitter(var, images):
|
352 |
+
"""
|
353 |
+
Perfrom contrast jittering on the input images. The channels of images
|
354 |
+
should be in order BGR.
|
355 |
+
Args:
|
356 |
+
var (float): jitter ratio for contrast.
|
357 |
+
images (tensor): images to perform color jitter. Dimension is
|
358 |
+
`num frames` x `channel` x `height` x `width`.
|
359 |
+
Returns:
|
360 |
+
images (tensor): the jittered images, the dimension is
|
361 |
+
`num frames` x `channel` x `height` x `width`.
|
362 |
+
"""
|
363 |
+
alpha = 1.0 + np.random.uniform(-var, var)
|
364 |
+
|
365 |
+
img_gray = grayscale(images)
|
366 |
+
img_gray[:] = torch.mean(img_gray, dim=(1, 2, 3), keepdim=True)
|
367 |
+
images = blend(images, img_gray, alpha)
|
368 |
+
return images
|
369 |
+
|
370 |
+
|
371 |
+
def saturation_jitter(var, images):
|
372 |
+
"""
|
373 |
+
Perfrom saturation jittering on the input images. The channels of images
|
374 |
+
should be in order BGR.
|
375 |
+
Args:
|
376 |
+
var (float): jitter ratio for saturation.
|
377 |
+
images (tensor): images to perform color jitter. Dimension is
|
378 |
+
`num frames` x `channel` x `height` x `width`.
|
379 |
+
Returns:
|
380 |
+
images (tensor): the jittered images, the dimension is
|
381 |
+
`num frames` x `channel` x `height` x `width`.
|
382 |
+
"""
|
383 |
+
alpha = 1.0 + np.random.uniform(-var, var)
|
384 |
+
img_gray = grayscale(images)
|
385 |
+
images = blend(images, img_gray, alpha)
|
386 |
+
|
387 |
+
return images
|
388 |
+
|
389 |
+
|
390 |
+
def lighting_jitter(images, alphastd, eigval, eigvec):
|
391 |
+
"""
|
392 |
+
Perform AlexNet-style PCA jitter on the given images.
|
393 |
+
Args:
|
394 |
+
images (tensor): images to perform lighting jitter. Dimension is
|
395 |
+
`num frames` x `channel` x `height` x `width`.
|
396 |
+
alphastd (float): jitter ratio for PCA jitter.
|
397 |
+
eigval (list): eigenvalues for PCA jitter.
|
398 |
+
eigvec (list[list]): eigenvectors for PCA jitter.
|
399 |
+
Returns:
|
400 |
+
out_images (tensor): the jittered images, the dimension is
|
401 |
+
`num frames` x `channel` x `height` x `width`.
|
402 |
+
"""
|
403 |
+
if alphastd == 0:
|
404 |
+
return images
|
405 |
+
# generate alpha1, alpha2, alpha3.
|
406 |
+
alpha = np.random.normal(0, alphastd, size=(1, 3))
|
407 |
+
eig_vec = np.array(eigvec)
|
408 |
+
eig_val = np.reshape(eigval, (1, 3))
|
409 |
+
rgb = np.sum(
|
410 |
+
eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0),
|
411 |
+
axis=1,
|
412 |
+
)
|
413 |
+
out_images = torch.zeros_like(images)
|
414 |
+
if len(images.shape) == 3:
|
415 |
+
# C H W
|
416 |
+
channel_dim = 0
|
417 |
+
elif len(images.shape) == 4:
|
418 |
+
# T C H W
|
419 |
+
channel_dim = 1
|
420 |
+
else:
|
421 |
+
raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")
|
422 |
+
|
423 |
+
for idx in range(images.shape[channel_dim]):
|
424 |
+
# C H W
|
425 |
+
if len(images.shape) == 3:
|
426 |
+
out_images[idx] = images[idx] + rgb[2 - idx]
|
427 |
+
# T C H W
|
428 |
+
elif len(images.shape) == 4:
|
429 |
+
out_images[:, idx] = images[:, idx] + rgb[2 - idx]
|
430 |
+
else:
|
431 |
+
raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")
|
432 |
+
|
433 |
+
return out_images
|
434 |
+
|
435 |
+
|
436 |
+
def color_normalization(images, mean, stddev):
|
437 |
+
"""
|
438 |
+
Perform color nomration on the given images.
|
439 |
+
Args:
|
440 |
+
images (tensor): images to perform color normalization. Dimension is
|
441 |
+
`num frames` x `channel` x `height` x `width`.
|
442 |
+
mean (list): mean values for normalization.
|
443 |
+
stddev (list): standard deviations for normalization.
|
444 |
+
|
445 |
+
Returns:
|
446 |
+
out_images (tensor): the noramlized images, the dimension is
|
447 |
+
`num frames` x `channel` x `height` x `width`.
|
448 |
+
"""
|
449 |
+
if len(images.shape) == 3:
|
450 |
+
assert len(mean) == images.shape[0], "channel mean not computed properly"
|
451 |
+
assert len(stddev) == images.shape[0], "channel stddev not computed properly"
|
452 |
+
elif len(images.shape) == 4:
|
453 |
+
assert len(mean) == images.shape[1], "channel mean not computed properly"
|
454 |
+
assert len(stddev) == images.shape[1], "channel stddev not computed properly"
|
455 |
+
else:
|
456 |
+
raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")
|
457 |
+
|
458 |
+
out_images = torch.zeros_like(images)
|
459 |
+
for idx in range(len(mean)):
|
460 |
+
# C H W
|
461 |
+
if len(images.shape) == 3:
|
462 |
+
out_images[idx] = (images[idx] - mean[idx]) / stddev[idx]
|
463 |
+
elif len(images.shape) == 4:
|
464 |
+
out_images[:, idx] = (images[:, idx] - mean[idx]) / stddev[idx]
|
465 |
+
else:
|
466 |
+
raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")
|
467 |
+
return out_images
|
468 |
+
|
469 |
+
|
470 |
+
def _get_param_spatial_crop(scale, ratio, height, width, num_repeat=10, log_scale=True, switch_hw=False):
|
471 |
+
"""
|
472 |
+
Given scale, ratio, height and width, return sampled coordinates of the videos.
|
473 |
+
"""
|
474 |
+
for _ in range(num_repeat):
|
475 |
+
area = height * width
|
476 |
+
target_area = random.uniform(*scale) * area
|
477 |
+
if log_scale:
|
478 |
+
log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
|
479 |
+
aspect_ratio = math.exp(random.uniform(*log_ratio))
|
480 |
+
else:
|
481 |
+
aspect_ratio = random.uniform(*ratio)
|
482 |
+
|
483 |
+
w = int(round(math.sqrt(target_area * aspect_ratio)))
|
484 |
+
h = int(round(math.sqrt(target_area / aspect_ratio)))
|
485 |
+
|
486 |
+
if np.random.uniform() < 0.5 and switch_hw:
|
487 |
+
w, h = h, w
|
488 |
+
|
489 |
+
if 0 < w <= width and 0 < h <= height:
|
490 |
+
i = random.randint(0, height - h)
|
491 |
+
j = random.randint(0, width - w)
|
492 |
+
return i, j, h, w
|
493 |
+
|
494 |
+
# Fallback to central crop
|
495 |
+
in_ratio = float(width) / float(height)
|
496 |
+
if in_ratio < min(ratio):
|
497 |
+
w = width
|
498 |
+
h = int(round(w / min(ratio)))
|
499 |
+
elif in_ratio > max(ratio):
|
500 |
+
h = height
|
501 |
+
w = int(round(h * max(ratio)))
|
502 |
+
else: # whole image
|
503 |
+
w = width
|
504 |
+
h = height
|
505 |
+
i = (height - h) // 2
|
506 |
+
j = (width - w) // 2
|
507 |
+
return i, j, h, w
|
508 |
+
|
509 |
+
|
510 |
+
def random_resized_crop(
|
511 |
+
images,
|
512 |
+
target_height,
|
513 |
+
target_width,
|
514 |
+
scale=(0.8, 1.0),
|
515 |
+
ratio=(3.0 / 4.0, 4.0 / 3.0),
|
516 |
+
):
|
517 |
+
"""
|
518 |
+
Crop the given images to random size and aspect ratio. A crop of random
|
519 |
+
size (default: of 0.08 to 1.0) of the original size and a random aspect
|
520 |
+
ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This
|
521 |
+
crop is finally resized to given size. This is popularly used to train the
|
522 |
+
Inception networks.
|
523 |
+
|
524 |
+
Args:
|
525 |
+
images: Images to perform resizing and cropping.
|
526 |
+
target_height: Desired height after cropping.
|
527 |
+
target_width: Desired width after cropping.
|
528 |
+
scale: Scale range of Inception-style area based random resizing.
|
529 |
+
ratio: Aspect ratio range of Inception-style area based random resizing.
|
530 |
+
"""
|
531 |
+
|
532 |
+
height = images.shape[2]
|
533 |
+
width = images.shape[3]
|
534 |
+
|
535 |
+
i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width)
|
536 |
+
cropped = images[:, :, i : i + h, j : j + w]
|
537 |
+
return torch.nn.functional.interpolate(
|
538 |
+
cropped,
|
539 |
+
size=(target_height, target_width),
|
540 |
+
mode="bilinear",
|
541 |
+
align_corners=False,
|
542 |
+
)
|
543 |
+
|
544 |
+
|
545 |
+
def random_resized_crop_with_shift(
|
546 |
+
images,
|
547 |
+
target_height,
|
548 |
+
target_width,
|
549 |
+
scale=(0.8, 1.0),
|
550 |
+
ratio=(3.0 / 4.0, 4.0 / 3.0),
|
551 |
+
):
|
552 |
+
"""
|
553 |
+
This is similar to random_resized_crop. However, it samples two different
|
554 |
+
boxes (for cropping) for the first and last frame. It then linearly
|
555 |
+
interpolates the two boxes for other frames.
|
556 |
+
|
557 |
+
Args:
|
558 |
+
images: Images to perform resizing and cropping.
|
559 |
+
target_height: Desired height after cropping.
|
560 |
+
target_width: Desired width after cropping.
|
561 |
+
scale: Scale range of Inception-style area based random resizing.
|
562 |
+
ratio: Aspect ratio range of Inception-style area based random resizing.
|
563 |
+
"""
|
564 |
+
t = images.shape[1]
|
565 |
+
height = images.shape[2]
|
566 |
+
width = images.shape[3]
|
567 |
+
|
568 |
+
i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width)
|
569 |
+
i_, j_, h_, w_ = _get_param_spatial_crop(scale, ratio, height, width)
|
570 |
+
i_s = [int(i) for i in torch.linspace(i, i_, steps=t).tolist()]
|
571 |
+
j_s = [int(i) for i in torch.linspace(j, j_, steps=t).tolist()]
|
572 |
+
h_s = [int(i) for i in torch.linspace(h, h_, steps=t).tolist()]
|
573 |
+
w_s = [int(i) for i in torch.linspace(w, w_, steps=t).tolist()]
|
574 |
+
out = torch.zeros((3, t, target_height, target_width))
|
575 |
+
for ind in range(t):
|
576 |
+
out[:, ind : ind + 1, :, :] = torch.nn.functional.interpolate(
|
577 |
+
images[
|
578 |
+
:,
|
579 |
+
ind : ind + 1,
|
580 |
+
i_s[ind] : i_s[ind] + h_s[ind],
|
581 |
+
j_s[ind] : j_s[ind] + w_s[ind],
|
582 |
+
],
|
583 |
+
size=(target_height, target_width),
|
584 |
+
mode="bilinear",
|
585 |
+
align_corners=False,
|
586 |
+
)
|
587 |
+
return out
|
588 |
+
|
589 |
+
|
590 |
+
def create_random_augment(
|
591 |
+
input_size,
|
592 |
+
auto_augment=None,
|
593 |
+
interpolation="bilinear",
|
594 |
+
):
|
595 |
+
"""
|
596 |
+
Get video randaug transform.
|
597 |
+
|
598 |
+
Args:
|
599 |
+
input_size: The size of the input video in tuple.
|
600 |
+
auto_augment: Parameters for randaug. An example:
|
601 |
+
"rand-m7-n4-mstd0.5-inc1" (m is the magnitude and n is the number
|
602 |
+
of operations to apply).
|
603 |
+
interpolation: Interpolation method.
|
604 |
+
"""
|
605 |
+
if isinstance(input_size, tuple):
|
606 |
+
img_size = input_size[-2:]
|
607 |
+
else:
|
608 |
+
img_size = input_size
|
609 |
+
|
610 |
+
if auto_augment:
|
611 |
+
assert isinstance(auto_augment, str)
|
612 |
+
if isinstance(img_size, tuple):
|
613 |
+
img_size_min = min(img_size)
|
614 |
+
else:
|
615 |
+
img_size_min = img_size
|
616 |
+
aa_params = {"translate_const": int(img_size_min * 0.45)}
|
617 |
+
if interpolation and interpolation != "random":
|
618 |
+
aa_params["interpolation"] = _pil_interp(interpolation)
|
619 |
+
if auto_augment.startswith("rand"):
|
620 |
+
return transforms.Compose([rand_augment_transform(auto_augment, aa_params)])
|
621 |
+
raise NotImplementedError
|
622 |
+
|
623 |
+
|
624 |
+
def random_sized_crop_img(
|
625 |
+
im,
|
626 |
+
size,
|
627 |
+
jitter_scale=(0.08, 1.0),
|
628 |
+
jitter_aspect=(3.0 / 4.0, 4.0 / 3.0),
|
629 |
+
max_iter=10,
|
630 |
+
):
|
631 |
+
"""
|
632 |
+
Performs Inception-style cropping (used for training).
|
633 |
+
"""
|
634 |
+
assert len(im.shape) == 3, "Currently only support image for random_sized_crop"
|
635 |
+
h, w = im.shape[1:3]
|
636 |
+
i, j, h, w = _get_param_spatial_crop(
|
637 |
+
scale=jitter_scale,
|
638 |
+
ratio=jitter_aspect,
|
639 |
+
height=h,
|
640 |
+
width=w,
|
641 |
+
num_repeat=max_iter,
|
642 |
+
log_scale=False,
|
643 |
+
switch_hw=True,
|
644 |
+
)
|
645 |
+
cropped = im[:, i : i + h, j : j + w]
|
646 |
+
return torch.nn.functional.interpolate(
|
647 |
+
cropped.unsqueeze(0),
|
648 |
+
size=(size, size),
|
649 |
+
mode="bilinear",
|
650 |
+
align_corners=False,
|
651 |
+
).squeeze(0)
|
652 |
+
|
653 |
+
|
654 |
+
def circulant_frame_padding(video: Tensor, total_frames: int) -> Tensor:
|
655 |
+
"""
|
656 |
+
Applies circulant frame padding (repeating the video) to a specified size.
|
657 |
+
|
658 |
+
Args:
|
659 |
+
video: The input video to be padded. Expected (C, T, H, W)
|
660 |
+
total_frames: The number of frames after padding.
|
661 |
+
|
662 |
+
Returns
|
663 |
+
The video padded to total_frames.
|
664 |
+
"""
|
665 |
+
start_frames = video.shape[1]
|
666 |
+
if start_frames == total_frames:
|
667 |
+
return video
|
668 |
+
|
669 |
+
num_repeats = total_frames // start_frames + (total_frames % start_frames > 0)
|
670 |
+
|
671 |
+
return video.repeat((1, num_repeats) + (1,) * (video.ndim - 2))[:, :total_frames]
|
672 |
+
|
673 |
+
|
674 |
+
def frame_pad(video: Tensor, total_frames: int, pad_frame_method: str) -> Tensor:
|
675 |
+
if pad_frame_method not in PAD_FRAME_METHODS:
|
676 |
+
raise ValueError(f"Unrecognized pad_frame_method {pad_frame_method}")
|
677 |
+
|
678 |
+
if pad_frame_method == "circulant":
|
679 |
+
return circulant_frame_padding(video, total_frames)
|
680 |
+
|
681 |
+
return None
|
682 |
+
|
683 |
+
|
684 |
+
# The following code are modified based on timm lib, we will replace the following
|
685 |
+
# contents with dependency from PyTorchVideo.
|
686 |
+
# https://github.com/facebookresearch/pytorchvideo
|
687 |
+
class RandomResizedCropAndInterpolation:
|
688 |
+
"""Crop the given PIL Image to random size and aspect ratio with random interpolation.
|
689 |
+
A crop of random size (default: of 0.08 to 1.0) of the original size and a random
|
690 |
+
aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
|
691 |
+
is finally resized to given size.
|
692 |
+
This is popularly used to train the Inception networks.
|
693 |
+
Args:
|
694 |
+
size: expected output size of each edge
|
695 |
+
scale: range of size of the origin size cropped
|
696 |
+
ratio: range of aspect ratio of the origin aspect ratio cropped
|
697 |
+
interpolation: Default: PIL.Image.BILINEAR
|
698 |
+
"""
|
699 |
+
|
700 |
+
def __init__(
|
701 |
+
self,
|
702 |
+
size,
|
703 |
+
scale=(0.08, 1.0),
|
704 |
+
ratio=(3.0 / 4.0, 4.0 / 3.0),
|
705 |
+
interpolation="bilinear",
|
706 |
+
):
|
707 |
+
if isinstance(size, tuple):
|
708 |
+
self.size = size
|
709 |
+
else:
|
710 |
+
self.size = (size, size)
|
711 |
+
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
|
712 |
+
print("range should be of kind (min, max)")
|
713 |
+
|
714 |
+
if interpolation == "random":
|
715 |
+
self.interpolation = _RANDOM_INTERPOLATION
|
716 |
+
else:
|
717 |
+
self.interpolation = _pil_interp(interpolation)
|
718 |
+
self.scale = scale
|
719 |
+
self.ratio = ratio
|
720 |
+
|
721 |
+
@staticmethod
|
722 |
+
def get_params(img, scale, ratio):
|
723 |
+
"""Get parameters for ``crop`` for a random sized crop.
|
724 |
+
Args:
|
725 |
+
img (PIL Image): Image to be cropped.
|
726 |
+
scale (tuple): range of size of the origin size cropped
|
727 |
+
ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
|
728 |
+
Returns:
|
729 |
+
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
|
730 |
+
sized crop.
|
731 |
+
"""
|
732 |
+
area = img.size[0] * img.size[1]
|
733 |
+
|
734 |
+
for _ in range(10):
|
735 |
+
target_area = random.uniform(*scale) * area
|
736 |
+
log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
|
737 |
+
aspect_ratio = math.exp(random.uniform(*log_ratio))
|
738 |
+
|
739 |
+
w = int(round(math.sqrt(target_area * aspect_ratio)))
|
740 |
+
h = int(round(math.sqrt(target_area / aspect_ratio)))
|
741 |
+
|
742 |
+
if w <= img.size[0] and h <= img.size[1]:
|
743 |
+
i = random.randint(0, img.size[1] - h)
|
744 |
+
j = random.randint(0, img.size[0] - w)
|
745 |
+
return i, j, h, w
|
746 |
+
|
747 |
+
# Fallback to central crop
|
748 |
+
in_ratio = img.size[0] / img.size[1]
|
749 |
+
if in_ratio < min(ratio):
|
750 |
+
w = img.size[0]
|
751 |
+
h = int(round(w / min(ratio)))
|
752 |
+
elif in_ratio > max(ratio):
|
753 |
+
h = img.size[1]
|
754 |
+
w = int(round(h * max(ratio)))
|
755 |
+
else: # whole image
|
756 |
+
w = img.size[0]
|
757 |
+
h = img.size[1]
|
758 |
+
i = (img.size[1] - h) // 2
|
759 |
+
j = (img.size[0] - w) // 2
|
760 |
+
return i, j, h, w
|
761 |
+
|
762 |
+
def __call__(self, img):
|
763 |
+
"""
|
764 |
+
Args:
|
765 |
+
img (PIL Image): Image to be cropped and resized.
|
766 |
+
Returns:
|
767 |
+
PIL Image: Randomly cropped and resized image.
|
768 |
+
"""
|
769 |
+
i, j, h, w = self.get_params(img, self.scale, self.ratio)
|
770 |
+
if isinstance(self.interpolation, (tuple, list)):
|
771 |
+
interpolation = random.choice(self.interpolation)
|
772 |
+
else:
|
773 |
+
interpolation = self.interpolation
|
774 |
+
return F.resized_crop(img, i, j, h, w, self.size, interpolation)
|
775 |
+
|
776 |
+
def __repr__(self):
|
777 |
+
if isinstance(self.interpolation, (tuple, list)):
|
778 |
+
interpolate_str = " ".join([_pil_interpolation_to_str[x] for x in self.interpolation])
|
779 |
+
else:
|
780 |
+
interpolate_str = _pil_interpolation_to_str[self.interpolation]
|
781 |
+
format_string = self.__class__.__name__ + "(size={0}".format(self.size)
|
782 |
+
format_string += ", scale={0}".format(tuple(round(s, 4) for s in self.scale))
|
783 |
+
format_string += ", ratio={0}".format(tuple(round(r, 4) for r in self.ratio))
|
784 |
+
format_string += ", interpolation={0})".format(interpolate_str)
|
785 |
+
return format_string
|
786 |
+
|
787 |
+
|
788 |
+
class Compose(object):
|
789 |
+
"""Composes several transforms
|
790 |
+
Args:
|
791 |
+
transforms (list of ``Transform`` objects): list of transforms
|
792 |
+
to compose
|
793 |
+
"""
|
794 |
+
|
795 |
+
def __init__(self, transforms):
|
796 |
+
self.transforms = transforms
|
797 |
+
|
798 |
+
def __call__(self, clip):
|
799 |
+
for t in self.transforms:
|
800 |
+
clip = t(clip)
|
801 |
+
return clip
|
802 |
+
|
803 |
+
|
804 |
+
class RandomHorizontalFlip(object):
|
805 |
+
"""Horizontally flip the list of given images randomly
|
806 |
+
with a probability 0.5
|
807 |
+
"""
|
808 |
+
|
809 |
+
def __call__(self, clip):
|
810 |
+
"""
|
811 |
+
Args:
|
812 |
+
img (PIL.Image or numpy.ndarray): List of images to be cropped
|
813 |
+
in format (h, w, c) in numpy.ndarray
|
814 |
+
Returns:
|
815 |
+
PIL.Image or numpy.ndarray: Randomly flipped clip
|
816 |
+
"""
|
817 |
+
if random.random() < 0.5:
|
818 |
+
if isinstance(clip[0], np.ndarray):
|
819 |
+
return [np.fliplr(img) for img in clip]
|
820 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
821 |
+
return [img.transpose(PIL.Image.FLIP_LEFT_RIGHT) for img in clip]
|
822 |
+
else:
|
823 |
+
raise TypeError("Expected numpy.ndarray or PIL.Image" + " but got list of {0}".format(type(clip[0])))
|
824 |
+
return clip
|
825 |
+
|
826 |
+
|
827 |
+
class RandomResize(object):
|
828 |
+
"""Resizes a list of (H x W x C) numpy.ndarray to the final size
|
829 |
+
The larger the original image is, the more times it takes to
|
830 |
+
interpolate
|
831 |
+
Args:
|
832 |
+
interpolation (str): Can be one of 'nearest', 'bilinear'
|
833 |
+
defaults to nearest
|
834 |
+
size (tuple): (widht, height)
|
835 |
+
"""
|
836 |
+
|
837 |
+
def __init__(self, ratio=(3.0 / 4.0, 4.0 / 3.0), interpolation="nearest"):
|
838 |
+
self.ratio = ratio
|
839 |
+
self.interpolation = interpolation
|
840 |
+
|
841 |
+
def __call__(self, clip):
|
842 |
+
scaling_factor = random.uniform(self.ratio[0], self.ratio[1])
|
843 |
+
|
844 |
+
if isinstance(clip[0], np.ndarray):
|
845 |
+
im_h, im_w, im_c = clip[0].shape
|
846 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
847 |
+
im_w, im_h = clip[0].size
|
848 |
+
|
849 |
+
new_w = int(im_w * scaling_factor)
|
850 |
+
new_h = int(im_h * scaling_factor)
|
851 |
+
new_size = (new_w, new_h)
|
852 |
+
resized = FF.resize_clip(clip, new_size, interpolation=self.interpolation)
|
853 |
+
return resized
|
854 |
+
|
855 |
+
|
856 |
+
class Resize(object):
|
857 |
+
"""Resizes a list of (H x W x C) numpy.ndarray to the final size
|
858 |
+
The larger the original image is, the more times it takes to
|
859 |
+
interpolate
|
860 |
+
Args:
|
861 |
+
interpolation (str): Can be one of 'nearest', 'bilinear'
|
862 |
+
defaults to nearest
|
863 |
+
size (tuple): (widht, height)
|
864 |
+
"""
|
865 |
+
|
866 |
+
def __init__(self, size, interpolation="nearest"):
|
867 |
+
self.size = size
|
868 |
+
self.interpolation = interpolation
|
869 |
+
|
870 |
+
def __call__(self, clip):
|
871 |
+
resized = FF.resize_clip(clip, self.size, interpolation=self.interpolation)
|
872 |
+
return resized
|
873 |
+
|
874 |
+
|
875 |
+
class RandomCrop(object):
|
876 |
+
"""Extract random crop at the same location for a list of images
|
877 |
+
Args:
|
878 |
+
size (sequence or int): Desired output size for the
|
879 |
+
crop in format (h, w)
|
880 |
+
"""
|
881 |
+
|
882 |
+
def __init__(self, size):
|
883 |
+
if isinstance(size, numbers.Number):
|
884 |
+
size = (size, size)
|
885 |
+
|
886 |
+
self.size = size
|
887 |
+
|
888 |
+
def __call__(self, clip):
|
889 |
+
"""
|
890 |
+
Args:
|
891 |
+
img (PIL.Image or numpy.ndarray): List of images to be cropped
|
892 |
+
in format (h, w, c) in numpy.ndarray
|
893 |
+
Returns:
|
894 |
+
PIL.Image or numpy.ndarray: Cropped list of images
|
895 |
+
"""
|
896 |
+
h, w = self.size
|
897 |
+
if isinstance(clip[0], np.ndarray):
|
898 |
+
im_h, im_w, im_c = clip[0].shape
|
899 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
900 |
+
im_w, im_h = clip[0].size
|
901 |
+
else:
|
902 |
+
raise TypeError("Expected numpy.ndarray or PIL.Image" + "but got list of {0}".format(type(clip[0])))
|
903 |
+
if w > im_w or h > im_h:
|
904 |
+
error_msg = (
|
905 |
+
"Initial image size should be larger then "
|
906 |
+
"cropped size but got cropped sizes : ({w}, {h}) while "
|
907 |
+
"initial image is ({im_w}, {im_h})".format(im_w=im_w, im_h=im_h, w=w, h=h)
|
908 |
+
)
|
909 |
+
raise ValueError(error_msg)
|
910 |
+
|
911 |
+
x1 = random.randint(0, im_w - w)
|
912 |
+
y1 = random.randint(0, im_h - h)
|
913 |
+
cropped = FF.crop_clip(clip, y1, x1, h, w)
|
914 |
+
|
915 |
+
return cropped
|
916 |
+
|
917 |
+
|
918 |
+
class ThreeCrop(object):
|
919 |
+
"""Extract random crop at the same location for a list of images
|
920 |
+
Args:
|
921 |
+
size (sequence or int): Desired output size for the
|
922 |
+
crop in format (h, w)
|
923 |
+
"""
|
924 |
+
|
925 |
+
def __init__(self, size):
|
926 |
+
if isinstance(size, numbers.Number):
|
927 |
+
size = (size, size)
|
928 |
+
|
929 |
+
self.size = size
|
930 |
+
|
931 |
+
def __call__(self, clip):
|
932 |
+
"""
|
933 |
+
Args:
|
934 |
+
img (PIL.Image or numpy.ndarray): List of images to be cropped
|
935 |
+
in format (h, w, c) in numpy.ndarray
|
936 |
+
Returns:
|
937 |
+
PIL.Image or numpy.ndarray: Cropped list of images
|
938 |
+
"""
|
939 |
+
h, w = self.size
|
940 |
+
if isinstance(clip[0], np.ndarray):
|
941 |
+
im_h, im_w, im_c = clip[0].shape
|
942 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
943 |
+
im_w, im_h = clip[0].size
|
944 |
+
else:
|
945 |
+
raise TypeError("Expected numpy.ndarray or PIL.Image" + "but got list of {0}".format(type(clip[0])))
|
946 |
+
if w != im_w and h != im_h:
|
947 |
+
clip = FF.resize_clip(clip, self.size, interpolation="bilinear")
|
948 |
+
im_h, im_w, im_c = clip[0].shape
|
949 |
+
|
950 |
+
step = np.max((np.max((im_w, im_h)) - self.size[0]) // 2, 0)
|
951 |
+
cropped = []
|
952 |
+
for i in range(3):
|
953 |
+
if im_h > self.size[0]:
|
954 |
+
x1 = 0
|
955 |
+
y1 = i * step
|
956 |
+
cropped.extend(FF.crop_clip(clip, y1, x1, h, w))
|
957 |
+
else:
|
958 |
+
x1 = i * step
|
959 |
+
y1 = 0
|
960 |
+
cropped.extend(FF.crop_clip(clip, y1, x1, h, w))
|
961 |
+
return cropped
|
962 |
+
|
963 |
+
|
964 |
+
class RandomRotation(object):
|
965 |
+
"""Rotate entire clip randomly by a random angle within
|
966 |
+
given bounds
|
967 |
+
Args:
|
968 |
+
degrees (sequence or int): Range of degrees to select from
|
969 |
+
If degrees is a number instead of sequence like (min, max),
|
970 |
+
the range of degrees, will be (-degrees, +degrees).
|
971 |
+
"""
|
972 |
+
|
973 |
+
def __init__(self, degrees):
|
974 |
+
if isinstance(degrees, numbers.Number):
|
975 |
+
if degrees < 0:
|
976 |
+
raise ValueError("If degrees is a single number," "must be positive")
|
977 |
+
degrees = (-degrees, degrees)
|
978 |
+
else:
|
979 |
+
if len(degrees) != 2:
|
980 |
+
raise ValueError("If degrees is a sequence," "it must be of len 2.")
|
981 |
+
|
982 |
+
self.degrees = degrees
|
983 |
+
|
984 |
+
def __call__(self, clip):
|
985 |
+
"""
|
986 |
+
Args:
|
987 |
+
img (PIL.Image or numpy.ndarray): List of images to be cropped
|
988 |
+
in format (h, w, c) in numpy.ndarray
|
989 |
+
Returns:
|
990 |
+
PIL.Image or numpy.ndarray: Cropped list of images
|
991 |
+
"""
|
992 |
+
import skimage
|
993 |
+
|
994 |
+
angle = random.uniform(self.degrees[0], self.degrees[1])
|
995 |
+
if isinstance(clip[0], np.ndarray):
|
996 |
+
rotated = [skimage.transform.rotate(img, angle) for img in clip]
|
997 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
998 |
+
rotated = [img.rotate(angle) for img in clip]
|
999 |
+
else:
|
1000 |
+
raise TypeError("Expected numpy.ndarray or PIL.Image" + "but got list of {0}".format(type(clip[0])))
|
1001 |
+
|
1002 |
+
return rotated
|
1003 |
+
|
1004 |
+
|
1005 |
+
class CenterCrop(object):
|
1006 |
+
"""Extract center crop at the same location for a list of images
|
1007 |
+
Args:
|
1008 |
+
size (sequence or int): Desired output size for the
|
1009 |
+
crop in format (h, w)
|
1010 |
+
"""
|
1011 |
+
|
1012 |
+
def __init__(self, size):
|
1013 |
+
if isinstance(size, numbers.Number):
|
1014 |
+
size = (size, size)
|
1015 |
+
|
1016 |
+
self.size = size
|
1017 |
+
|
1018 |
+
def __call__(self, clip):
|
1019 |
+
"""
|
1020 |
+
Args:
|
1021 |
+
img (PIL.Image or numpy.ndarray): List of images to be cropped
|
1022 |
+
in format (h, w, c) in numpy.ndarray
|
1023 |
+
Returns:
|
1024 |
+
PIL.Image or numpy.ndarray: Cropped list of images
|
1025 |
+
"""
|
1026 |
+
h, w = self.size
|
1027 |
+
if isinstance(clip[0], np.ndarray) or isinstance(clip[0], torch.Tensor):
|
1028 |
+
if clip[0].shape[-1] == 3:
|
1029 |
+
im_h, im_w, im_c = clip[0].shape
|
1030 |
+
else:
|
1031 |
+
assert clip[0].shape[0] == 3
|
1032 |
+
im_c, im_h, im_w = clip[0].shape
|
1033 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
1034 |
+
im_w, im_h = clip[0].size
|
1035 |
+
else:
|
1036 |
+
raise TypeError(
|
1037 |
+
"Expected numpy.ndarray or PIL.Image or torch.Tensor" + "but got list of {0}".format(type(clip[0]))
|
1038 |
+
)
|
1039 |
+
if w > im_w or h > im_h:
|
1040 |
+
error_msg = (
|
1041 |
+
"Initial image size should be larger then "
|
1042 |
+
"cropped size but got cropped sizes : ({w}, {h}) while "
|
1043 |
+
"initial image is ({im_w}, {im_h})".format(im_w=im_w, im_h=im_h, w=w, h=h)
|
1044 |
+
)
|
1045 |
+
raise ValueError(error_msg)
|
1046 |
+
|
1047 |
+
x1 = int(round((im_w - w) / 2.0))
|
1048 |
+
y1 = int(round((im_h - h) / 2.0))
|
1049 |
+
cropped = FF.crop_clip(clip, y1, x1, h, w)
|
1050 |
+
|
1051 |
+
return cropped
|
1052 |
+
|
1053 |
+
|
1054 |
+
class ColorJitter(object):
|
1055 |
+
"""
|
1056 |
+
Randomly change the brightness, contrast and saturation and hue of the clip
|
1057 |
+
|
1058 |
+
Args:
|
1059 |
+
brightness (float): How much to jitter brightness. brightness_factor
|
1060 |
+
is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
|
1061 |
+
contrast (float): How much to jitter contrast. contrast_factor
|
1062 |
+
is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
|
1063 |
+
saturation (float): How much to jitter saturation. saturation_factor
|
1064 |
+
is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
|
1065 |
+
hue(float): How much to jitter hue. hue_factor is chosen uniformly from
|
1066 |
+
[-hue, hue]. Should be >=0 and <= 0.5.
|
1067 |
+
"""
|
1068 |
+
|
1069 |
+
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
|
1070 |
+
self.brightness = brightness
|
1071 |
+
self.contrast = contrast
|
1072 |
+
self.saturation = saturation
|
1073 |
+
self.hue = hue
|
1074 |
+
|
1075 |
+
def get_params(self, brightness, contrast, saturation, hue):
|
1076 |
+
if brightness > 0:
|
1077 |
+
brightness_factor = random.uniform(max(0, 1 - brightness), 1 + brightness)
|
1078 |
+
else:
|
1079 |
+
brightness_factor = None
|
1080 |
+
|
1081 |
+
if contrast > 0:
|
1082 |
+
contrast_factor = random.uniform(max(0, 1 - contrast), 1 + contrast)
|
1083 |
+
else:
|
1084 |
+
contrast_factor = None
|
1085 |
+
|
1086 |
+
if saturation > 0:
|
1087 |
+
saturation_factor = random.uniform(max(0, 1 - saturation), 1 + saturation)
|
1088 |
+
else:
|
1089 |
+
saturation_factor = None
|
1090 |
+
|
1091 |
+
if hue > 0:
|
1092 |
+
hue_factor = random.uniform(-hue, hue)
|
1093 |
+
else:
|
1094 |
+
hue_factor = None
|
1095 |
+
return brightness_factor, contrast_factor, saturation_factor, hue_factor
|
1096 |
+
|
1097 |
+
def __call__(self, clip):
|
1098 |
+
"""
|
1099 |
+
Args:
|
1100 |
+
clip (list): list of PIL.Image
|
1101 |
+
Returns:
|
1102 |
+
list PIL.Image : list of transformed PIL.Image
|
1103 |
+
"""
|
1104 |
+
if isinstance(clip[0], np.ndarray):
|
1105 |
+
raise TypeError("Color jitter not yet implemented for numpy arrays")
|
1106 |
+
elif isinstance(clip[0], PIL.Image.Image):
|
1107 |
+
brightness, contrast, saturation, hue = self.get_params(
|
1108 |
+
self.brightness, self.contrast, self.saturation, self.hue
|
1109 |
+
)
|
1110 |
+
|
1111 |
+
# Create img transform function sequence
|
1112 |
+
img_transforms = []
|
1113 |
+
if brightness is not None:
|
1114 |
+
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
|
1115 |
+
if saturation is not None:
|
1116 |
+
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
|
1117 |
+
if hue is not None:
|
1118 |
+
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
|
1119 |
+
if contrast is not None:
|
1120 |
+
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
|
1121 |
+
random.shuffle(img_transforms)
|
1122 |
+
|
1123 |
+
# Apply to all images
|
1124 |
+
jittered_clip = []
|
1125 |
+
for img in clip:
|
1126 |
+
for func in img_transforms:
|
1127 |
+
jittered_img = func(img)
|
1128 |
+
jittered_clip.append(jittered_img)
|
1129 |
+
|
1130 |
+
else:
|
1131 |
+
raise TypeError("Expected numpy.ndarray or PIL.Image" + "but got list of {0}".format(type(clip[0])))
|
1132 |
+
return jittered_clip
|
1133 |
+
|
1134 |
+
|
1135 |
+
class Normalize(object):
|
1136 |
+
"""Normalize a clip with mean and standard deviation.
|
1137 |
+
Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
|
1138 |
+
will normalize each channel of the input ``torch.*Tensor`` i.e.
|
1139 |
+
``input[channel] = (input[channel] - mean[channel]) / std[channel]``
|
1140 |
+
.. note::
|
1141 |
+
This transform acts out of place, i.e., it does not mutates the input tensor.
|
1142 |
+
Args:
|
1143 |
+
mean (sequence): Sequence of means for each channel.
|
1144 |
+
std (sequence): Sequence of standard deviations for each channel.
|
1145 |
+
"""
|
1146 |
+
|
1147 |
+
def __init__(self, mean, std):
|
1148 |
+
self.mean = mean
|
1149 |
+
self.std = std
|
1150 |
+
|
1151 |
+
def __call__(self, clip):
|
1152 |
+
"""
|
1153 |
+
Args:
|
1154 |
+
clip (Tensor): Tensor clip of size (T, C, H, W) to be normalized.
|
1155 |
+
Returns:
|
1156 |
+
Tensor: Normalized Tensor clip.
|
1157 |
+
"""
|
1158 |
+
return FF.normalize(clip, self.mean, self.std)
|
1159 |
+
|
1160 |
+
def __repr__(self):
|
1161 |
+
return self.__class__.__name__ + "(mean={0}, std={1})".format(self.mean, self.std)
|
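
The clip transforms above (Compose, Resize, RandomCrop, RandomHorizontalFlip, ColorJitter, Normalize) operate on an entire clip, i.e. a list of frames, rather than a single image. Below is a minimal, illustrative sketch of chaining a few of them on a list of PIL frames; the import path and frame sizes are assumptions made for this example, not part of the upload.

```python
import numpy as np
from PIL import Image

import src.datasets.utils.video.transforms as video_transforms

# A fake 16-frame clip of 320x256 RGB images (random pixels, just for shape checks).
clip = [
    Image.fromarray(np.random.randint(0, 256, (256, 320, 3), dtype=np.uint8))
    for _ in range(16)
]

pipeline = video_transforms.Compose(
    [
        video_transforms.Resize((256, 256), interpolation="bilinear"),
        video_transforms.RandomCrop(224),
        video_transforms.RandomHorizontalFlip(),
    ]
)

out = pipeline(clip)  # still a list of 16 frames, each cropped to 224x224
```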
src/datasets/utils/video/transforms_builder.py
ADDED
@@ -0,0 +1,165 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional
+
+import torch
+import torchvision.transforms as transforms
+
+import src.datasets.utils.video.transforms as video_transforms
+from src.datasets.utils.video.randerase import RandomErasing
+
+
+def make_transforms(
+    random_horizontal_flip=True,
+    random_resize_aspect_ratio=(3 / 4, 4 / 3),
+    random_resize_scale=(0.3, 1.0),
+    reprob=0.0,
+    auto_augment=False,
+    motion_shift=False,
+    crop_size=224,
+    normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+    pad_frame_count: Optional[int] = None,
+    pad_frame_method: str = "circulant",
+):
+    _frames_augmentation = VideoTransform(
+        random_horizontal_flip=random_horizontal_flip,
+        random_resize_aspect_ratio=random_resize_aspect_ratio,
+        random_resize_scale=random_resize_scale,
+        reprob=reprob,
+        auto_augment=auto_augment,
+        motion_shift=motion_shift,
+        crop_size=crop_size,
+        normalize=normalize,
+        pad_frame_count=pad_frame_count,
+        pad_frame_method=pad_frame_method,
+    )
+    return _frames_augmentation
+
+
+class VideoTransform(object):
+
+    def __init__(
+        self,
+        random_horizontal_flip=True,
+        random_resize_aspect_ratio=(3 / 4, 4 / 3),
+        random_resize_scale=(0.3, 1.0),
+        reprob=0.0,
+        auto_augment=False,
+        motion_shift=False,
+        crop_size=224,
+        normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+        pad_frame_count: Optional[int] = None,
+        pad_frame_method: str = "circulant",
+    ):
+        self.random_horizontal_flip = random_horizontal_flip
+        self.random_resize_aspect_ratio = random_resize_aspect_ratio
+        self.random_resize_scale = random_resize_scale
+        self.auto_augment = auto_augment
+        self.motion_shift = motion_shift
+        self.crop_size = crop_size
+        self.mean = torch.tensor(normalize[0], dtype=torch.float32)
+        self.std = torch.tensor(normalize[1], dtype=torch.float32)
+        self.pad_frame_count = pad_frame_count
+        self.pad_frame_method = pad_frame_method
+
+        if not self.auto_augment:
+            # Without auto-augment, PIL and tensor conversions simply scale uint8 space by 255.
+            self.mean *= 255.0
+            self.std *= 255.0
+
+        self.autoaug_transform = video_transforms.create_random_augment(
+            input_size=(crop_size, crop_size),
+            auto_augment="rand-m7-n4-mstd0.5-inc1",
+            interpolation="bicubic",
+        )
+
+        self.spatial_transform = (
+            video_transforms.random_resized_crop_with_shift if motion_shift else video_transforms.random_resized_crop
+        )
+
+        self.reprob = reprob
+        self.erase_transform = RandomErasing(
+            reprob,
+            mode="pixel",
+            max_count=1,
+            num_splits=1,
+            device="cpu",
+        )
+
+    def __call__(self, buffer):
+
+        if self.auto_augment:
+            buffer = [transforms.ToPILImage()(frame) for frame in buffer]
+            buffer = self.autoaug_transform(buffer)
+            buffer = [transforms.ToTensor()(img) for img in buffer]
+            buffer = torch.stack(buffer)  # T C H W
+            buffer = buffer.permute(0, 2, 3, 1)  # T H W C
+        elif torch.is_tensor(buffer):
+            # TODO: ensure input is always a tensor?
+            buffer = buffer.to(torch.float32)
+        else:
+            buffer = torch.tensor(buffer, dtype=torch.float32)
+
+        buffer = buffer.permute(3, 0, 1, 2)  # T H W C -> C T H W
+
+        buffer = self.spatial_transform(
+            images=buffer,
+            target_height=self.crop_size,
+            target_width=self.crop_size,
+            scale=self.random_resize_scale,
+            ratio=self.random_resize_aspect_ratio,
+        )
+        if self.random_horizontal_flip:
+            buffer, _ = video_transforms.horizontal_flip(0.5, buffer)
+
+        buffer = _tensor_normalize_inplace(buffer, self.mean, self.std)
+        if self.reprob > 0:
+            buffer = buffer.permute(1, 0, 2, 3)
+            buffer = self.erase_transform(buffer)
+            buffer = buffer.permute(1, 0, 2, 3)
+
+        if self.pad_frame_count is not None:
+            buffer = video_transforms.frame_pad(buffer, self.pad_frame_count, self.pad_frame_method)
+
+        return buffer
+
+
+def tensor_normalize(tensor, mean, std):
+    """
+    Normalize a given tensor by subtracting the mean and dividing the std.
+    Args:
+        tensor (tensor): tensor to normalize.
+        mean (tensor or list): mean value to subtract.
+        std (tensor or list): std to divide.
+    """
+    if tensor.dtype == torch.uint8:
+        tensor = tensor.float()
+        tensor = tensor / 255.0
+    if isinstance(mean, list):
+        mean = torch.tensor(mean)
+    if isinstance(std, list):
+        std = torch.tensor(std)
+    tensor = tensor - mean
+    tensor = tensor / std
+    return tensor
+
+
+def _tensor_normalize_inplace(tensor, mean, std):
+    """
+    Normalize a given tensor by subtracting the mean and dividing the std.
+    Args:
+        tensor (tensor): tensor to normalize (with dimensions C, T, H, W).
+        mean (tensor): mean value to subtract (in 0 to 255 floats).
+        std (tensor): std to divide (in 0 to 255 floats).
+    """
+    if tensor.dtype == torch.uint8:
+        tensor = tensor.float()
+
+    C, T, H, W = tensor.shape
+    tensor = tensor.view(C, -1).permute(1, 0)  # Make C the last dimension
+    tensor.sub_(mean).div_(std)
+    tensor = tensor.permute(1, 0).view(C, T, H, W)  # Put C back in front
+    return tensor
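
make_transforms simply forwards its arguments to VideoTransform, whose __call__ takes a (T, H, W, C) frame buffer, applies the resized crop and optional flip/erase, normalizes in place, and returns a (C, T, H, W) tensor. A rough usage sketch follows; it is not part of the upload, and the uint8 THWC input shape is inferred from the permutes in __call__.

```python
import torch

from src.datasets.utils.video.transforms_builder import make_transforms

# Build the default training transform (no auto-augment, no random erasing).
transform = make_transforms(
    random_horizontal_flip=True,
    random_resize_scale=(0.3, 1.0),
    crop_size=224,
)

# Dummy clip: 16 frames of 256x320 RGB, laid out as (T, H, W, C) uint8.
clip = torch.randint(0, 256, (16, 256, 320, 3), dtype=torch.uint8)

out = transform(clip)
print(out.shape)  # expected: torch.Size([3, 16, 224, 224]), i.e. (C, T, H, W)
```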
src/datasets/utils/video/volume_transforms.py
ADDED
@@ -0,0 +1,159 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
|
11 |
+
def convert_img(img):
|
12 |
+
"""Converts (H, W, C) numpy.ndarray to (C, W, H) format"""
|
13 |
+
if len(img.shape) == 3:
|
14 |
+
img = img.transpose(2, 0, 1)
|
15 |
+
if len(img.shape) == 2:
|
16 |
+
img = np.expand_dims(img, 0)
|
17 |
+
return img
|
18 |
+
|
19 |
+
|
20 |
+
class ClipToTensor(object):
|
21 |
+
"""Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
|
22 |
+
to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0]
|
23 |
+
"""
|
24 |
+
|
25 |
+
def __init__(self, channel_nb=3, div_255=True, numpy=False):
|
26 |
+
self.channel_nb = channel_nb
|
27 |
+
self.div_255 = div_255
|
28 |
+
self.numpy = numpy
|
29 |
+
|
30 |
+
def __call__(self, clip):
|
31 |
+
"""
|
32 |
+
Args: clip (list of numpy.ndarray): clip (list of images)
|
33 |
+
to be converted to tensor.
|
34 |
+
"""
|
35 |
+
# Retrieve shape
|
36 |
+
if isinstance(clip[0], np.ndarray):
|
37 |
+
h, w, ch = clip[0].shape
|
38 |
+
assert ch == self.channel_nb, "Got {0} instead of 3 channels".format(ch)
|
39 |
+
elif isinstance(clip[0], Image.Image):
|
40 |
+
w, h = clip[0].size
|
41 |
+
elif isinstance(clip[0], torch.Tensor):
|
42 |
+
tensor_clip = torch.stack(clip)
|
43 |
+
# Converting (T, C, H, W) -> (C, T, H, W) to match what `convert_img` followed by
|
44 |
+
# `np_clip[:, img_idx, :, :] = img` does for other data types.
|
45 |
+
tensor_clip = tensor_clip.permute(1, 0, 2, 3)
|
46 |
+
if not isinstance(tensor_clip, torch.FloatTensor):
|
47 |
+
tensor_clip = tensor_clip.float()
|
48 |
+
if self.div_255:
|
49 |
+
tensor_clip = torch.div(tensor_clip, 255)
|
50 |
+
return tensor_clip
|
51 |
+
else:
|
52 |
+
raise TypeError(
|
53 |
+
"Expected numpy.ndarray or PIL.Image or torch.Tensor\
|
54 |
+
but got list of {0}".format(
|
55 |
+
type(clip[0])
|
56 |
+
)
|
57 |
+
)
|
58 |
+
|
59 |
+
np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
|
60 |
+
|
61 |
+
# Convert
|
62 |
+
for img_idx, img in enumerate(clip):
|
63 |
+
if isinstance(img, np.ndarray):
|
64 |
+
pass
|
65 |
+
elif isinstance(img, Image.Image):
|
66 |
+
img = np.array(img, copy=False)
|
67 |
+
else:
|
68 |
+
raise TypeError(
|
69 |
+
"Expected numpy.ndarray or PIL.Image\
|
70 |
+
but got list of {0}".format(
|
71 |
+
type(clip[0])
|
72 |
+
)
|
73 |
+
)
|
74 |
+
img = convert_img(img)
|
75 |
+
np_clip[:, img_idx, :, :] = img
|
76 |
+
|
77 |
+
if self.numpy:
|
78 |
+
if self.div_255:
|
79 |
+
np_clip = np_clip / 255.0
|
80 |
+
return np_clip
|
81 |
+
|
82 |
+
else:
|
83 |
+
tensor_clip = torch.from_numpy(np_clip)
|
84 |
+
|
85 |
+
if not isinstance(tensor_clip, torch.FloatTensor):
|
86 |
+
tensor_clip = tensor_clip.float()
|
87 |
+
if self.div_255:
|
88 |
+
tensor_clip = torch.div(tensor_clip, 255)
|
89 |
+
return tensor_clip
|
90 |
+
|
91 |
+
|
92 |
+
# Note this norms data to -1/1
|
93 |
+
class ClipToTensor_K(object):
|
94 |
+
"""Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
|
95 |
+
to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0]
|
96 |
+
"""
|
97 |
+
|
98 |
+
def __init__(self, channel_nb=3, div_255=True, numpy=False):
|
99 |
+
self.channel_nb = channel_nb
|
100 |
+
self.div_255 = div_255
|
101 |
+
self.numpy = numpy
|
102 |
+
|
103 |
+
def __call__(self, clip):
|
104 |
+
"""
|
105 |
+
Args: clip (list of numpy.ndarray): clip (list of images)
|
106 |
+
to be converted to tensor.
|
107 |
+
"""
|
108 |
+
# Retrieve shape
|
109 |
+
if isinstance(clip[0], np.ndarray):
|
110 |
+
h, w, ch = clip[0].shape
|
111 |
+
assert ch == self.channel_nb, "Got {0} instead of 3 channels".format(ch)
|
112 |
+
elif isinstance(clip[0], Image.Image):
|
113 |
+
w, h = clip[0].size
|
114 |
+
else:
|
115 |
+
raise TypeError(
|
116 |
+
"Expected numpy.ndarray or PIL.Image\
|
117 |
+
but got list of {0}".format(
|
118 |
+
type(clip[0])
|
119 |
+
)
|
120 |
+
)
|
121 |
+
|
122 |
+
np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
|
123 |
+
|
124 |
+
# Convert
|
125 |
+
for img_idx, img in enumerate(clip):
|
126 |
+
if isinstance(img, np.ndarray):
|
127 |
+
pass
|
128 |
+
elif isinstance(img, Image.Image):
|
129 |
+
img = np.array(img, copy=False)
|
130 |
+
else:
|
131 |
+
raise TypeError(
|
132 |
+
"Expected numpy.ndarray or PIL.Image\
|
133 |
+
but got list of {0}".format(
|
134 |
+
type(clip[0])
|
135 |
+
)
|
136 |
+
)
|
137 |
+
img = convert_img(img)
|
138 |
+
np_clip[:, img_idx, :, :] = img
|
139 |
+
if self.numpy:
|
140 |
+
if self.div_255:
|
141 |
+
np_clip = (np_clip - 127.5) / 127.5
|
142 |
+
return np_clip
|
143 |
+
|
144 |
+
else:
|
145 |
+
tensor_clip = torch.from_numpy(np_clip)
|
146 |
+
|
147 |
+
if not isinstance(tensor_clip, torch.FloatTensor):
|
148 |
+
tensor_clip = tensor_clip.float()
|
149 |
+
if self.div_255:
|
150 |
+
tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5)
|
151 |
+
return tensor_clip
|
152 |
+
|
153 |
+
|
154 |
+
class ToTensor(object):
|
155 |
+
"""Converts numpy array to tensor"""
|
156 |
+
|
157 |
+
def __call__(self, array):
|
158 |
+
tensor = torch.from_numpy(array)
|
159 |
+
return tensor
|
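
ClipToTensor stacks a list of (H, W, C) frames into a single (C, T, H, W) float tensor, optionally dividing by 255, while ClipToTensor_K maps pixel values to roughly [-1, 1] instead. A small sketch with made-up frame data:

```python
import numpy as np

from src.datasets.utils.video.volume_transforms import ClipToTensor

# Eight random 224x224 RGB frames in HWC uint8 layout.
frames = [np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8) for _ in range(8)]

to_tensor = ClipToTensor(channel_nb=3, div_255=True)
clip_tensor = to_tensor(frames)

print(clip_tensor.shape, clip_tensor.dtype)  # torch.Size([3, 8, 224, 224]) torch.float32
print(float(clip_tensor.max()) <= 1.0)       # True: values were scaled by 1/255
```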
src/datasets/utils/weighted_sampler.py
ADDED
@@ -0,0 +1,336 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
from typing import Iterator, Optional
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
from torch.utils.data import DistributedSampler, RandomSampler
|
12 |
+
|
13 |
+
from src.utils.logging import get_logger
|
14 |
+
|
15 |
+
logger = get_logger("WeightedSampler")
|
16 |
+
|
17 |
+
|
18 |
+
class DistributedWeightedSampler(DistributedSampler):
|
19 |
+
"""
|
20 |
+
This class implements a weighted sampler for distributed training.
|
21 |
+
See https://pytorch.org/docs/stable/_modules/torch/utils/data/distributed.html#DistributedSampler for more details.
|
22 |
+
|
23 |
+
It shares the same interface as `torch.utils.data.DistributedSampler`.
|
24 |
+
The effective change is replacing `DistributedSampler`'s `torch.randperm` for generating the sequence
|
25 |
+
of indices with `numpy.random.Generator.choice`, with replacement. This allows weighted sampling and
|
26 |
+
avoiding issue with `torch.randperm` when the number of samples is larger than 2^24 samples.
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
dataset,
|
32 |
+
num_replicas: Optional[int] = None,
|
33 |
+
rank: Optional[int] = None,
|
34 |
+
shuffle: bool = True,
|
35 |
+
seed: int = 0,
|
36 |
+
drop_last: bool = False,
|
37 |
+
):
|
38 |
+
logger.info(f"Using DistributedWeightedSampler with rank {rank} / {num_replicas}")
|
39 |
+
assert hasattr(
|
40 |
+
dataset, "sample_weights"
|
41 |
+
), "Dataset must have sample_weights property for using DistributedWeightedSampler"
|
42 |
+
super().__init__(
|
43 |
+
dataset,
|
44 |
+
num_replicas=num_replicas,
|
45 |
+
rank=rank,
|
46 |
+
shuffle=shuffle,
|
47 |
+
seed=seed,
|
48 |
+
drop_last=drop_last,
|
49 |
+
)
|
50 |
+
|
51 |
+
@property
|
52 |
+
def sample_probabilities(self) -> np.ndarray:
|
53 |
+
sample_weights = self.dataset.sample_weights
|
54 |
+
if isinstance(sample_weights, torch.Tensor):
|
55 |
+
sample_weights = sample_weights.cpu().numpy()
|
56 |
+
elif isinstance(sample_weights, list):
|
57 |
+
sample_weights = np.array(sample_weights)
|
58 |
+
assert isinstance(
|
59 |
+
sample_weights, np.ndarray
|
60 |
+
), f"sample_weights must be a numpy array, torch.Tensor, or python list; got {type(sample_weights)}"
|
61 |
+
return sample_weights / np.sum(sample_weights)
|
62 |
+
|
63 |
+
def __iter__(self) -> Iterator:
|
64 |
+
n = len(self.dataset)
|
65 |
+
|
66 |
+
# deterministically shuffle based on epoch and seed
|
67 |
+
rng = np.random.default_rng(self.seed + self.epoch)
|
68 |
+
indices = rng.choice(
|
69 |
+
range(0, n),
|
70 |
+
size=self.total_size,
|
71 |
+
p=self.sample_probabilities,
|
72 |
+
replace=True,
|
73 |
+
).tolist()
|
74 |
+
|
75 |
+
if not self.drop_last:
|
76 |
+
# add extra samples to make it evenly divisible
|
77 |
+
padding_size = self.total_size - len(indices)
|
78 |
+
if padding_size <= len(indices):
|
79 |
+
indices += indices[:padding_size]
|
80 |
+
else:
|
81 |
+
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
|
82 |
+
else:
|
83 |
+
# remove tail of data to make it evenly divisible
|
84 |
+
indices = indices[: self.total_size]
|
85 |
+
assert len(indices) == self.total_size
|
86 |
+
|
87 |
+
# subsample
|
88 |
+
indices = indices[self.rank : self.total_size : self.num_replicas]
|
89 |
+
assert len(indices) == self.num_samples
|
90 |
+
|
91 |
+
return iter(indices)
|
92 |
+
|
93 |
+
|
94 |
+
class MemoryEfficientDistributedWeightedSampler(DistributedSampler):
|
95 |
+
"""
|
96 |
+
This class implements a memory efficient version of `DistributedWeightedSampler`.
|
97 |
+
It shares the same interface as `DistributedWeightedSampler`.
|
98 |
+
The effective change is just-in-time sampling of the indices, instead of pre-computing them.
|
99 |
+
"""
|
100 |
+
|
101 |
+
def __init__(
|
102 |
+
self,
|
103 |
+
dataset,
|
104 |
+
num_replicas: Optional[int] = None,
|
105 |
+
rank: Optional[int] = None,
|
106 |
+
shuffle: bool = True,
|
107 |
+
seed: int = 0,
|
108 |
+
):
|
109 |
+
logger.info(f"Using MemoryEfficientDistributedWeightedSampler with rank {rank} / {num_replicas}")
|
110 |
+
assert hasattr(
|
111 |
+
dataset, "dataset_weights"
|
112 |
+
), "Dataset must have dataset_weights property for using MemoryEfficientDistributedWeightedSampler"
|
113 |
+
super().__init__(
|
114 |
+
dataset,
|
115 |
+
num_replicas=num_replicas,
|
116 |
+
rank=rank,
|
117 |
+
shuffle=shuffle,
|
118 |
+
seed=seed,
|
119 |
+
)
|
120 |
+
|
121 |
+
self.dataset_weights = dataset.dataset_weights
|
122 |
+
self.dataset_sizes = [len(d) for d in dataset.datasets]
|
123 |
+
if len(self.dataset_sizes) != len(self.dataset_weights):
|
124 |
+
raise ValueError(
|
125 |
+
f"Number of datasets ({len(self.dataset_sizes)}) "
|
126 |
+
f"does not match number of dataset weights ({len(self.dataset_weights)})"
|
127 |
+
)
|
128 |
+
|
129 |
+
if self.shuffle:
|
130 |
+
self.rng = np.random.default_rng(self.seed + self.rank + self.epoch)
|
131 |
+
total_weights = sum(self.dataset_weights)
|
132 |
+
self.dataset_probablities = np.array([w / total_weights for w in self.dataset_weights])
|
133 |
+
else:
|
134 |
+
if any([not isinstance(w, int) for w in self.dataset_weights]):
|
135 |
+
raise ValueError("Dataset weights must be integers when shuffle is False")
|
136 |
+
|
137 |
+
self.dataset_orders = []
|
138 |
+
for i, w in enumerate(self.dataset_weights):
|
139 |
+
self.dataset_orders.extend([i] * w)
|
140 |
+
|
141 |
+
self.drawn_samples = 0
|
142 |
+
|
143 |
+
def __iter__(self) -> Iterator:
|
144 |
+
return self
|
145 |
+
|
146 |
+
def __next__(self) -> int:
|
147 |
+
if self.shuffle:
|
148 |
+
selected_dataset_idx = self.rng.choice(range(len(self.dataset_weights)), p=self.dataset_probablities)
|
149 |
+
|
150 |
+
# In order to avoid sampling the same example multiple times between the ranks,
|
151 |
+
# we limit each rank to a subset of the total number of samples in the dataset.
|
152 |
+
# For example if our dataet is [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], and we have 2 ranks,
|
153 |
+
# then rank 0 will ONLY sample from [0, 2, 4, 6, 8], and rank 1 from [1, 3, 5, 7, 9].
|
154 |
+
# In each iteration we first produce `in_rank_sample` which is the sample index in the rank,
|
155 |
+
# based on the size of the subset which that rank can sample from.
|
156 |
+
# Then we computer `sample_idx_in_dataset` for the indx of the sample in the whole dataset.
|
157 |
+
# For the above example if we are sampling for rank 1, we have `self.rng.integers(5)`.
|
158 |
+
# Let's assume the result is 2, then `in_rank_sample` is 2 (number "5" in the subset),
|
159 |
+
# so the sample index in the whole dataset is
|
160 |
+
# `in_rank_sample * self.num_replicas + self.rank`: 2 * 2 + 1 = 5.
|
161 |
+
|
162 |
+
selected_dataset_size = self.dataset_sizes[selected_dataset_idx]
|
163 |
+
# 1) Getting sample index in the rank.
|
164 |
+
# NOTE: this may effectively drops the last batch,
|
165 |
+
# but given the sample sizes that we use this sampler with, it should not be an issue.
|
166 |
+
num_samples_in_rank = selected_dataset_size // self.num_replicas
|
167 |
+
in_rank_sample = self.rng.integers(num_samples_in_rank)
|
168 |
+
|
169 |
+
# 2) Getting sample index in the dataset.
|
170 |
+
sample_idx_in_dataset = in_rank_sample * self.num_replicas + self.rank
|
171 |
+
|
172 |
+
else:
|
173 |
+
# Iterate through the dataset orders in a round-robin fashion, offset by the rank
|
174 |
+
dataset_orders_idx = (self.rank + self.drawn_samples) % len(self.dataset_orders)
|
175 |
+
selected_dataset_idx = self.dataset_orders[dataset_orders_idx]
|
176 |
+
# Get the sample index in the selected dataset by skipping with the num_replicas * drawn_samples
|
177 |
+
sample_idx_in_dataset = (self.drawn_samples * self.num_replicas + self.rank) % self.dataset_sizes[
|
178 |
+
selected_dataset_idx
|
179 |
+
]
|
180 |
+
self.drawn_samples += 1
|
181 |
+
|
182 |
+
# Getting the index of the sample in the whole dataset
|
183 |
+
# For example if the total dataset has 4 datasets with sizes [10, 20, 30, 5].
|
184 |
+
# and our selected_dataset_idx=3 and sample_idx_in_dataset=5
|
185 |
+
# then the index of the sample in the whole dataset is
|
186 |
+
# 10 (for dataset 1) + 20 (for dataset 1) + 5 (for sample_idx_in_dataset) = 35
|
187 |
+
# This is because the first 10 samples are from the first dataset, the next 20 are from the second dataset,
|
188 |
+
# then we reach at the 3rd dataset which is the selected dataset, and the 5th sample in the 3rd dataset.
|
189 |
+
sample_idx = 0
|
190 |
+
for i, d in enumerate(self.dataset_sizes):
|
191 |
+
if selected_dataset_idx == i:
|
192 |
+
break
|
193 |
+
sample_idx += d
|
194 |
+
sample_idx += sample_idx_in_dataset
|
195 |
+
|
196 |
+
return sample_idx
|
197 |
+
|
198 |
+
|
199 |
+
def safe_next(iterator):
|
200 |
+
try:
|
201 |
+
return next(iterator)
|
202 |
+
except StopIteration:
|
203 |
+
return None
|
204 |
+
|
205 |
+
|
206 |
+
class MemoryEfficientDistributedWeightedSamplerLessRepeat(DistributedSampler):
|
207 |
+
"""
|
208 |
+
This class implements a memory efficient version of `DistributedWeightedSampler`.
|
209 |
+
It shares the same interface as `DistributedWeightedSampler`.
|
210 |
+
The effective change is pre-computing the permutations of indices over a subset of total indices.
|
211 |
+
This subset is the selected with picking the indices in a dataset with steps sizes equal to the world size.
|
212 |
+
For example, if world size is 12 and rank is 2, for a dataset of size N,
|
213 |
+
this sampler only permutes the indices in range(2, n, 12)
|
214 |
+
|
215 |
+
Compared with MemoryEfficientDistributedWeightedSampler, this will reduce the effective number of repeat.
|
216 |
+
See discussions here: https://github.com/fairinternal/jepa-internal/pull/254
|
217 |
+
"""
|
218 |
+
|
219 |
+
def __init__(
|
220 |
+
self,
|
221 |
+
dataset,
|
222 |
+
num_replicas: Optional[int] = None,
|
223 |
+
rank: Optional[int] = None,
|
224 |
+
shuffle: bool = True,
|
225 |
+
seed: int = 0,
|
226 |
+
):
|
227 |
+
logger.info(f"Using MemoryEfficientDistributedWeightedSamplerLessRepeat with rank {rank} / {num_replicas}")
|
228 |
+
assert hasattr(
|
229 |
+
dataset, "dataset_weights"
|
230 |
+
), "Dataset must have dataset_weights property for using MemoryEfficientDistributedWeightedSamplerLessRepeat"
|
231 |
+
super().__init__(
|
232 |
+
dataset,
|
233 |
+
num_replicas=num_replicas,
|
234 |
+
rank=rank,
|
235 |
+
shuffle=shuffle,
|
236 |
+
seed=seed,
|
237 |
+
)
|
238 |
+
|
239 |
+
self._generator = torch.Generator()
|
240 |
+
self._generator.manual_seed(seed)
|
241 |
+
|
242 |
+
self.dataset_weights = dataset.dataset_weights
|
243 |
+
self.dataset_sizes = [len(d) for d in dataset.datasets]
|
244 |
+
if len(self.dataset_sizes) != len(self.dataset_weights):
|
245 |
+
raise ValueError(
|
246 |
+
f"Number of datasets ({len(self.dataset_sizes)}) "
|
247 |
+
f"does not match number of dataset weights ({len(self.dataset_weights)})"
|
248 |
+
)
|
249 |
+
|
250 |
+
if self.shuffle:
|
251 |
+
self.rng = np.random.default_rng(self.seed + self.rank + self.epoch)
|
252 |
+
total_weights = sum(self.dataset_weights)
|
253 |
+
self.dataset_probablities = np.array([w / total_weights for w in self.dataset_weights])
|
254 |
+
|
255 |
+
# For each dataset we generate a permutation of the indices that will be processed by that rank.
|
256 |
+
# This is going to be the subset of indices, selected by the steps sizes of the world size.
|
257 |
+
logger.info(f"Generating dataset indices for rank {self.rank} / {self.num_replicas}")
|
258 |
+
|
259 |
+
# Getting a RandomSampler for indices assigned to each dataset.
|
260 |
+
self.individual_dataset_sampler = []
|
261 |
+
for ids, ds in enumerate(self.dataset_sizes):
|
262 |
+
|
263 |
+
# NOTE: this may effectively drops the last batch,
|
264 |
+
# but given the sample sizes that we use this sampler with, it should not be an issue.
|
265 |
+
num_samples_in_rank = ds // self.num_replicas
|
266 |
+
self.individual_dataset_sampler.append(self._new_sampler(num_samples_in_rank))
|
267 |
+
|
268 |
+
else:
|
269 |
+
if any([not isinstance(w, int) for w in self.dataset_weights]):
|
270 |
+
raise ValueError("Dataset weights must be integers when shuffle is False")
|
271 |
+
|
272 |
+
self.dataset_orders = []
|
273 |
+
for i, w in enumerate(self.dataset_weights):
|
274 |
+
self.dataset_orders.extend([i] * w)
|
275 |
+
|
276 |
+
self.drawn_samples = 0
|
277 |
+
|
278 |
+
def __iter__(self) -> Iterator:
|
279 |
+
return self
|
280 |
+
|
281 |
+
def _new_sampler(self, sample_size: int) -> RandomSampler:
|
282 |
+
assert self.shuffle
|
283 |
+
|
284 |
+
return iter(
|
285 |
+
RandomSampler(
|
286 |
+
range(sample_size),
|
287 |
+
generator=self._generator,
|
288 |
+
)
|
289 |
+
)
|
290 |
+
|
291 |
+
def _in_rank_next_index_for_dataset(self, dataset_idx: int) -> int:
|
292 |
+
assert self.shuffle
|
293 |
+
|
294 |
+
next_sampler_idx = safe_next(self.individual_dataset_sampler[dataset_idx])
|
295 |
+
if next_sampler_idx is None:
|
296 |
+
# We have reached the end of the dataset, we need to reset the sampler.
|
297 |
+
num_samples_in_rank = self.dataset_sizes[dataset_idx] // self.num_replicas
|
298 |
+
self.individual_dataset_sampler[dataset_idx] = self._new_sampler(num_samples_in_rank)
|
299 |
+
next_sampler_idx = safe_next(self.individual_dataset_sampler[dataset_idx])
|
300 |
+
assert next_sampler_idx is not None
|
301 |
+
|
302 |
+
return next_sampler_idx
|
303 |
+
|
304 |
+
def __next__(self) -> int:
|
305 |
+
if self.shuffle:
|
306 |
+
selected_dataset_idx = self.rng.choice(range(len(self.dataset_weights)), p=self.dataset_probablities)
|
307 |
+
in_rank_sample = self._in_rank_next_index_for_dataset(selected_dataset_idx)
|
308 |
+
|
309 |
+
# 2) Getting sample index in the dataset.
|
310 |
+
sample_idx_in_dataset = in_rank_sample * self.num_replicas + self.rank
|
311 |
+
|
312 |
+
else:
|
313 |
+
# Iterate through the dataset orders in a round-robin fashion, offset by the rank
|
314 |
+
dataset_orders_idx = (self.rank + self.drawn_samples) % len(self.dataset_orders)
|
315 |
+
selected_dataset_idx = self.dataset_orders[dataset_orders_idx]
|
316 |
+
# Get the sample index in the selected dataset by skipping with the num_replicas * drawn_samples
|
317 |
+
sample_idx_in_dataset = (self.drawn_samples * self.num_replicas + self.rank) % self.dataset_sizes[
|
318 |
+
selected_dataset_idx
|
319 |
+
]
|
320 |
+
self.drawn_samples += 1
|
321 |
+
|
322 |
+
# Getting the index of the sample in the whole dataset
|
323 |
+
# For example if the total dataset has 4 datasets with sizes [10, 20, 30, 5].
|
324 |
+
# and our selected_dataset_idx=3 and sample_idx_in_dataset=5
|
325 |
+
# then the index of the sample in the whole dataset is
|
326 |
+
# 10 (for dataset 1) + 20 (for dataset 1) + 5 (for sample_idx_in_dataset) = 35
|
327 |
+
# This is because the first 10 samples are from the first dataset, the next 20 are from the second dataset,
|
328 |
+
# then we reach at the 3rd dataset which is the selected dataset, and the 5th sample in the 3rd dataset.
|
329 |
+
sample_idx = 0
|
330 |
+
for i, d in enumerate(self.dataset_sizes):
|
331 |
+
if selected_dataset_idx == i:
|
332 |
+
break
|
333 |
+
sample_idx += d
|
334 |
+
sample_idx += sample_idx_in_dataset
|
335 |
+
|
336 |
+
return sample_idx
|
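
A hedged sketch of exercising DistributedWeightedSampler outside a real distributed job by passing num_replicas and rank explicitly; the toy dataset below is invented for illustration and only provides the sample_weights attribute that the sampler requires.

```python
from torch.utils.data import Dataset

from src.datasets.utils.weighted_sampler import DistributedWeightedSampler


class ToyDataset(Dataset):
    """100 items; the second half carries 10x the sampling weight of the first."""

    def __init__(self, n=100):
        self.sample_weights = [1.0] * (n // 2) + [10.0] * (n - n // 2)

    def __len__(self):
        return len(self.sample_weights)

    def __getitem__(self, idx):
        return idx


dataset = ToyDataset()
sampler = DistributedWeightedSampler(dataset, num_replicas=2, rank=0, shuffle=True, seed=0)
sampler.set_epoch(0)

indices = list(iter(sampler))
print(len(indices))                    # 50 indices for this rank (100 total / 2 replicas)
print(sum(i >= 50 for i in indices))   # heavily weighted second half should dominate
```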
src/datasets/utils/worker_init_fn.py
ADDED
@@ -0,0 +1,76 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+# Copyright The Lightning AI team.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This code originally comes from PyTorch Lightning with some light modifications:
+# https://github.com/Lightning-AI/pytorch-lightning/blob/a944e7744e57a5a2c13f3c73b9735edf2f71e329/src/lightning/fabric/utilities/seed.py
+
+
+import os
+import random
+from typing import Optional
+
+import numpy as np
+import torch
+
+from src.utils.logging import get_logger
+
+logger = get_logger("worker_init_fn")
+
+
+def _generate_seed_sequence(base_seed: int, worker_id: int, global_rank: int, count: int) -> list[int]:
+    """Generates a sequence of seeds from a base seed, worker id and rank using the linear congruential generator (LCG)
+    algorithm."""
+    # Combine base seed, worker id and rank into a unique 64-bit number
+    combined_seed = (base_seed << 32) | (worker_id << 16) | global_rank
+    seeds = []
+    for _ in range(count):
+        # x_(n+1) = (a * x_n + c) mod m. With c=1, m=2^64 and a is D. Knuth's constant
+        combined_seed = (combined_seed * 6364136223846793005 + 1) & ((1 << 64) - 1)
+        seeds.append(combined_seed)
+    return seeds
+
+
+def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:  # pragma: no cover
+    r"""The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed with
+    ``seed_everything(seed, workers=True)``.
+
+    See also the PyTorch documentation on
+    `randomness in DataLoaders <https://pytorch.org/docs/stable/notes/randomness.html#dataloader>`_.
+
+    """
+    # implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
+    if rank is None:
+        procid = os.environ.get("SLURM_PROCID")
+        if procid is None:
+            logger.warning("SLURM_PROCID is not set, setting rank to 0")
+            rank = 0
+        else:
+            rank = int(procid)
+
+    process_seed = torch.initial_seed()
+    # back out the base seed so we can use all the bits
+    base_seed = process_seed - worker_id
+    logger.debug(
+        f"Initializing random number generators of process {rank} worker {worker_id} with base seed {base_seed}"
+    )
+    seed_sequence = _generate_seed_sequence(base_seed, worker_id, rank, count=4)
+    torch.manual_seed(seed_sequence[0])  # torch takes a 64-bit seed
+    random.seed((seed_sequence[1] << 32) | seed_sequence[2])  # combine two 64-bit seeds
+
+    ss = np.random.SeedSequence([base_seed, worker_id, rank])
+    np_rng_seed = ss.generate_state(4)
+
+    np.random.seed(np_rng_seed)
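
A small sketch of wiring pl_worker_init_function into a DataLoader; passing rank=0 through functools.partial stands in for the SLURM_PROCID lookup that would happen on a cluster.

```python
from functools import partial

import torch
from torch.utils.data import DataLoader, TensorDataset

from src.datasets.utils.worker_init_fn import pl_worker_init_function

dataset = TensorDataset(torch.arange(128))
loader = DataLoader(
    dataset,
    batch_size=16,
    num_workers=2,
    worker_init_fn=partial(pl_worker_init_function, rank=0),
)

# Each worker process re-seeds torch, random and numpy on startup, derived from
# the process seed, its worker id, and the given rank.
for (batch,) in loader:
    pass
```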
src/datasets/video_dataset.py
ADDED
@@ -0,0 +1,373 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
import os
|
8 |
+
import pathlib
|
9 |
+
import warnings
|
10 |
+
from logging import getLogger
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import pandas as pd
|
14 |
+
import torch
|
15 |
+
import torchvision
|
16 |
+
from decord import VideoReader, cpu
|
17 |
+
|
18 |
+
from src.datasets.utils.dataloader import ConcatIndices, MonitoredDataset, NondeterministicDataLoader
|
19 |
+
from src.datasets.utils.weighted_sampler import DistributedWeightedSampler
|
20 |
+
|
21 |
+
_GLOBAL_SEED = 0
|
22 |
+
logger = getLogger()
|
23 |
+
|
24 |
+
|
25 |
+
def make_videodataset(
|
26 |
+
data_paths,
|
27 |
+
batch_size,
|
28 |
+
frames_per_clip=8,
|
29 |
+
dataset_fpcs=None,
|
30 |
+
frame_step=4,
|
31 |
+
duration=None,
|
32 |
+
fps=None,
|
33 |
+
num_clips=1,
|
34 |
+
random_clip_sampling=True,
|
35 |
+
allow_clip_overlap=False,
|
36 |
+
filter_short_videos=False,
|
37 |
+
filter_long_videos=int(10**9),
|
38 |
+
transform=None,
|
39 |
+
shared_transform=None,
|
40 |
+
rank=0,
|
41 |
+
world_size=1,
|
42 |
+
datasets_weights=None,
|
43 |
+
collator=None,
|
44 |
+
drop_last=True,
|
45 |
+
num_workers=10,
|
46 |
+
pin_mem=True,
|
47 |
+
persistent_workers=True,
|
48 |
+
deterministic=True,
|
49 |
+
log_dir=None,
|
50 |
+
):
|
51 |
+
dataset = VideoDataset(
|
52 |
+
data_paths=data_paths,
|
53 |
+
datasets_weights=datasets_weights,
|
54 |
+
frames_per_clip=frames_per_clip,
|
55 |
+
dataset_fpcs=dataset_fpcs,
|
56 |
+
duration=duration,
|
57 |
+
fps=fps,
|
58 |
+
frame_step=frame_step,
|
59 |
+
num_clips=num_clips,
|
60 |
+
random_clip_sampling=random_clip_sampling,
|
61 |
+
allow_clip_overlap=allow_clip_overlap,
|
62 |
+
filter_short_videos=filter_short_videos,
|
63 |
+
filter_long_videos=filter_long_videos,
|
64 |
+
shared_transform=shared_transform,
|
65 |
+
transform=transform,
|
66 |
+
)
|
67 |
+
|
68 |
+
log_dir = pathlib.Path(log_dir) if log_dir else None
|
69 |
+
if log_dir:
|
70 |
+
log_dir.mkdir(parents=True, exist_ok=True)
|
71 |
+
# Worker ID will replace '%w'
|
72 |
+
resource_log_filename = log_dir / f"resource_file_{rank}_%w.csv"
|
73 |
+
dataset = MonitoredDataset(
|
74 |
+
dataset=dataset,
|
75 |
+
log_filename=str(resource_log_filename),
|
76 |
+
log_interval=10.0,
|
77 |
+
monitor_interval=5.0,
|
78 |
+
)
|
79 |
+
|
80 |
+
logger.info("VideoDataset dataset created")
|
81 |
+
if datasets_weights is not None:
|
82 |
+
dist_sampler = DistributedWeightedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
|
83 |
+
else:
|
84 |
+
dist_sampler = torch.utils.data.distributed.DistributedSampler(
|
85 |
+
dataset, num_replicas=world_size, rank=rank, shuffle=True
|
86 |
+
)
|
87 |
+
|
88 |
+
if deterministic:
|
89 |
+
data_loader = torch.utils.data.DataLoader(
|
90 |
+
dataset,
|
91 |
+
collate_fn=collator,
|
92 |
+
sampler=dist_sampler,
|
93 |
+
batch_size=batch_size,
|
94 |
+
drop_last=drop_last,
|
95 |
+
pin_memory=pin_mem,
|
96 |
+
num_workers=num_workers,
|
97 |
+
persistent_workers=(num_workers > 0) and persistent_workers,
|
98 |
+
)
|
99 |
+
else:
|
100 |
+
data_loader = NondeterministicDataLoader(
|
101 |
+
dataset,
|
102 |
+
collate_fn=collator,
|
103 |
+
sampler=dist_sampler,
|
104 |
+
batch_size=batch_size,
|
105 |
+
drop_last=drop_last,
|
106 |
+
pin_memory=pin_mem,
|
107 |
+
num_workers=num_workers,
|
108 |
+
persistent_workers=(num_workers > 0) and persistent_workers,
|
109 |
+
)
|
110 |
+
logger.info("VideoDataset unsupervised data loader created")
|
111 |
+
|
112 |
+
return dataset, data_loader, dist_sampler
|
113 |
+
|
114 |
+
|
115 |
+
class VideoDataset(torch.utils.data.Dataset):
|
116 |
+
"""Video classification dataset."""
|
117 |
+
|
118 |
+
def __init__(
|
119 |
+
self,
|
120 |
+
data_paths,
|
121 |
+
datasets_weights=None,
|
122 |
+
frames_per_clip=16,
|
123 |
+
fps=None,
|
124 |
+
dataset_fpcs=None,
|
125 |
+
frame_step=4,
|
126 |
+
num_clips=1,
|
127 |
+
transform=None,
|
128 |
+
shared_transform=None,
|
129 |
+
random_clip_sampling=True,
|
130 |
+
allow_clip_overlap=False,
|
131 |
+
filter_short_videos=False,
|
132 |
+
filter_long_videos=int(10**9),
|
133 |
+
duration=None, # duration in seconds
|
134 |
+
):
|
135 |
+
self.data_paths = data_paths
|
136 |
+
self.datasets_weights = datasets_weights
|
137 |
+
self.frame_step = frame_step
|
138 |
+
self.num_clips = num_clips
|
139 |
+
self.transform = transform
|
140 |
+
self.shared_transform = shared_transform
|
141 |
+
self.random_clip_sampling = random_clip_sampling
|
142 |
+
self.allow_clip_overlap = allow_clip_overlap
|
143 |
+
self.filter_short_videos = filter_short_videos
|
144 |
+
self.filter_long_videos = filter_long_videos
|
145 |
+
self.duration = duration
|
146 |
+
self.fps = fps
|
147 |
+
|
148 |
+
if sum([v is not None for v in (fps, duration, frame_step)]) != 1:
|
149 |
+
raise ValueError(f"Must specify exactly one of either {fps=}, {duration=}, or {frame_step=}.")
|
150 |
+
|
151 |
+
if isinstance(data_paths, str):
|
152 |
+
data_paths = [data_paths]
|
153 |
+
|
154 |
+
if dataset_fpcs is None:
|
155 |
+
self.dataset_fpcs = [frames_per_clip for _ in data_paths]
|
156 |
+
else:
|
157 |
+
if len(dataset_fpcs) != len(data_paths):
|
158 |
+
raise ValueError("Frames per clip not properly specified for NFS data paths")
|
159 |
+
self.dataset_fpcs = dataset_fpcs
|
160 |
+
|
161 |
+
if VideoReader is None:
|
162 |
+
            raise ImportError('Unable to import "decord" which is required to read videos.')

        # Load video paths and labels
        samples, labels = [], []
        self.num_samples_per_dataset = []
        for data_path in self.data_paths:

            if data_path[-4:] == ".csv":
                try:
                    data = pd.read_csv(data_path, header=None, delimiter=" ")
                except pd.errors.ParserError:
                    # In image captioning datasets where we have space, we use :: as delimiter.
                    data = pd.read_csv(data_path, header=None, delimiter="::")
                samples += list(data.values[:, 0])
                labels += list(data.values[:, 1])
                num_samples = len(data)
                self.num_samples_per_dataset.append(num_samples)

            elif data_path[-4:] == ".npy":
                data = np.load(data_path, allow_pickle=True)
                data = list(map(lambda x: repr(x)[1:-1], data))
                samples += data
                labels += [0] * len(data)
                num_samples = len(data)
                self.num_samples_per_dataset.append(len(data))

        self.per_dataset_indices = ConcatIndices(self.num_samples_per_dataset)

        # [Optional] Weights for each sample to be used by downstream
        # weighted video sampler
        self.sample_weights = None
        if self.datasets_weights is not None:
            self.sample_weights = []
            for dw, ns in zip(self.datasets_weights, self.num_samples_per_dataset):
                self.sample_weights += [dw / ns] * ns

        self.samples = samples
        self.labels = labels

    def __getitem__(self, index):
        sample = self.samples[index]
        loaded_sample = False
        # Keep trying to load videos until you find a valid sample
        while not loaded_sample:
            if not isinstance(sample, str):
                logger.warning("Invalid sample.")
            else:
                if sample.split(".")[-1].lower() in ("jpg", "png", "jpeg"):
                    loaded_sample = self.get_item_image(index)
                else:
                    loaded_sample = self.get_item_video(index)

            if not loaded_sample:
                index = np.random.randint(self.__len__())
                sample = self.samples[index]

        return loaded_sample

    def get_item_video(self, index):
        sample = self.samples[index]
        dataset_idx, _ = self.per_dataset_indices[index]
        frames_per_clip = self.dataset_fpcs[dataset_idx]

        buffer, clip_indices = self.loadvideo_decord(sample, frames_per_clip)  # [T H W 3]
        loaded_video = len(buffer) > 0
        if not loaded_video:
            return

        # Label/annotations for video
        label = self.labels[index]

        def split_into_clips(video):
            """Split video into a list of clips"""
            fpc = frames_per_clip
            nc = self.num_clips
            return [video[i * fpc : (i + 1) * fpc] for i in range(nc)]

        # Parse video into frames & apply data augmentations
        if self.shared_transform is not None:
            buffer = self.shared_transform(buffer)
        buffer = split_into_clips(buffer)
        if self.transform is not None:
            buffer = [self.transform(clip) for clip in buffer]

        return buffer, label, clip_indices

    def get_item_image(self, index):
        sample = self.samples[index]
        dataset_idx, _ = self.per_dataset_indices[index]
        fpc = self.dataset_fpcs[dataset_idx]

        try:
            image_tensor = torchvision.io.read_image(path=sample, mode=torchvision.io.ImageReadMode.RGB)
        except Exception:
            return
        label = self.labels[index]
        clip_indices = [np.arange(start=0, stop=fpc, dtype=np.int32)]

        # Expanding the input image [3, H, W] ==> [T, 3, H, W]
        buffer = image_tensor.unsqueeze(dim=0).repeat((fpc, 1, 1, 1))
        buffer = buffer.permute((0, 2, 3, 1))  # [T, 3, H, W] ==> [T H W 3]

        if self.shared_transform is not None:
            # Technically we can have only transform, doing this just for the sake of consistency with videos.
            buffer = self.shared_transform(buffer)

        if self.transform is not None:
            buffer = [self.transform(buffer)]

        return buffer, label, clip_indices

    def loadvideo_decord(self, sample, fpc):
        """Load video content using Decord"""

        fname = sample
        if not os.path.exists(fname):
            warnings.warn(f"video path not found {fname=}")
            return [], None

        _fsize = os.path.getsize(fname)
        if _fsize > self.filter_long_videos:
            warnings.warn(f"skipping long video of size {_fsize=} (bytes)")
            return [], None

        try:
            vr = VideoReader(fname, num_threads=-1, ctx=cpu(0))
        except Exception:
            return [], None

        fstp = self.frame_step
        if self.duration is not None or self.fps is not None:
            try:
                video_fps = math.ceil(vr.get_avg_fps())
            except Exception as e:
                logger.warning(e)

            if self.duration is not None:
                assert self.fps is None
                fstp = int(self.duration * video_fps / fpc)
            else:
                assert self.duration is None
                fstp = video_fps // self.fps

        assert fstp is not None and fstp > 0
        clip_len = int(fpc * fstp)

        if self.filter_short_videos and len(vr) < clip_len:
            warnings.warn(f"skipping video of length {len(vr)}")
            return [], None

        vr.seek(0)  # Go to start of video before sampling frames

        # Partition video into equal sized segments and sample each clip
        # from a different segment
        partition_len = len(vr) // self.num_clips

        all_indices, clip_indices = [], []
        for i in range(self.num_clips):

            if partition_len > clip_len:
                # If partition_len > clip len, then sample a random window of
                # clip_len frames within the segment
                end_indx = clip_len
                if self.random_clip_sampling:
                    end_indx = np.random.randint(clip_len, partition_len)
                start_indx = end_indx - clip_len
                indices = np.linspace(start_indx, end_indx, num=fpc)
                indices = np.clip(indices, start_indx, end_indx - 1).astype(np.int64)
                # --
                indices = indices + i * partition_len
            else:
                # If partition overlap not allowed and partition_len < clip_len
                # then repeatedly append the last frame in the segment until
                # we reach the desired clip length
                if not self.allow_clip_overlap:
                    indices = np.linspace(0, partition_len, num=partition_len // fstp)
                    indices = np.concatenate(
                        (
                            indices,
                            np.ones(fpc - partition_len // fstp) * partition_len,
                        )
                    )
                    indices = np.clip(indices, 0, partition_len - 1).astype(np.int64)
                    # --
                    indices = indices + i * partition_len

                # If partition overlap is allowed and partition_len < clip_len
                # then start_indx of segment i+1 will lie within segment i
                else:
                    sample_len = min(clip_len, len(vr)) - 1
                    indices = np.linspace(0, sample_len, num=sample_len // fstp)
                    indices = np.concatenate(
                        (
                            indices,
                            np.ones(fpc - sample_len // fstp) * sample_len,
                        )
                    )
                    indices = np.clip(indices, 0, sample_len - 1).astype(np.int64)
                    # --
                    clip_step = 0
                    if len(vr) > clip_len:
                        clip_step = (len(vr) - clip_len) // (self.num_clips - 1)
                    indices = indices + i * clip_step

            clip_indices.append(indices)
            all_indices.extend(list(indices))

        buffer = vr.get_batch(all_indices).asnumpy()
        return buffer, clip_indices

    def __len__(self):
        return len(self.samples)
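The sampler in `loadvideo_decord` above splits each video into `num_clips` equal segments and draws `fpc` frame indices per segment. A minimal standalone sketch of that index arithmetic, using hypothetical values for the video length, frame step, and frames per clip (not part of the repository code):

```python
import numpy as np

# Hypothetical settings: 300-frame video, 2 clips, 16 frames per clip, frame step 4.
video_len, num_clips, fpc, fstp = 300, 2, 16, 4
clip_len = fpc * fstp                      # 64 frames spanned by one clip
partition_len = video_len // num_clips     # 150 frames per segment

for i in range(num_clips):
    # Random window of clip_len frames inside segment i (mirrors random_clip_sampling=True).
    end_indx = np.random.randint(clip_len, partition_len)
    start_indx = end_indx - clip_len
    indices = np.linspace(start_indx, end_indx, num=fpc)
    indices = np.clip(indices, start_indx, end_indx - 1).astype(np.int64) + i * partition_len
    print(f"clip {i}: frames {indices.min()}..{indices.max()}")
```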
src/hub/__init__.py
ADDED
File without changes
|
src/hub/backbones.py
ADDED
@@ -0,0 +1,177 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

VJEPA_BASE_URL = "https://dl.fbaipublicfiles.com/vjepa2"

# for testing
# VJEPA_BASE_URL = "http://localhost:8300"

ARCH_NAME_MAP = {
    "vit_large": ("vit_large", "vitl"),
    "vit_huge": ("vit_huge", "vith"),
    "vit_giant": ("vit_giant_xformers", "vitg"),
    "vit_ac_giant": ("vit_giant_xformers", "vjepa2-ac-vitg"),
    "vit_giant_384": ("vit_giant_xformers", "vitg-384"),
}


def _clean_backbone_key(state_dict):
    for key, val in state_dict.copy().items():
        _ = state_dict.pop(key)
        key = key.replace("module.", "")
        key = key.replace("backbone.", "")
        state_dict[key] = val
    return state_dict


def _make_vjepa2_ac_model(
    *,
    model_name: str = "vit_ac_giant",
    img_size=256,
    patch_size=16,
    tubelet_size=2,
    num_frames=64,
    pretrained: bool = True,
    **kwargs,
):
    from ..models import ac_predictor as vit_ac_predictor
    from ..models import vision_transformer as vit_encoder

    vit_encoder_kwargs = dict(
        patch_size=patch_size,
        img_size=(img_size, img_size),
        num_frames=num_frames,
        tubelet_size=tubelet_size,
        use_sdpa=True,
        use_SiLU=False,
        wide_SiLU=True,
        uniform_power=False,
        use_rope=True,
    )
    vit_encoder_kwargs.update(**kwargs)

    arch_name = ARCH_NAME_MAP[model_name][0]
    encoder = vit_encoder.__dict__[arch_name](**vit_encoder_kwargs)

    vit_predictor_kwargs = dict(
        img_size=(img_size, img_size),
        patch_size=patch_size,
        num_frames=num_frames,
        tubelet_size=tubelet_size,
        embed_dim=encoder.embed_dim,
    )
    vit_predictor_kwargs.update(**kwargs)

    predictor = vit_ac_predictor.__dict__["vit_ac_predictor"](**vit_predictor_kwargs)

    if pretrained:
        model_file = ARCH_NAME_MAP[model_name][-1]
        url = VJEPA_BASE_URL + f"/{model_file}.pt"
        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        encoder_state_dict = _clean_backbone_key(state_dict["encoder"])
        encoder.load_state_dict(encoder_state_dict, strict=False)
        predictor_state_dict = _clean_backbone_key(state_dict["predictor"])
        predictor.load_state_dict(predictor_state_dict, strict=True)

    return encoder, predictor


def _make_vjepa2_model(
    *,
    model_name: str = "vit_large",
    img_size=256,
    patch_size=16,
    tubelet_size=2,
    num_frames=64,
    pretrained: bool = True,
    **kwargs,
):
    from ..models import predictor as vit_predictor
    from ..models import vision_transformer as vit_encoder

    vit_encoder_kwargs = dict(
        patch_size=patch_size,
        img_size=(img_size, img_size),
        num_frames=num_frames,
        tubelet_size=tubelet_size,
        use_sdpa=True,
        use_SiLU=False,
        wide_SiLU=True,
        uniform_power=False,
        use_rope=True,
    )
    vit_encoder_kwargs.update(**kwargs)

    arch_name = ARCH_NAME_MAP[model_name][0]
    encoder = vit_encoder.__dict__[arch_name](**vit_encoder_kwargs)

    vit_predictor_kwargs = dict(
        img_size=(img_size, img_size),
        patch_size=patch_size,
        use_mask_tokens=True,
        embed_dim=encoder.embed_dim,
        predictor_embed_dim=384,
        num_frames=num_frames,
        tubelet_size=tubelet_size,
        depth=12,
        num_heads=12,
        num_mask_tokens=10,
        use_rope=True,
        uniform_power=False,
        use_sdpa=True,
        use_silu=False,
        wide_silu=True,
    )
    vit_predictor_kwargs.update(**kwargs)

    predictor = vit_predictor.__dict__["vit_predictor"](**vit_predictor_kwargs)

    if pretrained:
        model_file = ARCH_NAME_MAP[model_name][-1]
        url = VJEPA_BASE_URL + f"/{model_file}.pt"
        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        encoder_state_dict = _clean_backbone_key(state_dict["encoder"])
        encoder.load_state_dict(encoder_state_dict, strict=False)  # state_dict has pos_embed but we use RoPE
        predictor_state_dict = _clean_backbone_key(state_dict["predictor"])
        predictor.load_state_dict(predictor_state_dict, strict=False)  # state_dict has pos_embed but we use RoPE

    return encoder, predictor


def vjepa2_vit_large(*, pretrained: bool = True, **kwargs):
    """
    VJEPA 2 ViT-Large model
    """
    return _make_vjepa2_model(model_name="vit_large", img_size=256, pretrained=pretrained, **kwargs)


def vjepa2_vit_huge(*, pretrained: bool = True, **kwargs):
    """
    VJEPA 2 ViT-Huge model
    """
    return _make_vjepa2_model(model_name="vit_huge", img_size=256, pretrained=pretrained, **kwargs)


def vjepa2_vit_giant(*, pretrained: bool = True, **kwargs):
    """
    VJEPA 2 ViT-giant model
    """
    return _make_vjepa2_model(model_name="vit_giant", img_size=256, pretrained=pretrained, **kwargs)


def vjepa2_vit_giant_384(*, pretrained: bool = True, **kwargs):
    """
    VJEPA 2 ViT-giant-384 model
    """
    return _make_vjepa2_model(model_name="vit_giant_384", img_size=384, pretrained=pretrained, **kwargs)


def vjepa2_ac_vit_giant(*, pretrained: bool = True, **kwargs):
    """
    VJEPA 2-AC ViT-giant model
    """
    return _make_vjepa2_ac_model(model_name="vit_ac_giant", img_size=256, pretrained=pretrained, **kwargs)
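Each factory above returns an `(encoder, predictor)` pair and, when `pretrained=True`, pulls the matching checkpoint from `VJEPA_BASE_URL`. A minimal usage sketch, assuming the repository is importable as the `src` package and using a random tensor in place of a real clip (shapes in comments are illustrative):

```python
import torch

from src.hub.backbones import vjepa2_vit_large

# Build the ViT-L encoder/predictor without downloading weights.
encoder, predictor = vjepa2_vit_large(pretrained=False)

# One clip: [batch, channels, frames, height, width].
clip = torch.randn(1, 3, 64, 256, 256)
with torch.no_grad():
    tokens = encoder(clip)  # one token per 2x16x16 tubelet
print(tokens.shape)
```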
src/masks/__pycache__/utils.cpython-312.pyc
ADDED
Binary file (923 Bytes).
src/masks/default.py
ADDED
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from logging import getLogger

import torch

_GLOBAL_SEED = 0
logger = getLogger()


class DefaultCollator(object):

    def __call__(self, batch):
        collated_batch = torch.utils.data.default_collate(batch)
        return collated_batch, None, None
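`DefaultCollator` only wraps `torch.utils.data.default_collate` and pads the return value to `(batch, None, None)`, so it is interchangeable with the mask collators, which return `(batch, masks_enc, masks_pred)`. A small sketch with a toy dataset (the dataset itself is hypothetical, not part of the repository):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from src.masks.default import DefaultCollator

toy = TensorDataset(torch.randn(8, 4), torch.arange(8))
loader = DataLoader(toy, batch_size=4, collate_fn=DefaultCollator())

for (features, labels), masks_enc, masks_pred in loader:
    # masks_enc and masks_pred are None here, matching the mask-collator interface.
    print(features.shape, labels.shape, masks_enc, masks_pred)
```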
src/masks/multiseq_multiblock3d.py
ADDED
@@ -0,0 +1,239 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
from logging import getLogger
|
8 |
+
from multiprocessing import Value
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
_GLOBAL_SEED = 0
|
13 |
+
logger = getLogger()
|
14 |
+
|
15 |
+
|
16 |
+
class MaskCollator(object):
|
17 |
+
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
cfgs_mask,
|
21 |
+
dataset_fpcs,
|
22 |
+
crop_size=(224, 224),
|
23 |
+
patch_size=(16, 16),
|
24 |
+
tubelet_size=2,
|
25 |
+
):
|
26 |
+
super(MaskCollator, self).__init__()
|
27 |
+
|
28 |
+
self.mask_generators = dict()
|
29 |
+
for fpc in dataset_fpcs:
|
30 |
+
self.mask_generators[fpc] = []
|
31 |
+
for m in cfgs_mask:
|
32 |
+
mask_generator = _MaskGenerator(
|
33 |
+
crop_size=crop_size,
|
34 |
+
num_frames=fpc,
|
35 |
+
spatial_patch_size=patch_size,
|
36 |
+
temporal_patch_size=tubelet_size,
|
37 |
+
spatial_pred_mask_scale=m.get("spatial_scale"),
|
38 |
+
temporal_pred_mask_scale=m.get("temporal_scale"),
|
39 |
+
aspect_ratio=m.get("aspect_ratio"),
|
40 |
+
npred=m.get("num_blocks"),
|
41 |
+
max_context_frames_ratio=m.get("max_temporal_keep", 1.0),
|
42 |
+
max_keep=m.get("max_keep", None),
|
43 |
+
full_complement=m.get("full_complement", False),
|
44 |
+
pred_full_complement=m.get("pred_full_complement", False),
|
45 |
+
inv_block=m.get("inv_block", False),
|
46 |
+
)
|
47 |
+
self.mask_generators[fpc].append(mask_generator)
|
48 |
+
|
49 |
+
def step(self):
|
50 |
+
for fpc in self.mask_generators:
|
51 |
+
for mask_generator in self.mask_generators[fpc]:
|
52 |
+
mask_generator.step()
|
53 |
+
|
54 |
+
def __call__(self, batch):
|
55 |
+
|
56 |
+
# Batch: [buffer, label, clip_indices]
|
57 |
+
filtered_batches = {fpc: [] for fpc in self.mask_generators}
|
58 |
+
for sample in batch:
|
59 |
+
fpc = len(sample[-1][-1])
|
60 |
+
filtered_batches[fpc] += [sample]
|
61 |
+
|
62 |
+
fpc_collations = []
|
63 |
+
for fpc in filtered_batches:
|
64 |
+
fpc_batch = filtered_batches[fpc]
|
65 |
+
batch_size = len(fpc_batch)
|
66 |
+
if batch_size == 0:
|
67 |
+
continue
|
68 |
+
collated_batch = torch.utils.data.default_collate(fpc_batch)
|
69 |
+
collated_masks_pred, collated_masks_enc = [], []
|
70 |
+
for i, mask_generator in enumerate(self.mask_generators[fpc]):
|
71 |
+
masks_enc, masks_pred = mask_generator(batch_size)
|
72 |
+
collated_masks_enc.append(masks_enc)
|
73 |
+
collated_masks_pred.append(masks_pred)
|
74 |
+
fpc_collations += [(collated_batch, collated_masks_enc, collated_masks_pred)]
|
75 |
+
|
76 |
+
return fpc_collations
|
77 |
+
|
78 |
+
|
79 |
+
class _MaskGenerator(object):
|
80 |
+
|
81 |
+
def __init__(
|
82 |
+
self,
|
83 |
+
crop_size=(224, 224),
|
84 |
+
num_frames=16,
|
85 |
+
spatial_patch_size=(16, 16),
|
86 |
+
temporal_patch_size=2,
|
87 |
+
spatial_pred_mask_scale=(0.2, 0.8),
|
88 |
+
temporal_pred_mask_scale=(1.0, 1.0),
|
89 |
+
aspect_ratio=(0.3, 3.0),
|
90 |
+
npred=1,
|
91 |
+
max_context_frames_ratio=1.0,
|
92 |
+
max_keep=None,
|
93 |
+
inv_block=False,
|
94 |
+
full_complement=False,
|
95 |
+
pred_full_complement=False,
|
96 |
+
):
|
97 |
+
super(_MaskGenerator, self).__init__()
|
98 |
+
if not isinstance(crop_size, tuple):
|
99 |
+
crop_size = (crop_size,) * 2
|
100 |
+
if not isinstance(spatial_patch_size, tuple):
|
101 |
+
spatial_patch_size = (spatial_patch_size,) * 2
|
102 |
+
self.crop_size = crop_size
|
103 |
+
self.height, self.width = [crop_size[i] // spatial_patch_size[i] for i in (0, 1)]
|
104 |
+
self.duration = num_frames // temporal_patch_size
|
105 |
+
self.full_complement = full_complement
|
106 |
+
self.pred_full_complement = pred_full_complement
|
107 |
+
|
108 |
+
self.spatial_patch_size = spatial_patch_size
|
109 |
+
self.temporal_patch_size = temporal_patch_size
|
110 |
+
|
111 |
+
self.aspect_ratio = aspect_ratio
|
112 |
+
self.spatial_pred_mask_scale = spatial_pred_mask_scale
|
113 |
+
self.temporal_pred_mask_scale = temporal_pred_mask_scale
|
114 |
+
self.npred = npred
|
115 |
+
self.max_context_duration = max(
|
116 |
+
1, int(self.duration * max_context_frames_ratio)
|
117 |
+
) # maximum number of time-steps (frames) spanned by context mask
|
118 |
+
self.max_keep = max_keep # maximum number of patches to keep in context
|
119 |
+
self._itr_counter = Value("i", -1) # collator is shared across worker processes
|
120 |
+
self.inv_block = inv_block
|
121 |
+
|
122 |
+
def step(self):
|
123 |
+
i = self._itr_counter
|
124 |
+
with i.get_lock():
|
125 |
+
i.value += 1
|
126 |
+
v = i.value
|
127 |
+
return v
|
128 |
+
|
129 |
+
def _sample_block_size(self, generator, temporal_scale, spatial_scale, aspect_ratio_scale):
|
130 |
+
# -- Sample temporal block mask scale
|
131 |
+
_rand = torch.rand(1, generator=generator).item()
|
132 |
+
min_t, max_t = temporal_scale
|
133 |
+
temporal_mask_scale = min_t + _rand * (max_t - min_t)
|
134 |
+
t = max(1, int(self.duration * temporal_mask_scale))
|
135 |
+
|
136 |
+
# -- Sample spatial block mask scale
|
137 |
+
_rand = torch.rand(1, generator=generator).item()
|
138 |
+
min_s, max_s = spatial_scale
|
139 |
+
spatial_mask_scale = min_s + _rand * (max_s - min_s)
|
140 |
+
spatial_num_keep = int(self.height * self.width * spatial_mask_scale)
|
141 |
+
|
142 |
+
# -- Sample block aspect-ratio
|
143 |
+
_rand = torch.rand(1, generator=generator).item()
|
144 |
+
min_ar, max_ar = aspect_ratio_scale
|
145 |
+
aspect_ratio = min_ar + _rand * (max_ar - min_ar)
|
146 |
+
|
147 |
+
# -- Compute block height and width (given scale and aspect-ratio)
|
148 |
+
h = int(round(math.sqrt(spatial_num_keep * aspect_ratio)))
|
149 |
+
w = int(round(math.sqrt(spatial_num_keep / aspect_ratio)))
|
150 |
+
h = min(h, self.height)
|
151 |
+
w = min(w, self.width)
|
152 |
+
|
153 |
+
return (t, h, w)
|
154 |
+
|
155 |
+
def _sample_block_mask(self, b_size):
|
156 |
+
t, h, w = b_size
|
157 |
+
top = torch.randint(0, self.height - h + 1, (1,))
|
158 |
+
left = torch.randint(0, self.width - w + 1, (1,))
|
159 |
+
start = torch.randint(0, self.duration - t + 1, (1,))
|
160 |
+
|
161 |
+
mask = torch.ones((self.duration, self.height, self.width), dtype=torch.int32)
|
162 |
+
mask[start : start + t, top : top + h, left : left + w] = 0
|
163 |
+
|
164 |
+
# Context mask will only span the first X frames
|
165 |
+
# (X=self.max_context_frames)
|
166 |
+
if self.max_context_duration < self.duration:
|
167 |
+
mask[self.max_context_duration :, :, :] = 0
|
168 |
+
|
169 |
+
# --
|
170 |
+
return mask
|
171 |
+
|
172 |
+
def __call__(self, batch_size):
|
173 |
+
"""
|
174 |
+
Create encoder and predictor masks when collating imgs into a batch
|
175 |
+
# 1. sample pred block size using seed
|
176 |
+
# 2. sample several pred block locations for each image (w/o seed)
|
177 |
+
# 3. return pred masks and complement (enc mask)
|
178 |
+
"""
|
179 |
+
seed = self.step()
|
180 |
+
g = torch.Generator()
|
181 |
+
g.manual_seed(seed)
|
182 |
+
p_size = self._sample_block_size(
|
183 |
+
generator=g,
|
184 |
+
temporal_scale=self.temporal_pred_mask_scale,
|
185 |
+
spatial_scale=self.spatial_pred_mask_scale,
|
186 |
+
aspect_ratio_scale=self.aspect_ratio,
|
187 |
+
)
|
188 |
+
|
189 |
+
collated_masks_pred, collated_masks_enc = [], []
|
190 |
+
min_keep_enc = min_keep_pred = self.duration * self.height * self.width
|
191 |
+
for _ in range(batch_size):
|
192 |
+
|
193 |
+
empty_context = True
|
194 |
+
while empty_context:
|
195 |
+
|
196 |
+
mask_e = torch.ones((self.duration, self.height, self.width), dtype=torch.int32)
|
197 |
+
for _ in range(self.npred):
|
198 |
+
mask_e *= self._sample_block_mask(p_size)
|
199 |
+
mask_e = mask_e.flatten()
|
200 |
+
|
201 |
+
mask_p = torch.argwhere(mask_e == 0).squeeze()
|
202 |
+
mask_e = torch.nonzero(mask_e).squeeze()
|
203 |
+
|
204 |
+
empty_context = len(mask_e) == 0
|
205 |
+
if not empty_context:
|
206 |
+
min_keep_pred = min(min_keep_pred, len(mask_p))
|
207 |
+
min_keep_enc = min(min_keep_enc, len(mask_e))
|
208 |
+
collated_masks_pred.append(mask_p)
|
209 |
+
collated_masks_enc.append(mask_e)
|
210 |
+
|
211 |
+
if self.max_keep is not None:
|
212 |
+
min_keep_enc = min(min_keep_enc, self.max_keep)
|
213 |
+
|
214 |
+
collated_masks_enc = [cm[:min_keep_enc] for cm in collated_masks_enc]
|
215 |
+
collated_masks_pred = [cm[:min_keep_pred] for cm in collated_masks_pred]
|
216 |
+
if self.full_complement: # predictor mask is just complement of encoder mask
|
217 |
+
collated_masks_pred = [
|
218 |
+
torch.tensor(
|
219 |
+
sorted(list(set(range(int(self.duration * self.height * self.width))) - set(cm.tolist()))),
|
220 |
+
dtype=cm.dtype,
|
221 |
+
)
|
222 |
+
for cm in collated_masks_enc
|
223 |
+
]
|
224 |
+
elif self.pred_full_complement:
|
225 |
+
collated_masks_enc = [
|
226 |
+
torch.tensor(
|
227 |
+
sorted(list(set(range(int(self.duration * self.height * self.width))) - set(cm.tolist()))),
|
228 |
+
dtype=cm.dtype,
|
229 |
+
)
|
230 |
+
for cm in collated_masks_pred
|
231 |
+
]
|
232 |
+
|
233 |
+
collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc)
|
234 |
+
collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred)
|
235 |
+
|
236 |
+
if self.inv_block:
|
237 |
+
return collated_masks_pred, collated_masks_enc # predict context from block
|
238 |
+
else:
|
239 |
+
return collated_masks_enc, collated_masks_pred
|
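`_MaskGenerator` operates on the token grid, not on pixels: with a 224 crop, patch size 16, 16 frames, and tubelet size 2, the grid is 14x14 spatial positions over 8 time steps, and each sampled block zeroes out a (t, h, w) sub-volume. A standalone sketch of that bookkeeping with a hypothetical block size (not part of the repository code):

```python
import torch

crop_size, patch_size, num_frames, tubelet_size = 224, 16, 16, 2
height = width = crop_size // patch_size   # 14 patches per side
duration = num_frames // tubelet_size      # 8 temporal positions

# Sample one block of (t, h, w) = (8, 6, 6); zeros mark tokens the predictor must infer.
t, h, w = 8, 6, 6
top = torch.randint(0, height - h + 1, (1,))
left = torch.randint(0, width - w + 1, (1,))
start = torch.randint(0, duration - t + 1, (1,))
mask = torch.ones((duration, height, width), dtype=torch.int32)
mask[start : start + t, top : top + h, left : left + w] = 0

mask_flat = mask.flatten()
mask_enc = torch.nonzero(mask_flat).squeeze()          # context indices kept for the encoder
mask_pred = torch.argwhere(mask_flat == 0).squeeze()   # target indices handed to the predictor
print(len(mask_enc), len(mask_pred))  # 1280 + 288 = 1568 tokens in total
```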
src/masks/utils.py
ADDED
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch


def apply_masks(x, masks, concat=True):
    """
    :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)]
    :param masks: list of tensors of shape [B, K] containing indices of K patches in [N] to keep
    """
    all_x = []
    for m in masks:
        mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
        all_x += [torch.gather(x, dim=1, index=mask_keep)]
    if not concat:
        return all_x

    return torch.cat(all_x, dim=0)
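`apply_masks` gathers the kept patch tokens by index along the patch dimension. A quick sketch of its effect on a toy tensor (sizes are arbitrary):

```python
import torch

from src.masks.utils import apply_masks

x = torch.randn(2, 10, 8)                       # [B=2, N=10 patches, D=8]
masks = [torch.tensor([[0, 3, 7], [1, 2, 9]])]  # keep 3 patch indices per sample
out = apply_masks(x, masks)
print(out.shape)  # torch.Size([2, 3, 8])
```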
src/models/__pycache__/attentive_pooler.cpython-312.pyc
ADDED
Binary file (6.34 kB).
src/models/__pycache__/vision_transformer.cpython-312.pyc
ADDED
Binary file (16.2 kB).
src/models/ac_predictor.py
ADDED
@@ -0,0 +1,200 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
from functools import partial
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
|
12 |
+
from src.models.utils.modules import ACBlock as Block
|
13 |
+
from src.models.utils.modules import build_action_block_causal_attention_mask
|
14 |
+
from src.utils.tensors import trunc_normal_
|
15 |
+
|
16 |
+
|
17 |
+
class VisionTransformerPredictorAC(nn.Module):
|
18 |
+
"""Action Conditioned Vision Transformer Predictor"""
|
19 |
+
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
img_size=(224, 224),
|
23 |
+
patch_size=16,
|
24 |
+
num_frames=1,
|
25 |
+
tubelet_size=2,
|
26 |
+
embed_dim=768,
|
27 |
+
predictor_embed_dim=1024,
|
28 |
+
depth=24,
|
29 |
+
num_heads=16,
|
30 |
+
mlp_ratio=4.0,
|
31 |
+
qkv_bias=True,
|
32 |
+
qk_scale=None,
|
33 |
+
drop_rate=0.0,
|
34 |
+
attn_drop_rate=0.0,
|
35 |
+
drop_path_rate=0.0,
|
36 |
+
norm_layer=nn.LayerNorm,
|
37 |
+
init_std=0.02,
|
38 |
+
uniform_power=True,
|
39 |
+
use_silu=False,
|
40 |
+
wide_silu=True,
|
41 |
+
is_frame_causal=True,
|
42 |
+
use_activation_checkpointing=False,
|
43 |
+
use_rope=True,
|
44 |
+
action_embed_dim=7,
|
45 |
+
use_extrinsics=False,
|
46 |
+
**kwargs
|
47 |
+
):
|
48 |
+
super().__init__()
|
49 |
+
self.is_frame_causal = is_frame_causal
|
50 |
+
self.use_extrinsics = use_extrinsics
|
51 |
+
|
52 |
+
# Map input to predictor dimension
|
53 |
+
self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
|
54 |
+
self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
|
55 |
+
self.state_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
|
56 |
+
self.extrinsics_encoder = nn.Linear(action_embed_dim - 1, predictor_embed_dim, bias=True)
|
57 |
+
|
58 |
+
# Determine positional embedding
|
59 |
+
if type(img_size) is int:
|
60 |
+
img_size = (img_size, img_size)
|
61 |
+
self.img_height, self.img_width = img_size
|
62 |
+
self.patch_size = patch_size
|
63 |
+
# --
|
64 |
+
self.num_frames = num_frames
|
65 |
+
self.tubelet_size = tubelet_size
|
66 |
+
self.is_video = num_frames > 1
|
67 |
+
|
68 |
+
self.grid_height = img_size[0] // self.patch_size
|
69 |
+
self.grid_width = img_size[1] // self.patch_size
|
70 |
+
self.use_activation_checkpointing = use_activation_checkpointing
|
71 |
+
|
72 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
73 |
+
|
74 |
+
# Position embedding
|
75 |
+
self.uniform_power = uniform_power
|
76 |
+
|
77 |
+
# Attention Blocks
|
78 |
+
self.use_rope = use_rope
|
79 |
+
self.predictor_blocks = nn.ModuleList(
|
80 |
+
[
|
81 |
+
Block(
|
82 |
+
use_rope=use_rope,
|
83 |
+
grid_size=self.grid_height,
|
84 |
+
dim=predictor_embed_dim,
|
85 |
+
num_heads=num_heads,
|
86 |
+
mlp_ratio=mlp_ratio,
|
87 |
+
qkv_bias=qkv_bias,
|
88 |
+
qk_scale=qk_scale,
|
89 |
+
drop=drop_rate,
|
90 |
+
act_layer=nn.SiLU if use_silu else nn.GELU,
|
91 |
+
wide_silu=wide_silu,
|
92 |
+
attn_drop=attn_drop_rate,
|
93 |
+
drop_path=dpr[i],
|
94 |
+
norm_layer=norm_layer,
|
95 |
+
)
|
96 |
+
for i in range(depth)
|
97 |
+
]
|
98 |
+
)
|
99 |
+
|
100 |
+
# Normalize & project back to input dimension
|
101 |
+
self.predictor_norm = norm_layer(predictor_embed_dim)
|
102 |
+
self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
|
103 |
+
|
104 |
+
# ------ initialize weights
|
105 |
+
self.init_std = init_std
|
106 |
+
self.apply(self._init_weights)
|
107 |
+
self._rescale_blocks()
|
108 |
+
|
109 |
+
attn_mask = None
|
110 |
+
if self.is_frame_causal:
|
111 |
+
grid_depth = self.num_frames // self.tubelet_size
|
112 |
+
grid_height = self.img_height // self.patch_size
|
113 |
+
grid_width = self.img_width // self.patch_size
|
114 |
+
attn_mask = build_action_block_causal_attention_mask(
|
115 |
+
grid_depth, grid_height, grid_width, add_tokens=3 if use_extrinsics else 2
|
116 |
+
)
|
117 |
+
self.attn_mask = attn_mask
|
118 |
+
|
119 |
+
def _init_weights(self, m):
|
120 |
+
if isinstance(m, nn.Linear):
|
121 |
+
trunc_normal_(m.weight, std=self.init_std)
|
122 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
123 |
+
nn.init.constant_(m.bias, 0)
|
124 |
+
elif isinstance(m, nn.LayerNorm):
|
125 |
+
nn.init.constant_(m.bias, 0)
|
126 |
+
nn.init.constant_(m.weight, 1.0)
|
127 |
+
|
128 |
+
def _rescale_blocks(self):
|
129 |
+
def rescale(param, layer_id):
|
130 |
+
param.div_(math.sqrt(2.0 * layer_id))
|
131 |
+
|
132 |
+
for layer_id, layer in enumerate(self.predictor_blocks):
|
133 |
+
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
134 |
+
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
135 |
+
|
136 |
+
def forward(self, x, actions, states, extrinsics=None):
|
137 |
+
"""
|
138 |
+
:param x: context tokens
|
139 |
+
"""
|
140 |
+
# Map tokens to pedictor dimensions
|
141 |
+
x = self.predictor_embed(x)
|
142 |
+
B, N_ctxt, D = x.size()
|
143 |
+
T = N_ctxt // (self.grid_height * self.grid_width)
|
144 |
+
|
145 |
+
# Interleave action tokens
|
146 |
+
s = self.state_encoder(states).unsqueeze(2)
|
147 |
+
a = self.action_encoder(actions).unsqueeze(2)
|
148 |
+
x = x.view(B, T, self.grid_height * self.grid_width, D) # [B, T, H*W, D]
|
149 |
+
if self.use_extrinsics:
|
150 |
+
e = self.extrinsics_encoder(extrinsics).unsqueeze(2)
|
151 |
+
x = torch.cat([a, s, e, x], dim=2).flatten(1, 2) # [B, T*(H*W+3), D]
|
152 |
+
else:
|
153 |
+
x = torch.cat([a, s, x], dim=2).flatten(1, 2) # [B, T*(H*W+2), D]
|
154 |
+
|
155 |
+
cond_tokens = 3 if self.use_extrinsics else 2
|
156 |
+
attn_mask = self.attn_mask[: x.size(1), : x.size(1)].to(x.device, non_blocking=True)
|
157 |
+
|
158 |
+
# Fwd prop
|
159 |
+
for i, blk in enumerate(self.predictor_blocks):
|
160 |
+
if self.use_activation_checkpointing:
|
161 |
+
x = torch.utils.checkpoint.checkpoint(
|
162 |
+
blk,
|
163 |
+
x,
|
164 |
+
mask=None,
|
165 |
+
attn_mask=attn_mask,
|
166 |
+
T=T,
|
167 |
+
H=self.grid_height,
|
168 |
+
W=self.grid_width,
|
169 |
+
action_tokens=cond_tokens,
|
170 |
+
use_reentrant=False,
|
171 |
+
)
|
172 |
+
else:
|
173 |
+
x = blk(
|
174 |
+
x,
|
175 |
+
mask=None,
|
176 |
+
attn_mask=attn_mask,
|
177 |
+
T=T,
|
178 |
+
H=self.grid_height,
|
179 |
+
W=self.grid_width,
|
180 |
+
action_tokens=cond_tokens,
|
181 |
+
)
|
182 |
+
|
183 |
+
# Split out action and frame tokens
|
184 |
+
x = x.view(B, T, cond_tokens + self.grid_height * self.grid_width, D) # [B, T, K+H*W, D]
|
185 |
+
x = x[:, :, cond_tokens:, :].flatten(1, 2)
|
186 |
+
|
187 |
+
x = self.predictor_norm(x)
|
188 |
+
x = self.predictor_proj(x)
|
189 |
+
|
190 |
+
return x
|
191 |
+
|
192 |
+
|
193 |
+
def vit_ac_predictor(**kwargs):
|
194 |
+
model = VisionTransformerPredictorAC(
|
195 |
+
mlp_ratio=4,
|
196 |
+
qkv_bias=True,
|
197 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
198 |
+
**kwargs
|
199 |
+
)
|
200 |
+
return model
|
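The action-conditioned predictor interleaves one action token and one state token (plus an optional extrinsics token) with every frame's patch tokens, and the block-causal attention mask is sized accordingly. A short sketch of that token bookkeeping using the class defaults (plain arithmetic, no model instantiation):

```python
# Token bookkeeping for VisionTransformerPredictorAC with the default configuration.
img_size, patch_size, num_frames, tubelet_size = 256, 16, 64, 2
H = W = img_size // patch_size            # 16x16 patch grid per frame
T = num_frames // tubelet_size            # 32 temporal steps
cond_tokens = 2                           # action token + state token (3 with extrinsics)
tokens_per_step = cond_tokens + H * W     # 258
total_tokens = T * tokens_per_step        # 8256 = side length of the causal attention mask
print(tokens_per_step, total_tokens)
```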
src/models/attentive_pooler.py
ADDED
@@ -0,0 +1,137 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import math

import torch
import torch.nn as nn

from src.models.utils.modules import Block, CrossAttention, CrossAttentionBlock
from src.utils.tensors import trunc_normal_


class AttentivePooler(nn.Module):
    """Attentive Pooler"""

    def __init__(
        self,
        num_queries=1,
        embed_dim=768,
        num_heads=12,
        mlp_ratio=4.0,
        depth=1,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        qkv_bias=True,
        complete_block=True,
        use_activation_checkpointing=False,
    ):
        super().__init__()
        self.use_activation_checkpointing = use_activation_checkpointing
        self.query_tokens = nn.Parameter(torch.zeros(1, num_queries, embed_dim))

        self.complete_block = complete_block
        if complete_block:
            self.cross_attention_block = CrossAttentionBlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, norm_layer=norm_layer
            )
        else:
            self.cross_attention_block = CrossAttention(dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias)

        self.blocks = None
        if depth > 1:
            self.blocks = nn.ModuleList(
                [
                    Block(
                        dim=embed_dim,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        qk_scale=False,
                        norm_layer=norm_layer,
                    )
                    for i in range(depth - 1)
                ]
            )

        self.init_std = init_std
        trunc_normal_(self.query_tokens, std=self.init_std)
        self.apply(self._init_weights)
        self._rescale_blocks()

    def _rescale_blocks(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        layer_id = 0
        if self.blocks is not None:
            for layer_id, layer in enumerate(self.blocks):
                rescale(layer.attn.proj.weight.data, layer_id + 1)
                rescale(layer.mlp.fc2.weight.data, layer_id + 1)

        if self.complete_block:
            rescale(self.cross_attention_block.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=self.init_std)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        if self.blocks is not None:
            for blk in self.blocks:
                if self.use_activation_checkpointing:
                    x = torch.utils.checkpoint.checkpoint(blk, x, False, None, use_reentrant=False)
                else:
                    x = blk(x)
        q = self.query_tokens.repeat(len(x), 1, 1)
        q = self.cross_attention_block(q, x)
        return q


class AttentiveClassifier(nn.Module):
    """Attentive Classifier"""

    def __init__(
        self,
        embed_dim=768,
        num_heads=12,
        mlp_ratio=4.0,
        depth=1,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        qkv_bias=True,
        num_classes=1000,
        complete_block=True,
        use_activation_checkpointing=False,
    ):
        super().__init__()
        self.pooler = AttentivePooler(
            num_queries=1,
            embed_dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            depth=depth,
            norm_layer=norm_layer,
            init_std=init_std,
            qkv_bias=qkv_bias,
            complete_block=complete_block,
            use_activation_checkpointing=use_activation_checkpointing,
        )
        self.linear = nn.Linear(embed_dim, num_classes, bias=True)

    def forward(self, x):
        x = self.pooler(x).squeeze(1)
        x = self.linear(x)
        return x
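`AttentiveClassifier` gives a classification head over encoder patch tokens: the pooler cross-attends a single learned query over all tokens and a linear layer maps the pooled vector to class logits. A toy-scale sketch (feature shapes are arbitrary):

```python
import torch

from src.models.attentive_pooler import AttentiveClassifier

classifier = AttentiveClassifier(embed_dim=32, num_heads=4, depth=1, num_classes=10)

features = torch.randn(2, 50, 32)  # [batch, patch tokens, embed_dim] from a frozen encoder
logits = classifier(features)
print(logits.shape)                # torch.Size([2, 10])
```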
src/models/predictor.py
ADDED
@@ -0,0 +1,253 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
from functools import partial
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
|
12 |
+
from src.masks.utils import apply_masks
|
13 |
+
from src.models.utils.modules import Block
|
14 |
+
from src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed
|
15 |
+
from src.utils.tensors import repeat_interleave_batch, trunc_normal_
|
16 |
+
|
17 |
+
|
18 |
+
class VisionTransformerPredictor(nn.Module):
|
19 |
+
"""Vision Transformer"""
|
20 |
+
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
img_size=(224, 224),
|
24 |
+
patch_size=16,
|
25 |
+
num_frames=1,
|
26 |
+
tubelet_size=2,
|
27 |
+
embed_dim=768,
|
28 |
+
predictor_embed_dim=384,
|
29 |
+
depth=6,
|
30 |
+
num_heads=12,
|
31 |
+
mlp_ratio=4.0,
|
32 |
+
qkv_bias=True,
|
33 |
+
qk_scale=None,
|
34 |
+
drop_rate=0.0,
|
35 |
+
attn_drop_rate=0.0,
|
36 |
+
drop_path_rate=0.0,
|
37 |
+
norm_layer=nn.LayerNorm,
|
38 |
+
init_std=0.02,
|
39 |
+
uniform_power=False,
|
40 |
+
use_mask_tokens=False,
|
41 |
+
num_mask_tokens=2,
|
42 |
+
zero_init_mask_tokens=True,
|
43 |
+
use_silu=False,
|
44 |
+
wide_silu=True,
|
45 |
+
use_activation_checkpointing=False,
|
46 |
+
return_all_tokens=False,
|
47 |
+
chop_last_n_tokens=0,
|
48 |
+
use_rope=False,
|
49 |
+
**kwargs
|
50 |
+
):
|
51 |
+
super().__init__()
|
52 |
+
self.return_all_tokens = return_all_tokens
|
53 |
+
self.chop_last_n_tokens = chop_last_n_tokens
|
54 |
+
|
55 |
+
# Map input to predictor dimension
|
56 |
+
self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
|
57 |
+
|
58 |
+
# Mask tokens
|
59 |
+
self.mask_tokens = None
|
60 |
+
self.num_mask_tokens = 0
|
61 |
+
if use_mask_tokens:
|
62 |
+
self.num_mask_tokens = num_mask_tokens
|
63 |
+
self.mask_tokens = nn.ParameterList(
|
64 |
+
[nn.Parameter(torch.zeros(1, 1, predictor_embed_dim)) for i in range(num_mask_tokens)]
|
65 |
+
)
|
66 |
+
|
67 |
+
# Determine positional embedding
|
68 |
+
if type(img_size) is int:
|
69 |
+
img_size = (img_size, img_size)
|
70 |
+
self.img_height, self.img_width = img_size
|
71 |
+
self.patch_size = patch_size
|
72 |
+
# --
|
73 |
+
self.num_frames = num_frames
|
74 |
+
self.tubelet_size = tubelet_size
|
75 |
+
self.is_video = num_frames > 1
|
76 |
+
|
77 |
+
self.grid_height = img_size[0] // self.patch_size
|
78 |
+
self.grid_width = img_size[1] // self.patch_size
|
79 |
+
self.grid_depth = num_frames // self.tubelet_size
|
80 |
+
self.use_activation_checkpointing = use_activation_checkpointing
|
81 |
+
|
82 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
83 |
+
|
84 |
+
if self.is_video:
|
85 |
+
self.num_patches = num_patches = (
|
86 |
+
(num_frames // tubelet_size) * (img_size[0] // patch_size) * (img_size[1] // patch_size)
|
87 |
+
)
|
88 |
+
else:
|
89 |
+
self.num_patches = num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
|
90 |
+
# Position embedding
|
91 |
+
self.uniform_power = uniform_power
|
92 |
+
|
93 |
+
self.predictor_pos_embed = None
|
94 |
+
if not use_rope:
|
95 |
+
self.predictor_pos_embed = nn.Parameter(
|
96 |
+
torch.zeros(1, num_patches, predictor_embed_dim), requires_grad=False
|
97 |
+
)
|
98 |
+
|
99 |
+
# Attention Blocks
|
100 |
+
self.use_rope = use_rope
|
101 |
+
self.predictor_blocks = nn.ModuleList(
|
102 |
+
[
|
103 |
+
Block(
|
104 |
+
use_rope=use_rope,
|
105 |
+
grid_size=self.grid_height,
|
106 |
+
grid_depth=self.grid_depth,
|
107 |
+
dim=predictor_embed_dim,
|
108 |
+
num_heads=num_heads,
|
109 |
+
mlp_ratio=mlp_ratio,
|
110 |
+
qkv_bias=qkv_bias,
|
111 |
+
qk_scale=qk_scale,
|
112 |
+
drop=drop_rate,
|
113 |
+
act_layer=nn.SiLU if use_silu else nn.GELU,
|
114 |
+
wide_silu=wide_silu,
|
115 |
+
attn_drop=attn_drop_rate,
|
116 |
+
drop_path=dpr[i],
|
117 |
+
norm_layer=norm_layer,
|
118 |
+
)
|
119 |
+
for i in range(depth)
|
120 |
+
]
|
121 |
+
)
|
122 |
+
|
123 |
+
# Normalize & project back to input dimension
|
124 |
+
self.predictor_norm = norm_layer(predictor_embed_dim)
|
125 |
+
self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
|
126 |
+
|
127 |
+
# ------ initialize weights
|
128 |
+
if self.predictor_pos_embed is not None:
|
129 |
+
self._init_pos_embed(self.predictor_pos_embed.data) # sincos pos-embed
|
130 |
+
self.init_std = init_std
|
131 |
+
if not zero_init_mask_tokens:
|
132 |
+
for mt in self.mask_tokens:
|
133 |
+
trunc_normal_(mt, std=init_std)
|
134 |
+
self.apply(self._init_weights)
|
135 |
+
self._rescale_blocks()
|
136 |
+
|
137 |
+
def _init_pos_embed(self, pos_embed):
|
138 |
+
embed_dim = pos_embed.size(-1)
|
139 |
+
grid_size = self.img_height // self.patch_size # TODO: update; currently assumes square input
|
140 |
+
if self.is_video:
|
141 |
+
grid_depth = self.num_frames // self.tubelet_size
|
142 |
+
sincos = get_3d_sincos_pos_embed(
|
143 |
+
embed_dim, grid_size, grid_depth, cls_token=False, uniform_power=self.uniform_power
|
144 |
+
)
|
145 |
+
else:
|
146 |
+
sincos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False)
|
147 |
+
pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0))
|
148 |
+
|
149 |
+
def _init_weights(self, m):
|
150 |
+
if isinstance(m, nn.Linear):
|
151 |
+
trunc_normal_(m.weight, std=self.init_std)
|
152 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
153 |
+
nn.init.constant_(m.bias, 0)
|
154 |
+
elif isinstance(m, nn.LayerNorm):
|
155 |
+
nn.init.constant_(m.bias, 0)
|
156 |
+
nn.init.constant_(m.weight, 1.0)
|
157 |
+
|
158 |
+
def _rescale_blocks(self):
|
159 |
+
def rescale(param, layer_id):
|
160 |
+
param.div_(math.sqrt(2.0 * layer_id))
|
161 |
+
|
162 |
+
for layer_id, layer in enumerate(self.predictor_blocks):
|
163 |
+
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
164 |
+
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
165 |
+
|
166 |
+
def forward(self, x, masks_x, masks_y, mask_index=1, has_cls=False):
|
167 |
+
"""
|
168 |
+
:param x: context tokens
|
169 |
+
:param masks_x: indices of context tokens in input
|
170 |
+
:params masks_y: indices of target tokens in input
|
171 |
+
"""
|
172 |
+
assert (masks_x is not None) and (masks_y is not None), "Cannot run predictor without mask indices"
|
173 |
+
if not isinstance(masks_x, list):
|
174 |
+
masks_x = [masks_x]
|
175 |
+
if not isinstance(masks_y, list):
|
176 |
+
masks_y = [masks_y]
|
177 |
+
|
178 |
+
# Batch Size
|
179 |
+
B = len(x) // len(masks_x)
|
180 |
+
|
181 |
+
# Map context tokens to pedictor dimensions
|
182 |
+
x = self.predictor_embed(x)
|
183 |
+
if has_cls:
|
184 |
+
x_cls = x[:, :1, :]
|
185 |
+
x = x[:, 1:, :]
|
186 |
+
_, N_ctxt, D = x.shape
|
187 |
+
|
188 |
+
# Add positional embedding to ctxt tokens
|
189 |
+
if not self.use_rope:
|
190 |
+
x_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1)
|
191 |
+
x += apply_masks(x_pos_embed, masks_x)
|
192 |
+
|
193 |
+
# Make target tokens
|
194 |
+
mask_index = mask_index % self.num_mask_tokens
|
195 |
+
pred_tokens = self.mask_tokens[mask_index]
|
196 |
+
pred_tokens = pred_tokens.repeat(B, self.num_patches, 1)
|
197 |
+
pred_tokens = apply_masks(pred_tokens, masks_y)
|
198 |
+
# -- add pos embed
|
199 |
+
if not self.use_rope:
|
200 |
+
pos_embs = self.predictor_pos_embed.repeat(B, 1, 1)
|
201 |
+
pos_embs = apply_masks(pos_embs, masks_y)
|
202 |
+
pos_embs = repeat_interleave_batch(pos_embs, B, repeat=len(masks_x))
|
203 |
+
pred_tokens += pos_embs
|
204 |
+
|
205 |
+
# Concatenate context & target tokens
|
206 |
+
x = x.repeat(len(masks_x), 1, 1)
|
207 |
+
x = torch.cat([x, pred_tokens], dim=1)
|
208 |
+
|
209 |
+
# Positions of context & target tokens
|
210 |
+
masks_x = torch.cat(masks_x, dim=0)
|
211 |
+
masks_y = torch.cat(masks_y, dim=0)
|
212 |
+
masks = torch.cat([masks_x, masks_y], dim=1)
|
213 |
+
|
214 |
+
# Put tokens in sorted order
|
215 |
+
argsort = torch.argsort(masks, dim=1) # [B, N]
|
216 |
+
masks = torch.stack([masks[i, row] for i, row in enumerate(argsort)], dim=0)
|
217 |
+
x = torch.stack([x[i, row, :] for i, row in enumerate(argsort)], dim=0)
|
218 |
+
|
219 |
+
# Remove the last n tokens of sorted sequence before processing
|
220 |
+
if self.chop_last_n_tokens > 0:
|
221 |
+
x = x[:, : -self.chop_last_n_tokens]
|
222 |
+
masks = masks[:, : -self.chop_last_n_tokens]
|
223 |
+
|
224 |
+
if has_cls:
|
225 |
+
x = torch.cat([x_cls, x], dim=1)
|
226 |
+
|
227 |
+
# Fwd prop
|
228 |
+
for i, blk in enumerate(self.predictor_blocks):
|
229 |
+
if self.use_activation_checkpointing:
|
230 |
+
x = torch.utils.checkpoint.checkpoint(blk, x, masks, None, use_reentrant=False)
|
231 |
+
else:
|
232 |
+
x = blk(x, mask=masks, attn_mask=None)
|
233 |
+
x = self.predictor_norm(x)
|
234 |
+
|
235 |
+
if has_cls:
|
236 |
+
x = x[:, 1:, :]
|
237 |
+
|
238 |
+
# Return output corresponding to target tokens
|
239 |
+
if not self.return_all_tokens:
|
240 |
+
reverse_argsort = torch.argsort(argsort, dim=1) # [B, N]
|
241 |
+
x = torch.stack([x[i, row, :] for i, row in enumerate(reverse_argsort)], dim=0)
|
242 |
+
x = x[:, N_ctxt:]
|
243 |
+
|
244 |
+
x = self.predictor_proj(x)
|
245 |
+
|
246 |
+
return x
|
247 |
+
|
248 |
+
|
249 |
+
def vit_predictor(**kwargs):
|
250 |
+
model = VisionTransformerPredictor(
|
251 |
+
mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs
|
252 |
+
)
|
253 |
+
return model
|
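In `VisionTransformerPredictor.forward`, context tokens (at `masks_x` positions) and mask tokens (at `masks_y` positions) are concatenated and then re-sorted into raster order before the transformer blocks, and un-sorted afterwards so the target tokens can be sliced off the end. A tiny sketch of that index shuffle with hypothetical masks:

```python
import torch

masks_x = torch.tensor([[0, 1, 4, 5]])  # hypothetical context-token positions
masks_y = torch.tensor([[2, 3, 6, 7]])  # hypothetical target-token positions
masks = torch.cat([masks_x, masks_y], dim=1)

argsort = torch.argsort(masks, dim=1)            # permutation into raster order
print(masks[0, argsort[0]])                      # tensor([0, 1, 2, 3, 4, 5, 6, 7])

reverse_argsort = torch.argsort(argsort, dim=1)  # undoes the permutation after the blocks
print(reverse_argsort[0, 4:])                    # where the 4 target tokens sit in raster order
```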
src/models/utils/__pycache__/modules.cpython-312.pyc
ADDED
Binary file (30.1 kB).
src/models/utils/__pycache__/patch_embed.cpython-312.pyc
ADDED
Binary file (2.31 kB).
src/models/utils/__pycache__/pos_embs.cpython-312.pyc
ADDED
Binary file (4.2 kB).
src/models/utils/modules.py
ADDED
@@ -0,0 +1,610 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from timm.models.layers import drop_path
|
10 |
+
|
11 |
+
|
12 |
+
def build_action_block_causal_attention_mask(T, H, W, add_tokens=1):
|
13 |
+
N_T = add_tokens + (H * W)
|
14 |
+
N = T * N_T
|
15 |
+
mask = torch.zeros(N, N).bool()
|
16 |
+
mask_block = torch.ones(N_T, N_T).bool()
|
17 |
+
local_window_time = T
|
18 |
+
|
19 |
+
for t1 in range(T):
|
20 |
+
for t2 in range(max(0, t1 - local_window_time + 1), t1 + 1):
|
21 |
+
mask[t1 * N_T : (t1 + 1) * N_T, t2 * N_T : (t2 + 1) * N_T] = mask_block
|
22 |
+
|
23 |
+
return mask
|
24 |
+
|
25 |
+
|
26 |
+
def rotate_queries_or_keys(x, pos):
|
27 |
+
B, num_heads, N, D = x.size()
|
28 |
+
assert D % 2 == 0, "Embedding dimension must be a multiple of 2 for block matrix rotation"
|
29 |
+
|
30 |
+
# -- compute angle for each position
|
31 |
+
omega = torch.arange(D // 2, dtype=x.dtype, device=x.device)
|
32 |
+
omega /= D / 2.0
|
33 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
34 |
+
freq = torch.einsum("..., f -> ... f", pos, omega) # (..., N, D/2), outer product
|
35 |
+
|
36 |
+
# -- build rotation matrix and apply
|
37 |
+
emb_sin = freq.sin() # (..., N, D/2)
|
38 |
+
emb_cos = freq.cos() # (..., N, D/2)
|
39 |
+
|
40 |
+
emb_sin = emb_sin.squeeze(-1).repeat(1, 1, 1, 2)
|
41 |
+
emb_cos = emb_cos.squeeze(-1).repeat(1, 1, 1, 2)
|
42 |
+
|
43 |
+
# --
|
44 |
+
y = x.unflatten(-1, (-1, 2))
|
45 |
+
y1, y2 = y.unbind(
|
46 |
+
dim=-1,
|
47 |
+
)
|
48 |
+
y = torch.stack((-y2, y1), dim=-1)
|
49 |
+
y = y.flatten(-2)
|
50 |
+
return (x * emb_cos) + (y * emb_sin)
|
51 |
+
|
52 |
+
|
53 |
+
class DropPath(nn.Module):
|
54 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
55 |
+
|
56 |
+
def __init__(self, drop_prob=None):
|
57 |
+
super(DropPath, self).__init__()
|
58 |
+
self.drop_prob = drop_prob
|
59 |
+
|
60 |
+
def forward(self, x):
|
61 |
+
return drop_path(x, self.drop_prob, self.training)
|
62 |
+
|
63 |
+
def extra_repr(self) -> str:
|
64 |
+
return "p={}".format(self.drop_prob)
|
65 |
+
|
66 |
+
|
67 |
+
class MLP(nn.Module):
|
68 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
|
69 |
+
super().__init__()
|
70 |
+
out_features = out_features or in_features
|
71 |
+
hidden_features = hidden_features or in_features
|
72 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
73 |
+
self.act = act_layer()
|
74 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
75 |
+
self.drop = nn.Dropout(drop)
|
76 |
+
|
77 |
+
def forward(self, x):
|
78 |
+
x = self.fc1(x)
|
79 |
+
x = self.act(x)
|
80 |
+
x = self.drop(x)
|
81 |
+
x = self.fc2(x)
|
82 |
+
x = self.drop(x)
|
83 |
+
return x
|
84 |
+
|
85 |
+
|
86 |
+
class SwiGLUFFN(nn.Module):
|
87 |
+
def __init__(
|
88 |
+
self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.0, wide_silu=True
|
89 |
+
):
|
90 |
+
super().__init__()
|
91 |
+
out_features = out_features or in_features
|
92 |
+
swiglu_hidden_features = hidden_features = hidden_features or in_features
|
93 |
+
if wide_silu:
|
94 |
+
swiglu_hidden_features = int(2 * hidden_features / 3)
|
95 |
+
align_as = 8
|
96 |
+
swiglu_hidden_features = (swiglu_hidden_features + align_as - 1) // align_as * align_as
|
97 |
+
self.fc1 = nn.Linear(in_features, swiglu_hidden_features)
|
98 |
+
self.fc2 = nn.Linear(in_features, swiglu_hidden_features)
|
99 |
+
self.act = act_layer()
|
100 |
+
self.fc3 = nn.Linear(swiglu_hidden_features, out_features)
|
101 |
+
|
102 |
+
def forward(self, x):
|
103 |
+
x1 = self.fc1(x)
|
104 |
+
x2 = self.fc2(x)
|
105 |
+
hidden = F.silu(x1) * x2
|
106 |
+
return self.fc3(hidden)
|
107 |
+
|
108 |
+
|
109 |
+
class ACRoPEAttention(nn.Module):
|
110 |
+
def __init__(
|
111 |
+
self,
|
112 |
+
dim,
|
113 |
+
num_heads=8,
|
114 |
+
qkv_bias=False,
|
115 |
+
qk_scale=None,
|
116 |
+
attn_drop=0.0,
|
117 |
+
proj_drop=0.0,
|
118 |
+
use_sdpa=True,
|
119 |
+
is_causal=False,
|
120 |
+
grid_size=16,
|
121 |
+
):
|
122 |
+
super().__init__()
|
123 |
+
self.num_heads = num_heads
|
124 |
+
self.head_dim = head_dim = dim // num_heads
|
125 |
+
self.scale = qk_scale or head_dim**-0.5
|
126 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
127 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
128 |
+
self.proj = nn.Linear(dim, dim)
|
129 |
+
self.proj_drop_prob = proj_drop
|
130 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
131 |
+
self.use_sdpa = use_sdpa
|
132 |
+
# --
|
133 |
+
self.d_dim = int(2 * ((head_dim // 3) // 2))
|
134 |
+
self.h_dim = int(2 * ((head_dim // 3) // 2))
|
135 |
+
self.w_dim = int(2 * ((head_dim // 3) // 2))
|
136 |
+
self.grid_size = grid_size
|
137 |
+
self.is_causal = is_causal
|
138 |
+
|
139 |
+
def _get_frame_pos(self, ids, H_patches, W_patches):
|
140 |
+
tokens_per_frame = int(H_patches * W_patches)
|
141 |
+
return ids // tokens_per_frame
|
142 |
+
|
143 |
+
def _get_height_pos(self, ids, H_patches, W_patches):
|
144 |
+
# Remove frame component from ids
|
145 |
+
tokens_per_frame = int(H_patches * W_patches)
|
146 |
+
tokens_per_row = W_patches
|
147 |
+
frame_ids = self._get_frame_pos(ids, H_patches, W_patches)
|
148 |
+
ids = ids - tokens_per_frame * frame_ids
|
149 |
+
# --
|
150 |
+
return ids // tokens_per_row
|
151 |
+
|
152 |
+
def separate_positions(self, ids, H_patches, W_patches):
|
153 |
+
tokens_per_frame = int(H_patches * W_patches)
|
154 |
+
tokens_per_row = W_patches
|
155 |
+
frame_ids = self._get_frame_pos(ids, H_patches, W_patches)
|
156 |
+
# --
|
157 |
+
height_ids = self._get_height_pos(ids, H_patches, W_patches)
|
158 |
+
# --
|
159 |
+
# Remove frame component from ids (1st term) and height component (2nd term)
|
160 |
+
width_ids = (ids - tokens_per_frame * frame_ids) - tokens_per_row * height_ids
|
161 |
+
return 1.0 * frame_ids, 1.0 * height_ids, 1.0 * width_ids
|
162 |
+
|
163 |
+
def forward(self, x, mask=None, attn_mask=None, T=None, H=None, W=None, action_tokens=0):
|
164 |
+
B, N, C = x.size()
|
165 |
+
|
166 |
+
# -- compute position of each frame token
|
167 |
+
if mask is not None:
|
168 |
+
mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1)
|
169 |
+
d_mask, h_mask, w_mask = self.separate_positions(mask, H, W)
|
170 |
+
else:
|
171 |
+
mask = torch.arange(int(T * H * W), device=x.device)
|
172 |
+
d_mask, h_mask, w_mask = self.separate_positions(mask, H, W)
|
173 |
+
|
174 |
+
# -- snap spatial positions to grid size
|
175 |
+
h_mask *= self.grid_size / H
|
176 |
+
w_mask *= self.grid_size / W
|
177 |
+
|
178 |
+
# -- split out action tokens from sequence
|
179 |
+
if action_tokens > 0:
|
180 |
+
x = x.view(B, -1, action_tokens + H * W, C)  # [B, T, A+H*W, D]
|
181 |
+
|
182 |
+
action_q, action_k, action_v = [], [], []
|
183 |
+
for i in range(action_tokens):
|
184 |
+
a = x[:, :, i : i + 1, :].flatten(1, 2)
|
185 |
+
# Note action tokens do not work with masking
|
186 |
+
# -- compute qkv for action tokens and rotate
|
187 |
+
qkv = self.qkv(a).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
|
188 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # [B, num_heads, N, D]
|
189 |
+
# --
|
190 |
+
qd = rotate_queries_or_keys(q[..., : self.d_dim], pos=torch.arange(T, device=x.device))
|
191 |
+
kd = rotate_queries_or_keys(k[..., : self.d_dim], pos=torch.arange(T, device=x.device))
|
192 |
+
qr = q[..., self.d_dim :]
|
193 |
+
kr = k[..., self.d_dim :]
|
194 |
+
action_q += [torch.cat([qd, qr], dim=-1).view(B, self.num_heads, T, 1, -1)]
|
195 |
+
action_k += [torch.cat([kd, kr], dim=-1).view(B, self.num_heads, T, 1, -1)]
|
196 |
+
action_v += [v.view(B, self.num_heads, T, 1, -1)]
|
197 |
+
|
198 |
+
action_q = torch.cat(action_q, dim=3).flatten(2, 3)
|
199 |
+
action_k = torch.cat(action_k, dim=3).flatten(2, 3)
|
200 |
+
action_v = torch.cat(action_v, dim=3).flatten(2, 3)
|
201 |
+
x = x[:, :, action_tokens:, :].flatten(1, 2)
|
202 |
+
|
203 |
+
# -- compute qkv for frame tokens and rotate
|
204 |
+
qkv = self.qkv(x).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
|
205 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # [B, num_heads, N, D]
|
206 |
+
|
207 |
+
s = 0
|
208 |
+
# Rotate depth
|
209 |
+
qd = rotate_queries_or_keys(q[..., s : s + self.d_dim], pos=d_mask)
|
210 |
+
kd = rotate_queries_or_keys(k[..., s : s + self.d_dim], pos=d_mask)
|
211 |
+
s += self.d_dim
|
212 |
+
# Rotate height dim
|
213 |
+
qh = rotate_queries_or_keys(q[..., s : s + self.h_dim], pos=h_mask)
|
214 |
+
kh = rotate_queries_or_keys(k[..., s : s + self.h_dim], pos=h_mask)
|
215 |
+
s += self.h_dim
|
216 |
+
# Rotate width dim
|
217 |
+
qw = rotate_queries_or_keys(q[..., s : s + self.w_dim], pos=w_mask)
|
218 |
+
kw = rotate_queries_or_keys(k[..., s : s + self.w_dim], pos=w_mask)
|
219 |
+
s += self.w_dim
|
220 |
+
|
221 |
+
# Combine rotated dimensions
|
222 |
+
if s < self.head_dim:
|
223 |
+
qr = q[..., s:]
|
224 |
+
kr = k[..., s:]
|
225 |
+
q = torch.cat([qd, qh, qw, qr], dim=-1)
|
226 |
+
k = torch.cat([kd, kh, kw, kr], dim=-1)
|
227 |
+
else:
|
228 |
+
q = torch.cat([qd, qh, qw], dim=-1)
|
229 |
+
k = torch.cat([kd, kh, kw], dim=-1)
|
230 |
+
|
231 |
+
if action_tokens > 0:
|
232 |
+
|
233 |
+
def merge_(tx, ta):
|
234 |
+
"""tx, tx in [B, num_heads, N, D]"""
|
235 |
+
tx = tx.view(B, self.num_heads, T, H * W, -1)  # [B, num_heads, T, H*W, D]
|
236 |
+
ta = ta.view(B, self.num_heads, T, action_tokens, -1)  # [B, num_heads, T, A, D]
|
237 |
+
return torch.cat([ta, tx], dim=3).flatten(2, 3)
|
238 |
+
|
239 |
+
q = merge_(q, action_q)
|
240 |
+
k = merge_(k, action_k)
|
241 |
+
v = merge_(v, action_v)
|
242 |
+
|
243 |
+
if attn_mask is not None or self.use_sdpa:
|
244 |
+
with torch.backends.cuda.sdp_kernel():
|
245 |
+
x = F.scaled_dot_product_attention(
|
246 |
+
q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask
|
247 |
+
)
|
248 |
+
attn = None
|
249 |
+
else:
|
250 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale # [B, num_heads, D, D]
|
251 |
+
attn = attn.softmax(dim=-1)
|
252 |
+
attn = self.attn_drop(attn)
|
253 |
+
x = attn @ v
|
254 |
+
|
255 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
256 |
+
x = self.proj(x)
|
257 |
+
x = self.proj_drop(x)
|
258 |
+
return x
|
259 |
+
|
260 |
+
|
261 |
+
class RoPEAttention(nn.Module):
|
262 |
+
def __init__(
|
263 |
+
self,
|
264 |
+
dim,
|
265 |
+
num_heads=8,
|
266 |
+
qkv_bias=False,
|
267 |
+
qk_scale=None,
|
268 |
+
attn_drop=0.0,
|
269 |
+
proj_drop=0.0,
|
270 |
+
use_sdpa=True,
|
271 |
+
grid_size=14,
|
272 |
+
is_causal=False,
|
273 |
+
):
|
274 |
+
super().__init__()
|
275 |
+
self.num_heads = num_heads
|
276 |
+
self.head_dim = head_dim = dim // num_heads
|
277 |
+
self.scale = qk_scale or head_dim**-0.5
|
278 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
279 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
280 |
+
self.proj = nn.Linear(dim, dim)
|
281 |
+
self.proj_drop_prob = proj_drop
|
282 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
283 |
+
self.use_sdpa = use_sdpa
|
284 |
+
# --
|
285 |
+
self.d_dim = int(2 * ((head_dim // 3) // 2))
|
286 |
+
self.h_dim = int(2 * ((head_dim // 3) // 2))
|
287 |
+
self.w_dim = int(2 * ((head_dim // 3) // 2))
|
288 |
+
self.grid_size = grid_size
|
289 |
+
self.is_causal = is_causal
|
290 |
+
|
291 |
+
def _get_frame_pos(self, ids, H_patches=None, W_patches=None):
|
292 |
+
if H_patches is None or W_patches is None:
|
293 |
+
tokens_per_frame = int(self.grid_size * self.grid_size)
|
294 |
+
else:
|
295 |
+
tokens_per_frame = int(H_patches * W_patches)
|
296 |
+
return ids // tokens_per_frame
|
297 |
+
|
298 |
+
def _get_height_pos(self, ids, H_patches=None, W_patches=None):
|
299 |
+
# Remove frame component from ids
|
300 |
+
if H_patches is None or W_patches is None:
|
301 |
+
tokens_per_frame = int(self.grid_size * self.grid_size)
|
302 |
+
tokens_per_row = self.grid_size
|
303 |
+
else:
|
304 |
+
tokens_per_frame = int(H_patches * W_patches)
|
305 |
+
tokens_per_row = W_patches
|
306 |
+
frame_ids = self._get_frame_pos(ids, H_patches, W_patches)
|
307 |
+
ids = ids - tokens_per_frame * frame_ids
|
308 |
+
# --
|
309 |
+
return ids // tokens_per_row
|
310 |
+
|
311 |
+
def separate_positions(self, ids, H_patches=None, W_patches=None):
|
312 |
+
if H_patches is None or W_patches is None:
|
313 |
+
tokens_per_frame = int(self.grid_size * self.grid_size)
|
314 |
+
tokens_per_row = self.grid_size
|
315 |
+
else:
|
316 |
+
tokens_per_frame = int(H_patches * W_patches)
|
317 |
+
tokens_per_row = W_patches
|
318 |
+
frame_ids = self._get_frame_pos(ids, H_patches, W_patches)
|
319 |
+
# --
|
320 |
+
height_ids = self._get_height_pos(ids, H_patches, W_patches)
|
321 |
+
# --
|
322 |
+
# Remove frame component from ids (1st term) and height component (2nd term)
|
323 |
+
width_ids = (ids - tokens_per_frame * frame_ids) - tokens_per_row * height_ids
|
324 |
+
return frame_ids, height_ids, width_ids
|
325 |
+
|
326 |
+
def forward(self, x, mask=None, attn_mask=None, T=None, H_patches=None, W_patches=None):
|
327 |
+
B, N, C = x.size()
|
328 |
+
grid_depth = int(N // (self.grid_size * self.grid_size))
|
329 |
+
|
330 |
+
qkv = self.qkv(x).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
|
331 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # [B, num_heads, N, D]
|
332 |
+
|
333 |
+
if mask is not None:
|
334 |
+
mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1)
|
335 |
+
d_mask, h_mask, w_mask = self.separate_positions(mask, H_patches, W_patches)
|
336 |
+
else:
|
337 |
+
if T is None or H_patches is None or W_patches is None:
|
338 |
+
mask = torch.arange(int(grid_depth * self.grid_size * self.grid_size), device=x.device)
|
339 |
+
else:
|
340 |
+
mask = torch.arange(int(T * H_patches * W_patches), device=x.device)
|
341 |
+
d_mask, h_mask, w_mask = self.separate_positions(mask, H_patches, W_patches)
|
342 |
+
|
343 |
+
s = 0
|
344 |
+
# Rotate depth
|
345 |
+
qd = rotate_queries_or_keys(q[..., s : s + self.d_dim], pos=d_mask)
|
346 |
+
kd = rotate_queries_or_keys(k[..., s : s + self.d_dim], pos=d_mask)
|
347 |
+
s += self.d_dim
|
348 |
+
# Rotate height dim
|
349 |
+
qh = rotate_queries_or_keys(q[..., s : s + self.h_dim], pos=h_mask)
|
350 |
+
kh = rotate_queries_or_keys(k[..., s : s + self.h_dim], pos=h_mask)
|
351 |
+
s += self.h_dim
|
352 |
+
# Rotate width dim
|
353 |
+
qw = rotate_queries_or_keys(q[..., s : s + self.w_dim], pos=w_mask)
|
354 |
+
kw = rotate_queries_or_keys(k[..., s : s + self.w_dim], pos=w_mask)
|
355 |
+
s += self.w_dim
|
356 |
+
|
357 |
+
# Combine rotated dimensions
|
358 |
+
if s < self.head_dim:
|
359 |
+
qr = q[..., s:]
|
360 |
+
kr = k[..., s:]
|
361 |
+
q = torch.cat([qd, qh, qw, qr], dim=-1)
|
362 |
+
k = torch.cat([kd, kh, kw, kr], dim=-1)
|
363 |
+
else:
|
364 |
+
q = torch.cat([qd, qh, qw], dim=-1)
|
365 |
+
k = torch.cat([kd, kh, kw], dim=-1)
|
366 |
+
|
367 |
+
if attn_mask is not None or self.use_sdpa:
|
368 |
+
with torch.backends.cuda.sdp_kernel():
|
369 |
+
x = F.scaled_dot_product_attention(
|
370 |
+
q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask
|
371 |
+
)
|
372 |
+
attn = None
|
373 |
+
else:
|
374 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale # [B, num_heads, D, D]
|
375 |
+
attn = attn.softmax(dim=-1)
|
376 |
+
attn = self.attn_drop(attn)
|
377 |
+
x = attn @ v
|
378 |
+
|
379 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
380 |
+
x = self.proj(x)
|
381 |
+
x = self.proj_drop(x)
|
382 |
+
return x
|
383 |
+
|
384 |
+
|
385 |
+
class Attention(nn.Module):
|
386 |
+
def __init__(
|
387 |
+
self,
|
388 |
+
dim,
|
389 |
+
num_heads=8,
|
390 |
+
qkv_bias=False,
|
391 |
+
qk_scale=None,
|
392 |
+
attn_drop=0.0,
|
393 |
+
proj_drop=0.0,
|
394 |
+
use_sdpa=True,
|
395 |
+
is_causal=False,
|
396 |
+
):
|
397 |
+
super().__init__()
|
398 |
+
self.num_heads = num_heads
|
399 |
+
head_dim = dim // num_heads
|
400 |
+
self.scale = qk_scale or head_dim**-0.5
|
401 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
402 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
403 |
+
self.proj = nn.Linear(dim, dim)
|
404 |
+
self.proj_drop_prob = proj_drop
|
405 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
406 |
+
self.use_sdpa = use_sdpa
|
407 |
+
self.is_causal = is_causal
|
408 |
+
|
409 |
+
def forward(self, x, mask=None, attn_mask=None):
|
410 |
+
B, N, C = x.shape
|
411 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
412 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # [B, num_heads, N, D]
|
413 |
+
|
414 |
+
if attn_mask is not None or self.use_sdpa:
|
415 |
+
with torch.backends.cuda.sdp_kernel():
|
416 |
+
x = F.scaled_dot_product_attention(
|
417 |
+
q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask
|
418 |
+
)
|
419 |
+
attn = None
|
420 |
+
else:
|
421 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale # [B, num_heads, D, D]
|
422 |
+
attn = attn.softmax(dim=-1)
|
423 |
+
attn = self.attn_drop(attn)
|
424 |
+
x = attn @ v
|
425 |
+
|
426 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
427 |
+
x = self.proj(x)
|
428 |
+
x = self.proj_drop(x)
|
429 |
+
return x
|
430 |
+
|
431 |
+
|
432 |
+
class ACBlock(nn.Module):
|
433 |
+
def __init__(
|
434 |
+
self,
|
435 |
+
dim,
|
436 |
+
num_heads,
|
437 |
+
mlp_ratio=4.0,
|
438 |
+
qkv_bias=False,
|
439 |
+
qk_scale=None,
|
440 |
+
drop=0.0,
|
441 |
+
attn_drop=0.0,
|
442 |
+
drop_path=0.0,
|
443 |
+
act_layer=nn.GELU,
|
444 |
+
wide_silu=True,
|
445 |
+
norm_layer=nn.LayerNorm,
|
446 |
+
use_sdpa=True,
|
447 |
+
is_causal=False,
|
448 |
+
grid_size=16,
|
449 |
+
use_rope=False,
|
450 |
+
**kwargs,
|
451 |
+
):
|
452 |
+
super().__init__()
|
453 |
+
self.norm1 = norm_layer(dim)
|
454 |
+
if use_rope:
|
455 |
+
self.attn = ACRoPEAttention(
|
456 |
+
dim,
|
457 |
+
num_heads=num_heads,
|
458 |
+
qkv_bias=qkv_bias,
|
459 |
+
qk_scale=qk_scale,
|
460 |
+
attn_drop=attn_drop,
|
461 |
+
use_sdpa=use_sdpa,
|
462 |
+
is_causal=is_causal,
|
463 |
+
grid_size=grid_size,
|
464 |
+
proj_drop=drop,
|
465 |
+
)
|
466 |
+
else:
|
467 |
+
self.attn = Attention(
|
468 |
+
dim,
|
469 |
+
num_heads=num_heads,
|
470 |
+
qkv_bias=qkv_bias,
|
471 |
+
qk_scale=qk_scale,
|
472 |
+
attn_drop=attn_drop,
|
473 |
+
use_sdpa=use_sdpa,
|
474 |
+
is_causal=is_causal,
|
475 |
+
proj_drop=drop,
|
476 |
+
)
|
477 |
+
|
478 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
479 |
+
self.norm2 = norm_layer(dim)
|
480 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
481 |
+
if act_layer is nn.SiLU:
|
482 |
+
self.mlp = SwiGLUFFN(
|
483 |
+
in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, wide_silu=wide_silu, drop=drop
|
484 |
+
)
|
485 |
+
else:
|
486 |
+
self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
487 |
+
|
488 |
+
def forward(self, x, mask=None, attn_mask=None, T=None, H=None, W=None, action_tokens=0):
|
489 |
+
y = self.norm1(x)
|
490 |
+
if isinstance(self.attn, ACRoPEAttention):
|
491 |
+
y = self.attn(y, mask=mask, attn_mask=attn_mask, T=T, H=H, W=W, action_tokens=action_tokens)
|
492 |
+
else:
|
493 |
+
y = self.attn(y, mask=mask, attn_mask=attn_mask)
|
494 |
+
x = x + self.drop_path(y)
|
495 |
+
y = self.norm2(x)
|
496 |
+
x = x + self.drop_path(self.mlp(y))
|
497 |
+
return x
|
498 |
+
|
499 |
+
|
500 |
+
class Block(nn.Module):
|
501 |
+
def __init__(
|
502 |
+
self,
|
503 |
+
dim,
|
504 |
+
num_heads,
|
505 |
+
mlp_ratio=4.0,
|
506 |
+
qkv_bias=False,
|
507 |
+
qk_scale=None,
|
508 |
+
drop=0.0,
|
509 |
+
attn_drop=0.0,
|
510 |
+
drop_path=0.0,
|
511 |
+
act_layer=nn.GELU,
|
512 |
+
wide_silu=True,
|
513 |
+
norm_layer=nn.LayerNorm,
|
514 |
+
use_sdpa=True,
|
515 |
+
is_causal=False,
|
516 |
+
grid_size=16,
|
517 |
+
use_rope=False,
|
518 |
+
**kwargs,
|
519 |
+
):
|
520 |
+
super().__init__()
|
521 |
+
self.norm1 = norm_layer(dim)
|
522 |
+
if use_rope:
|
523 |
+
self.attn = RoPEAttention(
|
524 |
+
dim,
|
525 |
+
num_heads=num_heads,
|
526 |
+
qkv_bias=qkv_bias,
|
527 |
+
qk_scale=qk_scale,
|
528 |
+
attn_drop=attn_drop,
|
529 |
+
use_sdpa=use_sdpa,
|
530 |
+
is_causal=is_causal,
|
531 |
+
grid_size=grid_size,
|
532 |
+
proj_drop=drop,
|
533 |
+
)
|
534 |
+
else:
|
535 |
+
self.attn = Attention(
|
536 |
+
dim,
|
537 |
+
num_heads=num_heads,
|
538 |
+
qkv_bias=qkv_bias,
|
539 |
+
qk_scale=qk_scale,
|
540 |
+
attn_drop=attn_drop,
|
541 |
+
use_sdpa=use_sdpa,
|
542 |
+
is_causal=is_causal,
|
543 |
+
proj_drop=drop,
|
544 |
+
)
|
545 |
+
|
546 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
547 |
+
self.norm2 = norm_layer(dim)
|
548 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
549 |
+
if act_layer is nn.SiLU:
|
550 |
+
self.mlp = SwiGLUFFN(
|
551 |
+
in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, wide_silu=wide_silu, drop=drop
|
552 |
+
)
|
553 |
+
else:
|
554 |
+
self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
555 |
+
|
556 |
+
def forward(self, x, mask=None, attn_mask=None, T=None, H_patches=None, W_patches=None):
|
557 |
+
if isinstance(self.attn, RoPEAttention):
|
558 |
+
y = self.attn(self.norm1(x), mask=mask, attn_mask=attn_mask, T=T, H_patches=H_patches, W_patches=W_patches)
|
559 |
+
else:
|
560 |
+
y = self.attn(self.norm1(x), mask=mask, attn_mask=attn_mask)
|
561 |
+
x = x + self.drop_path(y)
|
562 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
563 |
+
return x
|
564 |
+
|
565 |
+
|
566 |
+
class CrossAttention(nn.Module):
|
567 |
+
def __init__(self, dim, num_heads=12, qkv_bias=False, use_sdpa=True):
|
568 |
+
super().__init__()
|
569 |
+
self.num_heads = num_heads
|
570 |
+
head_dim = dim // num_heads
|
571 |
+
self.scale = head_dim**-0.5
|
572 |
+
self.q = nn.Linear(dim, dim, bias=qkv_bias)
|
573 |
+
self.kv = nn.Linear(dim, int(dim * 2), bias=qkv_bias)
|
574 |
+
# self.proj = nn.Linear(dim, dim)
|
575 |
+
self.use_sdpa = use_sdpa
|
576 |
+
|
577 |
+
def forward(self, q, x):
|
578 |
+
B, n, C = q.shape
|
579 |
+
q = self.q(q).reshape(B, n, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
580 |
+
|
581 |
+
B, N, C = x.shape
|
582 |
+
kv = self.kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
583 |
+
k, v = kv[0], kv[1] # (batch_size, num_heads, seq_len, feature_dim_per_head)
|
584 |
+
|
585 |
+
if self.use_sdpa:
|
586 |
+
with torch.backends.cuda.sdp_kernel():
|
587 |
+
q = F.scaled_dot_product_attention(q, k, v)
|
588 |
+
else:
|
589 |
+
xattn = (q @ k.transpose(-2, -1)) * self.scale
|
590 |
+
xattn = xattn.softmax(dim=-1) # (batch_size, num_heads, query_len, seq_len)
|
591 |
+
q = xattn @ v
|
592 |
+
|
593 |
+
q = q.transpose(1, 2).reshape(B, n, C)
|
594 |
+
return q
|
595 |
+
|
596 |
+
|
597 |
+
class CrossAttentionBlock(nn.Module):
|
598 |
+
def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=False, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
599 |
+
super().__init__()
|
600 |
+
self.norm1 = norm_layer(dim)
|
601 |
+
self.xattn = CrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
|
602 |
+
self.norm2 = norm_layer(dim)
|
603 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
604 |
+
self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
|
605 |
+
|
606 |
+
def forward(self, q, x):
|
607 |
+
y = self.xattn(q, self.norm1(x))
|
608 |
+
q = q + y
|
609 |
+
q = q + self.mlp(self.norm2(q))
|
610 |
+
return q
|
src/models/utils/patch_embed.py
ADDED
@@ -0,0 +1,52 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch.nn as nn
|
7 |
+
from einops import rearrange
|
8 |
+
|
9 |
+
|
10 |
+
class PatchEmbed(nn.Module):
|
11 |
+
"""
|
12 |
+
Image to Patch Embedding
|
13 |
+
"""
|
14 |
+
|
15 |
+
def __init__(self, patch_size=16, in_chans=3, embed_dim=768):
|
16 |
+
super().__init__()
|
17 |
+
self.patch_size = patch_size
|
18 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
19 |
+
|
20 |
+
def forward(self, x):
|
21 |
+
B, C, H, W = x.shape
|
22 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
23 |
+
return x
|
24 |
+
|
25 |
+
|
26 |
+
class PatchEmbed3D(nn.Module):
|
27 |
+
"""
|
28 |
+
Image to Patch Embedding
|
29 |
+
"""
|
30 |
+
|
31 |
+
def __init__(
|
32 |
+
self,
|
33 |
+
patch_size=16,
|
34 |
+
tubelet_size=2,
|
35 |
+
in_chans=3,
|
36 |
+
embed_dim=768,
|
37 |
+
):
|
38 |
+
super().__init__()
|
39 |
+
self.patch_size = patch_size
|
40 |
+
self.tubelet_size = tubelet_size
|
41 |
+
|
42 |
+
self.proj = nn.Conv3d(
|
43 |
+
in_channels=in_chans,
|
44 |
+
out_channels=embed_dim,
|
45 |
+
kernel_size=(tubelet_size, patch_size, patch_size),
|
46 |
+
stride=(tubelet_size, patch_size, patch_size),
|
47 |
+
)
|
48 |
+
|
49 |
+
def forward(self, x, **kwargs):
|
50 |
+
B, C, T, H, W = x.shape
|
51 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
52 |
+
return x
|
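
A minimal usage sketch for the 3-D patch embedding defined above (clip shape and sizes are illustrative assumptions, not shipped defaults):

```python
# Tokenize a dummy video clip with PatchEmbed3D and check the token count.
import torch
from src.models.utils.patch_embed import PatchEmbed3D

embed = PatchEmbed3D(patch_size=16, tubelet_size=2, in_chans=3, embed_dim=768)
clip = torch.randn(1, 3, 16, 224, 224)                 # [B, C, T, H, W]
tokens = embed(clip)                                    # [B, N, embed_dim]
assert tokens.shape == (1, (16 // 2) * (224 // 16) ** 2, 768)   # N = 8 * 14 * 14
```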
src/models/utils/pos_embs.py
ADDED
@@ -0,0 +1,93 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
|
9 |
+
def get_3d_sincos_pos_embed(embed_dim, grid_size, grid_depth, cls_token=False, uniform_power=False):
|
10 |
+
"""
|
11 |
+
grid_size: int of the grid height and width
|
12 |
+
grid_depth: int of the grid depth
|
13 |
+
returns:
|
14 |
+
pos_embed: [grid_depth*grid_size*grid_size, embed_dim] (w/o cls_token)
|
15 |
+
or [1+grid_depth*grid_size*grid_size, embed_dim] (w/ cls_token)
|
16 |
+
"""
|
17 |
+
grid_d = np.arange(grid_depth, dtype=float)
|
18 |
+
grid_h = np.arange(grid_size, dtype=float)
|
19 |
+
grid_w = np.arange(grid_size, dtype=float)
|
20 |
+
grid_h, grid_d, grid_w = np.meshgrid(
|
21 |
+
grid_h, grid_d, grid_w
|
22 |
+
) # order of meshgrid is very important for indexing as [d,h,w]
|
23 |
+
|
24 |
+
if not uniform_power:
|
25 |
+
h_embed_dim = embed_dim // 4
|
26 |
+
w_embed_dim = embed_dim // 4
|
27 |
+
d_embed_dim = embed_dim // 2
|
28 |
+
else:
|
29 |
+
h_embed_dim = w_embed_dim = d_embed_dim = int(np.ceil(embed_dim / 6) * 2)
|
30 |
+
|
31 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(h_embed_dim, grid_h) # (T*H*W, D1)
|
32 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(w_embed_dim, grid_w) # (T*H*W, D2)
|
33 |
+
emb_d = get_1d_sincos_pos_embed_from_grid(d_embed_dim, grid_d) # (T*H*W, D3)
|
34 |
+
pos_embed = np.concatenate([emb_d, emb_h, emb_w], axis=1)
|
35 |
+
pos_embed = pos_embed[:, :embed_dim]
|
36 |
+
if cls_token:
|
37 |
+
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
38 |
+
return pos_embed
|
39 |
+
|
40 |
+
|
41 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
|
42 |
+
"""
|
43 |
+
grid_size: int of the grid height and width
|
44 |
+
returns:
|
45 |
+
pos_embed: [grid_size*grid_size, embed_dim] (w/o cls_token)
|
46 |
+
or [1+grid_size*grid_size, embed_dim] (w/ cls_token)
|
47 |
+
"""
|
48 |
+
grid_h = np.arange(grid_size, dtype=float)
|
49 |
+
grid_w = np.arange(grid_size, dtype=float)
|
50 |
+
grid_w, grid_h = np.meshgrid(grid_w, grid_h) # order of meshgrid is very important for indexing as [h, w]
|
51 |
+
|
52 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_h) # (H*W, D/2)
|
53 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_w) # (H*W, D/2)
|
54 |
+
pos_embed = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
55 |
+
if cls_token:
|
56 |
+
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
57 |
+
return pos_embed
|
58 |
+
|
59 |
+
|
60 |
+
def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
|
61 |
+
"""
|
62 |
+
embed_dim: output dimension for each position
|
63 |
+
grid_size: int of the grid length
|
64 |
+
returns:
|
65 |
+
pos_embed: [grid_size, embed_dim] (w/o cls_token)
|
66 |
+
or [1+grid_size, embed_dim] (w/ cls_token)
|
67 |
+
"""
|
68 |
+
grid = np.arange(grid_size, dtype=float)
|
69 |
+
pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid)
|
70 |
+
if cls_token:
|
71 |
+
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
72 |
+
return pos_embed
|
73 |
+
|
74 |
+
|
75 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
76 |
+
"""
|
77 |
+
embed_dim: output dimension for each position
|
78 |
+
pos: a list of positions to be encoded: size (M,)
|
79 |
+
returns: (M, D)
|
80 |
+
"""
|
81 |
+
assert embed_dim % 2 == 0
|
82 |
+
omega = np.arange(embed_dim // 2, dtype=float)
|
83 |
+
omega /= embed_dim / 2.0
|
84 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
85 |
+
|
86 |
+
pos = pos.reshape(-1) # (M,)
|
87 |
+
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
88 |
+
|
89 |
+
emb_sin = np.sin(out) # (M, D/2)
|
90 |
+
emb_cos = np.cos(out) # (M, D/2)
|
91 |
+
|
92 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
93 |
+
return emb
|
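
A minimal sketch of building the fixed 3-D sin/cos table that the encoder copies into its non-trainable `pos_embed` buffer (grid sizes below are illustrative assumptions):

```python
# Build a 3-D sincos positional embedding and verify its shape.
from src.models.utils.pos_embs import get_3d_sincos_pos_embed

embed_dim, grid_size, grid_depth = 768, 14, 8           # dim, H=W patches, frames/tubelet
pos = get_3d_sincos_pos_embed(embed_dim, grid_size, grid_depth, cls_token=False)
assert pos.shape == (grid_depth * grid_size * grid_size, embed_dim)   # (1568, 768)
```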
src/models/vision_transformer.py
ADDED
@@ -0,0 +1,487 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
from functools import partial
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
|
12 |
+
from src.masks.utils import apply_masks
|
13 |
+
from src.models.utils.modules import Block
|
14 |
+
from src.models.utils.patch_embed import PatchEmbed, PatchEmbed3D
|
15 |
+
from src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed
|
16 |
+
from src.utils.tensors import trunc_normal_
|
17 |
+
|
18 |
+
|
19 |
+
class VisionTransformer(nn.Module):
|
20 |
+
"""Vision Transformer"""
|
21 |
+
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
img_size=(224, 224),
|
25 |
+
patch_size=16,
|
26 |
+
num_frames=1,
|
27 |
+
tubelet_size=2,
|
28 |
+
in_chans=3,
|
29 |
+
embed_dim=768,
|
30 |
+
depth=12,
|
31 |
+
num_heads=12,
|
32 |
+
mlp_ratio=4.0,
|
33 |
+
qkv_bias=True,
|
34 |
+
qk_scale=None,
|
35 |
+
drop_rate=0.0,
|
36 |
+
attn_drop_rate=0.0,
|
37 |
+
drop_path_rate=0.0,
|
38 |
+
norm_layer=nn.LayerNorm,
|
39 |
+
init_std=0.02,
|
40 |
+
out_layers=None,
|
41 |
+
uniform_power=False,
|
42 |
+
use_silu=False,
|
43 |
+
wide_silu=True,
|
44 |
+
use_sdpa=True,
|
45 |
+
use_activation_checkpointing=False,
|
46 |
+
use_rope=False,
|
47 |
+
handle_nonsquare_inputs=True,
|
48 |
+
**kwargs
|
49 |
+
):
|
50 |
+
super().__init__()
|
51 |
+
self.num_features = self.embed_dim = embed_dim
|
52 |
+
self.num_heads = num_heads
|
53 |
+
self.out_layers = out_layers
|
54 |
+
self.handle_nonsquare_inputs = handle_nonsquare_inputs
|
55 |
+
|
56 |
+
if type(img_size) is int:
|
57 |
+
img_size = (img_size, img_size)
|
58 |
+
self.img_height, self.img_width = img_size
|
59 |
+
self.patch_size = patch_size
|
60 |
+
self.num_frames = num_frames
|
61 |
+
self.tubelet_size = tubelet_size
|
62 |
+
self.is_video = num_frames > 1
|
63 |
+
|
64 |
+
self.use_activation_checkpointing = use_activation_checkpointing
|
65 |
+
|
66 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
67 |
+
|
68 |
+
# Tokenize pixels with convolution
|
69 |
+
if self.is_video:
|
70 |
+
self.patch_embed = PatchEmbed3D(
|
71 |
+
patch_size=patch_size, tubelet_size=tubelet_size, in_chans=in_chans, embed_dim=embed_dim
|
72 |
+
)
|
73 |
+
self.num_patches = (num_frames // tubelet_size) * (img_size[0] // patch_size) * (img_size[1] // patch_size)
|
74 |
+
else:
|
75 |
+
self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
76 |
+
self.num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
|
77 |
+
|
78 |
+
# Position embedding
|
79 |
+
self.uniform_power = uniform_power
|
80 |
+
self.use_rope = use_rope
|
81 |
+
if self.use_rope:
|
82 |
+
self.pos_embed = None
|
83 |
+
else:
|
84 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim), requires_grad=False)
|
85 |
+
|
86 |
+
# Attention Blocks
|
87 |
+
self.blocks = nn.ModuleList(
|
88 |
+
[
|
89 |
+
Block(
|
90 |
+
use_rope=use_rope,
|
91 |
+
grid_size=img_size[0] // patch_size,
|
92 |
+
grid_depth=num_frames // tubelet_size,
|
93 |
+
dim=embed_dim,
|
94 |
+
num_heads=num_heads,
|
95 |
+
mlp_ratio=mlp_ratio,
|
96 |
+
use_sdpa=use_sdpa,
|
97 |
+
qkv_bias=qkv_bias,
|
98 |
+
qk_scale=qk_scale,
|
99 |
+
drop=drop_rate,
|
100 |
+
act_layer=nn.SiLU if use_silu else nn.GELU,
|
101 |
+
wide_silu=wide_silu,
|
102 |
+
attn_drop=attn_drop_rate,
|
103 |
+
drop_path=dpr[i],
|
104 |
+
norm_layer=norm_layer,
|
105 |
+
)
|
106 |
+
for i in range(depth)
|
107 |
+
]
|
108 |
+
)
|
109 |
+
self.norm = norm_layer(embed_dim)
|
110 |
+
|
111 |
+
# ------ initialize weights
|
112 |
+
if self.pos_embed is not None:
|
113 |
+
self._init_pos_embed(self.pos_embed.data) # sincos pos-embed
|
114 |
+
self.init_std = init_std
|
115 |
+
self.apply(self._init_weights)
|
116 |
+
self._rescale_blocks()
|
117 |
+
|
118 |
+
def _init_pos_embed(self, pos_embed):
|
119 |
+
embed_dim = pos_embed.size(-1)
|
120 |
+
grid_size = self.img_height // self.patch_size # TODO: update; currently assumes square input
|
121 |
+
if self.is_video:
|
122 |
+
grid_depth = self.num_frames // self.tubelet_size
|
123 |
+
sincos = get_3d_sincos_pos_embed(
|
124 |
+
embed_dim, grid_size, grid_depth, cls_token=False, uniform_power=self.uniform_power
|
125 |
+
)
|
126 |
+
else:
|
127 |
+
sincos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False)
|
128 |
+
pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0))
|
129 |
+
|
130 |
+
def _init_weights(self, m):
|
131 |
+
if isinstance(m, nn.Linear):
|
132 |
+
trunc_normal_(m.weight, std=self.init_std)
|
133 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
134 |
+
nn.init.constant_(m.bias, 0)
|
135 |
+
elif isinstance(m, nn.LayerNorm):
|
136 |
+
nn.init.constant_(m.bias, 0)
|
137 |
+
nn.init.constant_(m.weight, 1.0)
|
138 |
+
elif isinstance(m, nn.Conv2d):
|
139 |
+
trunc_normal_(m.weight, std=self.init_std)
|
140 |
+
if m.bias is not None:
|
141 |
+
nn.init.constant_(m.bias, 0)
|
142 |
+
elif isinstance(m, nn.Conv3d):
|
143 |
+
trunc_normal_(m.weight, std=self.init_std)
|
144 |
+
if m.bias is not None:
|
145 |
+
nn.init.constant_(m.bias, 0)
|
146 |
+
|
147 |
+
def _rescale_blocks(self):
|
148 |
+
def rescale(param, layer_id):
|
149 |
+
param.div_(math.sqrt(2.0 * layer_id))
|
150 |
+
|
151 |
+
for layer_id, layer in enumerate(self.blocks):
|
152 |
+
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
153 |
+
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
154 |
+
|
155 |
+
def get_num_layers(self):
|
156 |
+
return len(self.blocks)
|
157 |
+
|
158 |
+
def no_weight_decay(self):
|
159 |
+
return {}
|
160 |
+
|
161 |
+
def forward(self, x, masks=None):
|
162 |
+
"""
|
163 |
+
:param x: input image/video
|
164 |
+
:param masks: indices of patch tokens to mask (remove)
|
165 |
+
"""
|
166 |
+
if masks is not None and not isinstance(masks, list):
|
167 |
+
masks = [masks]
|
168 |
+
|
169 |
+
# Tokenize input
|
170 |
+
# Image
|
171 |
+
if x.ndim == 4:
|
172 |
+
_, _, H, W = x.shape
|
173 |
+
T = 1
|
174 |
+
# Video
|
175 |
+
elif x.ndim == 5:
|
176 |
+
_, _, T, H, W = x.shape
|
177 |
+
T = T // self.tubelet_size
|
178 |
+
H_patches = H // self.patch_size
|
179 |
+
W_patches = W // self.patch_size
|
180 |
+
if not self.handle_nonsquare_inputs:
|
181 |
+
T = H_patches = W_patches = None
|
182 |
+
|
183 |
+
if not self.use_rope:
|
184 |
+
pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
|
185 |
+
x = self.patch_embed(x)
|
186 |
+
x += pos_embed
|
187 |
+
else:
|
188 |
+
x = self.patch_embed(x)
|
189 |
+
|
190 |
+
# Mask away unwanted tokens (if masks provided)
|
191 |
+
if masks is not None:
|
192 |
+
x = apply_masks(x, masks)
|
193 |
+
masks = torch.cat(masks, dim=0)
|
194 |
+
|
195 |
+
# Fwd prop
|
196 |
+
outs = []
|
197 |
+
for i, blk in enumerate(self.blocks):
|
198 |
+
if self.use_activation_checkpointing:
|
199 |
+
x = torch.utils.checkpoint.checkpoint(
|
200 |
+
blk, x, masks, None, T=T, H_patches=H_patches, W_patches=W_patches, use_reentrant=False
|
201 |
+
)
|
202 |
+
else:
|
203 |
+
x = blk(x, mask=masks, attn_mask=None, T=T, H_patches=H_patches, W_patches=W_patches)
|
204 |
+
if self.out_layers is not None and i in self.out_layers:
|
205 |
+
outs.append(self.norm(x))
|
206 |
+
|
207 |
+
if self.out_layers is not None:
|
208 |
+
return outs
|
209 |
+
|
210 |
+
if self.norm is not None:
|
211 |
+
x = self.norm(x)
|
212 |
+
|
213 |
+
return x
|
214 |
+
|
215 |
+
def interpolate_pos_encoding(self, x, pos_embed):
|
216 |
+
|
217 |
+
_, N, dim = pos_embed.shape
|
218 |
+
|
219 |
+
if self.is_video:
|
220 |
+
|
221 |
+
# If pos_embed is already the correct size, just return
|
222 |
+
_, _, T, H, W = x.shape
|
223 |
+
if H == self.img_height and W == self.img_width and T == self.num_frames:
|
224 |
+
return pos_embed
|
225 |
+
|
226 |
+
# Just chop off last N tokens of positional embedding
|
227 |
+
elif H == self.img_height and W == self.img_width and T < self.num_frames:
|
228 |
+
new_N = int((T // self.tubelet_size) * (H // self.patch_size) * (W // self.patch_size))
|
229 |
+
return pos_embed[:, :new_N, :]
|
230 |
+
|
231 |
+
# Convert depth, height, width of input to be measured in patches
|
232 |
+
# instead of pixels/frames
|
233 |
+
T = T // self.tubelet_size
|
234 |
+
H = H // self.patch_size
|
235 |
+
W = W // self.patch_size
|
236 |
+
|
237 |
+
# Compute the initialized shape of the positional embedding measured
|
238 |
+
# in patches
|
239 |
+
N_t = self.num_frames // self.tubelet_size
|
240 |
+
N_h = self.img_height // self.patch_size
|
241 |
+
N_w = self.img_width // self.patch_size
|
242 |
+
assert N_h * N_w * N_t == N, "Positional embedding initialized incorrectly"
|
243 |
+
|
244 |
+
# Compute scale factor for spatio-temporal interpolation
|
245 |
+
scale_factor = (T / N_t, H / N_h, W / N_w)
|
246 |
+
|
247 |
+
pos_embed = nn.functional.interpolate(
|
248 |
+
pos_embed.reshape(1, N_t, N_h, N_w, dim).permute(0, 4, 1, 2, 3),
|
249 |
+
scale_factor=scale_factor,
|
250 |
+
mode="trilinear",
|
251 |
+
)
|
252 |
+
pos_embed = pos_embed.permute(0, 2, 3, 4, 1).view(1, -1, dim)
|
253 |
+
return pos_embed
|
254 |
+
|
255 |
+
else:
|
256 |
+
|
257 |
+
# If pos_embed is already the correct size, just return
|
258 |
+
_, _, H, W = x.shape
|
259 |
+
if H == self.img_height and W == self.img_width:
|
260 |
+
return pos_embed
|
261 |
+
|
262 |
+
# Compute scale factor for spatial interpolation
|
263 |
+
npatch = (H // self.patch_size) * (W // self.patch_size)
|
264 |
+
scale_factor = math.sqrt(npatch / N)
|
265 |
+
|
266 |
+
pos_embed = nn.functional.interpolate(
|
267 |
+
pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
|
268 |
+
scale_factor=scale_factor,
|
269 |
+
mode="bicubic",
|
270 |
+
)
|
271 |
+
pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
272 |
+
return pos_embed
|
273 |
+
|
274 |
+
|
275 |
+
def vit_large(patch_size=16, **kwargs):
|
276 |
+
model = VisionTransformer(
|
277 |
+
patch_size=patch_size,
|
278 |
+
embed_dim=1024,
|
279 |
+
depth=24,
|
280 |
+
num_heads=16,
|
281 |
+
mlp_ratio=4,
|
282 |
+
qkv_bias=True,
|
283 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
284 |
+
**kwargs
|
285 |
+
)
|
286 |
+
return model
|
287 |
+
|
288 |
+
|
289 |
+
def vit_huge(patch_size=16, **kwargs):
|
290 |
+
model = VisionTransformer(
|
291 |
+
patch_size=patch_size,
|
292 |
+
embed_dim=1280,
|
293 |
+
depth=32,
|
294 |
+
num_heads=16,
|
295 |
+
mlp_ratio=4,
|
296 |
+
qkv_bias=True,
|
297 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
298 |
+
**kwargs
|
299 |
+
)
|
300 |
+
return model
|
301 |
+
|
302 |
+
|
303 |
+
def vit_giant_xformers(patch_size=16, **kwargs):
|
304 |
+
model = VisionTransformer(
|
305 |
+
patch_size=patch_size,
|
306 |
+
embed_dim=1408,
|
307 |
+
depth=40,
|
308 |
+
num_heads=22,
|
309 |
+
mlp_ratio=48 / 11,
|
310 |
+
qkv_bias=True,
|
311 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
312 |
+
**kwargs
|
313 |
+
)
|
314 |
+
return model
|
315 |
+
|
316 |
+
|
317 |
+
# We do not use any of the following ViT definitions in V-JEPA 2, but retain them for
|
318 |
+
# compatibility reasons.
|
319 |
+
def vit_synthetic(patch_size=16, **kwargs):
|
320 |
+
# For performance testing only
|
321 |
+
model = VisionTransformer(
|
322 |
+
patch_size=patch_size,
|
323 |
+
embed_dim=1,
|
324 |
+
depth=1,
|
325 |
+
num_heads=1,
|
326 |
+
mlp_ratio=4,
|
327 |
+
qkv_bias=True,
|
328 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
329 |
+
**kwargs
|
330 |
+
)
|
331 |
+
return model
|
332 |
+
|
333 |
+
|
334 |
+
def vit_tiny(patch_size=16, **kwargs):
|
335 |
+
model = VisionTransformer(
|
336 |
+
patch_size=patch_size,
|
337 |
+
embed_dim=192,
|
338 |
+
depth=12,
|
339 |
+
num_heads=3,
|
340 |
+
mlp_ratio=4,
|
341 |
+
qkv_bias=True,
|
342 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
343 |
+
**kwargs
|
344 |
+
)
|
345 |
+
return model
|
346 |
+
|
347 |
+
|
348 |
+
def vit_small(patch_size=16, **kwargs):
|
349 |
+
model = VisionTransformer(
|
350 |
+
patch_size=patch_size,
|
351 |
+
embed_dim=384,
|
352 |
+
depth=12,
|
353 |
+
num_heads=6,
|
354 |
+
mlp_ratio=4,
|
355 |
+
qkv_bias=True,
|
356 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
357 |
+
**kwargs
|
358 |
+
)
|
359 |
+
return model
|
360 |
+
|
361 |
+
|
362 |
+
def vit_base(patch_size=16, **kwargs):
|
363 |
+
model = VisionTransformer(
|
364 |
+
patch_size=patch_size,
|
365 |
+
embed_dim=768,
|
366 |
+
depth=12,
|
367 |
+
num_heads=12,
|
368 |
+
mlp_ratio=4,
|
369 |
+
qkv_bias=True,
|
370 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
371 |
+
**kwargs
|
372 |
+
)
|
373 |
+
return model
|
374 |
+
|
375 |
+
|
376 |
+
def vit_large_rope(patch_size=16, **kwargs):
|
377 |
+
model = VisionTransformer(
|
378 |
+
patch_size=patch_size,
|
379 |
+
embed_dim=1024,
|
380 |
+
depth=24,
|
381 |
+
num_heads=16,
|
382 |
+
mlp_ratio=4,
|
383 |
+
qkv_bias=True,
|
384 |
+
use_rope=True,
|
385 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
386 |
+
**kwargs
|
387 |
+
)
|
388 |
+
return model
|
389 |
+
|
390 |
+
|
391 |
+
def vit_huge_rope(patch_size=16, **kwargs):
|
392 |
+
model = VisionTransformer(
|
393 |
+
patch_size=patch_size,
|
394 |
+
embed_dim=1280,
|
395 |
+
depth=32,
|
396 |
+
num_heads=16,
|
397 |
+
mlp_ratio=4,
|
398 |
+
qkv_bias=True,
|
399 |
+
use_rope=True,
|
400 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
401 |
+
**kwargs
|
402 |
+
)
|
403 |
+
return model
|
404 |
+
|
405 |
+
|
406 |
+
def vit_giant(patch_size=16, **kwargs):
|
407 |
+
model = VisionTransformer(
|
408 |
+
patch_size=patch_size,
|
409 |
+
embed_dim=1408,
|
410 |
+
depth=40,
|
411 |
+
num_heads=16,
|
412 |
+
mlp_ratio=48 / 11,
|
413 |
+
qkv_bias=True,
|
414 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
415 |
+
**kwargs
|
416 |
+
)
|
417 |
+
return model
|
418 |
+
|
419 |
+
|
420 |
+
def vit_giant_rope(patch_size=16, **kwargs):
|
421 |
+
model = VisionTransformer(
|
422 |
+
patch_size=patch_size,
|
423 |
+
embed_dim=1408,
|
424 |
+
depth=40,
|
425 |
+
num_heads=16,
|
426 |
+
mlp_ratio=48 / 11,
|
427 |
+
qkv_bias=True,
|
428 |
+
use_rope=True,
|
429 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
430 |
+
**kwargs
|
431 |
+
)
|
432 |
+
return model
|
433 |
+
|
434 |
+
|
435 |
+
def vit_giant_xformers_rope(patch_size=16, **kwargs):
|
436 |
+
model = VisionTransformer(
|
437 |
+
patch_size=patch_size,
|
438 |
+
embed_dim=1408,
|
439 |
+
depth=40,
|
440 |
+
num_heads=22,
|
441 |
+
mlp_ratio=48 / 11,
|
442 |
+
qkv_bias=True,
|
443 |
+
use_rope=True,
|
444 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
445 |
+
**kwargs
|
446 |
+
)
|
447 |
+
return model
|
448 |
+
|
449 |
+
|
450 |
+
def vit_gigantic(patch_size=16, **kwargs):
|
451 |
+
model = VisionTransformer(
|
452 |
+
patch_size=patch_size,
|
453 |
+
embed_dim=1664,
|
454 |
+
depth=48,
|
455 |
+
num_heads=16,
|
456 |
+
mlp_ratio=64 / 13,
|
457 |
+
qkv_bias=True,
|
458 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
459 |
+
**kwargs
|
460 |
+
)
|
461 |
+
return model
|
462 |
+
|
463 |
+
|
464 |
+
def vit_gigantic_xformers(patch_size=16, **kwargs):
|
465 |
+
model = VisionTransformer(
|
466 |
+
patch_size=patch_size,
|
467 |
+
embed_dim=1664,
|
468 |
+
depth=48,
|
469 |
+
num_heads=26,
|
470 |
+
mlp_ratio=64 / 13,
|
471 |
+
qkv_bias=True,
|
472 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
473 |
+
**kwargs
|
474 |
+
)
|
475 |
+
return model
|
476 |
+
|
477 |
+
|
478 |
+
VIT_EMBED_DIMS = {
|
479 |
+
"vit_synthetic": 1,
|
480 |
+
"vit_tiny": 192,
|
481 |
+
"vit_small": 384,
|
482 |
+
"vit_base": 768,
|
483 |
+
"vit_large": 1024,
|
484 |
+
"vit_huge": 1280,
|
485 |
+
"vit_giant": 1408,
|
486 |
+
"vit_gigantic": 1664,
|
487 |
+
}
|
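
A minimal sketch of instantiating the video encoder above and running a dummy clip through it (clip shape is an illustrative assumption; no pretrained weights are loaded here):

```python
# Forward a dummy clip through the vit_large video encoder.
import torch
from src.models.vision_transformer import vit_large, VIT_EMBED_DIMS

encoder = vit_large(img_size=(224, 224), num_frames=16, tubelet_size=2)
clip = torch.randn(1, 3, 16, 224, 224)                  # [B, C, T, H, W]
features = encoder(clip)                                # [B, num_patches, embed_dim]
assert features.shape[-1] == VIT_EMBED_DIMS["vit_large"]   # 1024
```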
src/utils/__pycache__/tensors.cpython-312.pyc
ADDED
Binary file (2.18 kB).
|
|
src/utils/checkpoint_loader.py
ADDED
@@ -0,0 +1,37 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import os
|
7 |
+
import random
|
8 |
+
import time
|
9 |
+
from typing import Any
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from torch.serialization import MAP_LOCATION
|
13 |
+
|
14 |
+
from src.utils.logging import get_logger
|
15 |
+
|
16 |
+
logger = get_logger(os.path.basename(__file__))
|
17 |
+
|
18 |
+
|
19 |
+
def robust_checkpoint_loader(r_path: str, map_location: MAP_LOCATION = "cpu", max_retries: int = 3) -> Any:
|
20 |
+
"""
|
21 |
+
Loads a checkpoint from a path, retrying up to max_retries times if loading fails.
|
22 |
+
"""
|
23 |
+
retries = 0
|
24 |
+
|
25 |
+
while retries < max_retries:
|
26 |
+
try:
|
27 |
+
return torch.load(r_path, map_location=map_location)
|
28 |
+
except Exception as e:
|
29 |
+
logger.warning(f"Encountered exception when loading checkpoint {e}")
|
30 |
+
retries += 1
|
31 |
+
if retries < max_retries:
|
32 |
+
sleep_time_s = (2**retries) * random.uniform(1.0, 1.1)
|
33 |
+
logger.warning(f"Sleeping {sleep_time_s}s and trying again, count {retries}/{max_retries}")
|
34 |
+
time.sleep(sleep_time_s)
|
35 |
+
continue
|
36 |
+
else:
|
37 |
+
raise e
|
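
A minimal sketch of using the retrying loader above; the checkpoint path and the `"encoder"` key are illustrative placeholders, not paths shipped with this repo:

```python
# Load a checkpoint with retries, then pull out an encoder state_dict if present.
from src.utils.checkpoint_loader import robust_checkpoint_loader

ckpt = robust_checkpoint_loader("checkpoints/encoder.pt", map_location="cpu")
state_dict = ckpt.get("encoder", ckpt)   # assumes ckpt is a dict; falls back to the raw object
```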
src/utils/distributed.py
ADDED
@@ -0,0 +1,101 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import os
|
7 |
+
from pathlib import Path
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.distributed as dist
|
11 |
+
|
12 |
+
from src.utils.logging import get_logger
|
13 |
+
|
14 |
+
logger = get_logger()
|
15 |
+
|
16 |
+
|
17 |
+
def init_distributed(port=37129, rank_and_world_size=(None, None)):
|
18 |
+
# try to set all environment variables to avoid triggering a segfault
|
19 |
+
# environment variables can be reallocated during the execution of torch.distributed.init_process_group
|
20 |
+
# the idea is a race condition may trigger if init_process_group is modifying an environment variable at
|
21 |
+
# the same time as Python, so we try to set all environs before initializing distributed
|
22 |
+
if "SLURM_JOB_ID" in os.environ:
|
23 |
+
# Use the slurm_tmpdir (if it exists) instead of /tmp
|
24 |
+
tmpdir = Path(f"/scratch/slurm_tmpdir/{os.environ['SLURM_JOB_ID']}")
|
25 |
+
if tmpdir.exists():
|
26 |
+
os.environ["TMPDIR"] = str(tmpdir)
|
27 |
+
|
28 |
+
if dist.is_available() and dist.is_initialized():
|
29 |
+
return dist.get_world_size(), dist.get_rank()
|
30 |
+
|
31 |
+
rank, world_size = rank_and_world_size
|
32 |
+
os.environ["MASTER_ADDR"] = "localhost"
|
33 |
+
|
34 |
+
if (rank is None) or (world_size is None):
|
35 |
+
try:
|
36 |
+
world_size = int(os.environ["SLURM_NTASKS"])
|
37 |
+
rank = int(os.environ["SLURM_PROCID"])
|
38 |
+
os.environ["MASTER_ADDR"] = os.environ["HOSTNAME"]
|
39 |
+
except Exception:
|
40 |
+
logger.info("SLURM vars not set (distributed training not available)")
|
41 |
+
world_size, rank = 1, 0
|
42 |
+
return world_size, rank
|
43 |
+
|
44 |
+
try:
|
45 |
+
os.environ["MASTER_PORT"] = str(port)
|
46 |
+
torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=rank)
|
47 |
+
except Exception as e:
|
48 |
+
world_size, rank = 1, 0
|
49 |
+
logger.info(f"Rank: {rank}. Distributed training not available {e}")
|
50 |
+
|
51 |
+
return world_size, rank
|
52 |
+
|
53 |
+
|
54 |
+
class AllGather(torch.autograd.Function):
|
55 |
+
|
56 |
+
@staticmethod
|
57 |
+
def forward(ctx, x):
|
58 |
+
if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1):
|
59 |
+
x = x.contiguous()
|
60 |
+
outputs = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
|
61 |
+
dist.all_gather(outputs, x)
|
62 |
+
return torch.cat(outputs, 0)
|
63 |
+
return x
|
64 |
+
|
65 |
+
@staticmethod
|
66 |
+
def backward(ctx, grads):
|
67 |
+
if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1):
|
68 |
+
s = (grads.shape[0] // dist.get_world_size()) * dist.get_rank()
|
69 |
+
e = (grads.shape[0] // dist.get_world_size()) * (dist.get_rank() + 1)
|
70 |
+
grads = grads.contiguous()
|
71 |
+
dist.all_reduce(grads)
|
72 |
+
return grads[s:e]
|
73 |
+
return grads
|
74 |
+
|
75 |
+
|
76 |
+
class AllReduceSum(torch.autograd.Function):
|
77 |
+
|
78 |
+
@staticmethod
|
79 |
+
def forward(ctx, x):
|
80 |
+
if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1):
|
81 |
+
x = x.contiguous()
|
82 |
+
dist.all_reduce(x)
|
83 |
+
return x
|
84 |
+
|
85 |
+
@staticmethod
|
86 |
+
def backward(ctx, grads):
|
87 |
+
return grads
|
88 |
+
|
89 |
+
|
90 |
+
class AllReduce(torch.autograd.Function):
|
91 |
+
|
92 |
+
@staticmethod
|
93 |
+
def forward(ctx, x):
|
94 |
+
if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1):
|
95 |
+
x = x.contiguous() / dist.get_world_size()
|
96 |
+
dist.all_reduce(x)
|
97 |
+
return x
|
98 |
+
|
99 |
+
@staticmethod
|
100 |
+
def backward(ctx, grads):
|
101 |
+
return grads
|
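
A minimal sketch of the autograd-aware all-reduce defined above; when no process group is initialized it degenerates to the identity, so the snippet also runs single-process:

```python
# Average a per-rank scalar with AllReduce (no-op when world_size == 1).
import torch
from src.utils.distributed import init_distributed, AllReduce

world_size, rank = init_distributed()
local_loss = torch.tensor(float(rank), requires_grad=True)
mean_loss = AllReduce.apply(local_loss)   # mean across ranks, gradients pass straight through
```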
src/utils/logging.py
ADDED
@@ -0,0 +1,108 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import logging
|
7 |
+
import os
|
8 |
+
import subprocess
|
9 |
+
import sys
|
10 |
+
|
11 |
+
import torch
|
12 |
+
|
13 |
+
|
14 |
+
def gpu_timer(closure, log_timings=True):
|
15 |
+
"""Helper to time gpu-time to execute closure()"""
|
16 |
+
log_timings = log_timings and torch.cuda.is_available()
|
17 |
+
|
18 |
+
elapsed_time = -1.0
|
19 |
+
if log_timings:
|
20 |
+
start = torch.cuda.Event(enable_timing=True)
|
21 |
+
end = torch.cuda.Event(enable_timing=True)
|
22 |
+
start.record()
|
23 |
+
|
24 |
+
result = closure()
|
25 |
+
|
26 |
+
if log_timings:
|
27 |
+
end.record()
|
28 |
+
torch.cuda.synchronize()
|
29 |
+
elapsed_time = start.elapsed_time(end)
|
30 |
+
|
31 |
+
return result, elapsed_time
|
32 |
+
|
33 |
+
|
34 |
+
LOG_FORMAT = "[%(levelname)-8s][%(asctime)s][%(name)-20s][%(funcName)-25s] %(message)s"
|
35 |
+
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
|
36 |
+
|
37 |
+
|
38 |
+
def get_logger(name=None, force=False):
|
39 |
+
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=LOG_FORMAT, datefmt=DATE_FORMAT, force=force)
|
40 |
+
return logging.getLogger(name=name)
|
41 |
+
|
42 |
+
|
43 |
+
class CSVLogger(object):
|
44 |
+
|
45 |
+
def __init__(self, fname, *argv, **kwargs):
|
46 |
+
self.fname = fname
|
47 |
+
self.types = []
|
48 |
+
mode = kwargs.get("mode", "+a")
|
49 |
+
self.delim = kwargs.get("delim", ",")
|
50 |
+
# -- print headers
|
51 |
+
with open(self.fname, mode) as f:
|
52 |
+
for i, v in enumerate(argv, 1):
|
53 |
+
self.types.append(v[0])
|
54 |
+
if i < len(argv):
|
55 |
+
print(v[1], end=self.delim, file=f)
|
56 |
+
else:
|
57 |
+
print(v[1], end="\n", file=f)
|
58 |
+
|
59 |
+
def log(self, *argv):
|
60 |
+
with open(self.fname, "+a") as f:
|
61 |
+
for i, tv in enumerate(zip(self.types, argv), 1):
|
62 |
+
end = self.delim if i < len(argv) else "\n"
|
63 |
+
print(tv[0] % tv[1], end=end, file=f)
|
64 |
+
|
65 |
+
|
66 |
+
class AverageMeter(object):
|
67 |
+
"""computes and stores the average and current value"""
|
68 |
+
|
69 |
+
def __init__(self):
|
70 |
+
self.reset()
|
71 |
+
|
72 |
+
def reset(self):
|
73 |
+
self.val = 0
|
74 |
+
self.avg = 0
|
75 |
+
self.max = float("-inf")
|
76 |
+
self.min = float("inf")
|
77 |
+
self.sum = 0
|
78 |
+
self.count = 0
|
79 |
+
|
80 |
+
def update(self, val, n=1):
|
81 |
+
self.val = val
|
82 |
+
try:
|
83 |
+
self.max = max(val, self.max)
|
84 |
+
self.min = min(val, self.min)
|
85 |
+
except Exception:
|
86 |
+
pass
|
87 |
+
self.sum += val * n
|
88 |
+
self.count += n
|
89 |
+
self.avg = self.sum / self.count
|
90 |
+
|
91 |
+
|
92 |
+
def jepa_rootpath():
|
93 |
+
this_file = os.path.abspath(__file__)
|
94 |
+
return "/".join(this_file.split("/")[:-3])
|
95 |
+
|
96 |
+
|
97 |
+
def git_information():
|
98 |
+
jepa_root = jepa_rootpath()
|
99 |
+
try:
|
100 |
+
resp = (
|
101 |
+
subprocess.check_output(["git", "-C", jepa_root, "rev-parse", "HEAD", "--abbrev-ref", "HEAD"])
|
102 |
+
.decode("ascii")
|
103 |
+
.strip()
|
104 |
+
)
|
105 |
+
commit, branch = resp.split("\n")
|
106 |
+
return f"branch: {branch}\ncommit: {commit}\n"
|
107 |
+
except Exception:
|
108 |
+
return "unknown"
|
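
A minimal sketch of the CSV/metric helpers above; the file name, column formats, and values are illustrative assumptions:

```python
# Track a running loss and append one row to a CSV log.
from src.utils.logging import AverageMeter, CSVLogger, get_logger

logger = get_logger(__name__)
csv = CSVLogger("train_log.csv", ("%d", "epoch"), ("%d", "iter"), ("%.5f", "loss"))
meter = AverageMeter()
meter.update(0.734)
csv.log(0, 10, meter.avg)
logger.info(f"avg loss so far: {meter.avg:.5f}")
```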
src/utils/monitoring.py
ADDED
@@ -0,0 +1,171 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import dataclasses
|
7 |
+
import threading
|
8 |
+
import time
|
9 |
+
from typing import Dict, Tuple
|
10 |
+
|
11 |
+
import psutil
|
12 |
+
|
13 |
+
|
14 |
+
@dataclasses.dataclass
|
15 |
+
class ResourceStatsSample:
|
16 |
+
timestamp: float
|
17 |
+
cpu_percent: float
|
18 |
+
read_count: int
|
19 |
+
write_count: int
|
20 |
+
read_bytes: int
|
21 |
+
write_bytes: int
|
22 |
+
read_chars: int
|
23 |
+
write_chars: int
|
24 |
+
cpu_times_user: float
|
25 |
+
cpu_times_system: float
|
26 |
+
cpu_times_children_user: float
|
27 |
+
cpu_times_children_system: float
|
28 |
+
cpu_times_iowait: float
|
29 |
+
cpu_affinity: str
|
30 |
+
cpu_num: int
|
31 |
+
num_threads: int
|
32 |
+
num_voluntary_ctx_switches: int
|
33 |
+
num_involuntary_ctx_switches: int
|
34 |
+
|
35 |
+
def as_tuple(self) -> Tuple:
|
36 |
+
"""Return values mirroring fields."""
|
37 |
+
return dataclasses.astuple(self)
|
38 |
+
|
39 |
+
def fields(self) -> Tuple[dataclasses.Field, ...]:
|
40 |
+
"""Return fields in this dataclass."""
|
41 |
+
return dataclasses.fields(self.__class__)
|
42 |
+
|
43 |
+
|
44 |
+
class ResourceMonitoringThread(threading.Thread):
|
45 |
+
def __init__(self, pid=None, refresh_interval=None, stats_callback_fn=None):
|
46 |
+
"""Starts a thread to monitor pid every refresh_interval seconds.
|
47 |
+
|
48 |
+
Passes a ResourceStatsSample object to the callback."""
|
49 |
+
super(ResourceMonitoringThread, self).__init__()
|
50 |
+
if refresh_interval is None:
|
51 |
+
refresh_interval = 5
|
52 |
+
self.is_running_event = threading.Event()
|
53 |
+
self.p = psutil.Process(pid)
|
54 |
+
self.refresh_interval = refresh_interval
|
55 |
+
if stats_callback_fn is None:
|
56 |
+
# Default callback
|
57 |
+
def stats_callback_fn(resource_sample: ResourceStatsSample):
|
58 |
+
print(f"PID {self.p.pid} Stats: {resource_sample.resource_stats}")
|
59 |
+
|
60 |
+
elif not callable(stats_callback_fn):
|
61 |
+
raise ValueError("Callback needs to be callable, got {}".format(type(stats_callback_fn)))
|
62 |
+
self.stats_callback_fn = stats_callback_fn
|
63 |
+
|
64 |
+
def stop(self) -> None:
|
65 |
+
self.is_running_event.set()
|
66 |
+
|
67 |
+
def run(self) -> None:
|
68 |
+
while not self.is_running_event.is_set():
|
69 |
+
self.sample_counters()
|
70 |
+
self.is_running_event.wait(self.refresh_interval)
|
71 |
+
|
72 |
+
def log_sample(self, resource_sample: ResourceStatsSample) -> None:
|
73 |
+
self.stats_callback_fn(resource_sample)
|
74 |
+
|
75 |
+
def sample_counters(self) -> None:
|
76 |
+
if not self.p.is_running():
|
77 |
+
self.stop()
|
78 |
+
return
|
79 |
+
|
80 |
+
with self.p.oneshot():
|
81 |
+
cpu_percent = self.p.cpu_percent()
|
82 |
+
cpu_times = self.p.cpu_times()
|
83 |
+
io_counters = self.p.io_counters()
|
84 |
+
cpu_affinity = self.p.cpu_affinity()
|
85 |
+
cpu_num = self.p.cpu_num()
|
86 |
+
num_threads = self.p.num_threads()
|
87 |
+
num_ctx_switches = self.p.num_ctx_switches()
|
88 |
+
timestamp = time.time()
|
89 |
+
|
90 |
+
read_count = io_counters.read_count
|
91 |
+
write_count = io_counters.write_count
|
92 |
+
read_bytes = io_counters.read_bytes
|
93 |
+
write_bytes = io_counters.write_bytes
|
94 |
+
read_chars = io_counters.read_chars
|
95 |
+
write_chars = io_counters.write_chars
|
96 |
+
|
97 |
+
def compress_cpu_affinity(cpu_affinity):
|
98 |
+
"""Change list representation to interval/range representation."""
|
99 |
+
if not cpu_affinity:
|
100 |
+
return ""
|
101 |
+
cpu_affinity_compressed = []
|
102 |
+
min_x = None
|
103 |
+
max_x = None
|
104 |
+
last_x = None
|
105 |
+
|
106 |
+
# Find contiguous ranges
|
107 |
+
for x in cpu_affinity:
|
108 |
+
if last_x is None:
|
109 |
+
# Start interval
|
110 |
+
min_x = x
|
111 |
+
max_x = x
|
112 |
+
last_x = x
|
113 |
+
continue
|
114 |
+
elif x == (last_x + 1):
|
115 |
+
# Move interval up
|
116 |
+
max_x = x
|
117 |
+
elif max_x is not None:
|
118 |
+
# Interval ended, start again
|
119 |
+
if min_x == max_x:
|
120 |
+
cpu_affinity_compressed.append("{}".format(min_x))
|
121 |
+
else:
|
122 |
+
cpu_affinity_compressed.append("{}-{}".format(min_x, max_x))
|
123 |
+
min_x = x
|
124 |
+
max_x = x
|
125 |
+
last_x = x
|
126 |
+
# Terminate last range
|
127 |
+
if max_x is not None:
|
128 |
+
if min_x == max_x:
|
129 |
+
cpu_affinity_compressed.append("{}".format(min_x))
|
130 |
+
else:
|
131 |
+
cpu_affinity_compressed.append("{}-{}".format(min_x, max_x))
|
132 |
+
|
133 |
+
# Concat
|
134 |
+
cpu_affinity_compressed = ",".join(cpu_affinity_compressed)
|
135 |
+
|
136 |
+
return cpu_affinity_compressed
|
137 |
+
|
138 |
+
cpu_affinity = compress_cpu_affinity(cpu_affinity)
|
139 |
+
|
140 |
+
resource_sample = ResourceStatsSample(
|
141 |
+
timestamp=timestamp,
|
142 |
+
cpu_percent=cpu_percent,
|
143 |
+
read_count=read_count,
|
144 |
+
write_count=write_count,
|
145 |
+
read_bytes=read_bytes,
|
146 |
+
write_bytes=write_bytes,
|
147 |
+
read_chars=read_chars,
|
148 |
+
write_chars=write_chars,
|
149 |
+
cpu_times_user=cpu_times.user,
|
150 |
+
cpu_times_system=cpu_times.system,
|
151 |
+
cpu_times_children_user=cpu_times.children_user,
|
152 |
+
cpu_times_children_system=cpu_times.children_system,
|
153 |
+
cpu_times_iowait=cpu_times.iowait,
|
154 |
+
cpu_affinity=cpu_affinity,
|
155 |
+
cpu_num=cpu_num,
|
156 |
+
num_threads=num_threads,
|
157 |
+
num_voluntary_ctx_switches=num_ctx_switches.voluntary,
|
158 |
+
num_involuntary_ctx_switches=num_ctx_switches.involuntary,
|
159 |
+
)
|
160 |
+
self.log_sample(resource_sample)
|
161 |
+
|
162 |
+
|
163 |
+
if __name__ == "__main__":
|
164 |
+
import multiprocessing
|
165 |
+
|
166 |
+
pid = multiprocessing.current_process().pid
|
167 |
+
monitor_thread = ResourceMonitoringThread(pid, 1)
|
168 |
+
monitor_thread.start()
|
169 |
+
time.sleep(5)
|
170 |
+
print("Shutdown")
|
171 |
+
monitor_thread.stop()
|
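For reference, a minimal usage sketch (an assumption, not part of the uploaded files): it monitors the current process and forwards each sample to a standard-library logger instead of printing. The import path src.utils.monitoring and the refresh_interval value are illustrative; the io_counters character counts used above are Linux-specific.

# Usage sketch (assumption): log resource samples for the current process.
import logging
import os
import time

from src.utils.monitoring import ResourceMonitoringThread

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("resource-monitor")

def log_stats(sample):
    # ResourceStatsSample exposes the collected counters via .resource_stats
    logger.info("pid=%d stats=%s", os.getpid(), sample.resource_stats)

monitor = ResourceMonitoringThread(os.getpid(), refresh_interval=2, stats_callback_fn=log_stats)
monitor.start()
time.sleep(10)  # training or data-loading work would happen here
monitor.stop()
monitor.join()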
src/utils/schedulers.py
ADDED
@@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math


class WSDSchedule(object):

    def __init__(self, optimizer, warmup_steps, anneal_steps, T_max, start_lr, ref_lr, final_lr=0.0):
        self.optimizer = optimizer
        self.start_lr = start_lr
        self.ref_lr = ref_lr
        self.final_lr = final_lr
        self.anneal_steps = anneal_steps
        self.warmup_steps = warmup_steps
        self.T_max = T_max - warmup_steps - anneal_steps
        self._step = 0.0

    def step(self):
        self._step += 1
        if self._step < self.warmup_steps:
            progress = float(self._step) / float(max(1, self.warmup_steps))
            new_lr = self.start_lr + progress * (self.ref_lr - self.start_lr)
        elif self._step < self.T_max + self.warmup_steps:
            new_lr = self.ref_lr
        else:
            _step = self._step - (self.T_max + self.warmup_steps)
            progress = float(_step) / float(max(1, self.anneal_steps))
            new_lr = self.ref_lr + progress * (self.final_lr - self.ref_lr)

        for group in self.optimizer.param_groups:
            group["lr"] = new_lr
            if "lr_scale" in group:
                group["lr"] *= group["lr_scale"]

        return new_lr


class WarmupCosineSchedule(object):

    def __init__(self, optimizer, warmup_steps, start_lr, ref_lr, T_max, last_epoch=-1, final_lr=0.0):
        self.optimizer = optimizer
        self.start_lr = start_lr
        self.ref_lr = ref_lr
        self.final_lr = final_lr
        self.warmup_steps = warmup_steps
        self.T_max = T_max - warmup_steps
        self._step = 0.0

    def step(self):
        self._step += 1
        if self._step < self.warmup_steps:
            progress = float(self._step) / float(max(1, self.warmup_steps))
            new_lr = self.start_lr + progress * (self.ref_lr - self.start_lr)
        else:
            # -- progress after warmup
            progress = float(self._step - self.warmup_steps) / float(max(1, self.T_max))
            new_lr = max(
                self.final_lr,
                self.final_lr + (self.ref_lr - self.final_lr) * 0.5 * (1.0 + math.cos(math.pi * progress)),
            )

        for group in self.optimizer.param_groups:
            group["lr"] = new_lr

        return new_lr


class CosineWDSchedule(object):

    def __init__(self, optimizer, ref_wd, T_max, final_wd=0.0):
        self.optimizer = optimizer
        self.ref_wd = ref_wd
        self.final_wd = final_wd
        self.T_max = T_max
        self._step = 0.0

    def step(self):
        self._step += 1
        progress = self._step / self.T_max
        new_wd = self.final_wd + (self.ref_wd - self.final_wd) * 0.5 * (1.0 + math.cos(math.pi * progress))

        if self.final_wd <= self.ref_wd:
            new_wd = max(self.final_wd, new_wd)
        else:
            new_wd = min(self.final_wd, new_wd)

        for group in self.optimizer.param_groups:
            if ("WD_exclude" not in group) or not group["WD_exclude"]:
                group["weight_decay"] = new_wd
        return new_wd
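As a rough illustration (an assumption, not code from this upload), these schedulers are stepped once per optimizer update; the import path, model, and hyperparameter values below are placeholders.

# Usage sketch (assumption): warmup-cosine LR plus cosine weight-decay annealing.
import torch

from src.utils.schedulers import WarmupCosineSchedule, CosineWDSchedule

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)

total_steps = 1000
lr_schedule = WarmupCosineSchedule(
    optimizer, warmup_steps=100, start_lr=1e-6, ref_lr=1e-3, T_max=total_steps, final_lr=1e-6
)
wd_schedule = CosineWDSchedule(optimizer, ref_wd=0.05, T_max=total_steps, final_wd=0.01)

for step in range(total_steps):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()
    lr = lr_schedule.step()  # returns the LR written to every param group
    wd = wd_schedule.step()  # returns the annealed weight decay (skips WD_exclude groups)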
src/utils/tensors.py
ADDED
@@ -0,0 +1,53 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from logging import getLogger

import torch

logger = getLogger()


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [lower, upper], then translate to
        # [2*lower-1, 2*upper-1].
        tensor.uniform_(2 * lower - 1, 2 * upper - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    # type: (Tensor, float, float, float, float) -> Tensor
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def repeat_interleave_batch(x, B, repeat):
    N = len(x) // B
    x = torch.cat([torch.cat([x[i * B : (i + 1) * B] for _ in range(repeat)], dim=0) for i in range(N)], dim=0)
    return x
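For context, a small sketch (an assumption, not part of the upload) of how these helpers are typically used; the shapes and std value are illustrative only.

# Usage sketch (assumption): truncated-normal init and batch-wise repetition.
import torch

from src.utils.tensors import trunc_normal_, repeat_interleave_batch

# Fill a weight matrix in-place with a normal truncated to [-2, 2] std units.
w = torch.empty(768, 768)
trunc_normal_(w, std=0.02)

# Repeat each block of B samples `repeat` times while keeping block order:
# an input of shape [N*B, ...] becomes [N*B*repeat, ...].
x = torch.arange(6).view(6, 1).float()          # N=3 blocks of B=2 samples
y = repeat_interleave_batch(x, B=2, repeat=2)   # shape [12, 1]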
src/utils/wrappers.py
ADDED
@@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn


class MultiSeqWrapper(nn.Module):

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def forward(self, x, masks=None):
        """
        :param x: [list] List of Tensors of different seq lengths
        :param masks: [list[list]] List of Tensors (outer index: masks for given seq length, inner index: multimasks for that seq len)
        """
        if masks is None:
            return [self.backbone(xi) for xi in x]

        outs = [[] for _ in x]
        for i, (xi, mi) in enumerate(zip(x, masks)):
            for mij in mi:
                outs[i] += [self.backbone(xi, masks=mij)]
        return outs


class PredictorMultiSeqWrapper(nn.Module):

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def forward(self, x, masks_x, masks_y, has_cls=False):
        n = 0
        outs = [[] for _ in x]
        for i, (xi, mxi, myi) in enumerate(zip(x, masks_x, masks_y)):
            for xij, mxij, myij in zip(xi, mxi, myi):
                outs[i] += [self.backbone(xij, mxij, myij, mask_index=i, has_cls=has_cls)]
                n += 1
        return outs
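A minimal sketch (an assumption, not part of the upload) of how MultiSeqWrapper dispatches one backbone over inputs of different sequence lengths; nn.Identity and the token counts stand in for a real encoder and real clips.

# Usage sketch (assumption): one backbone applied to sequences of different lengths.
import torch
import torch.nn as nn

from src.utils.wrappers import MultiSeqWrapper

backbone = nn.Identity()   # stand-in for a vision transformer encoder
wrapped = MultiSeqWrapper(backbone)

# Two inputs with different token counts; without masks each is encoded independently.
x = [torch.randn(2, 196, 768), torch.randn(2, 392, 768)]
outs = wrapped(x)          # list of two outputs, one per sequence length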
tokenizer.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4aa645548d069f0235573e314b319436bc9c7f4a7aa6e2c07f494de56a57b955
size 17210205