VarunKodathala committed
Commit 0e37bb2 · verified · 1 Parent(s): 940aabd

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff.

Files changed (50)
  1. .gitattributes +1 -0
  2. README.md +74 -3
  3. config.json +26 -0
  4. model.safetensors +3 -0
  5. soccer_qa_inference.py +304 -0
  6. special_tokens_map.json +26 -0
  7. src/datasets/data_manager.py +90 -0
  8. src/datasets/imagenet1k.py +152 -0
  9. src/datasets/utils/dataloader.py +234 -0
  10. src/datasets/utils/utils.py +21 -0
  11. src/datasets/utils/video/__pycache__/functional.cpython-312.pyc +0 -0
  12. src/datasets/utils/video/__pycache__/randaugment.cpython-312.pyc +0 -0
  13. src/datasets/utils/video/__pycache__/transforms.cpython-312.pyc +0 -0
  14. src/datasets/utils/video/__pycache__/volume_transforms.cpython-312.pyc +0 -0
  15. src/datasets/utils/video/functional.py +110 -0
  16. src/datasets/utils/video/randaugment.py +536 -0
  17. src/datasets/utils/video/randerase.py +170 -0
  18. src/datasets/utils/video/transforms.py +1161 -0
  19. src/datasets/utils/video/transforms_builder.py +165 -0
  20. src/datasets/utils/video/volume_transforms.py +159 -0
  21. src/datasets/utils/weighted_sampler.py +336 -0
  22. src/datasets/utils/worker_init_fn.py +76 -0
  23. src/datasets/video_dataset.py +373 -0
  24. src/hub/__init__.py +0 -0
  25. src/hub/backbones.py +177 -0
  26. src/masks/__pycache__/utils.cpython-312.pyc +0 -0
  27. src/masks/default.py +18 -0
  28. src/masks/multiseq_multiblock3d.py +239 -0
  29. src/masks/utils.py +21 -0
  30. src/models/__pycache__/attentive_pooler.cpython-312.pyc +0 -0
  31. src/models/__pycache__/vision_transformer.cpython-312.pyc +0 -0
  32. src/models/ac_predictor.py +200 -0
  33. src/models/attentive_pooler.py +137 -0
  34. src/models/predictor.py +253 -0
  35. src/models/utils/__pycache__/modules.cpython-312.pyc +0 -0
  36. src/models/utils/__pycache__/patch_embed.cpython-312.pyc +0 -0
  37. src/models/utils/__pycache__/pos_embs.cpython-312.pyc +0 -0
  38. src/models/utils/modules.py +610 -0
  39. src/models/utils/patch_embed.py +52 -0
  40. src/models/utils/pos_embs.py +93 -0
  41. src/models/vision_transformer.py +487 -0
  42. src/utils/__pycache__/tensors.cpython-312.pyc +0 -0
  43. src/utils/checkpoint_loader.py +37 -0
  44. src/utils/distributed.py +101 -0
  45. src/utils/logging.py +108 -0
  46. src/utils/monitoring.py +171 -0
  47. src/utils/schedulers.py +93 -0
  48. src/utils/tensors.py +53 -0
  49. src/utils/wrappers.py +43 -0
  50. tokenizer.json +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,74 @@
- ---
- license: cc-by-nc-4.0
- ---
+ ---
+ language: en
+ license: cc-by-nc-4.0
+ tags:
+ - soccer
+ - video-qa
+ - question-answering
+ - vision-language
+ - multimodal
+ - sports-analysis
+ library_name: transformers
+ pipeline_tag: video-text-to-text
+ ---
+
+ # Soccer QA 4B - Soccer Video Question Answering Model
+
+ **⚠️ RESEARCH USE ONLY - NON-COMMERCIAL LICENSE**
+
+ Soccer QA 4B is a unified video question-answering model designed specifically for soccer video understanding.
+
+ ## Model Description
+
+ This model answers questions about soccer videos by analyzing their visual content and generating natural-language responses.
+
+ **Example:**
+ - **Input**: Video + "What unfolded during the game in the video?"
+ - **Output**: "During the game, there was a foul committed by a player from the yellow-jerseyed team, leading to a yellow card being issued..."
+
+ ## Architecture
+ - **Vision Encoder**: DWT-VJEPA2-based video encoder (vit_giant, 1408-dim)
+ - **Text Model**: LLaMA 3.2-3B with LoRA fine-tuning
+ - **Vision-Text Bridge**: Learned projection layer (1408 → 2048 → 3072)
+ - **Specialization**: Fine-tuned on soccer video QA data
+
+ ## Usage (helper functions are included in this repo)
+
+ ```python
+ from soccer_qa_inference import SoccerQA
+
+ model = SoccerQA("/path/to/model")
+ answer = model.ask("video.mp4", "Was this a Foul?", max_tokens=45)
+ print(answer)
+ ```
+
+ ## Model Details
+ - **Parameters**: ~4B total
+ - **Input**: Video files (16 frames, 256x256) + text questions
+ - **Output**: Natural-language answers
+ - **Domain**: Soccer/football video analysis
+ - **Context**: Handles complex game situations, player actions, fouls, etc.
+
+ ## Training Data
+ - Soccer video clips paired with question-answer annotations
+ - Covers various game situations: fouls, shots, saves, player actions
+ - Annotated with detailed descriptions of game events
+
+ ## Limitations
+ - Research use only; no commercial applications
+ - Optimized specifically for soccer content
+ - May not generalize well to other sports or video domains
+ - Requires high-quality video input for best results
+
+ ## License
+ CC-BY-NC-4.0 - Research use only; no commercial applications permitted.
+
+ ## Citation
+ ```bibtex
+ @misc{soccer-qa-4b-2025,
+   title={Soccer QA 4B: Video Question Answering for Soccer Analysis},
+   author={Kodathala, Varun and {Sports Vision}},
+   year={2025},
+   note={Research model for soccer video understanding}
+ }
+ ```
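For several questions about the same clip, the `soccer_qa_inference.py` script added in this commit also exposes `batch_ask`; a minimal sketch (the model directory and video path are placeholders):

```python
from soccer_qa_inference import SoccerQA

# Placeholder paths - point these at a local copy of this repo and a soccer clip
model = SoccerQA("/path/to/soccer-qa-4b")
results = model.batch_ask(
    "clip.mp4",
    ["Was this a foul?", "Which team committed it?", "What card was shown?"],
    max_tokens=45,
)
for item in results:
    print(item["question"], "->", item["answer"])
```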
config.json ADDED
@@ -0,0 +1,26 @@
+ {
+   "model_type": "soccer_qa_4b",
+   "architectures": [
+     "SoccerQA4BModel"
+   ],
+   "vision_dim": 1408,
+   "projection_dim": 2048,
+   "text_dim": 3072,
+   "img_size": 256,
+   "num_frames": 16,
+   "max_length": 256,
+   "temperature": 0.7,
+   "imagenet_mean": [
+     0.485,
+     0.456,
+     0.406
+   ],
+   "imagenet_std": [
+     0.229,
+     0.224,
+     0.225
+   ],
+   "hidden_size": 3072,
+   "vocab_size": 128257,
+   "model_description": "Soccer video question answering model"
+ }
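These fields are exactly what `soccer_qa_inference.py` reads at start-up; a small sketch of consuming them directly (assumes `config.json` from this repo has been downloaded locally):

```python
import json

# Placeholder path - a local copy of this repo's config.json
with open("config.json") as f:
    cfg = json.load(f)

# The vision-text bridge maps vision features to the LLaMA hidden size:
# vision_dim (1408) -> projection_dim (2048) -> text_dim (3072)
print(cfg["vision_dim"], cfg["projection_dim"], cfg["text_dim"])
print(f'{cfg["num_frames"]} frames at {cfg["img_size"]}x{cfg["img_size"]}')
```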
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f268918412af5ab623937ca776bf5c91eb26f04c3ff7e4cc257598aeda61b7cc
+ size 18512562808
soccer_qa_inference.py ADDED
@@ -0,0 +1,304 @@
+ #!/usr/bin/env python3
+ """
+ Soccer QA Inference - Single Class, Clean API
+
+ Usage in Colab:
+     from soccer_qa_inference import SoccerQA
+     model = SoccerQA("soccer-qa-3b-unified")
+     answer = model.ask("video.mp4", "What happened?", max_tokens=128)
+ """
+
+ import os
+ import json
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from safetensors.torch import load_file
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from decord import VideoReader
+
+ # Import your existing modules
+ import src.datasets.utils.video.transforms as video_transforms
+ import src.datasets.utils.video.volume_transforms as volume_transforms
+ from src.models.vision_transformer import vit_giant_rope
+
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+
+ def get_video(fname, num_frames=16):
+     """Load video and sample frames uniformly"""
+     vr = VideoReader(fname)
+     frame_idx = np.linspace(0, len(vr) - 1, num=num_frames).astype(np.int64)
+     video = vr.get_batch(frame_idx).asnumpy()
+     return video
+
+ def build_video_transform(img_size):
+     """Build video preprocessing transform"""
+     short_side_size = int(256.0 / 224 * img_size)
+     eval_transform = video_transforms.Compose([
+         video_transforms.Resize(short_side_size, interpolation="bilinear"),
+         video_transforms.CenterCrop(size=(img_size, img_size)),
+         volume_transforms.ClipToTensor(),
+         video_transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+     ])
+     return eval_transform
+
+ class SoccerQA:
+     """Single class for Soccer QA inference - Clean Colab API"""
+
+     def __init__(self, model_dir="/home/varunkodathala/jepa_llm/soccer_pretrain/soccer-qa-3b-unified"):
+         """Initialize Soccer QA model
+
+         Args:
+             model_dir: Path to merged model directory
+         """
+         self.model_dir = model_dir
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+         print(f"🚀 Loading Soccer QA from {model_dir}...")
+
+         # Load config and tokenizer
+         self._load_config()
+         self._load_tokenizer()
+
+         # Build models
+         self._build_vision_model()
+         self._build_text_model()
+         self._build_projection()
+
+         # Load all weights
+         self._load_weights()
+
+         # Build video transforms
+         self.video_transform = build_video_transform(self.img_size)
+
+         print("✅ Soccer QA ready!")
+
+     def _load_config(self):
+         """Load model configuration"""
+         config_path = os.path.join(self.model_dir, "config.json")
+         with open(config_path, 'r') as f:
+             self.config = json.load(f)
+
+         self.vision_dim = self.config["vision_dim"]          # 1408
+         self.projection_dim = self.config["projection_dim"]  # 2048
+         self.text_dim = self.config["text_dim"]              # 3072
+         self.img_size = self.config["img_size"]              # 256
+         self.num_frames = self.config["num_frames"]          # 16
+
+     def _load_tokenizer(self):
+         """Load tokenizer with <video> token"""
+         self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
+         if self.tokenizer.pad_token is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+
+     def _build_vision_model(self):
+         """Build vision transformer using your src modules"""
+         self.vision_model = vit_giant_rope(
+             img_size=(self.img_size, self.img_size),
+             num_frames=self.num_frames
+         )
+         self.vision_model.to(self.device).eval()
+
+         # Freeze vision model
+         for param in self.vision_model.parameters():
+             param.requires_grad = False
+
+     def _build_text_model(self):
+         """Build text model - we'll load merged weights later"""
+         self.text_model = AutoModelForCausalLM.from_pretrained(
+             "meta-llama/Llama-3.2-3B",
+             torch_dtype=torch.float32,
+             device_map=self.device,
+             trust_remote_code=True
+         )
+
+         # Resize for <video> token to match saved model
+         self.text_model.resize_token_embeddings(len(self.tokenizer))
+         self.text_model.eval()
+
+     def _build_projection(self):
+         """Build vision projection layer"""
+         self.vision_projection = nn.Sequential(
+             nn.Linear(self.vision_dim, self.projection_dim),  # 1408 -> 2048
+             nn.ReLU(),
+             nn.Dropout(0.1),
+             nn.Linear(self.projection_dim, self.text_dim),    # 2048 -> 3072
+             nn.LayerNorm(self.text_dim)
+         ).to(self.device)
+
+     def _load_weights(self):
+         """Load all weights from safetensors - optimized approach"""
+         model_path = os.path.join(self.model_dir, "model.safetensors")
+         print(f"Loading weights from: {model_path}")
+         state_dict = load_file(model_path, device=str(self.device))
+
+         # Load vision encoder weights
+         vision_state = {}
+         for key, value in state_dict.items():
+             if key.startswith("vision_encoder."):
+                 new_key = key.replace("vision_encoder.", "")
+                 vision_state[new_key] = value
+
+         msg = self.vision_model.load_state_dict(vision_state, strict=False)
+         print(f"Vision model loaded: {msg}")
+
+         # Load projection weights
+         projection_state = {}
+         for key, value in state_dict.items():
+             if key.startswith("vision_projection."):
+                 new_key = key.replace("vision_projection.", "")
+                 projection_state[new_key] = value
+
+         self.vision_projection.load_state_dict(projection_state)
+         print("Projection layer loaded")
+
+         # Load text model weights directly from merged state_dict
+         text_state = {}
+         for key, value in state_dict.items():
+             if key.startswith("text_model."):
+                 new_key = key.replace("text_model.", "")
+                 text_state[new_key] = value
+
+         # Apply merged weights directly to text model
+         missing_keys, unexpected_keys = self.text_model.load_state_dict(text_state, strict=False)
+         if missing_keys:
+             print(f"Missing keys in text model: {len(missing_keys)} (this is normal)")
+         if unexpected_keys:
+             print(f"Unexpected keys in text model: {len(unexpected_keys)}")
+         print("✅ Text model loaded with merged weights")
+
+         # Clear state_dict from memory
+         del state_dict
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+     def _get_video_embeddings(self, video_path):
+         """Extract video embeddings from video file"""
+         with torch.inference_mode():
+             # Load video
+             video = get_video(video_path, self.num_frames)
+             video = torch.from_numpy(video).permute(0, 3, 1, 2)  # T x C x H x W
+
+             # Preprocess
+             x = self.video_transform(video).to(self.device).unsqueeze(0)  # [1, 16, 3, 256, 256]
+
+             # Extract features
+             features = self.vision_model(x)  # [1, 2048, 1408]
+
+             # Handle reshaping
+             squeezed = features.squeeze(0)  # [2048, 1408]
+             if squeezed.shape[0] % 2048 == 0:
+                 num_clips = squeezed.shape[0] // 2048
+                 reshaped = squeezed.view(num_clips, 2048, 1408)
+             else:
+                 reshaped = squeezed.unsqueeze(0)  # [1, 2048, 1408]
+
+             return reshaped
+
+     def _project_vision_features(self, vision_features):
+         """Project vision features to text embedding space"""
+         # vision_features: [num_clips, 2048, 1408]
+         num_clips, patches_per_clip, feature_dim = vision_features.shape
+
+         # Flatten: [num_clips * 2048, 1408]
+         flattened = vision_features.view(-1, feature_dim)
+
+         # Project: [num_clips * 2048, 3072]
+         projected = self.vision_projection(flattened)
+
+         # Return flattened for sequence: [total_patches, 3072]
+         return projected
+
+     def ask(self, video_path, question, max_tokens=128, temperature=0.7, top_p=0.9,
+             repetition_penalty=1.2, no_repeat_ngram_size=3):
+         """Ask a question about a video
+
+         Args:
+             video_path: Path to video file
+             question: Question about the video
+             max_tokens: Maximum tokens to generate
+             temperature: Sampling temperature
+             top_p: Nucleus sampling parameter
+             repetition_penalty: Penalty for repetition
+             no_repeat_ngram_size: N-gram size for repetition blocking
+
+         Returns:
+             Generated answer as string
+         """
+         with torch.no_grad():
+             # Get video embeddings
+             video_features = self._get_video_embeddings(video_path)        # [num_clips, 2048, 1408]
+             vision_embeds = self._project_vision_features(video_features)  # [total_patches, 3072]
+             vision_embeds = vision_embeds.unsqueeze(0)                     # [1, total_patches, 3072]
+
+             # Process question (remove <video> token if present)
+             question_clean = question.replace("<video>", "").strip()
+
+             # Tokenize question
+             question_tokens = self.tokenizer(
+                 question_clean,
+                 return_tensors="pt",
+                 add_special_tokens=True
+             ).to(self.device)
+
+             # Get text embeddings
+             text_embeds = self.text_model.get_input_embeddings()(question_tokens.input_ids)
+
+             # Combine vision and text embeddings
+             combined_embeds = torch.cat([vision_embeds, text_embeds], dim=1)
+
+             # Create attention mask
+             vision_attention = torch.ones(
+                 1, vision_embeds.shape[1],
+                 dtype=question_tokens.attention_mask.dtype,
+                 device=self.device
+             )
+             combined_attention_mask = torch.cat([vision_attention, question_tokens.attention_mask], dim=1)
+
+             # Generate response
+             generated_ids = self.text_model.generate(
+                 inputs_embeds=combined_embeds,
+                 attention_mask=combined_attention_mask,
+                 max_new_tokens=max_tokens,
+                 temperature=temperature,
+                 top_p=top_p,
+                 do_sample=True,
+                 pad_token_id=self.tokenizer.pad_token_id,
+                 eos_token_id=self.tokenizer.eos_token_id,
+                 repetition_penalty=repetition_penalty,
+                 no_repeat_ngram_size=no_repeat_ngram_size,
+                 use_cache=True,
+                 return_dict_in_generate=False
+             )
+
+             # Handle different return formats from generate()
+             if generated_ids.shape[1] > combined_embeds.shape[1]:
+                 # Full sequence returned - slice from combined length
+                 new_tokens = generated_ids[:, combined_embeds.shape[1]:]
+             else:
+                 # Only new tokens returned - use all
+                 new_tokens = generated_ids
+
+             generated_text = self.tokenizer.batch_decode(
+                 new_tokens,
+                 skip_special_tokens=True
+             )[0]
+
+             return generated_text.strip()
+
+     def batch_ask(self, video_path, questions, **kwargs):
+         """Ask multiple questions about the same video
+
+         Args:
+             video_path: Path to video file
+             questions: List of questions
+             **kwargs: Generation parameters
+
+         Returns:
+             List of {"question": str, "answer": str} dicts
+         """
+         results = []
+         for question in questions:
+             answer = self.ask(video_path, question, **kwargs)
+             results.append({"question": question, "answer": answer})
+         return results
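Because the class imports from the repo's `src/` package, the script needs to run with the repository root on `PYTHONPATH`. A hedged sketch of wiring it up from a Hub download (the repo id and video path are placeholders):

```python
from huggingface_hub import snapshot_download

from soccer_qa_inference import SoccerQA

# Placeholder repo id - substitute the actual Hub id of this model
local_dir = snapshot_download("VarunKodathala/soccer-qa-4b")

model = SoccerQA(local_dir)
answer = model.ask(
    "match_clip.mp4",  # placeholder video file
    "What unfolded during the game in the video?",
    max_tokens=128,
    temperature=0.7,
)
print(answer)
```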
special_tokens_map.json ADDED
@@ -0,0 +1,26 @@
+ {
+   "additional_special_tokens": [
+     "<video>"
+   ],
+   "bos_token": {
+     "content": "<|begin_of_text|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "content": "<|end_of_text|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<|end_of_text|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
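The `<video>` entry above is the extra token the text model is resized for in `_build_text_model`. A hedged sketch of how such a token is typically registered with a `transformers` tokenizer (base model name taken from the inference script):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B")

# Register the placeholder token that marks where video features are spliced in,
# then grow the embedding matrix to cover the new vocabulary entry.
num_added = tokenizer.add_special_tokens({"additional_special_tokens": ["<video>"]})
if num_added > 0:
    model.resize_token_embeddings(len(tokenizer))
```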
src/datasets/data_manager.py ADDED
@@ -0,0 +1,90 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from logging import getLogger
7
+
8
+ _GLOBAL_SEED = 0
9
+ logger = getLogger()
10
+
11
+
12
+ def init_data(
13
+ batch_size,
14
+ transform=None,
15
+ shared_transform=None,
16
+ data="ImageNet",
17
+ collator=None,
18
+ pin_mem=True,
19
+ num_workers=8,
20
+ world_size=1,
21
+ rank=0,
22
+ root_path=None,
23
+ image_folder=None,
24
+ training=True,
25
+ drop_last=True,
26
+ subset_file=None,
27
+ clip_len=None,
28
+ dataset_fpcs=None,
29
+ frame_sample_rate=None,
30
+ duration=None,
31
+ fps=None,
32
+ num_clips=1,
33
+ random_clip_sampling=True,
34
+ allow_clip_overlap=False,
35
+ filter_short_videos=False,
36
+ filter_long_videos=int(1e9),
37
+ datasets_weights=None,
38
+ persistent_workers=False,
39
+ deterministic=True,
40
+ log_dir=None,
41
+ ):
42
+ if data.lower() == "imagenet":
43
+ from src.datasets.imagenet1k import make_imagenet1k
44
+
45
+ dataset, data_loader, dist_sampler = make_imagenet1k(
46
+ transform=transform,
47
+ batch_size=batch_size,
48
+ collator=collator,
49
+ pin_mem=pin_mem,
50
+ training=training,
51
+ num_workers=num_workers,
52
+ world_size=world_size,
53
+ rank=rank,
54
+ root_path=root_path,
55
+ image_folder=image_folder,
56
+ persistent_workers=persistent_workers,
57
+ drop_last=drop_last,
58
+ subset_file=subset_file,
59
+ )
60
+
61
+ elif data.lower() == "videodataset":
62
+ from src.datasets.video_dataset import make_videodataset
63
+
64
+ dataset, data_loader, dist_sampler = make_videodataset(
65
+ data_paths=root_path,
66
+ batch_size=batch_size,
67
+ frames_per_clip=clip_len,
68
+ dataset_fpcs=dataset_fpcs,
69
+ frame_step=frame_sample_rate,
70
+ duration=duration,
71
+ fps=fps,
72
+ num_clips=num_clips,
73
+ random_clip_sampling=random_clip_sampling,
74
+ allow_clip_overlap=allow_clip_overlap,
75
+ filter_short_videos=filter_short_videos,
76
+ filter_long_videos=filter_long_videos,
77
+ shared_transform=shared_transform,
78
+ transform=transform,
79
+ datasets_weights=datasets_weights,
80
+ collator=collator,
81
+ num_workers=num_workers,
82
+ pin_mem=pin_mem,
83
+ persistent_workers=persistent_workers,
84
+ world_size=world_size,
85
+ rank=rank,
86
+ deterministic=deterministic,
87
+ log_dir=log_dir,
88
+ )
89
+
90
+ return (data_loader, dist_sampler)
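A hedged sketch of calling `init_data` for the video branch; only arguments from the signature above are used, and the annotation-file format expected by `make_videodataset` is an assumption here:

```python
from src.datasets.data_manager import init_data

# Placeholder inputs - a real call needs valid annotation files and a transform pipeline
data_loader, dist_sampler = init_data(
    batch_size=8,
    data="VideoDataset",
    root_path=["/path/to/video_annotations.csv"],  # assumed annotation format
    clip_len=16,
    frame_sample_rate=4,
    num_clips=1,
    num_workers=4,
    world_size=1,
    rank=0,
    transform=None,
)
batch = next(iter(data_loader))
```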
src/datasets/imagenet1k.py ADDED
@@ -0,0 +1,152 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import subprocess
8
+ import time
9
+ from logging import getLogger
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torchvision
14
+
15
+ _GLOBAL_SEED = 0
16
+ logger = getLogger()
17
+
18
+
19
+ class ImageNet(torchvision.datasets.ImageFolder):
20
+
21
+ def __init__(
22
+ self,
23
+ root,
24
+ image_folder="imagenet_full_size/061417/",
25
+ tar_file="imagenet_full_size-061417.tar.gz",
26
+ transform=None,
27
+ train=True,
28
+ job_id=None,
29
+ local_rank=None,
30
+ index_targets=False,
31
+ ):
32
+ """
33
+ ImageNet
34
+
35
+ Dataset wrapper
36
+
37
+ :param root: root network directory for ImageNet data
38
+ :param image_folder: path to images inside root network directory
39
+ :param tar_file: zipped image_folder inside root network directory
40
+ :param train: whether to load train data (or validation)
41
+ :param job_id: scheduler job-id used to create dir on local machine
42
+ :param index_targets: whether to index the id of each labeled image
43
+ """
44
+
45
+ suffix = "train/" if train else "val/"
46
+ data_path = os.path.join(root, image_folder, suffix)
47
+ logger.info(f"data-path {data_path}")
48
+
49
+ super(ImageNet, self).__init__(root=data_path, transform=transform)
50
+ logger.info("Initialized ImageNet")
51
+
52
+ if index_targets:
53
+ self.targets = []
54
+ for sample in self.samples:
55
+ self.targets.append(sample[1])
56
+ self.targets = np.array(self.targets)
57
+ self.samples = np.array(self.samples)
58
+
59
+ mint = None
60
+ self.target_indices = []
61
+ for t in range(len(self.classes)):
62
+ indices = np.squeeze(np.argwhere(self.targets == t)).tolist()
63
+ self.target_indices.append(indices)
64
+ mint = len(indices) if mint is None else min(mint, len(indices))
65
+ logger.debug(f"num-labeled target {t} {len(indices)}")
66
+ logger.info(f"min. labeled indices {mint}")
67
+
68
+
69
+ class ImageNetSubset(object):
70
+
71
+ def __init__(self, dataset, subset_file):
72
+ """
73
+ ImageNetSubset
74
+
75
+ :param dataset: ImageNet dataset object
76
+ :param subset_file: '.txt' file containing IDs of IN1K images to keep
77
+ """
78
+ self.dataset = dataset
79
+ self.subset_file = subset_file
80
+ self.filter_dataset_(subset_file)
81
+
82
+ def filter_dataset_(self, subset_file):
83
+ """Filter self.dataset to a subset"""
84
+ root = self.dataset.root
85
+ class_to_idx = self.dataset.class_to_idx
86
+ # -- update samples to subset of IN1k targets/samples
87
+ new_samples = []
88
+ logger.info(f"Using {subset_file}")
89
+ with open(subset_file, "r") as rfile:
90
+ for line in rfile:
91
+ class_name = line.split("_")[0]
92
+ target = class_to_idx[class_name]
93
+ img = line.split("\n")[0]
94
+ new_samples.append((os.path.join(root, class_name, img), target))
95
+ self.samples = new_samples
96
+
97
+ @property
98
+ def classes(self):
99
+ return self.dataset.classes
100
+
101
+ def __len__(self):
102
+ return len(self.samples)
103
+
104
+ def __getitem__(self, index):
105
+ path, target = self.samples[index]
106
+ img = self.dataset.loader(path)
107
+ if self.dataset.transform is not None:
108
+ img = self.dataset.transform(img)
109
+ if self.dataset.target_transform is not None:
110
+ target = self.dataset.target_transform(target)
111
+ return img, target
112
+
113
+
114
+ def make_imagenet1k(
115
+ transform,
116
+ batch_size,
117
+ collator=None,
118
+ pin_mem=True,
119
+ num_workers=8,
120
+ world_size=1,
121
+ rank=0,
122
+ root_path=None,
123
+ image_folder=None,
124
+ training=True,
125
+ drop_last=True,
126
+ persistent_workers=False,
127
+ subset_file=None,
128
+ ):
129
+ dataset = ImageNet(
130
+ root=root_path,
131
+ image_folder=image_folder,
132
+ transform=transform,
133
+ train=training,
134
+ index_targets=False,
135
+ )
136
+ if subset_file is not None:
137
+ dataset = ImageNetSubset(dataset, subset_file)
138
+ logger.info("ImageNet dataset created")
139
+ dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset=dataset, num_replicas=world_size, rank=rank)
140
+ data_loader = torch.utils.data.DataLoader(
141
+ dataset,
142
+ collate_fn=collator,
143
+ sampler=dist_sampler,
144
+ batch_size=batch_size,
145
+ drop_last=drop_last,
146
+ pin_memory=pin_mem,
147
+ num_workers=num_workers,
148
+ persistent_workers=persistent_workers,
149
+ )
150
+ logger.info("ImageNet unsupervised data loader created")
151
+
152
+ return dataset, data_loader, dist_sampler
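For the image branch, a minimal sketch of `make_imagenet1k` (the root path is a placeholder; the expected layout is `<root>/<image_folder>/train/` in standard ImageFolder form):

```python
import torchvision.transforms as T

from src.datasets.imagenet1k import make_imagenet1k

transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])

dataset, loader, sampler = make_imagenet1k(
    transform=transform,
    batch_size=64,
    root_path="/path/to/imagenet",               # placeholder
    image_folder="imagenet_full_size/061417/",   # default folder name from the class above
    training=True,
    num_workers=2,
    world_size=1,
    rank=0,
)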
src/datasets/utils/dataloader.py ADDED
@@ -0,0 +1,234 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import bisect
7
+ import csv
8
+ import io
9
+ import time
10
+
11
+ import numpy as np
12
+ import torch
13
+ from torch.utils.data import _utils
14
+ from torch.utils.data.dataloader import ExceptionWrapper, _DatasetKind, _MultiProcessingDataLoaderIter
15
+
16
+ from src.utils.monitoring import ResourceMonitoringThread
17
+
18
+
19
+ class ConcatIndices:
20
+ """Helper to map indices of concatenated/mixed datasets to the sample index for the corresponding dataset."""
21
+
22
+ cumulative_sizes: np.ndarray
23
+
24
+ def __init__(self, sizes):
25
+ self.cumulative_sizes = np.cumsum(sizes)
26
+
27
+ def __len__(self):
28
+ return self.cumulative_sizes[-1]
29
+
30
+ def __getitem__(self, idx):
31
+ # Returns a pair (dataset_idx, sample_idx)
32
+ if idx < 0 or idx >= len(self):
33
+ raise ValueError(f"index must be between 0 and the total size ({len(self)})")
34
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
35
+ if dataset_idx == 0:
36
+ return dataset_idx, idx
37
+ return dataset_idx, idx - self.cumulative_sizes[dataset_idx - 1]
38
+
39
+
40
+ class CSVLogger(object):
41
+ """An append-to CSV abstraction. File I/O requires a flush."""
42
+
43
+ def __init__(self, fname, header):
44
+ """Write header to internal buffers."""
45
+ self.fname = fname
46
+ self.buffer = io.StringIO()
47
+ self.writer = csv.writer(self.buffer, quoting=csv.QUOTE_NONNUMERIC)
48
+ self.writer.writerow(header)
49
+ self.initialized = False
50
+
51
+ def writerow(self, row) -> None:
52
+ """Write row to internal buffers."""
53
+ self.writer.writerow(row)
54
+
55
+ def flush(self) -> None:
56
+ """Flush buffer to file."""
57
+ # Overwrite old file
58
+ mode = "a+" if self.initialized else "w"
59
+
60
+ with open(self.fname, mode, newline="") as f:
61
+ f.write(self.buffer.getvalue())
62
+
63
+ self.buffer = io.StringIO()
64
+ self.writer = csv.writer(self.buffer, quoting=csv.QUOTE_NONNUMERIC)
65
+ self.initialized = True
66
+
67
+
68
+ class MonitoredDataset(torch.utils.data.Dataset):
69
+ """Implement resource monitoring on a per-worker basis.
70
+
71
+ The sampling occurs every monitor_interval seconds and writes the log
72
+ every log_interval seconds to a file specified by log_filename, which
73
+ maps a worker id to a file using the '%w' placeholder.
74
+
75
+ Warning: Do not call this dataset before it is consumed in the DataLoader.
76
+ """
77
+
78
+ def __init__(
79
+ self, dataset: torch.utils.data.Dataset, log_filename: str, log_interval: float, monitor_interval: float
80
+ ):
81
+ self.dataset = dataset
82
+ self.log_filename = str(log_filename)
83
+ self.log_interval = log_interval
84
+ self.monitor_interval = monitor_interval
85
+ self._csv_log = None
86
+ self._monitoring_thread = None
87
+ self._last_log_time = None
88
+ # Patch getitems dynamically
89
+ if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
90
+
91
+ def __getitems__(self, index):
92
+ self.maybe_start_resource_monitoring()
93
+ return self.dataset.__getitems__(index)
94
+
95
+ self.__getitems__ = __getitems__
96
+
97
+ def __del__(self):
98
+ self.stop_resource_monitoring()
99
+
100
+ def __getitem__(self, index):
101
+ self.maybe_start_resource_monitoring()
102
+ return self.dataset.__getitem__(index)
103
+
104
+ def __len__(self):
105
+ return len(self.dataset)
106
+
107
+ def _elapsed_log_time(self):
108
+ if self._last_log_time is None:
109
+ return float("inf")
110
+ else:
111
+ return time.perf_counter() - self._last_log_time
112
+
113
+ def _update_log_time(self):
114
+ self._last_log_time = time.perf_counter()
115
+
116
+ def maybe_start_resource_monitoring(self):
117
+ if self._monitoring_thread is None:
118
+
119
+ def callback_fn(resource_sample):
120
+ worker_info = torch.utils.data.get_worker_info()
121
+ worker_id = worker_info.id
122
+
123
+ if self._csv_log is None:
124
+ header = [f.name for f in resource_sample.fields()]
125
+ log_filename = self.log_filename.replace("%w", str(worker_id))
126
+ self._csv_log = CSVLogger(log_filename, header)
127
+ row_values = resource_sample.as_tuple()
128
+ self._csv_log.writerow(row_values)
129
+
130
+ if self._elapsed_log_time() > self.log_interval:
131
+ self._csv_log.flush()
132
+ self._update_log_time()
133
+
134
+ self._monitoring_thread = ResourceMonitoringThread(
135
+ None, self.monitor_interval, stats_callback_fn=callback_fn
136
+ )
137
+ self._monitoring_thread.start()
138
+
139
+ def stop_resource_monitoring(self):
140
+ if self._monitoring_thread:
141
+ self._monitoring_thread.stop()
142
+
143
+
144
+ class NondeterministicDataLoader(torch.utils.data.DataLoader):
145
+ """Override torch dataloader to return out of order."""
146
+
147
+ def __init__(self, *args, **kwargs):
148
+ """Pass through constructor."""
149
+ super().__init__(*args, **kwargs)
150
+
151
+ def _get_iterator(self):
152
+ if self.num_workers:
153
+ self.check_worker_number_rationality()
154
+ return _SloppyMultiProcessingDataLoaderIter(self)
155
+ else:
156
+ return super()._get_iterator()
157
+
158
+
159
+ class _SloppyMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter):
160
+
161
+ def __init__(self, *args, **kwargs):
162
+ """Pass through constructor."""
163
+ super().__init__(*args, **kwargs)
164
+
165
+ def _next_data(self):
166
+ """Adds out of order returns."""
167
+ while True:
168
+ # If the worker responsible for `self._rcvd_idx` has already ended
169
+ # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
170
+ # we try to advance `self._rcvd_idx` to find the next valid index.
171
+ #
172
+ # This part needs to run in the loop because both the `self._get_data()`
173
+ # call and `_IterableDatasetStopIteration` check below can mark
174
+ # extra worker(s) as dead.
175
+ while self._rcvd_idx < self._send_idx:
176
+ info = self._task_info[self._rcvd_idx]
177
+ if info is None:
178
+ # Found a reordered tombstone
179
+ del self._task_info[self._rcvd_idx]
180
+ self._rcvd_idx += 1
181
+ self._try_put_index()
182
+ else:
183
+ worker_id = info[0]
184
+ # has data or is still active
185
+ if len(info) == 2 or self._workers_status[worker_id]:
186
+ break
187
+ del self._task_info[self._rcvd_idx]
188
+ self._rcvd_idx += 1
189
+ else:
190
+ # no valid `self._rcvd_idx` is found (i.e., didn't break)
191
+ if not self._persistent_workers:
192
+ self._shutdown_workers()
193
+ raise StopIteration
194
+
195
+ # Now `self._rcvd_idx` is the batch index we want to fetch
196
+
197
+ # Check if the next sample has already been generated
198
+ if len(self._task_info[self._rcvd_idx]) == 2:
199
+ data = self._task_info.pop(self._rcvd_idx)[1]
200
+ return self._process_data(data)
201
+
202
+ assert not self._shutdown and self._tasks_outstanding > 0
203
+ idx, data = self._get_data()
204
+ self._tasks_outstanding -= 1
205
+ if self._dataset_kind == _DatasetKind.Iterable:
206
+ # Check for _IterableDatasetStopIteration
207
+ if isinstance(data, _utils.worker._IterableDatasetStopIteration):
208
+ if self._persistent_workers:
209
+ self._workers_status[data.worker_id] = False
210
+ else:
211
+ self._mark_worker_as_unavailable(data.worker_id)
212
+ self._try_put_index()
213
+ continue
214
+
215
+ if idx != self._rcvd_idx:
216
+ # Tombstone to receive later
217
+ self._task_info[idx] = None
218
+ if isinstance(data, ExceptionWrapper):
219
+ data.reraise()
220
+ return data
221
+ else:
222
+ del self._task_info[idx]
223
+ return self._process_data(data)
224
+
225
+
226
+ def get_worker_info():
227
+ worker_info = torch.utils.data.get_worker_info()
228
+ if worker_info is None:
229
+ num_workers = 1
230
+ worker_id = 0
231
+ else:
232
+ num_workers = worker_info.num_workers
233
+ worker_id = worker_info.id
234
+ return num_workers, worker_id
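A sketch of how the pieces above compose: wrap a dataset in `MonitoredDataset` to get per-worker resource logs, then iterate it with `NondeterministicDataLoader`, which accepts batches out of order. The dataset and log path are placeholders:

```python
import torch

from src.datasets.utils.dataloader import MonitoredDataset, NondeterministicDataLoader

# Placeholder dataset - any map-style torch Dataset works here
base = torch.utils.data.TensorDataset(torch.randn(128, 3, 224, 224))

monitored = MonitoredDataset(
    base,
    log_filename="/tmp/worker_%w_resources.csv",  # '%w' is replaced by the worker id
    log_interval=30.0,
    monitor_interval=5.0,
)

loader = NondeterministicDataLoader(monitored, batch_size=16, num_workers=2)
for (batch,) in loader:
    pass  # batches may arrive out of order across workers
```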
src/datasets/utils/utils.py ADDED
@@ -0,0 +1,21 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from src.utils.cluster import dataset_paths
+ from src.utils.logging import get_logger
+
+ logger = get_logger("Datasets utils")
+
+
+ def get_dataset_paths(datasets: list[str]):
+     paths = []
+     for d in datasets:
+         try:
+             path = dataset_paths().get(d)
+         except Exception:
+             raise Exception(f"Unknown dataset: {d}")
+         paths.append(path)
+     logger.info(f"Datapaths {paths}")
+     return paths
src/datasets/utils/video/__pycache__/functional.cpython-312.pyc ADDED
Binary file (5.52 kB)
src/datasets/utils/video/__pycache__/randaugment.cpython-312.pyc ADDED
Binary file (18.1 kB)
src/datasets/utils/video/__pycache__/transforms.cpython-312.pyc ADDED
Binary file (53.7 kB)
src/datasets/utils/video/__pycache__/volume_transforms.cpython-312.pyc ADDED
Binary file (6.57 kB)
src/datasets/utils/video/functional.py ADDED
@@ -0,0 +1,110 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numbers
7
+
8
+ import cv2
9
+ import numpy as np
10
+ import PIL
11
+ import torch
12
+ from torchvision.transforms import functional as tvf
13
+
14
+
15
+ def _is_tensor_clip(clip):
16
+ return torch.is_tensor(clip) and clip.ndimension() == 4
17
+
18
+
19
+ def crop_clip(clip, min_h, min_w, h, w):
20
+ if isinstance(clip[0], np.ndarray) or isinstance(clip[0], torch.Tensor):
21
+ if clip[0].shape[-1] == 3:
22
+ cropped = [img[min_h : min_h + h, min_w : min_w + w, :] for img in clip]
23
+ else:
24
+ assert clip[0].shape[0] == 3
25
+ cropped = [img[:, min_h : min_h + h, min_w : min_w + w] for img in clip]
26
+
27
+ elif isinstance(clip[0], PIL.Image.Image):
28
+ cropped = [img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip]
29
+
30
+ else:
31
+ raise TypeError(
32
+ "Expected numpy.ndarray or PIL.Image or torch.Tensor):" + "but got list of {0}".format(type(clip[0]))
33
+ )
34
+ return cropped
35
+
36
+
37
+ def resize_clip(clip, size, interpolation="bilinear"):
38
+ if isinstance(clip[0], np.ndarray) or isinstance(clip[0], torch.Tensor):
39
+ if isinstance(size, numbers.Number):
40
+ if clip[0].shape[-1] == 3:
41
+ im_h, im_w, im_c = clip[0].shape
42
+ else:
43
+ assert clip[0].shape[0] == 3
44
+ im_c, im_h, im_w = clip[0].shape
45
+ # Min spatial dim already matches minimal size
46
+ if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size):
47
+ return clip
48
+ new_h, new_w = get_resize_sizes(im_h, im_w, size)
49
+ size = (new_w, new_h)
50
+ else:
51
+ size = size[0], size[1]
52
+
53
+ if isinstance(clip[0], np.ndarray):
54
+ if interpolation == "bilinear":
55
+ np_inter = cv2.INTER_LINEAR
56
+ else:
57
+ np_inter = cv2.INTER_NEAREST
58
+ scaled = [cv2.resize(img, size, interpolation=np_inter) for img in clip]
59
+ else: # isinstance(clip[0], torch.Tensor)
60
+ if interpolation == "bilinear":
61
+ np_inter = tvf.InterpolationMode.BILINEAR
62
+ else:
63
+ np_inter = tvf.InterpolationMode.NEAREST
64
+ size = (size[1], size[0]) # torchvision transforms expect the size in (h, w) order.
65
+ scaled = [tvf.resize(img, size, interpolation=np_inter) for img in clip]
66
+ elif isinstance(clip[0], PIL.Image.Image):
67
+ if isinstance(size, numbers.Number):
68
+ im_w, im_h = clip[0].size
69
+ # Min spatial dim already matches minimal size
70
+ if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size):
71
+ return clip
72
+ new_h, new_w = get_resize_sizes(im_h, im_w, size)
73
+ size = (new_w, new_h)
74
+ else:
75
+ size = size[1], size[0]
76
+ if interpolation == "bilinear":
77
+ pil_inter = PIL.Image.BILINEAR
78
+ else:
79
+ pil_inter = PIL.Image.NEAREST
80
+ scaled = [img.resize(size, pil_inter) for img in clip]
81
+ else:
82
+ raise TypeError(
83
+ "Expected numpy.ndarray or PIL.Image or torch.Tensor" + "but got list of {0}".format(type(clip[0]))
84
+ )
85
+ return scaled
86
+
87
+
88
+ def get_resize_sizes(im_h, im_w, size):
89
+ if im_w < im_h:
90
+ ow = size
91
+ oh = int(size * im_h / im_w)
92
+ else:
93
+ oh = size
94
+ ow = int(size * im_w / im_h)
95
+ return oh, ow
96
+
97
+
98
+ def normalize(clip, mean, std, inplace=False):
99
+ if not _is_tensor_clip(clip):
100
+ raise TypeError("tensor is not a torch clip.")
101
+
102
+ if not inplace:
103
+ clip = clip.clone()
104
+
105
+ dtype = clip.dtype
106
+ mean = torch.as_tensor(mean, dtype=dtype, device=clip.device)
107
+ std = torch.as_tensor(std, dtype=dtype, device=clip.device)
108
+ clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
109
+
110
+ return clip
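A small sketch of the helpers above applied to a list of frames (the clip contents and shapes are illustrative; `normalize` expects a 4D float tensor with channels first, as implied by its broadcasting):

```python
import numpy as np
import torch

from src.datasets.utils.video.functional import crop_clip, normalize, resize_clip

# A fake 16-frame clip of 480x640 RGB images
clip = [np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) for _ in range(16)]

clip = resize_clip(clip, 256, interpolation="bilinear")  # short side -> 256
clip = crop_clip(clip, 0, 0, 256, 256)                   # top-left 256x256 crop

# Stack to T x C x H x W, then permute to C x T x H x W before normalizing
tensor_clip = torch.stack([torch.from_numpy(f).permute(2, 0, 1) for f in clip])
tensor_clip = tensor_clip.permute(1, 0, 2, 3).float() / 255.0
tensor_clip = normalize(tensor_clip, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
```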
src/datasets/utils/video/randaugment.py ADDED
@@ -0,0 +1,536 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # Copyright 2020 Ross Wightman
4
+
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ # This implementation is based on
18
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py
19
+ # published under an Apache License 2.0.
20
+
21
+ # COMMENT FROM ORIGINAL:
22
+ # AutoAugment, RandAugment, and AugMix for PyTorch
23
+ # This code implements the searched ImageNet policies with various tweaks and
24
+ # improvements and does not include any of the search code. AA and RA
25
+ # Implementation adapted from:
26
+ # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
27
+ # AugMix adapted from:
28
+ # https://github.com/google-research/augmix
29
+ # Papers:
30
+ # AutoAugment: Learning Augmentation Policies from Data
31
+ # https://arxiv.org/abs/1805.09501
32
+ # Learning Data Augmentation Strategies for Object Detection
33
+ # https://arxiv.org/abs/1906.11172
34
+ # RandAugment: Practical automated data augmentation...
35
+ # https://arxiv.org/abs/1909.13719
36
+ # AugMix: A Simple Data Processing Method to Improve Robustness and
37
+ # Uncertainty https://arxiv.org/abs/1912.02781
38
+
39
+ import math
40
+ import random
41
+ import re
42
+
43
+ import numpy as np
44
+ import PIL
45
+ from PIL import Image, ImageEnhance, ImageOps
46
+
47
+ _PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]])
48
+
49
+ _FILL = (128, 128, 128)
50
+
51
+ # This signifies the max integer that the controller RNN could predict for the
52
+ # augmentation scheme.
53
+ _MAX_LEVEL = 10.0
54
+
55
+ _HPARAMS_DEFAULT = {
56
+ "translate_const": 250,
57
+ "img_mean": _FILL,
58
+ }
59
+
60
+ _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
61
+
62
+
63
+ def _interpolation(kwargs):
64
+ interpolation = kwargs.pop("resample", Image.BILINEAR)
65
+ if isinstance(interpolation, (list, tuple)):
66
+ return random.choice(interpolation)
67
+ else:
68
+ return interpolation
69
+
70
+
71
+ def _check_args_tf(kwargs):
72
+ if "fillcolor" in kwargs and _PIL_VER < (5, 0):
73
+ kwargs.pop("fillcolor")
74
+ kwargs["resample"] = _interpolation(kwargs)
75
+
76
+
77
+ def shear_x(img, factor, **kwargs):
78
+ _check_args_tf(kwargs)
79
+ return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
80
+
81
+
82
+ def shear_y(img, factor, **kwargs):
83
+ _check_args_tf(kwargs)
84
+ return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
85
+
86
+
87
+ def translate_x_rel(img, pct, **kwargs):
88
+ pixels = pct * img.size[0]
89
+ _check_args_tf(kwargs)
90
+ return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
91
+
92
+
93
+ def translate_y_rel(img, pct, **kwargs):
94
+ pixels = pct * img.size[1]
95
+ _check_args_tf(kwargs)
96
+ return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
97
+
98
+
99
+ def translate_x_abs(img, pixels, **kwargs):
100
+ _check_args_tf(kwargs)
101
+ return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
102
+
103
+
104
+ def translate_y_abs(img, pixels, **kwargs):
105
+ _check_args_tf(kwargs)
106
+ return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
107
+
108
+
109
+ def rotate(img, degrees, **kwargs):
110
+ _check_args_tf(kwargs)
111
+ if _PIL_VER >= (5, 2):
112
+ return img.rotate(degrees, **kwargs)
113
+ elif _PIL_VER >= (5, 0):
114
+ w, h = img.size
115
+ post_trans = (0, 0)
116
+ rotn_center = (w / 2.0, h / 2.0)
117
+ angle = -math.radians(degrees)
118
+ matrix = [
119
+ round(math.cos(angle), 15),
120
+ round(math.sin(angle), 15),
121
+ 0.0,
122
+ round(-math.sin(angle), 15),
123
+ round(math.cos(angle), 15),
124
+ 0.0,
125
+ ]
126
+
127
+ def transform(x, y, matrix):
128
+ (a, b, c, d, e, f) = matrix
129
+ return a * x + b * y + c, d * x + e * y + f
130
+
131
+ matrix[2], matrix[5] = transform(
132
+ -rotn_center[0] - post_trans[0],
133
+ -rotn_center[1] - post_trans[1],
134
+ matrix,
135
+ )
136
+ matrix[2] += rotn_center[0]
137
+ matrix[5] += rotn_center[1]
138
+ return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
139
+ else:
140
+ return img.rotate(degrees, resample=kwargs["resample"])
141
+
142
+
143
+ def auto_contrast(img, **__):
144
+ return ImageOps.autocontrast(img)
145
+
146
+
147
+ def invert(img, **__):
148
+ return ImageOps.invert(img)
149
+
150
+
151
+ def equalize(img, **__):
152
+ return ImageOps.equalize(img)
153
+
154
+
155
+ def solarize(img, thresh, **__):
156
+ return ImageOps.solarize(img, thresh)
157
+
158
+
159
+ def solarize_add(img, add, thresh=128, **__):
160
+ lut = []
161
+ for i in range(256):
162
+ if i < thresh:
163
+ lut.append(min(255, i + add))
164
+ else:
165
+ lut.append(i)
166
+ if img.mode in ("L", "RGB"):
167
+ if img.mode == "RGB" and len(lut) == 256:
168
+ lut = lut + lut + lut
169
+ return img.point(lut)
170
+ else:
171
+ return img
172
+
173
+
174
+ def posterize(img, bits_to_keep, **__):
175
+ if bits_to_keep >= 8:
176
+ return img
177
+ return ImageOps.posterize(img, bits_to_keep)
178
+
179
+
180
+ def contrast(img, factor, **__):
181
+ return ImageEnhance.Contrast(img).enhance(factor)
182
+
183
+
184
+ def color(img, factor, **__):
185
+ return ImageEnhance.Color(img).enhance(factor)
186
+
187
+
188
+ def brightness(img, factor, **__):
189
+ return ImageEnhance.Brightness(img).enhance(factor)
190
+
191
+
192
+ def sharpness(img, factor, **__):
193
+ return ImageEnhance.Sharpness(img).enhance(factor)
194
+
195
+
196
+ def _randomly_negate(v):
197
+ """With 50% prob, negate the value"""
198
+ return -v if random.random() > 0.5 else v
199
+
200
+
201
+ def _rotate_level_to_arg(level, _hparams):
202
+ # range [-30, 30]
203
+ level = (level / _MAX_LEVEL) * 30.0
204
+ level = _randomly_negate(level)
205
+ return (level,)
206
+
207
+
208
+ def _enhance_level_to_arg(level, _hparams):
209
+ # range [0.1, 1.9]
210
+ return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
211
+
212
+
213
+ def _enhance_increasing_level_to_arg(level, _hparams):
214
+ # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
215
+ # range [0.1, 1.9]
216
+ level = (level / _MAX_LEVEL) * 0.9
217
+ level = 1.0 + _randomly_negate(level)
218
+ return (level,)
219
+
220
+
221
+ def _shear_level_to_arg(level, _hparams):
222
+ # range [-0.3, 0.3]
223
+ level = (level / _MAX_LEVEL) * 0.3
224
+ level = _randomly_negate(level)
225
+ return (level,)
226
+
227
+
228
+ def _translate_abs_level_to_arg(level, hparams):
229
+ translate_const = hparams["translate_const"]
230
+ level = (level / _MAX_LEVEL) * float(translate_const)
231
+ level = _randomly_negate(level)
232
+ return (level,)
233
+
234
+
235
+ def _translate_rel_level_to_arg(level, hparams):
236
+ # default range [-0.45, 0.45]
237
+ translate_pct = hparams.get("translate_pct", 0.45)
238
+ level = (level / _MAX_LEVEL) * translate_pct
239
+ level = _randomly_negate(level)
240
+ return (level,)
241
+
242
+
243
+ def _posterize_level_to_arg(level, _hparams):
244
+ # As per Tensorflow TPU EfficientNet impl
245
+ # range [0, 4], 'keep 0 up to 4 MSB of original image'
246
+ # intensity/severity of augmentation decreases with level
247
+ return (int((level / _MAX_LEVEL) * 4),)
248
+
249
+
250
+ def _posterize_increasing_level_to_arg(level, hparams):
251
+ # As per Tensorflow models research and UDA impl
252
+ # range [4, 0], 'keep 4 down to 0 MSB of original image',
253
+ # intensity/severity of augmentation increases with level
254
+ return (4 - _posterize_level_to_arg(level, hparams)[0],)
255
+
256
+
257
+ def _posterize_original_level_to_arg(level, _hparams):
258
+ # As per original AutoAugment paper description
259
+ # range [4, 8], 'keep 4 up to 8 MSB of image'
260
+ # intensity/severity of augmentation decreases with level
261
+ return (int((level / _MAX_LEVEL) * 4) + 4,)
262
+
263
+
264
+ def _solarize_level_to_arg(level, _hparams):
265
+ # range [0, 256]
266
+ # intensity/severity of augmentation decreases with level
267
+ return (int((level / _MAX_LEVEL) * 256),)
268
+
269
+
270
+ def _solarize_increasing_level_to_arg(level, _hparams):
271
+ # range [0, 256]
272
+ # intensity/severity of augmentation increases with level
273
+ return (256 - _solarize_level_to_arg(level, _hparams)[0],)
274
+
275
+
276
+ def _solarize_add_level_to_arg(level, _hparams):
277
+ # range [0, 110]
278
+ return (int((level / _MAX_LEVEL) * 110),)
279
+
280
+
281
+ LEVEL_TO_ARG = {
282
+ "AutoContrast": None,
283
+ "Equalize": None,
284
+ "Invert": None,
285
+ "Rotate": _rotate_level_to_arg,
286
+ # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
287
+ "Posterize": _posterize_level_to_arg,
288
+ "PosterizeIncreasing": _posterize_increasing_level_to_arg,
289
+ "PosterizeOriginal": _posterize_original_level_to_arg,
290
+ "Solarize": _solarize_level_to_arg,
291
+ "SolarizeIncreasing": _solarize_increasing_level_to_arg,
292
+ "SolarizeAdd": _solarize_add_level_to_arg,
293
+ "Color": _enhance_level_to_arg,
294
+ "ColorIncreasing": _enhance_increasing_level_to_arg,
295
+ "Contrast": _enhance_level_to_arg,
296
+ "ContrastIncreasing": _enhance_increasing_level_to_arg,
297
+ "Brightness": _enhance_level_to_arg,
298
+ "BrightnessIncreasing": _enhance_increasing_level_to_arg,
299
+ "Sharpness": _enhance_level_to_arg,
300
+ "SharpnessIncreasing": _enhance_increasing_level_to_arg,
301
+ "ShearX": _shear_level_to_arg,
302
+ "ShearY": _shear_level_to_arg,
303
+ "TranslateX": _translate_abs_level_to_arg,
304
+ "TranslateY": _translate_abs_level_to_arg,
305
+ "TranslateXRel": _translate_rel_level_to_arg,
306
+ "TranslateYRel": _translate_rel_level_to_arg,
307
+ }
308
+
309
+
310
+ NAME_TO_OP = {
311
+ "AutoContrast": auto_contrast,
312
+ "Equalize": equalize,
313
+ "Invert": invert,
314
+ "Rotate": rotate,
315
+ "Posterize": posterize,
316
+ "PosterizeIncreasing": posterize,
317
+ "PosterizeOriginal": posterize,
318
+ "Solarize": solarize,
319
+ "SolarizeIncreasing": solarize,
320
+ "SolarizeAdd": solarize_add,
321
+ "Color": color,
322
+ "ColorIncreasing": color,
323
+ "Contrast": contrast,
324
+ "ContrastIncreasing": contrast,
325
+ "Brightness": brightness,
326
+ "BrightnessIncreasing": brightness,
327
+ "Sharpness": sharpness,
328
+ "SharpnessIncreasing": sharpness,
329
+ "ShearX": shear_x,
330
+ "ShearY": shear_y,
331
+ "TranslateX": translate_x_abs,
332
+ "TranslateY": translate_y_abs,
333
+ "TranslateXRel": translate_x_rel,
334
+ "TranslateYRel": translate_y_rel,
335
+ }
336
+
337
+
338
+ class AugmentOp:
339
+ """
340
+ Apply for video.
341
+ """
342
+
343
+ def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
344
+ hparams = hparams or _HPARAMS_DEFAULT
345
+ self.aug_fn = NAME_TO_OP[name]
346
+ self.level_fn = LEVEL_TO_ARG[name]
347
+ self.prob = prob
348
+ self.magnitude = magnitude
349
+ self.hparams = hparams.copy()
350
+ self.kwargs = {
351
+ "fillcolor": hparams["img_mean"] if "img_mean" in hparams else _FILL,
352
+ "resample": hparams["interpolation"] if "interpolation" in hparams else _RANDOM_INTERPOLATION,
353
+ }
354
+
355
+ # If magnitude_std is > 0, we introduce some randomness
356
+ # in the usually fixed policy and sample magnitude from a normal distribution
357
+ # with mean `magnitude` and std-dev of `magnitude_std`.
358
+ # NOTE This is my own hack, being tested, not in papers or reference impls.
359
+ self.magnitude_std = self.hparams.get("magnitude_std", 0)
360
+
361
+ def __call__(self, img_list):
362
+ if self.prob < 1.0 and random.random() > self.prob:
363
+ return img_list
364
+ magnitude = self.magnitude
365
+ if self.magnitude_std and self.magnitude_std > 0:
366
+ magnitude = random.gauss(magnitude, self.magnitude_std)
367
+ magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
368
+ level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else ()
369
+
370
+ if isinstance(img_list, list):
371
+ return [self.aug_fn(img, *level_args, **self.kwargs) for img in img_list]
372
+ else:
373
+ return self.aug_fn(img_list, *level_args, **self.kwargs)
374
+
375
+
376
+ _RAND_TRANSFORMS = [
377
+ "AutoContrast",
378
+ "Equalize",
379
+ "Invert",
380
+ "Rotate",
381
+ "Posterize",
382
+ "Solarize",
383
+ "SolarizeAdd",
384
+ "Color",
385
+ "Contrast",
386
+ "Brightness",
387
+ "Sharpness",
388
+ "ShearX",
389
+ "ShearY",
390
+ "TranslateXRel",
391
+ "TranslateYRel",
392
+ ]
393
+
394
+
395
+ _RAND_INCREASING_TRANSFORMS = [
396
+ "AutoContrast",
397
+ "Equalize",
398
+ "Invert",
399
+ "Rotate",
400
+ "PosterizeIncreasing",
401
+ "SolarizeIncreasing",
402
+ "SolarizeAdd",
403
+ "ColorIncreasing",
404
+ "ContrastIncreasing",
405
+ "BrightnessIncreasing",
406
+ "SharpnessIncreasing",
407
+ "ShearX",
408
+ "ShearY",
409
+ "TranslateXRel",
410
+ "TranslateYRel",
411
+ ]
412
+
413
+
414
+ # These experimental weights are based loosely on the relative improvements mentioned in paper.
415
+ # They may not result in increased performance, but could likely be tuned to do so.
416
+ _RAND_CHOICE_WEIGHTS_0 = {
417
+ "Rotate": 0.3,
418
+ "ShearX": 0.2,
419
+ "ShearY": 0.2,
420
+ "TranslateXRel": 0.1,
421
+ "TranslateYRel": 0.1,
422
+ "Color": 0.025,
423
+ "Sharpness": 0.025,
424
+ "AutoContrast": 0.025,
425
+ "Solarize": 0.005,
426
+ "SolarizeAdd": 0.005,
427
+ "Contrast": 0.005,
428
+ "Brightness": 0.005,
429
+ "Equalize": 0.005,
430
+ "Posterize": 0,
431
+ "Invert": 0,
432
+ }
433
+
434
+ _RAND_CHOICE_WEIGHTS_1 = {
435
+ "Rotate": 0.0,
436
+ "ShearX": 0.0,
437
+ "ShearY": 0.0,
438
+ "TranslateXRel": 0.0,
439
+ "TranslateYRel": 0.0,
440
+ "Color": 0.25,
441
+ "Sharpness": 0.25,
442
+ "AutoContrast": 0.25,
443
+ "Solarize": 0.05,
444
+ "SolarizeAdd": 0.05,
445
+ "Contrast": 0.05,
446
+ "Brightness": 0.05,
447
+ "Equalize": 0.05,
448
+ "Posterize": 0,
449
+ "Invert": 0,
450
+ }
451
+
452
+
453
+ def _select_rand_weights(weight_idx=0, transforms=None):
454
+ transforms = transforms or _RAND_TRANSFORMS
455
+ assert weight_idx == 0 or weight_idx == 1 # only two sets of weights currently
456
+ if weight_idx == 0:
457
+ rand_weights = _RAND_CHOICE_WEIGHTS_0
458
+ elif weight_idx == 1:
459
+ rand_weights = _RAND_CHOICE_WEIGHTS_1
460
+ probs = [rand_weights[k] for k in transforms]
461
+ probs /= np.sum(probs)
462
+ return probs
463
+
464
+
465
+ def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
466
+ hparams = hparams or _HPARAMS_DEFAULT
467
+ transforms = transforms or _RAND_TRANSFORMS
468
+ return [AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
469
+
470
+
471
+ class RandAugment:
472
+ def __init__(self, ops, num_layers=2, choice_weights=None):
473
+ self.ops = ops
474
+ self.num_layers = num_layers
475
+ self.choice_weights = choice_weights
476
+
477
+ def __call__(self, img):
478
+ # no replacement when using weighted choice
479
+ ops = np.random.choice(
480
+ self.ops,
481
+ self.num_layers,
482
+ replace=self.choice_weights is None,
483
+ p=self.choice_weights,
484
+ )
485
+ for op in ops:
486
+ img = op(img)
487
+ return img
488
+
489
+
490
+ def rand_augment_transform(config_str, hparams):
491
+ """
492
+ RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
493
+
494
+ Create a RandAugment transform
495
+ :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
496
+ dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
497
+ sections, which are not order specific, determine
498
+ 'm' - integer magnitude of rand augment
499
+ 'n' - integer num layers (number of transform ops selected per image)
500
+ 'w' - integer probability weight index (index of a set of weights to influence choice of op)
501
+ 'mstd' - float std deviation of magnitude noise applied
502
+ 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
503
+ Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
504
+ 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
505
+ :param hparams: Other hparams (kwargs) for the RandAugmentation scheme
506
+ :return: A PyTorch compatible Transform
507
+ """
508
+ magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
509
+ num_layers = 2 # default to 2 ops per image
510
+ weight_idx = None # default to no probability weights for op choice
511
+ transforms = _RAND_TRANSFORMS
512
+ config = config_str.split("-")
513
+ assert config[0] == "rand"
514
+ config = config[1:]
515
+ for c in config:
516
+ cs = re.split(r"(\d.*)", c)
517
+ if len(cs) < 2:
518
+ continue
519
+ key, val = cs[:2]
520
+ if key == "mstd":
521
+ # noise param injected via hparams for now
522
+ hparams.setdefault("magnitude_std", float(val))
523
+ elif key == "inc":
524
+ if bool(val):
525
+ transforms = _RAND_INCREASING_TRANSFORMS
526
+ elif key == "m":
527
+ magnitude = int(val)
528
+ elif key == "n":
529
+ num_layers = int(val)
530
+ elif key == "w":
531
+ weight_idx = int(val)
532
+ else:
533
+ raise NotImplementedError(f"Unknown RandAugment config section: {key}")
534
+ ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
535
+ choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
536
+ return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
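
A minimal usage sketch of `rand_augment_transform` as defined above, mirroring the hparams that `create_random_augment` in `transforms.py` passes in; the frame list, sizes, and parameter values below are illustrative placeholders, not part of the repository:

```python
# Minimal sketch, assuming PIL frames have already been decoded; the clip
# contents and sizes here are placeholders.
from PIL import Image

from src.datasets.utils.video.randaugment import rand_augment_transform

frames = [Image.new("RGB", (224, 224)) for _ in range(16)]  # stand-in clip

# "rand-m7-n4-mstd0.5-inc1": magnitude 7, 4 ops per clip, magnitude noise
# std 0.5, and the "increasing severity" transform set.
aa_params = {"translate_const": int(224 * 0.45), "interpolation": Image.BICUBIC}
augment = rand_augment_transform("rand-m7-n4-mstd0.5-inc1", aa_params)

# AugmentOp.__call__ accepts a list, so each sampled op is applied to every
# frame of the clip with the same sampled magnitude.
augmented = augment(frames)
```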
src/datasets/utils/video/randerase.py ADDED
@@ -0,0 +1,170 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # Copyright 2020 Ross Wightman
4
+
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ # This implementation is based on
18
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py
19
+ # published under an Apache License 2.0.
20
+
21
+
22
+ import math
23
+ import random
24
+
25
+ import torch
26
+
27
+
28
+ def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda"):
29
+ # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
30
+ # paths, flip the order so normal is run on CPU if this becomes a problem
31
+ # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
32
+ if per_pixel:
33
+ return torch.empty(patch_size, dtype=dtype, device=device).normal_()
34
+ elif rand_color:
35
+ return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_()
36
+ else:
37
+ return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
38
+
39
+
40
+ class RandomErasing:
41
+ """Randomly selects a rectangle region in an image and erases its pixels.
42
+ 'Random Erasing Data Augmentation' by Zhong et al.
43
+ See https://arxiv.org/pdf/1708.04896.pdf
44
+ This variant of RandomErasing is intended to be applied to either a batch
45
+ or single image tensor after it has been normalized by dataset mean and std.
46
+ Args:
47
+ probability: Probability that the Random Erasing operation will be performed.
48
+ min_area: Minimum percentage of erased area wrt input image area.
49
+ max_area: Maximum percentage of erased area wrt input image area.
50
+ min_aspect: Minimum aspect ratio of erased area.
51
+ mode: pixel color mode, one of 'const', 'rand', or 'pixel'
52
+ 'const' - erase block is constant color of 0 for all channels
53
+ 'rand' - erase block is same per-channel random (normal) color
54
+ 'pixel' - erase block is per-pixel random (normal) color
55
+ max_count: maximum number of erasing blocks per image, area per box is scaled by count.
56
+ per-image count is randomly chosen between 1 and this value.
57
+ """
58
+
59
+ def __init__(
60
+ self,
61
+ probability=0.5,
62
+ min_area=0.02,
63
+ max_area=1 / 3,
64
+ min_aspect=0.3,
65
+ max_aspect=None,
66
+ mode="const",
67
+ min_count=1,
68
+ max_count=None,
69
+ num_splits=0,
70
+ device="cuda",
71
+ cube=True,
72
+ ):
73
+ self.probability = probability
74
+ self.min_area = min_area
75
+ self.max_area = max_area
76
+ max_aspect = max_aspect or 1 / min_aspect
77
+ self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
78
+ self.min_count = min_count
79
+ self.max_count = max_count or min_count
80
+ self.num_splits = num_splits
81
+ mode = mode.lower()
82
+ self.rand_color = False
83
+ self.per_pixel = False
84
+ self.cube = cube
85
+ if mode == "rand":
86
+ self.rand_color = True # per block random normal
87
+ elif mode == "pixel":
88
+ self.per_pixel = True # per pixel random normal
89
+ else:
90
+ assert not mode or mode == "const"
91
+ self.device = device
92
+
93
+ def _erase(self, img, chan, img_h, img_w, dtype):
94
+ if random.random() > self.probability:
95
+ return
96
+ area = img_h * img_w
97
+ count = self.min_count if self.min_count == self.max_count else random.randint(self.min_count, self.max_count)
98
+ for _ in range(count):
99
+ for _ in range(10):
100
+ target_area = random.uniform(self.min_area, self.max_area) * area / count
101
+ aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
102
+ h = int(round(math.sqrt(target_area * aspect_ratio)))
103
+ w = int(round(math.sqrt(target_area / aspect_ratio)))
104
+ if w < img_w and h < img_h:
105
+ top = random.randint(0, img_h - h)
106
+ left = random.randint(0, img_w - w)
107
+ img[:, top : top + h, left : left + w] = _get_pixels(
108
+ self.per_pixel,
109
+ self.rand_color,
110
+ (chan, h, w),
111
+ dtype=dtype,
112
+ device=self.device,
113
+ )
114
+ break
115
+
116
+ def _erase_cube(
117
+ self,
118
+ img,
119
+ batch_start,
120
+ batch_size,
121
+ chan,
122
+ img_h,
123
+ img_w,
124
+ dtype,
125
+ ):
126
+ if random.random() > self.probability:
127
+ return
128
+ area = img_h * img_w
129
+ count = self.min_count if self.min_count == self.max_count else random.randint(self.min_count, self.max_count)
130
+ for _ in range(count):
131
+ for _ in range(100):
132
+ target_area = random.uniform(self.min_area, self.max_area) * area / count
133
+ aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
134
+ h = int(round(math.sqrt(target_area * aspect_ratio)))
135
+ w = int(round(math.sqrt(target_area / aspect_ratio)))
136
+ if w < img_w and h < img_h:
137
+ top = random.randint(0, img_h - h)
138
+ left = random.randint(0, img_w - w)
139
+ for i in range(batch_start, batch_size):
140
+ img_instance = img[i]
141
+ img_instance[:, top : top + h, left : left + w] = _get_pixels(
142
+ self.per_pixel,
143
+ self.rand_color,
144
+ (chan, h, w),
145
+ dtype=dtype,
146
+ device=self.device,
147
+ )
148
+ break
149
+
150
+ def __call__(self, input):
151
+ if len(input.size()) == 3:
152
+ self._erase(input, *input.size(), input.dtype)
153
+ else:
154
+ batch_size, chan, img_h, img_w = input.size()
155
+ # skip first slice of batch if num_splits is set (for clean portion of samples)
156
+ batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
157
+ if self.cube:
158
+ self._erase_cube(
159
+ input,
160
+ batch_start,
161
+ batch_size,
162
+ chan,
163
+ img_h,
164
+ img_w,
165
+ input.dtype,
166
+ )
167
+ else:
168
+ for i in range(batch_start, batch_size):
169
+ self._erase(input[i], chan, img_h, img_w, input.dtype)
170
+ return input
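
For reference, `VideoTransform` in `transforms_builder.py` applies `RandomErasing` after normalization on a (T, C, H, W) view of the clip. A minimal sketch with an illustrative random tensor (shapes and probability are assumptions for the example):

```python
# Minimal sketch, assuming a clip that has already been normalized.
import torch

from src.datasets.utils.video.randerase import RandomErasing

erase = RandomErasing(probability=0.25, mode="pixel", max_count=1, num_splits=1, device="cpu")

clip = torch.randn(16, 3, 224, 224)  # T x C x H x W, post-normalization
# With the default cube=True, the same rectangle is erased in every frame,
# filled with per-pixel normal noise because mode="pixel".
clip = erase(clip)
```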
src/datasets/utils/video/transforms.py ADDED
@@ -0,0 +1,1161 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import numbers
8
+ import random
9
+
10
+ import numpy as np
11
+ import PIL
12
+ import torch
13
+ import torchvision
14
+ import torchvision.transforms.functional as F
15
+ from PIL import Image
16
+ from torch import Tensor
17
+ from torchvision import transforms
18
+
19
+ import src.datasets.utils.video.functional as FF
20
+ from src.datasets.utils.video.randaugment import rand_augment_transform
21
+
22
+ _pil_interpolation_to_str = {
23
+ Image.NEAREST: "PIL.Image.NEAREST",
24
+ Image.BILINEAR: "PIL.Image.BILINEAR",
25
+ Image.BICUBIC: "PIL.Image.BICUBIC",
26
+ Image.LANCZOS: "PIL.Image.LANCZOS",
27
+ Image.HAMMING: "PIL.Image.HAMMING",
28
+ Image.BOX: "PIL.Image.BOX",
29
+ }
30
+
31
+
32
+ _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
33
+ PAD_FRAME_METHODS = ["circulant"]
34
+
35
+
36
+ def _pil_interp(method):
37
+ if method == "bicubic":
38
+ return Image.BICUBIC
39
+ elif method == "lanczos":
40
+ return Image.LANCZOS
41
+ elif method == "hamming":
42
+ return Image.HAMMING
43
+ else:
44
+ return Image.BILINEAR
45
+
46
+
47
+ def random_short_side_scale_jitter(images, min_size, max_size, boxes=None, inverse_uniform_sampling=False):
48
+ """
49
+ Perform a spatial short scale jittering on the given images and
50
+ corresponding boxes.
51
+ Args:
52
+ images (tensor): images to perform scale jitter. Dimension is
53
+ `num frames` x `channel` x `height` x `width`.
54
+ min_size (int): the minimal size to scale the frames.
55
+ max_size (int): the maximal size to scale the frames.
56
+ boxes (ndarray): optional. Corresponding boxes to images.
57
+ Dimension is `num boxes` x 4.
58
+ inverse_uniform_sampling (bool): if True, sample uniformly in
59
+ [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
60
+ scale. If False, take a uniform sample from [min_scale, max_scale].
61
+ Returns:
62
+ (tensor): the scaled images with dimension of
63
+ `num frames` x `channel` x `new height` x `new width`.
64
+ (ndarray or None): the scaled boxes with dimension of
65
+ `num boxes` x 4.
66
+ """
67
+ if inverse_uniform_sampling:
68
+ size = int(round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size)))
69
+ else:
70
+ size = int(round(np.random.uniform(min_size, max_size)))
71
+
72
+ height = images.shape[2]
73
+ width = images.shape[3]
74
+ if (width <= height and width == size) or (height <= width and height == size):
75
+ return images, boxes
76
+ new_width = size
77
+ new_height = size
78
+ if width < height:
79
+ new_height = int(math.floor((float(height) / width) * size))
80
+ if boxes is not None:
81
+ boxes = boxes * float(new_height) / height
82
+ else:
83
+ new_width = int(math.floor((float(width) / height) * size))
84
+ if boxes is not None:
85
+ boxes = boxes * float(new_width) / width
86
+
87
+ return (
88
+ torch.nn.functional.interpolate(
89
+ images,
90
+ size=(new_height, new_width),
91
+ mode="bilinear",
92
+ align_corners=False,
93
+ ),
94
+ boxes,
95
+ )
96
+
97
+
98
+ def crop_boxes(boxes, x_offset, y_offset):
99
+ """
100
+ Perform crop on the bounding boxes given the offsets.
101
+ Args:
102
+ boxes (ndarray or None): bounding boxes to perform crop. The dimension
103
+ is `num boxes` x 4.
104
+ x_offset (int): cropping offset in the x axis.
105
+ y_offset (int): cropping offset in the y axis.
106
+ Returns:
107
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
108
+ `num boxes` x 4.
109
+ """
110
+ cropped_boxes = boxes.copy()
111
+ cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
112
+ cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
113
+
114
+ return cropped_boxes
115
+
116
+
117
+ def random_crop(images, size, boxes=None):
118
+ """
119
+ Perform random spatial crop on the given images and corresponding boxes.
120
+ Args:
121
+ images (tensor): images to perform random crop. The dimension is
122
+ `num frames` x `channel` x `height` x `width`.
123
+ size (int): the size of height and width to crop on the image.
124
+ boxes (ndarray or None): optional. Corresponding boxes to images.
125
+ Dimension is `num boxes` x 4.
126
+ Returns:
127
+ cropped (tensor): cropped images with dimension of
128
+ `num frames` x `channel` x `size` x `size`.
129
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
130
+ `num boxes` x 4.
131
+ """
132
+ if images.shape[2] == size and images.shape[3] == size:
133
+ return images, boxes
134
+ height = images.shape[2]
135
+ width = images.shape[3]
136
+ y_offset = 0
137
+ if height > size:
138
+ y_offset = int(np.random.randint(0, height - size))
139
+ x_offset = 0
140
+ if width > size:
141
+ x_offset = int(np.random.randint(0, width - size))
142
+ cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
143
+
144
+ cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
145
+
146
+ return cropped, cropped_boxes
147
+
148
+
149
+ def horizontal_flip(prob, images, boxes=None):
150
+ """
151
+ Perform horizontal flip on the given images and corresponding boxes.
152
+ Args:
153
+ prob (float): probability to flip the images.
154
+ images (tensor): images to perform horizontal flip, the dimension is
155
+ `num frames` x `channel` x `height` x `width`.
156
+ boxes (ndarray or None): optional. Corresponding boxes to images.
157
+ Dimension is `num boxes` x 4.
158
+ Returns:
159
+ images (tensor): images with dimension of
160
+ `num frames` x `channel` x `height` x `width`.
161
+ flipped_boxes (ndarray or None): the flipped boxes with dimension of
162
+ `num boxes` x 4.
163
+ """
164
+ if boxes is None:
165
+ flipped_boxes = None
166
+ else:
167
+ flipped_boxes = boxes.copy()
168
+
169
+ if np.random.uniform() < prob:
170
+ images = images.flip((-1))
171
+
172
+ if len(images.shape) == 3:
173
+ width = images.shape[2]
174
+ elif len(images.shape) == 4:
175
+ width = images.shape[3]
176
+ else:
177
+ raise NotImplementedError("Dimension does not supported")
178
+ if boxes is not None:
179
+ flipped_boxes[:, [0, 2]] = width - boxes[:, [2, 0]] - 1
180
+
181
+ return images, flipped_boxes
182
+
183
+
184
+ def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
185
+ """
186
+ Perform uniform spatial sampling on the images and corresponding boxes.
187
+ Args:
188
+ images (tensor): images to perform uniform crop. The dimension is
189
+ `num frames` x `channel` x `height` x `width`.
190
+ size (int): size of height and width to crop the images.
191
+ spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
192
+ is larger than height. Or 0, 1, or 2 for top, center, and bottom
193
+ crop if height is larger than width.
194
+ boxes (ndarray or None): optional. Corresponding boxes to images.
195
+ Dimension is `num boxes` x 4.
196
+ scale_size (int): optional. If not None, resize the images to scale_size before
197
+ performing any crop.
198
+ Returns:
199
+ cropped (tensor): images with dimension of
200
+ `num frames` x `channel` x `size` x `size`.
201
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
202
+ `num boxes` x 4.
203
+ """
204
+ assert spatial_idx in [0, 1, 2]
205
+ ndim = len(images.shape)
206
+ if ndim == 3:
207
+ images = images.unsqueeze(0)
208
+ height = images.shape[2]
209
+ width = images.shape[3]
210
+
211
+ if scale_size is not None:
212
+ if width <= height:
213
+ width, height = scale_size, int(height / width * scale_size)
214
+ else:
215
+ width, height = int(width / height * scale_size), scale_size
216
+ images = torch.nn.functional.interpolate(
217
+ images,
218
+ size=(height, width),
219
+ mode="bilinear",
220
+ align_corners=False,
221
+ )
222
+
223
+ y_offset = int(math.ceil((height - size) / 2))
224
+ x_offset = int(math.ceil((width - size) / 2))
225
+
226
+ if height > width:
227
+ if spatial_idx == 0:
228
+ y_offset = 0
229
+ elif spatial_idx == 2:
230
+ y_offset = height - size
231
+ else:
232
+ if spatial_idx == 0:
233
+ x_offset = 0
234
+ elif spatial_idx == 2:
235
+ x_offset = width - size
236
+ cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
237
+ cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
238
+ if ndim == 3:
239
+ cropped = cropped.squeeze(0)
240
+ return cropped, cropped_boxes
241
+
242
+
243
+ def clip_boxes_to_image(boxes, height, width):
244
+ """
245
+ Clip an array of boxes to an image with the given height and width.
246
+ Args:
247
+ boxes (ndarray): bounding boxes to perform clipping.
248
+ Dimension is `num boxes` x 4.
249
+ height (int): given image height.
250
+ width (int): given image width.
251
+ Returns:
252
+ clipped_boxes (ndarray): the clipped boxes with dimension of
253
+ `num boxes` x 4.
254
+ """
255
+ clipped_boxes = boxes.copy()
256
+ clipped_boxes[:, [0, 2]] = np.minimum(width - 1.0, np.maximum(0.0, boxes[:, [0, 2]]))
257
+ clipped_boxes[:, [1, 3]] = np.minimum(height - 1.0, np.maximum(0.0, boxes[:, [1, 3]]))
258
+ return clipped_boxes
259
+
260
+
261
+ def blend(images1, images2, alpha):
262
+ """
263
+ Blend two images with a given weight alpha.
264
+ Args:
265
+ images1 (tensor): the first images to be blended, the dimension is
266
+ `num frames` x `channel` x `height` x `width`.
267
+ images2 (tensor): the second images to be blended, the dimension is
268
+ `num frames` x `channel` x `height` x `width`.
269
+ alpha (float): the blending weight.
270
+ Returns:
271
+ (tensor): blended images, the dimension is
272
+ `num frames` x `channel` x `height` x `width`.
273
+ """
274
+ return images1 * alpha + images2 * (1 - alpha)
275
+
276
+
277
+ def grayscale(images):
278
+ """
279
+ Get the grayscale for the input images. The channels of images should be
280
+ in order BGR.
281
+ Args:
282
+ images (tensor): the input images for getting grayscale. Dimension is
283
+ `num frames` x `channel` x `height` x `width`.
284
+ Returns:
285
+ img_gray (tensor): grayscale images, the dimension is
286
+ `num frames` x `channel` x `height` x `width`.
287
+ """
288
+ # R -> 0.299, G -> 0.587, B -> 0.114.
289
+ img_gray = torch.tensor(images)
290
+ gray_channel = 0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0]
291
+ img_gray[:, 0] = gray_channel
292
+ img_gray[:, 1] = gray_channel
293
+ img_gray[:, 2] = gray_channel
294
+ return img_gray
295
+
296
+
297
+ def color_jitter(images, img_brightness=0, img_contrast=0, img_saturation=0):
298
+ """
299
+ Perform a color jittering on the input images. The channels of images
300
+ should be in order BGR.
301
+ Args:
302
+ images (tensor): images to perform color jitter. Dimension is
303
+ `num frames` x `channel` x `height` x `width`.
304
+ img_brightness (float): jitter ratio for brightness.
305
+ img_contrast (float): jitter ratio for contrast.
306
+ img_saturation (float): jitter ratio for saturation.
307
+ Returns:
308
+ images (tensor): the jittered images, the dimension is
309
+ `num frames` x `channel` x `height` x `width`.
310
+ """
311
+
312
+ jitter = []
313
+ if img_brightness != 0:
314
+ jitter.append("brightness")
315
+ if img_contrast != 0:
316
+ jitter.append("contrast")
317
+ if img_saturation != 0:
318
+ jitter.append("saturation")
319
+
320
+ if len(jitter) > 0:
321
+ order = np.random.permutation(np.arange(len(jitter)))
322
+ for idx in range(0, len(jitter)):
323
+ if jitter[order[idx]] == "brightness":
324
+ images = brightness_jitter(img_brightness, images)
325
+ elif jitter[order[idx]] == "contrast":
326
+ images = contrast_jitter(img_contrast, images)
327
+ elif jitter[order[idx]] == "saturation":
328
+ images = saturation_jitter(img_saturation, images)
329
+ return images
330
+
331
+
332
+ def brightness_jitter(var, images):
333
+ """
334
+ Perform brightness jittering on the input images. The channels of images
335
+ should be in order BGR.
336
+ Args:
337
+ var (float): jitter ratio for brightness.
338
+ images (tensor): images to perform color jitter. Dimension is
339
+ `num frames` x `channel` x `height` x `width`.
340
+ Returns:
341
+ images (tensor): the jittered images, the dimension is
342
+ `num frames` x `channel` x `height` x `width`.
343
+ """
344
+ alpha = 1.0 + np.random.uniform(-var, var)
345
+
346
+ img_bright = torch.zeros(images.shape)
347
+ images = blend(images, img_bright, alpha)
348
+ return images
349
+
350
+
351
+ def contrast_jitter(var, images):
352
+ """
353
+ Perform contrast jittering on the input images. The channels of images
354
+ should be in order BGR.
355
+ Args:
356
+ var (float): jitter ratio for contrast.
357
+ images (tensor): images to perform color jitter. Dimension is
358
+ `num frames` x `channel` x `height` x `width`.
359
+ Returns:
360
+ images (tensor): the jittered images, the dimension is
361
+ `num frames` x `channel` x `height` x `width`.
362
+ """
363
+ alpha = 1.0 + np.random.uniform(-var, var)
364
+
365
+ img_gray = grayscale(images)
366
+ img_gray[:] = torch.mean(img_gray, dim=(1, 2, 3), keepdim=True)
367
+ images = blend(images, img_gray, alpha)
368
+ return images
369
+
370
+
371
+ def saturation_jitter(var, images):
372
+ """
373
+ Perform saturation jittering on the input images. The channels of images
374
+ should be in order BGR.
375
+ Args:
376
+ var (float): jitter ratio for saturation.
377
+ images (tensor): images to perform color jitter. Dimension is
378
+ `num frames` x `channel` x `height` x `width`.
379
+ Returns:
380
+ images (tensor): the jittered images, the dimension is
381
+ `num frames` x `channel` x `height` x `width`.
382
+ """
383
+ alpha = 1.0 + np.random.uniform(-var, var)
384
+ img_gray = grayscale(images)
385
+ images = blend(images, img_gray, alpha)
386
+
387
+ return images
388
+
389
+
390
+ def lighting_jitter(images, alphastd, eigval, eigvec):
391
+ """
392
+ Perform AlexNet-style PCA jitter on the given images.
393
+ Args:
394
+ images (tensor): images to perform lighting jitter. Dimension is
395
+ `num frames` x `channel` x `height` x `width`.
396
+ alphastd (float): jitter ratio for PCA jitter.
397
+ eigval (list): eigenvalues for PCA jitter.
398
+ eigvec (list[list]): eigenvectors for PCA jitter.
399
+ Returns:
400
+ out_images (tensor): the jittered images, the dimension is
401
+ `num frames` x `channel` x `height` x `width`.
402
+ """
403
+ if alphastd == 0:
404
+ return images
405
+ # generate alpha1, alpha2, alpha3.
406
+ alpha = np.random.normal(0, alphastd, size=(1, 3))
407
+ eig_vec = np.array(eigvec)
408
+ eig_val = np.reshape(eigval, (1, 3))
409
+ rgb = np.sum(
410
+ eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0),
411
+ axis=1,
412
+ )
413
+ out_images = torch.zeros_like(images)
414
+ if len(images.shape) == 3:
415
+ # C H W
416
+ channel_dim = 0
417
+ elif len(images.shape) == 4:
418
+ # T C H W
419
+ channel_dim = 1
420
+ else:
421
+ raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")
422
+
423
+ for idx in range(images.shape[channel_dim]):
424
+ # C H W
425
+ if len(images.shape) == 3:
426
+ out_images[idx] = images[idx] + rgb[2 - idx]
427
+ # T C H W
428
+ elif len(images.shape) == 4:
429
+ out_images[:, idx] = images[:, idx] + rgb[2 - idx]
430
+ else:
431
+ raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")
432
+
433
+ return out_images
434
+
435
+
436
+ def color_normalization(images, mean, stddev):
437
+ """
438
+ Perform color normalization on the given images.
439
+ Args:
440
+ images (tensor): images to perform color normalization. Dimension is
441
+ `num frames` x `channel` x `height` x `width`.
442
+ mean (list): mean values for normalization.
443
+ stddev (list): standard deviations for normalization.
444
+
445
+ Returns:
446
+ out_images (tensor): the normalized images, the dimension is
447
+ `num frames` x `channel` x `height` x `width`.
448
+ """
449
+ if len(images.shape) == 3:
450
+ assert len(mean) == images.shape[0], "channel mean not computed properly"
451
+ assert len(stddev) == images.shape[0], "channel stddev not computed properly"
452
+ elif len(images.shape) == 4:
453
+ assert len(mean) == images.shape[1], "channel mean not computed properly"
454
+ assert len(stddev) == images.shape[1], "channel stddev not computed properly"
455
+ else:
456
+ raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")
457
+
458
+ out_images = torch.zeros_like(images)
459
+ for idx in range(len(mean)):
460
+ # C H W
461
+ if len(images.shape) == 3:
462
+ out_images[idx] = (images[idx] - mean[idx]) / stddev[idx]
463
+ elif len(images.shape) == 4:
464
+ out_images[:, idx] = (images[:, idx] - mean[idx]) / stddev[idx]
465
+ else:
466
+ raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")
467
+ return out_images
468
+
469
+
470
+ def _get_param_spatial_crop(scale, ratio, height, width, num_repeat=10, log_scale=True, switch_hw=False):
471
+ """
472
+ Given scale, ratio, height and width, return sampled coordinates of the videos.
473
+ """
474
+ for _ in range(num_repeat):
475
+ area = height * width
476
+ target_area = random.uniform(*scale) * area
477
+ if log_scale:
478
+ log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
479
+ aspect_ratio = math.exp(random.uniform(*log_ratio))
480
+ else:
481
+ aspect_ratio = random.uniform(*ratio)
482
+
483
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
484
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
485
+
486
+ if np.random.uniform() < 0.5 and switch_hw:
487
+ w, h = h, w
488
+
489
+ if 0 < w <= width and 0 < h <= height:
490
+ i = random.randint(0, height - h)
491
+ j = random.randint(0, width - w)
492
+ return i, j, h, w
493
+
494
+ # Fallback to central crop
495
+ in_ratio = float(width) / float(height)
496
+ if in_ratio < min(ratio):
497
+ w = width
498
+ h = int(round(w / min(ratio)))
499
+ elif in_ratio > max(ratio):
500
+ h = height
501
+ w = int(round(h * max(ratio)))
502
+ else: # whole image
503
+ w = width
504
+ h = height
505
+ i = (height - h) // 2
506
+ j = (width - w) // 2
507
+ return i, j, h, w
508
+
509
+
510
+ def random_resized_crop(
511
+ images,
512
+ target_height,
513
+ target_width,
514
+ scale=(0.8, 1.0),
515
+ ratio=(3.0 / 4.0, 4.0 / 3.0),
516
+ ):
517
+ """
518
+ Crop the given images to random size and aspect ratio. A crop of random
519
+ size (default: of 0.08 to 1.0) of the original size and a random aspect
520
+ ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This
521
+ crop is finally resized to given size. This is popularly used to train the
522
+ Inception networks.
523
+
524
+ Args:
525
+ images: Images to perform resizing and cropping.
526
+ target_height: Desired height after cropping.
527
+ target_width: Desired width after cropping.
528
+ scale: Scale range of Inception-style area based random resizing.
529
+ ratio: Aspect ratio range of Inception-style area based random resizing.
530
+ """
531
+
532
+ height = images.shape[2]
533
+ width = images.shape[3]
534
+
535
+ i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width)
536
+ cropped = images[:, :, i : i + h, j : j + w]
537
+ return torch.nn.functional.interpolate(
538
+ cropped,
539
+ size=(target_height, target_width),
540
+ mode="bilinear",
541
+ align_corners=False,
542
+ )
543
+
544
+
545
+ def random_resized_crop_with_shift(
546
+ images,
547
+ target_height,
548
+ target_width,
549
+ scale=(0.8, 1.0),
550
+ ratio=(3.0 / 4.0, 4.0 / 3.0),
551
+ ):
552
+ """
553
+ This is similar to random_resized_crop. However, it samples two different
554
+ boxes (for cropping) for the first and last frame. It then linearly
555
+ interpolates the two boxes for other frames.
556
+
557
+ Args:
558
+ images: Images to perform resizing and cropping.
559
+ target_height: Desired height after cropping.
560
+ target_width: Desired width after cropping.
561
+ scale: Scale range of Inception-style area based random resizing.
562
+ ratio: Aspect ratio range of Inception-style area based random resizing.
563
+ """
564
+ t = images.shape[1]
565
+ height = images.shape[2]
566
+ width = images.shape[3]
567
+
568
+ i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width)
569
+ i_, j_, h_, w_ = _get_param_spatial_crop(scale, ratio, height, width)
570
+ i_s = [int(i) for i in torch.linspace(i, i_, steps=t).tolist()]
571
+ j_s = [int(i) for i in torch.linspace(j, j_, steps=t).tolist()]
572
+ h_s = [int(i) for i in torch.linspace(h, h_, steps=t).tolist()]
573
+ w_s = [int(i) for i in torch.linspace(w, w_, steps=t).tolist()]
574
+ out = torch.zeros((3, t, target_height, target_width))
575
+ for ind in range(t):
576
+ out[:, ind : ind + 1, :, :] = torch.nn.functional.interpolate(
577
+ images[
578
+ :,
579
+ ind : ind + 1,
580
+ i_s[ind] : i_s[ind] + h_s[ind],
581
+ j_s[ind] : j_s[ind] + w_s[ind],
582
+ ],
583
+ size=(target_height, target_width),
584
+ mode="bilinear",
585
+ align_corners=False,
586
+ )
587
+ return out
588
+
589
+
590
+ def create_random_augment(
591
+ input_size,
592
+ auto_augment=None,
593
+ interpolation="bilinear",
594
+ ):
595
+ """
596
+ Get video randaug transform.
597
+
598
+ Args:
599
+ input_size: The size of the input video in tuple.
600
+ auto_augment: Parameters for randaug. An example:
601
+ "rand-m7-n4-mstd0.5-inc1" (m is the magnitude and n is the number
602
+ of operations to apply).
603
+ interpolation: Interpolation method.
604
+ """
605
+ if isinstance(input_size, tuple):
606
+ img_size = input_size[-2:]
607
+ else:
608
+ img_size = input_size
609
+
610
+ if auto_augment:
611
+ assert isinstance(auto_augment, str)
612
+ if isinstance(img_size, tuple):
613
+ img_size_min = min(img_size)
614
+ else:
615
+ img_size_min = img_size
616
+ aa_params = {"translate_const": int(img_size_min * 0.45)}
617
+ if interpolation and interpolation != "random":
618
+ aa_params["interpolation"] = _pil_interp(interpolation)
619
+ if auto_augment.startswith("rand"):
620
+ return transforms.Compose([rand_augment_transform(auto_augment, aa_params)])
621
+ raise NotImplementedError
622
+
623
+
624
+ def random_sized_crop_img(
625
+ im,
626
+ size,
627
+ jitter_scale=(0.08, 1.0),
628
+ jitter_aspect=(3.0 / 4.0, 4.0 / 3.0),
629
+ max_iter=10,
630
+ ):
631
+ """
632
+ Performs Inception-style cropping (used for training).
633
+ """
634
+ assert len(im.shape) == 3, "Currently only support image for random_sized_crop"
635
+ h, w = im.shape[1:3]
636
+ i, j, h, w = _get_param_spatial_crop(
637
+ scale=jitter_scale,
638
+ ratio=jitter_aspect,
639
+ height=h,
640
+ width=w,
641
+ num_repeat=max_iter,
642
+ log_scale=False,
643
+ switch_hw=True,
644
+ )
645
+ cropped = im[:, i : i + h, j : j + w]
646
+ return torch.nn.functional.interpolate(
647
+ cropped.unsqueeze(0),
648
+ size=(size, size),
649
+ mode="bilinear",
650
+ align_corners=False,
651
+ ).squeeze(0)
652
+
653
+
654
+ def circulant_frame_padding(video: Tensor, total_frames: int) -> Tensor:
655
+ """
656
+ Applies circulant frame padding (repeating the video) to a specified size.
657
+
658
+ Args:
659
+ video: The input video to be padded. Expected (C, T, H, W)
660
+ total_frames: The number of frames after padding.
661
+
662
+ Returns
663
+ The video padded to total_frames.
664
+ """
665
+ start_frames = video.shape[1]
666
+ if start_frames == total_frames:
667
+ return video
668
+
669
+ num_repeats = total_frames // start_frames + (total_frames % start_frames > 0)
670
+
671
+ return video.repeat((1, num_repeats) + (1,) * (video.ndim - 2))[:, :total_frames]
672
+
673
+
674
+ def frame_pad(video: Tensor, total_frames: int, pad_frame_method: str) -> Tensor:
675
+ if pad_frame_method not in PAD_FRAME_METHODS:
676
+ raise ValueError(f"Unrecognized pad_frame_method {pad_frame_method}")
677
+
678
+ if pad_frame_method == "circulant":
679
+ return circulant_frame_padding(video, total_frames)
680
+
681
+ return None
682
+
683
+
684
+ # The following code are modified based on timm lib, we will replace the following
685
+ # contents with dependency from PyTorchVideo.
686
+ # https://github.com/facebookresearch/pytorchvideo
687
+ class RandomResizedCropAndInterpolation:
688
+ """Crop the given PIL Image to random size and aspect ratio with random interpolation.
689
+ A crop of random size (default: of 0.08 to 1.0) of the original size and a random
690
+ aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
691
+ is finally resized to given size.
692
+ This is popularly used to train the Inception networks.
693
+ Args:
694
+ size: expected output size of each edge
695
+ scale: range of size of the origin size cropped
696
+ ratio: range of aspect ratio of the origin aspect ratio cropped
697
+ interpolation: Default: PIL.Image.BILINEAR
698
+ """
699
+
700
+ def __init__(
701
+ self,
702
+ size,
703
+ scale=(0.08, 1.0),
704
+ ratio=(3.0 / 4.0, 4.0 / 3.0),
705
+ interpolation="bilinear",
706
+ ):
707
+ if isinstance(size, tuple):
708
+ self.size = size
709
+ else:
710
+ self.size = (size, size)
711
+ if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
712
+ print("range should be of kind (min, max)")
713
+
714
+ if interpolation == "random":
715
+ self.interpolation = _RANDOM_INTERPOLATION
716
+ else:
717
+ self.interpolation = _pil_interp(interpolation)
718
+ self.scale = scale
719
+ self.ratio = ratio
720
+
721
+ @staticmethod
722
+ def get_params(img, scale, ratio):
723
+ """Get parameters for ``crop`` for a random sized crop.
724
+ Args:
725
+ img (PIL Image): Image to be cropped.
726
+ scale (tuple): range of size of the origin size cropped
727
+ ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
728
+ Returns:
729
+ tuple: params (i, j, h, w) to be passed to ``crop`` for a random
730
+ sized crop.
731
+ """
732
+ area = img.size[0] * img.size[1]
733
+
734
+ for _ in range(10):
735
+ target_area = random.uniform(*scale) * area
736
+ log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
737
+ aspect_ratio = math.exp(random.uniform(*log_ratio))
738
+
739
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
740
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
741
+
742
+ if w <= img.size[0] and h <= img.size[1]:
743
+ i = random.randint(0, img.size[1] - h)
744
+ j = random.randint(0, img.size[0] - w)
745
+ return i, j, h, w
746
+
747
+ # Fallback to central crop
748
+ in_ratio = img.size[0] / img.size[1]
749
+ if in_ratio < min(ratio):
750
+ w = img.size[0]
751
+ h = int(round(w / min(ratio)))
752
+ elif in_ratio > max(ratio):
753
+ h = img.size[1]
754
+ w = int(round(h * max(ratio)))
755
+ else: # whole image
756
+ w = img.size[0]
757
+ h = img.size[1]
758
+ i = (img.size[1] - h) // 2
759
+ j = (img.size[0] - w) // 2
760
+ return i, j, h, w
761
+
762
+ def __call__(self, img):
763
+ """
764
+ Args:
765
+ img (PIL Image): Image to be cropped and resized.
766
+ Returns:
767
+ PIL Image: Randomly cropped and resized image.
768
+ """
769
+ i, j, h, w = self.get_params(img, self.scale, self.ratio)
770
+ if isinstance(self.interpolation, (tuple, list)):
771
+ interpolation = random.choice(self.interpolation)
772
+ else:
773
+ interpolation = self.interpolation
774
+ return F.resized_crop(img, i, j, h, w, self.size, interpolation)
775
+
776
+ def __repr__(self):
777
+ if isinstance(self.interpolation, (tuple, list)):
778
+ interpolate_str = " ".join([_pil_interpolation_to_str[x] for x in self.interpolation])
779
+ else:
780
+ interpolate_str = _pil_interpolation_to_str[self.interpolation]
781
+ format_string = self.__class__.__name__ + "(size={0}".format(self.size)
782
+ format_string += ", scale={0}".format(tuple(round(s, 4) for s in self.scale))
783
+ format_string += ", ratio={0}".format(tuple(round(r, 4) for r in self.ratio))
784
+ format_string += ", interpolation={0})".format(interpolate_str)
785
+ return format_string
786
+
787
+
788
+ class Compose(object):
789
+ """Composes several transforms
790
+ Args:
791
+ transforms (list of ``Transform`` objects): list of transforms
792
+ to compose
793
+ """
794
+
795
+ def __init__(self, transforms):
796
+ self.transforms = transforms
797
+
798
+ def __call__(self, clip):
799
+ for t in self.transforms:
800
+ clip = t(clip)
801
+ return clip
802
+
803
+
804
+ class RandomHorizontalFlip(object):
805
+ """Horizontally flip the list of given images randomly
806
+ with a probability 0.5
807
+ """
808
+
809
+ def __call__(self, clip):
810
+ """
811
+ Args:
812
+ img (PIL.Image or numpy.ndarray): List of images to be cropped
813
+ in format (h, w, c) in numpy.ndarray
814
+ Returns:
815
+ PIL.Image or numpy.ndarray: Randomly flipped clip
816
+ """
817
+ if random.random() < 0.5:
818
+ if isinstance(clip[0], np.ndarray):
819
+ return [np.fliplr(img) for img in clip]
820
+ elif isinstance(clip[0], PIL.Image.Image):
821
+ return [img.transpose(PIL.Image.FLIP_LEFT_RIGHT) for img in clip]
822
+ else:
823
+ raise TypeError("Expected numpy.ndarray or PIL.Image" + " but got list of {0}".format(type(clip[0])))
824
+ return clip
825
+
826
+
827
+ class RandomResize(object):
828
+ """Resizes a list of (H x W x C) numpy.ndarray to the final size
829
+ The larger the original image is, the more times it takes to
830
+ interpolate
831
+ Args:
832
+ interpolation (str): Can be one of 'nearest', 'bilinear'
833
+ defaults to nearest
834
+ size (tuple): (width, height)
835
+ """
836
+
837
+ def __init__(self, ratio=(3.0 / 4.0, 4.0 / 3.0), interpolation="nearest"):
838
+ self.ratio = ratio
839
+ self.interpolation = interpolation
840
+
841
+ def __call__(self, clip):
842
+ scaling_factor = random.uniform(self.ratio[0], self.ratio[1])
843
+
844
+ if isinstance(clip[0], np.ndarray):
845
+ im_h, im_w, im_c = clip[0].shape
846
+ elif isinstance(clip[0], PIL.Image.Image):
847
+ im_w, im_h = clip[0].size
848
+
849
+ new_w = int(im_w * scaling_factor)
850
+ new_h = int(im_h * scaling_factor)
851
+ new_size = (new_w, new_h)
852
+ resized = FF.resize_clip(clip, new_size, interpolation=self.interpolation)
853
+ return resized
854
+
855
+
856
+ class Resize(object):
857
+ """Resizes a list of (H x W x C) numpy.ndarray to the final size
858
+ The larger the original image is, the more times it takes to
859
+ interpolate
860
+ Args:
861
+ interpolation (str): Can be one of 'nearest', 'bilinear'
862
+ defaults to nearest
863
+ size (tuple): (width, height)
864
+ """
865
+
866
+ def __init__(self, size, interpolation="nearest"):
867
+ self.size = size
868
+ self.interpolation = interpolation
869
+
870
+ def __call__(self, clip):
871
+ resized = FF.resize_clip(clip, self.size, interpolation=self.interpolation)
872
+ return resized
873
+
874
+
875
+ class RandomCrop(object):
876
+ """Extract random crop at the same location for a list of images
877
+ Args:
878
+ size (sequence or int): Desired output size for the
879
+ crop in format (h, w)
880
+ """
881
+
882
+ def __init__(self, size):
883
+ if isinstance(size, numbers.Number):
884
+ size = (size, size)
885
+
886
+ self.size = size
887
+
888
+ def __call__(self, clip):
889
+ """
890
+ Args:
891
+ img (PIL.Image or numpy.ndarray): List of images to be cropped
892
+ in format (h, w, c) in numpy.ndarray
893
+ Returns:
894
+ PIL.Image or numpy.ndarray: Cropped list of images
895
+ """
896
+ h, w = self.size
897
+ if isinstance(clip[0], np.ndarray):
898
+ im_h, im_w, im_c = clip[0].shape
899
+ elif isinstance(clip[0], PIL.Image.Image):
900
+ im_w, im_h = clip[0].size
901
+ else:
902
+ raise TypeError("Expected numpy.ndarray or PIL.Image" + "but got list of {0}".format(type(clip[0])))
903
+ if w > im_w or h > im_h:
904
+ error_msg = (
905
+ "Initial image size should be larger then "
906
+ "cropped size but got cropped sizes : ({w}, {h}) while "
907
+ "initial image is ({im_w}, {im_h})".format(im_w=im_w, im_h=im_h, w=w, h=h)
908
+ )
909
+ raise ValueError(error_msg)
910
+
911
+ x1 = random.randint(0, im_w - w)
912
+ y1 = random.randint(0, im_h - h)
913
+ cropped = FF.crop_clip(clip, y1, x1, h, w)
914
+
915
+ return cropped
916
+
917
+
918
+ class ThreeCrop(object):
919
+ """Extract random crop at the same location for a list of images
920
+ Args:
921
+ size (sequence or int): Desired output size for the
922
+ crop in format (h, w)
923
+ """
924
+
925
+ def __init__(self, size):
926
+ if isinstance(size, numbers.Number):
927
+ size = (size, size)
928
+
929
+ self.size = size
930
+
931
+ def __call__(self, clip):
932
+ """
933
+ Args:
934
+ img (PIL.Image or numpy.ndarray): List of images to be cropped
935
+ in format (h, w, c) in numpy.ndarray
936
+ Returns:
937
+ PIL.Image or numpy.ndarray: Cropped list of images
938
+ """
939
+ h, w = self.size
940
+ if isinstance(clip[0], np.ndarray):
941
+ im_h, im_w, im_c = clip[0].shape
942
+ elif isinstance(clip[0], PIL.Image.Image):
943
+ im_w, im_h = clip[0].size
944
+ else:
945
+ raise TypeError("Expected numpy.ndarray or PIL.Image" + "but got list of {0}".format(type(clip[0])))
946
+ if w != im_w and h != im_h:
947
+ clip = FF.resize_clip(clip, self.size, interpolation="bilinear")
948
+ im_h, im_w, im_c = clip[0].shape
949
+
950
+ step = np.max((np.max((im_w, im_h)) - self.size[0]) // 2, 0)
951
+ cropped = []
952
+ for i in range(3):
953
+ if im_h > self.size[0]:
954
+ x1 = 0
955
+ y1 = i * step
956
+ cropped.extend(FF.crop_clip(clip, y1, x1, h, w))
957
+ else:
958
+ x1 = i * step
959
+ y1 = 0
960
+ cropped.extend(FF.crop_clip(clip, y1, x1, h, w))
961
+ return cropped
962
+
963
+
964
+ class RandomRotation(object):
965
+ """Rotate entire clip randomly by a random angle within
966
+ given bounds
967
+ Args:
968
+ degrees (sequence or int): Range of degrees to select from
969
+ If degrees is a number instead of sequence like (min, max),
970
+ the range of degrees, will be (-degrees, +degrees).
971
+ """
972
+
973
+ def __init__(self, degrees):
974
+ if isinstance(degrees, numbers.Number):
975
+ if degrees < 0:
976
+ raise ValueError("If degrees is a single number," "must be positive")
977
+ degrees = (-degrees, degrees)
978
+ else:
979
+ if len(degrees) != 2:
980
+ raise ValueError("If degrees is a sequence," "it must be of len 2.")
981
+
982
+ self.degrees = degrees
983
+
984
+ def __call__(self, clip):
985
+ """
986
+ Args:
987
+ img (PIL.Image or numpy.ndarray): List of images to be cropped
988
+ in format (h, w, c) in numpy.ndarray
989
+ Returns:
990
+ PIL.Image or numpy.ndarray: Cropped list of images
991
+ """
992
+ import skimage
993
+
994
+ angle = random.uniform(self.degrees[0], self.degrees[1])
995
+ if isinstance(clip[0], np.ndarray):
996
+ rotated = [skimage.transform.rotate(img, angle) for img in clip]
997
+ elif isinstance(clip[0], PIL.Image.Image):
998
+ rotated = [img.rotate(angle) for img in clip]
999
+ else:
1000
+ raise TypeError("Expected numpy.ndarray or PIL.Image" + "but got list of {0}".format(type(clip[0])))
1001
+
1002
+ return rotated
1003
+
1004
+
1005
+ class CenterCrop(object):
1006
+ """Extract center crop at the same location for a list of images
1007
+ Args:
1008
+ size (sequence or int): Desired output size for the
1009
+ crop in format (h, w)
1010
+ """
1011
+
1012
+ def __init__(self, size):
1013
+ if isinstance(size, numbers.Number):
1014
+ size = (size, size)
1015
+
1016
+ self.size = size
1017
+
1018
+ def __call__(self, clip):
1019
+ """
1020
+ Args:
1021
+ img (PIL.Image or numpy.ndarray): List of images to be cropped
1022
+ in format (h, w, c) in numpy.ndarray
1023
+ Returns:
1024
+ PIL.Image or numpy.ndarray: Cropped list of images
1025
+ """
1026
+ h, w = self.size
1027
+ if isinstance(clip[0], np.ndarray) or isinstance(clip[0], torch.Tensor):
1028
+ if clip[0].shape[-1] == 3:
1029
+ im_h, im_w, im_c = clip[0].shape
1030
+ else:
1031
+ assert clip[0].shape[0] == 3
1032
+ im_c, im_h, im_w = clip[0].shape
1033
+ elif isinstance(clip[0], PIL.Image.Image):
1034
+ im_w, im_h = clip[0].size
1035
+ else:
1036
+ raise TypeError(
1037
+ "Expected numpy.ndarray or PIL.Image or torch.Tensor" + "but got list of {0}".format(type(clip[0]))
1038
+ )
1039
+ if w > im_w or h > im_h:
1040
+ error_msg = (
1041
+ "Initial image size should be larger then "
1042
+ "cropped size but got cropped sizes : ({w}, {h}) while "
1043
+ "initial image is ({im_w}, {im_h})".format(im_w=im_w, im_h=im_h, w=w, h=h)
1044
+ )
1045
+ raise ValueError(error_msg)
1046
+
1047
+ x1 = int(round((im_w - w) / 2.0))
1048
+ y1 = int(round((im_h - h) / 2.0))
1049
+ cropped = FF.crop_clip(clip, y1, x1, h, w)
1050
+
1051
+ return cropped
1052
+
1053
+
1054
+ class ColorJitter(object):
1055
+ """
1056
+ Randomly change the brightness, contrast and saturation and hue of the clip
1057
+
1058
+ Args:
1059
+ brightness (float): How much to jitter brightness. brightness_factor
1060
+ is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
1061
+ contrast (float): How much to jitter contrast. contrast_factor
1062
+ is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
1063
+ saturation (float): How much to jitter saturation. saturation_factor
1064
+ is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
1065
+ hue(float): How much to jitter hue. hue_factor is chosen uniformly from
1066
+ [-hue, hue]. Should be >=0 and <= 0.5.
1067
+ """
1068
+
1069
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
1070
+ self.brightness = brightness
1071
+ self.contrast = contrast
1072
+ self.saturation = saturation
1073
+ self.hue = hue
1074
+
1075
+ def get_params(self, brightness, contrast, saturation, hue):
1076
+ if brightness > 0:
1077
+ brightness_factor = random.uniform(max(0, 1 - brightness), 1 + brightness)
1078
+ else:
1079
+ brightness_factor = None
1080
+
1081
+ if contrast > 0:
1082
+ contrast_factor = random.uniform(max(0, 1 - contrast), 1 + contrast)
1083
+ else:
1084
+ contrast_factor = None
1085
+
1086
+ if saturation > 0:
1087
+ saturation_factor = random.uniform(max(0, 1 - saturation), 1 + saturation)
1088
+ else:
1089
+ saturation_factor = None
1090
+
1091
+ if hue > 0:
1092
+ hue_factor = random.uniform(-hue, hue)
1093
+ else:
1094
+ hue_factor = None
1095
+ return brightness_factor, contrast_factor, saturation_factor, hue_factor
1096
+
1097
+ def __call__(self, clip):
1098
+ """
1099
+ Args:
1100
+ clip (list): list of PIL.Image
1101
+ Returns:
1102
+ list PIL.Image : list of transformed PIL.Image
1103
+ """
1104
+ if isinstance(clip[0], np.ndarray):
1105
+ raise TypeError("Color jitter not yet implemented for numpy arrays")
1106
+ elif isinstance(clip[0], PIL.Image.Image):
1107
+ brightness, contrast, saturation, hue = self.get_params(
1108
+ self.brightness, self.contrast, self.saturation, self.hue
1109
+ )
1110
+
1111
+ # Create img transform function sequence
1112
+ img_transforms = []
1113
+ if brightness is not None:
1114
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
1115
+ if saturation is not None:
1116
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
1117
+ if hue is not None:
1118
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
1119
+ if contrast is not None:
1120
+ img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
1121
+ random.shuffle(img_transforms)
1122
+
1123
+ # Apply to all images
1124
+ jittered_clip = []
1125
+ for img in clip:
1126
+ for func in img_transforms:
1127
+ jittered_img = func(img)
1128
+ jittered_clip.append(jittered_img)
1129
+
1130
+ else:
1131
+ raise TypeError("Expected numpy.ndarray or PIL.Image" + "but got list of {0}".format(type(clip[0])))
1132
+ return jittered_clip
1133
+
1134
+
1135
+ class Normalize(object):
1136
+ """Normalize a clip with mean and standard deviation.
1137
+ Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
1138
+ will normalize each channel of the input ``torch.*Tensor`` i.e.
1139
+ ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
1140
+ .. note::
1141
+ This transform acts out of place, i.e., it does not mutate the input tensor.
1142
+ Args:
1143
+ mean (sequence): Sequence of means for each channel.
1144
+ std (sequence): Sequence of standard deviations for each channel.
1145
+ """
1146
+
1147
+ def __init__(self, mean, std):
1148
+ self.mean = mean
1149
+ self.std = std
1150
+
1151
+ def __call__(self, clip):
1152
+ """
1153
+ Args:
1154
+ clip (Tensor): Tensor clip of size (T, C, H, W) to be normalized.
1155
+ Returns:
1156
+ Tensor: Normalized Tensor clip.
1157
+ """
1158
+ return FF.normalize(clip, self.mean, self.std)
1159
+
1160
+ def __repr__(self):
1161
+ return self.__class__.__name__ + "(mean={0}, std={1})".format(self.mean, self.std)
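
Among the helpers above, `circulant_frame_padding` has the one non-obvious edge case: clips shorter than the requested length are tiled and then truncated. A small sketch of that behaviour with a toy clip (the shapes below are illustrative):

```python
# Minimal sketch with a toy 1-channel, 10-frame clip.
import torch

from src.datasets.utils.video.transforms import circulant_frame_padding

clip = torch.arange(10, dtype=torch.float32).view(1, 10, 1, 1)  # C=1, T=10
padded = circulant_frame_padding(clip, total_frames=16)

print(padded.shape)        # torch.Size([1, 16, 1, 1])
print(padded[0, :, 0, 0])  # frames 0..9 followed by 0..5 (clip repeated, then cut)
```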
src/datasets/utils/video/transforms_builder.py ADDED
@@ -0,0 +1,165 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Optional
7
+
8
+ import torch
9
+ import torchvision.transforms as transforms
10
+
11
+ import src.datasets.utils.video.transforms as video_transforms
12
+ from src.datasets.utils.video.randerase import RandomErasing
13
+
14
+
15
+ def make_transforms(
16
+ random_horizontal_flip=True,
17
+ random_resize_aspect_ratio=(3 / 4, 4 / 3),
18
+ random_resize_scale=(0.3, 1.0),
19
+ reprob=0.0,
20
+ auto_augment=False,
21
+ motion_shift=False,
22
+ crop_size=224,
23
+ normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
24
+ pad_frame_count: Optional[int] = None,
25
+ pad_frame_method: str = "circulant",
26
+ ):
27
+ _frames_augmentation = VideoTransform(
28
+ random_horizontal_flip=random_horizontal_flip,
29
+ random_resize_aspect_ratio=random_resize_aspect_ratio,
30
+ random_resize_scale=random_resize_scale,
31
+ reprob=reprob,
32
+ auto_augment=auto_augment,
33
+ motion_shift=motion_shift,
34
+ crop_size=crop_size,
35
+ normalize=normalize,
36
+ pad_frame_count=pad_frame_count,
37
+ pad_frame_method=pad_frame_method,
38
+ )
39
+ return _frames_augmentation
40
+
41
+
42
+ class VideoTransform(object):
43
+
44
+ def __init__(
45
+ self,
46
+ random_horizontal_flip=True,
47
+ random_resize_aspect_ratio=(3 / 4, 4 / 3),
48
+ random_resize_scale=(0.3, 1.0),
49
+ reprob=0.0,
50
+ auto_augment=False,
51
+ motion_shift=False,
52
+ crop_size=224,
53
+ normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
54
+ pad_frame_count: Optional[int] = None,
55
+ pad_frame_method: str = "circulant",
56
+ ):
57
+ self.random_horizontal_flip = random_horizontal_flip
58
+ self.random_resize_aspect_ratio = random_resize_aspect_ratio
59
+ self.random_resize_scale = random_resize_scale
60
+ self.auto_augment = auto_augment
61
+ self.motion_shift = motion_shift
62
+ self.crop_size = crop_size
63
+ self.mean = torch.tensor(normalize[0], dtype=torch.float32)
64
+ self.std = torch.tensor(normalize[1], dtype=torch.float32)
65
+ self.pad_frame_count = pad_frame_count
66
+ self.pad_frame_method = pad_frame_method
67
+
68
+ if not self.auto_augment:
69
+ # Without auto-augment, PIL and tensor conversions simply scale uint8 space by 255.
70
+ self.mean *= 255.0
71
+ self.std *= 255.0
72
+
73
+ self.autoaug_transform = video_transforms.create_random_augment(
74
+ input_size=(crop_size, crop_size),
75
+ auto_augment="rand-m7-n4-mstd0.5-inc1",
76
+ interpolation="bicubic",
77
+ )
78
+
79
+ self.spatial_transform = (
80
+ video_transforms.random_resized_crop_with_shift if motion_shift else video_transforms.random_resized_crop
81
+ )
82
+
83
+ self.reprob = reprob
84
+ self.erase_transform = RandomErasing(
85
+ reprob,
86
+ mode="pixel",
87
+ max_count=1,
88
+ num_splits=1,
89
+ device="cpu",
90
+ )
91
+
92
+ def __call__(self, buffer):
93
+
94
+ if self.auto_augment:
95
+ buffer = [transforms.ToPILImage()(frame) for frame in buffer]
96
+ buffer = self.autoaug_transform(buffer)
97
+ buffer = [transforms.ToTensor()(img) for img in buffer]
98
+ buffer = torch.stack(buffer) # T C H W
99
+ buffer = buffer.permute(0, 2, 3, 1) # T H W C
100
+ elif torch.is_tensor(buffer):
101
+ # TODO: ensure input is always a tensor?
102
+ buffer = buffer.to(torch.float32)
103
+ else:
104
+ buffer = torch.tensor(buffer, dtype=torch.float32)
105
+
106
+ buffer = buffer.permute(3, 0, 1, 2) # T H W C -> C T H W
107
+
108
+ buffer = self.spatial_transform(
109
+ images=buffer,
110
+ target_height=self.crop_size,
111
+ target_width=self.crop_size,
112
+ scale=self.random_resize_scale,
113
+ ratio=self.random_resize_aspect_ratio,
114
+ )
115
+ if self.random_horizontal_flip:
116
+ buffer, _ = video_transforms.horizontal_flip(0.5, buffer)
117
+
118
+ buffer = _tensor_normalize_inplace(buffer, self.mean, self.std)
119
+ if self.reprob > 0:
120
+ buffer = buffer.permute(1, 0, 2, 3)
121
+ buffer = self.erase_transform(buffer)
122
+ buffer = buffer.permute(1, 0, 2, 3)
123
+
124
+ if self.pad_frame_count is not None:
125
+ buffer = video_transforms.frame_pad(buffer, self.pad_frame_count, self.pad_frame_method)
126
+
127
+ return buffer
128
+
129
+
130
+ def tensor_normalize(tensor, mean, std):
131
+ """
132
+ Normalize a given tensor by subtracting the mean and dividing the std.
133
+ Args:
134
+ tensor (tensor): tensor to normalize.
135
+ mean (tensor or list): mean value to subtract.
136
+ std (tensor or list): std to divide.
137
+ """
138
+ if tensor.dtype == torch.uint8:
139
+ tensor = tensor.float()
140
+ tensor = tensor / 255.0
141
+ if isinstance(mean, list):
142
+ mean = torch.tensor(mean)
143
+ if isinstance(std, list):
144
+ std = torch.tensor(std)
145
+ tensor = tensor - mean
146
+ tensor = tensor / std
147
+ return tensor
148
+
149
+
150
+ def _tensor_normalize_inplace(tensor, mean, std):
151
+ """
152
+ Normalize a given tensor by subtracting the mean and dividing the std.
153
+ Args:
154
+ tensor (tensor): tensor to normalize (with dimensions C, T, H, W).
155
+ mean (tensor): mean value to subtract (in 0 to 255 floats).
156
+ std (tensor): std to divide (in 0 to 255 floats).
157
+ """
158
+ if tensor.dtype == torch.uint8:
159
+ tensor = tensor.float()
160
+
161
+ C, T, H, W = tensor.shape
162
+ tensor = tensor.view(C, -1).permute(1, 0) # Make C the last dimension
163
+ tensor.sub_(mean).div_(std)
164
+ tensor = tensor.permute(1, 0).view(C, T, H, W) # Put C back in front
165
+ return tensor
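A minimal usage sketch of the `VideoTransform` defined above (the clip shape and settings are illustrative, not the training configuration): a uint8 `[T, H, W, C]` clip comes out as a normalized `[C, T, crop, crop]` float tensor.

```python
# Hedged sketch: illustrative clip shape; assumes VideoTransform from this module is in scope.
import torch

transform = VideoTransform(crop_size=224, auto_augment=False, reprob=0.0)
clip = torch.randint(0, 256, (16, 256, 320, 3), dtype=torch.uint8)  # [T, H, W, C]
out = transform(clip)  # random resized crop + random flip + in-place normalization
print(out.shape)  # torch.Size([3, 16, 224, 224]), i.e. [C, T, H, W]
```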
src/datasets/utils/video/volume_transforms.py ADDED
@@ -0,0 +1,159 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import torch
8
+ from PIL import Image
9
+
10
+
11
+ def convert_img(img):
12
+ """Converts (H, W, C) numpy.ndarray to (C, W, H) format"""
13
+ if len(img.shape) == 3:
14
+ img = img.transpose(2, 0, 1)
15
+ if len(img.shape) == 2:
16
+ img = np.expand_dims(img, 0)
17
+ return img
18
+
19
+
20
+ class ClipToTensor(object):
21
+ """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
22
+ to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0]
23
+ """
24
+
25
+ def __init__(self, channel_nb=3, div_255=True, numpy=False):
26
+ self.channel_nb = channel_nb
27
+ self.div_255 = div_255
28
+ self.numpy = numpy
29
+
30
+ def __call__(self, clip):
31
+ """
32
+ Args: clip (list of numpy.ndarray): clip (list of images)
33
+ to be converted to tensor.
34
+ """
35
+ # Retrieve shape
36
+ if isinstance(clip[0], np.ndarray):
37
+ h, w, ch = clip[0].shape
38
+ assert ch == self.channel_nb, "Got {0} instead of 3 channels".format(ch)
39
+ elif isinstance(clip[0], Image.Image):
40
+ w, h = clip[0].size
41
+ elif isinstance(clip[0], torch.Tensor):
42
+ tensor_clip = torch.stack(clip)
43
+ # Converting (T, C, H, W) -> (C, T, H, W) to match what `convert_img` followed by
44
+ # `np_clip[:, img_idx, :, :] = img` does for other data types.
45
+ tensor_clip = tensor_clip.permute(1, 0, 2, 3)
46
+ if not isinstance(tensor_clip, torch.FloatTensor):
47
+ tensor_clip = tensor_clip.float()
48
+ if self.div_255:
49
+ tensor_clip = torch.div(tensor_clip, 255)
50
+ return tensor_clip
51
+ else:
52
+ raise TypeError(
53
+ "Expected numpy.ndarray or PIL.Image or torch.Tensor\
54
+ but got list of {0}".format(
55
+ type(clip[0])
56
+ )
57
+ )
58
+
59
+ np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
60
+
61
+ # Convert
62
+ for img_idx, img in enumerate(clip):
63
+ if isinstance(img, np.ndarray):
64
+ pass
65
+ elif isinstance(img, Image.Image):
66
+ img = np.array(img, copy=False)
67
+ else:
68
+ raise TypeError(
69
+ "Expected numpy.ndarray or PIL.Image\
70
+ but got list of {0}".format(
71
+ type(clip[0])
72
+ )
73
+ )
74
+ img = convert_img(img)
75
+ np_clip[:, img_idx, :, :] = img
76
+
77
+ if self.numpy:
78
+ if self.div_255:
79
+ np_clip = np_clip / 255.0
80
+ return np_clip
81
+
82
+ else:
83
+ tensor_clip = torch.from_numpy(np_clip)
84
+
85
+ if not isinstance(tensor_clip, torch.FloatTensor):
86
+ tensor_clip = tensor_clip.float()
87
+ if self.div_255:
88
+ tensor_clip = torch.div(tensor_clip, 255)
89
+ return tensor_clip
90
+
91
+
92
+ # Note this norms data to -1/1
93
+ class ClipToTensor_K(object):
94
+ """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
95
+ to a torch.FloatTensor of shape (C x m x H x W) in the range [-1.0, 1.0]
96
+ """
97
+
98
+ def __init__(self, channel_nb=3, div_255=True, numpy=False):
99
+ self.channel_nb = channel_nb
100
+ self.div_255 = div_255
101
+ self.numpy = numpy
102
+
103
+ def __call__(self, clip):
104
+ """
105
+ Args: clip (list of numpy.ndarray): clip (list of images)
106
+ to be converted to tensor.
107
+ """
108
+ # Retrieve shape
109
+ if isinstance(clip[0], np.ndarray):
110
+ h, w, ch = clip[0].shape
111
+ assert ch == self.channel_nb, "Got {0} instead of 3 channels".format(ch)
112
+ elif isinstance(clip[0], Image.Image):
113
+ w, h = clip[0].size
114
+ else:
115
+ raise TypeError(
116
+ "Expected numpy.ndarray or PIL.Image\
117
+ but got list of {0}".format(
118
+ type(clip[0])
119
+ )
120
+ )
121
+
122
+ np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
123
+
124
+ # Convert
125
+ for img_idx, img in enumerate(clip):
126
+ if isinstance(img, np.ndarray):
127
+ pass
128
+ elif isinstance(img, Image.Image):
129
+ img = np.array(img, copy=False)
130
+ else:
131
+ raise TypeError(
132
+ "Expected numpy.ndarray or PIL.Image\
133
+ but got list of {0}".format(
134
+ type(clip[0])
135
+ )
136
+ )
137
+ img = convert_img(img)
138
+ np_clip[:, img_idx, :, :] = img
139
+ if self.numpy:
140
+ if self.div_255:
141
+ np_clip = (np_clip - 127.5) / 127.5
142
+ return np_clip
143
+
144
+ else:
145
+ tensor_clip = torch.from_numpy(np_clip)
146
+
147
+ if not isinstance(tensor_clip, torch.FloatTensor):
148
+ tensor_clip = tensor_clip.float()
149
+ if self.div_255:
150
+ tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5)
151
+ return tensor_clip
152
+
153
+
154
+ class ToTensor(object):
155
+ """Converts numpy array to tensor"""
156
+
157
+ def __call__(self, array):
158
+ tensor = torch.from_numpy(array)
159
+ return tensor
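A minimal sketch of the two converters above (frame count and resolution are illustrative): `ClipToTensor` maps a list of HxWxC uint8 frames to a `[0, 1]` float tensor, while `ClipToTensor_K` maps it to `[-1, 1]`.

```python
# Hedged sketch: frame shapes are illustrative.
import numpy as np

frames = [np.random.randint(0, 256, (240, 320, 3), dtype=np.uint8) for _ in range(8)]
clip_01 = ClipToTensor()(frames)     # shape (3, 8, 240, 320), values in [0, 1]
clip_pm1 = ClipToTensor_K()(frames)  # shape (3, 8, 240, 320), values in [-1, 1]
print(clip_01.shape, float(clip_01.max()), float(clip_pm1.min()))
```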
src/datasets/utils/weighted_sampler.py ADDED
@@ -0,0 +1,336 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from typing import Iterator, Optional
8
+
9
+ import numpy as np
10
+ import torch
11
+ from torch.utils.data import DistributedSampler, RandomSampler
12
+
13
+ from src.utils.logging import get_logger
14
+
15
+ logger = get_logger("WeightedSampler")
16
+
17
+
18
+ class DistributedWeightedSampler(DistributedSampler):
19
+ """
20
+ This class implements a weighted sampler for distributed training.
21
+ See https://pytorch.org/docs/stable/_modules/torch/utils/data/distributed.html#DistributedSampler for more details.
22
+
23
+ It shares the same interface as `torch.utils.data.DistributedSampler`.
24
+ The effective change is replacing `DistributedSampler`'s `torch.randperm` for generating the sequence
25
+ of indices with `numpy.random.Generator.choice`, with replacement. This allows weighted sampling and
26
+ avoids issues with `torch.randperm` when the number of samples is larger than 2^24.
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ dataset,
32
+ num_replicas: Optional[int] = None,
33
+ rank: Optional[int] = None,
34
+ shuffle: bool = True,
35
+ seed: int = 0,
36
+ drop_last: bool = False,
37
+ ):
38
+ logger.info(f"Using DistributedWeightedSampler with rank {rank} / {num_replicas}")
39
+ assert hasattr(
40
+ dataset, "sample_weights"
41
+ ), "Dataset must have sample_weights property for using DistributedWeightedSampler"
42
+ super().__init__(
43
+ dataset,
44
+ num_replicas=num_replicas,
45
+ rank=rank,
46
+ shuffle=shuffle,
47
+ seed=seed,
48
+ drop_last=drop_last,
49
+ )
50
+
51
+ @property
52
+ def sample_probabilities(self) -> np.ndarray:
53
+ sample_weights = self.dataset.sample_weights
54
+ if isinstance(sample_weights, torch.Tensor):
55
+ sample_weights = sample_weights.cpu().numpy()
56
+ elif isinstance(sample_weights, list):
57
+ sample_weights = np.array(sample_weights)
58
+ assert isinstance(
59
+ sample_weights, np.ndarray
60
+ ), f"sample_weights must be a numpy array, torch.Tensor, or python list; got {type(sample_weights)}"
61
+ return sample_weights / np.sum(sample_weights)
62
+
63
+ def __iter__(self) -> Iterator:
64
+ n = len(self.dataset)
65
+
66
+ # deterministically shuffle based on epoch and seed
67
+ rng = np.random.default_rng(self.seed + self.epoch)
68
+ indices = rng.choice(
69
+ range(0, n),
70
+ size=self.total_size,
71
+ p=self.sample_probabilities,
72
+ replace=True,
73
+ ).tolist()
74
+
75
+ if not self.drop_last:
76
+ # add extra samples to make it evenly divisible
77
+ padding_size = self.total_size - len(indices)
78
+ if padding_size <= len(indices):
79
+ indices += indices[:padding_size]
80
+ else:
81
+ indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
82
+ else:
83
+ # remove tail of data to make it evenly divisible
84
+ indices = indices[: self.total_size]
85
+ assert len(indices) == self.total_size
86
+
87
+ # subsample
88
+ indices = indices[self.rank : self.total_size : self.num_replicas]
89
+ assert len(indices) == self.num_samples
90
+
91
+ return iter(indices)
92
+
93
+
94
+ class MemoryEfficientDistributedWeightedSampler(DistributedSampler):
95
+ """
96
+ This class implements a memory efficient version of `DistributedWeightedSampler`.
97
+ It shares the same interface as `DistributedWeightedSampler`.
98
+ The effective change is just-in-time sampling of the indices, instead of pre-computing them.
99
+ """
100
+
101
+ def __init__(
102
+ self,
103
+ dataset,
104
+ num_replicas: Optional[int] = None,
105
+ rank: Optional[int] = None,
106
+ shuffle: bool = True,
107
+ seed: int = 0,
108
+ ):
109
+ logger.info(f"Using MemoryEfficientDistributedWeightedSampler with rank {rank} / {num_replicas}")
110
+ assert hasattr(
111
+ dataset, "dataset_weights"
112
+ ), "Dataset must have dataset_weights property for using MemoryEfficientDistributedWeightedSampler"
113
+ super().__init__(
114
+ dataset,
115
+ num_replicas=num_replicas,
116
+ rank=rank,
117
+ shuffle=shuffle,
118
+ seed=seed,
119
+ )
120
+
121
+ self.dataset_weights = dataset.dataset_weights
122
+ self.dataset_sizes = [len(d) for d in dataset.datasets]
123
+ if len(self.dataset_sizes) != len(self.dataset_weights):
124
+ raise ValueError(
125
+ f"Number of datasets ({len(self.dataset_sizes)}) "
126
+ f"does not match number of dataset weights ({len(self.dataset_weights)})"
127
+ )
128
+
129
+ if self.shuffle:
130
+ self.rng = np.random.default_rng(self.seed + self.rank + self.epoch)
131
+ total_weights = sum(self.dataset_weights)
132
+ self.dataset_probablities = np.array([w / total_weights for w in self.dataset_weights])
133
+ else:
134
+ if any([not isinstance(w, int) for w in self.dataset_weights]):
135
+ raise ValueError("Dataset weights must be integers when shuffle is False")
136
+
137
+ self.dataset_orders = []
138
+ for i, w in enumerate(self.dataset_weights):
139
+ self.dataset_orders.extend([i] * w)
140
+
141
+ self.drawn_samples = 0
142
+
143
+ def __iter__(self) -> Iterator:
144
+ return self
145
+
146
+ def __next__(self) -> int:
147
+ if self.shuffle:
148
+ selected_dataset_idx = self.rng.choice(range(len(self.dataset_weights)), p=self.dataset_probablities)
149
+
150
+ # In order to avoid sampling the same example multiple times between the ranks,
151
+ # we limit each rank to a subset of the total number of samples in the dataset.
152
+ # For example if our dataset is [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], and we have 2 ranks,
153
+ # then rank 0 will ONLY sample from [0, 2, 4, 6, 8], and rank 1 from [1, 3, 5, 7, 9].
154
+ # In each iteration we first produce `in_rank_sample` which is the sample index in the rank,
155
+ # based on the size of the subset which that rank can sample from.
156
+ # Then we compute `sample_idx_in_dataset` for the index of the sample in the whole dataset.
157
+ # For the above example if we are sampling for rank 1, we have `self.rng.integers(5)`.
158
+ # Let's assume the result is 2; then `in_rank_sample` is 2 (the element "5" in that subset),
159
+ # so the sample index in the whole dataset is
160
+ # `in_rank_sample * self.num_replicas + self.rank`: 2 * 2 + 1 = 5.
161
+
162
+ selected_dataset_size = self.dataset_sizes[selected_dataset_idx]
163
+ # 1) Getting sample index in the rank.
164
+ # NOTE: this may effectively drop the last batch,
165
+ # but given the sample sizes that we use this sampler with, it should not be an issue.
166
+ num_samples_in_rank = selected_dataset_size // self.num_replicas
167
+ in_rank_sample = self.rng.integers(num_samples_in_rank)
168
+
169
+ # 2) Getting sample index in the dataset.
170
+ sample_idx_in_dataset = in_rank_sample * self.num_replicas + self.rank
171
+
172
+ else:
173
+ # Iterate through the dataset orders in a round-robin fashion, offset by the rank
174
+ dataset_orders_idx = (self.rank + self.drawn_samples) % len(self.dataset_orders)
175
+ selected_dataset_idx = self.dataset_orders[dataset_orders_idx]
176
+ # Get the sample index in the selected dataset by skipping with the num_replicas * drawn_samples
177
+ sample_idx_in_dataset = (self.drawn_samples * self.num_replicas + self.rank) % self.dataset_sizes[
178
+ selected_dataset_idx
179
+ ]
180
+ self.drawn_samples += 1
181
+
182
+ # Getting the index of the sample in the whole dataset
183
+ # For example if the total dataset has 4 datasets with sizes [10, 20, 30, 5].
184
+ # and our selected_dataset_idx=3 and sample_idx_in_dataset=5
185
+ # then the index of the sample in the whole dataset is
186
+ # 10 (for dataset 1) + 20 (for dataset 2) + 5 (for sample_idx_in_dataset) = 35
187
+ # This is because the first 10 samples are from the first dataset, the next 20 are from the second dataset,
188
+ # then we reach the 3rd dataset, which is the selected dataset, and take the 5th sample within it.
189
+ sample_idx = 0
190
+ for i, d in enumerate(self.dataset_sizes):
191
+ if selected_dataset_idx == i:
192
+ break
193
+ sample_idx += d
194
+ sample_idx += sample_idx_in_dataset
195
+
196
+ return sample_idx
197
+
198
+
199
+ def safe_next(iterator):
200
+ try:
201
+ return next(iterator)
202
+ except StopIteration:
203
+ return None
204
+
205
+
206
+ class MemoryEfficientDistributedWeightedSamplerLessRepeat(DistributedSampler):
207
+ """
208
+ This class implements a memory efficient version of `DistributedWeightedSampler`.
209
+ It shares the same interface as `DistributedWeightedSampler`.
210
+ The effective change is pre-computing the permutations of indices over a subset of total indices.
211
+ This subset is selected by picking the indices in a dataset with step sizes equal to the world size.
212
+ For example, if world size is 12 and rank is 2, for a dataset of size N,
213
+ this sampler only permutes the indices in range(2, N, 12).
214
+
215
+ Compared with MemoryEfficientDistributedWeightedSampler, this will reduce the effective number of repeats.
216
+ See discussions here: https://github.com/fairinternal/jepa-internal/pull/254
217
+ """
218
+
219
+ def __init__(
220
+ self,
221
+ dataset,
222
+ num_replicas: Optional[int] = None,
223
+ rank: Optional[int] = None,
224
+ shuffle: bool = True,
225
+ seed: int = 0,
226
+ ):
227
+ logger.info(f"Using MemoryEfficientDistributedWeightedSamplerLessRepeat with rank {rank} / {num_replicas}")
228
+ assert hasattr(
229
+ dataset, "dataset_weights"
230
+ ), "Dataset must have dataset_weights property for using MemoryEfficientDistributedWeightedSamplerLessRepeat"
231
+ super().__init__(
232
+ dataset,
233
+ num_replicas=num_replicas,
234
+ rank=rank,
235
+ shuffle=shuffle,
236
+ seed=seed,
237
+ )
238
+
239
+ self._generator = torch.Generator()
240
+ self._generator.manual_seed(seed)
241
+
242
+ self.dataset_weights = dataset.dataset_weights
243
+ self.dataset_sizes = [len(d) for d in dataset.datasets]
244
+ if len(self.dataset_sizes) != len(self.dataset_weights):
245
+ raise ValueError(
246
+ f"Number of datasets ({len(self.dataset_sizes)}) "
247
+ f"does not match number of dataset weights ({len(self.dataset_weights)})"
248
+ )
249
+
250
+ if self.shuffle:
251
+ self.rng = np.random.default_rng(self.seed + self.rank + self.epoch)
252
+ total_weights = sum(self.dataset_weights)
253
+ self.dataset_probablities = np.array([w / total_weights for w in self.dataset_weights])
254
+
255
+ # For each dataset we generate a permutation of the indices that will be processed by that rank.
256
+ # This is going to be the subset of indices, selected by the steps sizes of the world size.
257
+ logger.info(f"Generating dataset indices for rank {self.rank} / {self.num_replicas}")
258
+
259
+ # Getting a RandomSampler for indices assigned to each dataset.
260
+ self.individual_dataset_sampler = []
261
+ for ids, ds in enumerate(self.dataset_sizes):
262
+
263
+ # NOTE: this may effectively drop the last batch,
264
+ # but given the sample sizes that we use this sampler with, it should not be an issue.
265
+ num_samples_in_rank = ds // self.num_replicas
266
+ self.individual_dataset_sampler.append(self._new_sampler(num_samples_in_rank))
267
+
268
+ else:
269
+ if any([not isinstance(w, int) for w in self.dataset_weights]):
270
+ raise ValueError("Dataset weights must be integers when shuffle is False")
271
+
272
+ self.dataset_orders = []
273
+ for i, w in enumerate(self.dataset_weights):
274
+ self.dataset_orders.extend([i] * w)
275
+
276
+ self.drawn_samples = 0
277
+
278
+ def __iter__(self) -> Iterator:
279
+ return self
280
+
281
+ def _new_sampler(self, sample_size: int) -> RandomSampler:
282
+ assert self.shuffle
283
+
284
+ return iter(
285
+ RandomSampler(
286
+ range(sample_size),
287
+ generator=self._generator,
288
+ )
289
+ )
290
+
291
+ def _in_rank_next_index_for_dataset(self, dataset_idx: int) -> int:
292
+ assert self.shuffle
293
+
294
+ next_sampler_idx = safe_next(self.individual_dataset_sampler[dataset_idx])
295
+ if next_sampler_idx is None:
296
+ # We have reached the end of the dataset, we need to reset the sampler.
297
+ num_samples_in_rank = self.dataset_sizes[dataset_idx] // self.num_replicas
298
+ self.individual_dataset_sampler[dataset_idx] = self._new_sampler(num_samples_in_rank)
299
+ next_sampler_idx = safe_next(self.individual_dataset_sampler[dataset_idx])
300
+ assert next_sampler_idx is not None
301
+
302
+ return next_sampler_idx
303
+
304
+ def __next__(self) -> int:
305
+ if self.shuffle:
306
+ selected_dataset_idx = self.rng.choice(range(len(self.dataset_weights)), p=self.dataset_probablities)
307
+ in_rank_sample = self._in_rank_next_index_for_dataset(selected_dataset_idx)
308
+
309
+ # 2) Getting sample index in the dataset.
310
+ sample_idx_in_dataset = in_rank_sample * self.num_replicas + self.rank
311
+
312
+ else:
313
+ # Iterate through the dataset orders in a round-robin fashion, offset by the rank
314
+ dataset_orders_idx = (self.rank + self.drawn_samples) % len(self.dataset_orders)
315
+ selected_dataset_idx = self.dataset_orders[dataset_orders_idx]
316
+ # Get the sample index in the selected dataset by skipping with the num_replicas * drawn_samples
317
+ sample_idx_in_dataset = (self.drawn_samples * self.num_replicas + self.rank) % self.dataset_sizes[
318
+ selected_dataset_idx
319
+ ]
320
+ self.drawn_samples += 1
321
+
322
+ # Getting the index of the sample in the whole dataset
323
+ # For example if the total dataset has 4 datasets with sizes [10, 20, 30, 5].
324
+ # and our selected_dataset_idx=3 and sample_idx_in_dataset=5
325
+ # then the index of the sample in the whole dataset is
326
+ # 10 (for dataset 1) + 20 (for dataset 2) + 5 (for sample_idx_in_dataset) = 35
327
+ # This is because the first 10 samples are from the first dataset, the next 20 are from the second dataset,
328
+ # then we reach the 3rd dataset, which is the selected dataset, and take the 5th sample within it.
329
+ sample_idx = 0
330
+ for i, d in enumerate(self.dataset_sizes):
331
+ if selected_dataset_idx == i:
332
+ break
333
+ sample_idx += d
334
+ sample_idx += sample_idx_in_dataset
335
+
336
+ return sample_idx
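A hedged usage sketch for `DistributedWeightedSampler`: any map-style dataset exposing a `sample_weights` attribute works; `ToyWeightedDataset` below is a hypothetical stand-in, and `num_replicas`/`rank` are passed explicitly so no process group is required.

```python
# Hedged sketch: ToyWeightedDataset is a hypothetical stand-in for a dataset
# that exposes per-sample weights.
import torch
from torch.utils.data import DataLoader, Dataset


class ToyWeightedDataset(Dataset):
    def __init__(self, n=100):
        self.data = torch.arange(n)
        # oversample the last 10 items 10x relative to the rest
        self.sample_weights = [1.0] * (n - 10) + [10.0] * 10

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


dataset = ToyWeightedDataset()
sampler = DistributedWeightedSampler(dataset, num_replicas=2, rank=0, shuffle=True, seed=0)
sampler.set_epoch(0)  # reseed per epoch, as with torch's DistributedSampler
loader = DataLoader(dataset, batch_size=8, sampler=sampler)
batch = next(iter(loader))
```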
src/datasets/utils/worker_init_fn.py ADDED
@@ -0,0 +1,76 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # Copyright The Lightning AI team.
4
+
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ # This code originally comes from PyTorch Lightning with some light modifications:
18
+ # https://github.com/Lightning-AI/pytorch-lightning/blob/a944e7744e57a5a2c13f3c73b9735edf2f71e329/src/lightning/fabric/utilities/seed.py
19
+
20
+
21
+ import os
22
+ import random
23
+ from typing import Optional
24
+
25
+ import numpy as np
26
+ import torch
27
+
28
+ from src.utils.logging import get_logger
29
+
30
+ logger = get_logger("worker_init_fn")
31
+
32
+
33
+ def _generate_seed_sequence(base_seed: int, worker_id: int, global_rank: int, count: int) -> list[int]:
34
+ """Generates a sequence of seeds from a base seed, worker id and rank using the linear congruential generator (LCG)
35
+ algorithm."""
36
+ # Combine base seed, worker id and rank into a unique 64-bit number
37
+ combined_seed = (base_seed << 32) | (worker_id << 16) | global_rank
38
+ seeds = []
39
+ for _ in range(count):
40
+ # x_(n+1) = (a * x_n + c) mod m. With c=1, m=2^64 and a is D. Knuth's constant
41
+ combined_seed = (combined_seed * 6364136223846793005 + 1) & ((1 << 64) - 1)
42
+ seeds.append(combined_seed)
43
+ return seeds
44
+
45
+
46
+ def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover
47
+ r"""The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed with
48
+ ``seed_everything(seed, workers=True)``.
49
+
50
+ See also the PyTorch documentation on
51
+ `randomness in DataLoaders <https://pytorch.org/docs/stable/notes/randomness.html#dataloader>`_.
52
+
53
+ """
54
+ # implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
55
+ if rank is None:
56
+ procid = os.environ.get("SLURM_PROCID")
57
+ if procid is None:
58
+ logger.warning("SLURM_PROCID is not set, setting rank to 0")
59
+ rank = 0
60
+ else:
61
+ rank = int(procid)
62
+
63
+ process_seed = torch.initial_seed()
64
+ # back out the base seed so we can use all the bits
65
+ base_seed = process_seed - worker_id
66
+ logger.debug(
67
+ f"Initializing random number generators of process {rank} worker {worker_id} with base seed {base_seed}"
68
+ )
69
+ seed_sequence = _generate_seed_sequence(base_seed, worker_id, rank, count=4)
70
+ torch.manual_seed(seed_sequence[0]) # torch takes a 64-bit seed
71
+ random.seed((seed_sequence[1] << 32) | seed_sequence[2]) # combine two 64-bit seeds
72
+
73
+ ss = np.random.SeedSequence([base_seed, worker_id, rank])
74
+ np_rng_seed = ss.generate_state(4)
75
+
76
+ np.random.seed(np_rng_seed)
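A hedged sketch of wiring `pl_worker_init_function` into a `DataLoader` (the toy dataset and `rank=0` are illustrative):

```python
# Hedged sketch: toy dataset and rank are illustrative.
from functools import partial

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(32).float())
loader = DataLoader(
    dataset,
    batch_size=8,
    num_workers=2,
    worker_init_fn=partial(pl_worker_init_function, rank=0),
)
for (batch,) in loader:
    pass  # each worker's torch / random / numpy RNGs are seeded deterministically
```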
src/datasets/video_dataset.py ADDED
@@ -0,0 +1,373 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import os
8
+ import pathlib
9
+ import warnings
10
+ from logging import getLogger
11
+
12
+ import numpy as np
13
+ import pandas as pd
14
+ import torch
15
+ import torchvision
16
+ from decord import VideoReader, cpu
17
+
18
+ from src.datasets.utils.dataloader import ConcatIndices, MonitoredDataset, NondeterministicDataLoader
19
+ from src.datasets.utils.weighted_sampler import DistributedWeightedSampler
20
+
21
+ _GLOBAL_SEED = 0
22
+ logger = getLogger()
23
+
24
+
25
+ def make_videodataset(
26
+ data_paths,
27
+ batch_size,
28
+ frames_per_clip=8,
29
+ dataset_fpcs=None,
30
+ frame_step=4,
31
+ duration=None,
32
+ fps=None,
33
+ num_clips=1,
34
+ random_clip_sampling=True,
35
+ allow_clip_overlap=False,
36
+ filter_short_videos=False,
37
+ filter_long_videos=int(10**9),
38
+ transform=None,
39
+ shared_transform=None,
40
+ rank=0,
41
+ world_size=1,
42
+ datasets_weights=None,
43
+ collator=None,
44
+ drop_last=True,
45
+ num_workers=10,
46
+ pin_mem=True,
47
+ persistent_workers=True,
48
+ deterministic=True,
49
+ log_dir=None,
50
+ ):
51
+ dataset = VideoDataset(
52
+ data_paths=data_paths,
53
+ datasets_weights=datasets_weights,
54
+ frames_per_clip=frames_per_clip,
55
+ dataset_fpcs=dataset_fpcs,
56
+ duration=duration,
57
+ fps=fps,
58
+ frame_step=frame_step,
59
+ num_clips=num_clips,
60
+ random_clip_sampling=random_clip_sampling,
61
+ allow_clip_overlap=allow_clip_overlap,
62
+ filter_short_videos=filter_short_videos,
63
+ filter_long_videos=filter_long_videos,
64
+ shared_transform=shared_transform,
65
+ transform=transform,
66
+ )
67
+
68
+ log_dir = pathlib.Path(log_dir) if log_dir else None
69
+ if log_dir:
70
+ log_dir.mkdir(parents=True, exist_ok=True)
71
+ # Worker ID will replace '%w'
72
+ resource_log_filename = log_dir / f"resource_file_{rank}_%w.csv"
73
+ dataset = MonitoredDataset(
74
+ dataset=dataset,
75
+ log_filename=str(resource_log_filename),
76
+ log_interval=10.0,
77
+ monitor_interval=5.0,
78
+ )
79
+
80
+ logger.info("VideoDataset dataset created")
81
+ if datasets_weights is not None:
82
+ dist_sampler = DistributedWeightedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
83
+ else:
84
+ dist_sampler = torch.utils.data.distributed.DistributedSampler(
85
+ dataset, num_replicas=world_size, rank=rank, shuffle=True
86
+ )
87
+
88
+ if deterministic:
89
+ data_loader = torch.utils.data.DataLoader(
90
+ dataset,
91
+ collate_fn=collator,
92
+ sampler=dist_sampler,
93
+ batch_size=batch_size,
94
+ drop_last=drop_last,
95
+ pin_memory=pin_mem,
96
+ num_workers=num_workers,
97
+ persistent_workers=(num_workers > 0) and persistent_workers,
98
+ )
99
+ else:
100
+ data_loader = NondeterministicDataLoader(
101
+ dataset,
102
+ collate_fn=collator,
103
+ sampler=dist_sampler,
104
+ batch_size=batch_size,
105
+ drop_last=drop_last,
106
+ pin_memory=pin_mem,
107
+ num_workers=num_workers,
108
+ persistent_workers=(num_workers > 0) and persistent_workers,
109
+ )
110
+ logger.info("VideoDataset unsupervised data loader created")
111
+
112
+ return dataset, data_loader, dist_sampler
113
+
114
+
115
+ class VideoDataset(torch.utils.data.Dataset):
116
+ """Video classification dataset."""
117
+
118
+ def __init__(
119
+ self,
120
+ data_paths,
121
+ datasets_weights=None,
122
+ frames_per_clip=16,
123
+ fps=None,
124
+ dataset_fpcs=None,
125
+ frame_step=4,
126
+ num_clips=1,
127
+ transform=None,
128
+ shared_transform=None,
129
+ random_clip_sampling=True,
130
+ allow_clip_overlap=False,
131
+ filter_short_videos=False,
132
+ filter_long_videos=int(10**9),
133
+ duration=None, # duration in seconds
134
+ ):
135
+ self.data_paths = data_paths
136
+ self.datasets_weights = datasets_weights
137
+ self.frame_step = frame_step
138
+ self.num_clips = num_clips
139
+ self.transform = transform
140
+ self.shared_transform = shared_transform
141
+ self.random_clip_sampling = random_clip_sampling
142
+ self.allow_clip_overlap = allow_clip_overlap
143
+ self.filter_short_videos = filter_short_videos
144
+ self.filter_long_videos = filter_long_videos
145
+ self.duration = duration
146
+ self.fps = fps
147
+
148
+ if sum([v is not None for v in (fps, duration, frame_step)]) != 1:
149
+ raise ValueError(f"Must specify exactly one of either {fps=}, {duration=}, or {frame_step=}.")
150
+
151
+ if isinstance(data_paths, str):
152
+ data_paths = [data_paths]
153
+
154
+ if dataset_fpcs is None:
155
+ self.dataset_fpcs = [frames_per_clip for _ in data_paths]
156
+ else:
157
+ if len(dataset_fpcs) != len(data_paths):
158
+ raise ValueError("Frames per clip not properly specified for NFS data paths")
159
+ self.dataset_fpcs = dataset_fpcs
160
+
161
+ if VideoReader is None:
162
+ raise ImportError('Unable to import "decord" which is required to read videos.')
163
+
164
+ # Load video paths and labels
165
+ samples, labels = [], []
166
+ self.num_samples_per_dataset = []
167
+ for data_path in self.data_paths:
168
+
169
+ if data_path[-4:] == ".csv":
170
+ try:
171
+ data = pd.read_csv(data_path, header=None, delimiter=" ")
172
+ except pd.errors.ParserError:
173
+ # For image captioning datasets whose paths contain spaces, we use "::" as the delimiter.
174
+ data = pd.read_csv(data_path, header=None, delimiter="::")
175
+ samples += list(data.values[:, 0])
176
+ labels += list(data.values[:, 1])
177
+ num_samples = len(data)
178
+ self.num_samples_per_dataset.append(num_samples)
179
+
180
+ elif data_path[-4:] == ".npy":
181
+ data = np.load(data_path, allow_pickle=True)
182
+ data = list(map(lambda x: repr(x)[1:-1], data))
183
+ samples += data
184
+ labels += [0] * len(data)
185
+ num_samples = len(data)
186
+ self.num_samples_per_dataset.append(len(data))
187
+
188
+ self.per_dataset_indices = ConcatIndices(self.num_samples_per_dataset)
189
+
190
+ # [Optional] Weights for each sample to be used by downstream
191
+ # weighted video sampler
192
+ self.sample_weights = None
193
+ if self.datasets_weights is not None:
194
+ self.sample_weights = []
195
+ for dw, ns in zip(self.datasets_weights, self.num_samples_per_dataset):
196
+ self.sample_weights += [dw / ns] * ns
197
+
198
+ self.samples = samples
199
+ self.labels = labels
200
+
201
+ def __getitem__(self, index):
202
+ sample = self.samples[index]
203
+ loaded_sample = False
204
+ # Keep trying to load videos until you find a valid sample
205
+ while not loaded_sample:
206
+ if not isinstance(sample, str):
207
+ logger.warning("Invalid sample.")
208
+ else:
209
+ if sample.split(".")[-1].lower() in ("jpg", "png", "jpeg"):
210
+ loaded_sample = self.get_item_image(index)
211
+ else:
212
+ loaded_sample = self.get_item_video(index)
213
+
214
+ if not loaded_sample:
215
+ index = np.random.randint(self.__len__())
216
+ sample = self.samples[index]
217
+
218
+ return loaded_sample
219
+
220
+ def get_item_video(self, index):
221
+ sample = self.samples[index]
222
+ dataset_idx, _ = self.per_dataset_indices[index]
223
+ frames_per_clip = self.dataset_fpcs[dataset_idx]
224
+
225
+ buffer, clip_indices = self.loadvideo_decord(sample, frames_per_clip) # [T H W 3]
226
+ loaded_video = len(buffer) > 0
227
+ if not loaded_video:
228
+ return
229
+
230
+ # Label/annotations for video
231
+ label = self.labels[index]
232
+
233
+ def split_into_clips(video):
234
+ """Split video into a list of clips"""
235
+ fpc = frames_per_clip
236
+ nc = self.num_clips
237
+ return [video[i * fpc : (i + 1) * fpc] for i in range(nc)]
238
+
239
+ # Parse video into frames & apply data augmentations
240
+ if self.shared_transform is not None:
241
+ buffer = self.shared_transform(buffer)
242
+ buffer = split_into_clips(buffer)
243
+ if self.transform is not None:
244
+ buffer = [self.transform(clip) for clip in buffer]
245
+
246
+ return buffer, label, clip_indices
247
+
248
+ def get_item_image(self, index):
249
+ sample = self.samples[index]
250
+ dataset_idx, _ = self.per_dataset_indices[index]
251
+ fpc = self.dataset_fpcs[dataset_idx]
252
+
253
+ try:
254
+ image_tensor = torchvision.io.read_image(path=sample, mode=torchvision.io.ImageReadMode.RGB)
255
+ except Exception:
256
+ return
257
+ label = self.labels[index]
258
+ clip_indices = [np.arange(start=0, stop=fpc, dtype=np.int32)]
259
+
260
+ # Expanding the input image [3, H, W] ==> [T, 3, H, W]
261
+ buffer = image_tensor.unsqueeze(dim=0).repeat((fpc, 1, 1, 1))
262
+ buffer = buffer.permute((0, 2, 3, 1)) # [T, 3, H, W] ==> [T H W 3]
263
+
264
+ if self.shared_transform is not None:
265
+ # Technically only `transform` is needed here; we keep both for consistency with the video path.
266
+ buffer = self.shared_transform(buffer)
267
+
268
+ if self.transform is not None:
269
+ buffer = [self.transform(buffer)]
270
+
271
+ return buffer, label, clip_indices
272
+
273
+ def loadvideo_decord(self, sample, fpc):
274
+ """Load video content using Decord"""
275
+
276
+ fname = sample
277
+ if not os.path.exists(fname):
278
+ warnings.warn(f"video path not found {fname=}")
279
+ return [], None
280
+
281
+ _fsize = os.path.getsize(fname)
282
+ if _fsize > self.filter_long_videos:
283
+ warnings.warn(f"skipping long video of size {_fsize=} (bytes)")
284
+ return [], None
285
+
286
+ try:
287
+ vr = VideoReader(fname, num_threads=-1, ctx=cpu(0))
288
+ except Exception:
289
+ return [], None
290
+
291
+ fstp = self.frame_step
292
+ if self.duration is not None or self.fps is not None:
293
+ try:
294
+ video_fps = math.ceil(vr.get_avg_fps())
295
+ except Exception as e:
296
+ logger.warning(e)
297
+
298
+ if self.duration is not None:
299
+ assert self.fps is None
300
+ fstp = int(self.duration * video_fps / fpc)
301
+ else:
302
+ assert self.duration is None
303
+ fstp = video_fps // self.fps
304
+
305
+ assert fstp is not None and fstp > 0
306
+ clip_len = int(fpc * fstp)
307
+
308
+ if self.filter_short_videos and len(vr) < clip_len:
309
+ warnings.warn(f"skipping video of length {len(vr)}")
310
+ return [], None
311
+
312
+ vr.seek(0) # Go to start of video before sampling frames
313
+
314
+ # Partition video into equal sized segments and sample each clip
315
+ # from a different segment
316
+ partition_len = len(vr) // self.num_clips
317
+
318
+ all_indices, clip_indices = [], []
319
+ for i in range(self.num_clips):
320
+
321
+ if partition_len > clip_len:
322
+ # If partition_len > clip len, then sample a random window of
323
+ # clip_len frames within the segment
324
+ end_indx = clip_len
325
+ if self.random_clip_sampling:
326
+ end_indx = np.random.randint(clip_len, partition_len)
327
+ start_indx = end_indx - clip_len
328
+ indices = np.linspace(start_indx, end_indx, num=fpc)
329
+ indices = np.clip(indices, start_indx, end_indx - 1).astype(np.int64)
330
+ # --
331
+ indices = indices + i * partition_len
332
+ else:
333
+ # If partition overlap not allowed and partition_len < clip_len
334
+ # then repeatedly append the last frame in the segment until
335
+ # we reach the desired clip length
336
+ if not self.allow_clip_overlap:
337
+ indices = np.linspace(0, partition_len, num=partition_len // fstp)
338
+ indices = np.concatenate(
339
+ (
340
+ indices,
341
+ np.ones(fpc - partition_len // fstp) * partition_len,
342
+ )
343
+ )
344
+ indices = np.clip(indices, 0, partition_len - 1).astype(np.int64)
345
+ # --
346
+ indices = indices + i * partition_len
347
+
348
+ # If partition overlap is allowed and partition_len < clip_len
349
+ # then start_indx of segment i+1 will lie within segment i
350
+ else:
351
+ sample_len = min(clip_len, len(vr)) - 1
352
+ indices = np.linspace(0, sample_len, num=sample_len // fstp)
353
+ indices = np.concatenate(
354
+ (
355
+ indices,
356
+ np.ones(fpc - sample_len // fstp) * sample_len,
357
+ )
358
+ )
359
+ indices = np.clip(indices, 0, sample_len - 1).astype(np.int64)
360
+ # --
361
+ clip_step = 0
362
+ if len(vr) > clip_len:
363
+ clip_step = (len(vr) - clip_len) // (self.num_clips - 1)
364
+ indices = indices + i * clip_step
365
+
366
+ clip_indices.append(indices)
367
+ all_indices.extend(list(indices))
368
+
369
+ buffer = vr.get_batch(all_indices).asnumpy()
370
+ return buffer, clip_indices
371
+
372
+ def __len__(self):
373
+ return len(self.samples)
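A hedged sketch of building a loader with `make_videodataset`. The manifest path is a placeholder; each CSV row is `<video_path> <label>` (space-delimited), and exactly one of `fps`, `duration`, or `frame_step` may be non-None.

```python
# Hedged sketch: the manifest path is a placeholder.
dataset, loader, sampler = make_videodataset(
    data_paths=["/path/to/train_manifest.csv"],
    batch_size=4,
    frames_per_clip=16,
    frame_step=None,  # disabled so fps is the single clip-sampling control
    fps=5,
    num_clips=1,
    rank=0,
    world_size=1,
    num_workers=0,
    log_dir=None,
)
# Iterate as usual: each batch yields collated (clips, labels, clip_indices) via the default collator.
```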
src/hub/__init__.py ADDED
File without changes
src/hub/backbones.py ADDED
@@ -0,0 +1,177 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+ VJEPA_BASE_URL = "https://dl.fbaipublicfiles.com/vjepa2"
9
+
10
+ # for testing
11
+ # VJEPA_BASE_URL = "http://localhost:8300"
12
+
13
+ ARCH_NAME_MAP = {
14
+ "vit_large": ("vit_large", "vitl"),
15
+ "vit_huge": ("vit_huge", "vith"),
16
+ "vit_giant": ("vit_giant_xformers", "vitg"),
17
+ "vit_ac_giant": ("vit_giant_xformers", "vjepa2-ac-vitg"),
18
+ "vit_giant_384": ("vit_giant_xformers", "vitg-384"),
19
+ }
20
+
21
+
22
+ def _clean_backbone_key(state_dict):
23
+ for key, val in state_dict.copy().items():
24
+ _ = state_dict.pop(key)
25
+ key = key.replace("module.", "")
26
+ key = key.replace("backbone.", "")
27
+ state_dict[key] = val
28
+ return state_dict
29
+
30
+
31
+ def _make_vjepa2_ac_model(
32
+ *,
33
+ model_name: str = "vit_ac_giant",
34
+ img_size=256,
35
+ patch_size=16,
36
+ tubelet_size=2,
37
+ num_frames=64,
38
+ pretrained: bool = True,
39
+ **kwargs,
40
+ ):
41
+ from ..models import ac_predictor as vit_ac_predictor
42
+ from ..models import vision_transformer as vit_encoder
43
+
44
+ vit_encoder_kwargs = dict(
45
+ patch_size=patch_size,
46
+ img_size=(img_size, img_size),
47
+ num_frames=num_frames,
48
+ tubelet_size=tubelet_size,
49
+ use_sdpa=True,
50
+ use_SiLU=False,
51
+ wide_SiLU=True,
52
+ uniform_power=False,
53
+ use_rope=True,
54
+ )
55
+ vit_encoder_kwargs.update(**kwargs)
56
+
57
+ arch_name = ARCH_NAME_MAP[model_name][0]
58
+ encoder = vit_encoder.__dict__[arch_name](**vit_encoder_kwargs)
59
+
60
+ vit_predictor_kwargs = dict(
61
+ img_size=(img_size, img_size),
62
+ patch_size=patch_size,
63
+ num_frames=num_frames,
64
+ tubelet_size=tubelet_size,
65
+ embed_dim=encoder.embed_dim,
66
+ )
67
+ vit_predictor_kwargs.update(**kwargs)
68
+
69
+ predictor = vit_ac_predictor.__dict__["vit_ac_predictor"](**vit_predictor_kwargs)
70
+
71
+ if pretrained:
72
+ model_file = ARCH_NAME_MAP[model_name][-1]
73
+ url = VJEPA_BASE_URL + f"/{model_file}.pt"
74
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
75
+ encoder_state_dict = _clean_backbone_key(state_dict["encoder"])
76
+ encoder.load_state_dict(encoder_state_dict, strict=False)
77
+ predictor_state_dict = _clean_backbone_key(state_dict["predictor"])
78
+ predictor.load_state_dict(predictor_state_dict, strict=True)
79
+
80
+ return encoder, predictor
81
+
82
+
83
+ def _make_vjepa2_model(
84
+ *,
85
+ model_name: str = "vit_large",
86
+ img_size=256,
87
+ patch_size=16,
88
+ tubelet_size=2,
89
+ num_frames=64,
90
+ pretrained: bool = True,
91
+ **kwargs,
92
+ ):
93
+ from ..models import predictor as vit_predictor
94
+ from ..models import vision_transformer as vit_encoder
95
+
96
+ vit_encoder_kwargs = dict(
97
+ patch_size=patch_size,
98
+ img_size=(img_size, img_size),
99
+ num_frames=num_frames,
100
+ tubelet_size=tubelet_size,
101
+ use_sdpa=True,
102
+ use_SiLU=False,
103
+ wide_SiLU=True,
104
+ uniform_power=False,
105
+ use_rope=True,
106
+ )
107
+ vit_encoder_kwargs.update(**kwargs)
108
+
109
+ arch_name = ARCH_NAME_MAP[model_name][0]
110
+ encoder = vit_encoder.__dict__[arch_name](**vit_encoder_kwargs)
111
+
112
+ vit_predictor_kwargs = dict(
113
+ img_size=(img_size, img_size),
114
+ patch_size=patch_size,
115
+ use_mask_tokens=True,
116
+ embed_dim=encoder.embed_dim,
117
+ predictor_embed_dim=384,
118
+ num_frames=num_frames,
119
+ tubelet_size=tubelet_size,
120
+ depth=12,
121
+ num_heads=12,
122
+ num_mask_tokens=10,
123
+ use_rope=True,
124
+ uniform_power=False,
125
+ use_sdpa=True,
126
+ use_silu=False,
127
+ wide_silu=True,
128
+ )
129
+ vit_predictor_kwargs.update(**kwargs)
130
+
131
+ predictor = vit_predictor.__dict__["vit_predictor"](**vit_predictor_kwargs)
132
+
133
+ if pretrained:
134
+ model_file = ARCH_NAME_MAP[model_name][-1]
135
+ url = VJEPA_BASE_URL + f"/{model_file}.pt"
136
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
137
+ encoder_state_dict = _clean_backbone_key(state_dict["encoder"])
138
+ encoder.load_state_dict(encoder_state_dict, strict=False) # state_dict has pos_embed but we use RoPE
139
+ predictor_state_dict = _clean_backbone_key(state_dict["predictor"])
140
+ predictor.load_state_dict(predictor_state_dict, strict=False) # state_dict has pos_embed but we use RoPE
141
+
142
+ return encoder, predictor
143
+
144
+
145
+ def vjepa2_vit_large(*, pretrained: bool = True, **kwargs):
146
+ """
147
+ VJEPA 2 ViT-Large model
148
+ """
149
+ return _make_vjepa2_model(model_name="vit_large", img_size=256, pretrained=pretrained, **kwargs)
150
+
151
+
152
+ def vjepa2_vit_huge(*, pretrained: bool = True, **kwargs):
153
+ """
154
+ VJEPA 2 ViT-Huge model
155
+ """
156
+ return _make_vjepa2_model(model_name="vit_huge", img_size=256, pretrained=pretrained, **kwargs)
157
+
158
+
159
+ def vjepa2_vit_giant(*, pretrained: bool = True, **kwargs):
160
+ """
161
+ VJEPA 2 ViT-giant model
162
+ """
163
+ return _make_vjepa2_model(model_name="vit_giant", img_size=256, pretrained=pretrained, **kwargs)
164
+
165
+
166
+ def vjepa2_vit_giant_384(*, pretrained: bool = True, **kwargs):
167
+ """
168
+ VJEPA 2 ViT-giant-384 model
169
+ """
170
+ return _make_vjepa2_model(model_name="vit_giant_384", img_size=384, pretrained=pretrained, **kwargs)
171
+
172
+
173
+ def vjepa2_ac_vit_giant(*, pretrained: bool = True, **kwargs):
174
+ """
175
+ VJEPA 2-AC ViT-giant model
176
+ """
177
+ return _make_vjepa2_ac_model(model_name="vit_ac_giant", img_size=256, pretrained=pretrained, **kwargs)
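A hedged sketch of instantiating one of the hub entry points above; `pretrained=False` skips the weight download (set it to `True` to fetch the checkpoint from `VJEPA_BASE_URL`).

```python
# Hedged sketch: builds the ViT-L encoder/predictor pair without downloading weights.
encoder, predictor = vjepa2_vit_large(pretrained=False)
n_params = sum(p.numel() for p in encoder.parameters())
print(f"encoder parameters: {n_params / 1e6:.1f}M")
```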
src/masks/__pycache__/utils.cpython-312.pyc ADDED
Binary file (923 Bytes). View file
 
src/masks/default.py ADDED
@@ -0,0 +1,18 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from logging import getLogger
7
+
8
+ import torch
9
+
10
+ _GLOBAL_SEED = 0
11
+ logger = getLogger()
12
+
13
+
14
+ class DefaultCollator(object):
15
+
16
+ def __call__(self, batch):
17
+ collated_batch = torch.utils.data.default_collate(batch)
18
+ return collated_batch, None, None
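A hedged sketch of `DefaultCollator`: it applies plain default collation and returns a three-element tuple, presumably so it stays interchangeable with the mask collators (which return encoder/predictor masks in the last two slots).

```python
# Hedged sketch: the toy batch is illustrative.
import torch

collator = DefaultCollator()
batch = [(torch.randn(3, 4), 0), (torch.randn(3, 4), 1)]
collated, masks_enc, masks_pred = collator(batch)
print(collated[0].shape, masks_enc, masks_pred)  # torch.Size([2, 3, 4]) None None
```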
src/masks/multiseq_multiblock3d.py ADDED
@@ -0,0 +1,239 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from logging import getLogger
8
+ from multiprocessing import Value
9
+
10
+ import torch
11
+
12
+ _GLOBAL_SEED = 0
13
+ logger = getLogger()
14
+
15
+
16
+ class MaskCollator(object):
17
+
18
+ def __init__(
19
+ self,
20
+ cfgs_mask,
21
+ dataset_fpcs,
22
+ crop_size=(224, 224),
23
+ patch_size=(16, 16),
24
+ tubelet_size=2,
25
+ ):
26
+ super(MaskCollator, self).__init__()
27
+
28
+ self.mask_generators = dict()
29
+ for fpc in dataset_fpcs:
30
+ self.mask_generators[fpc] = []
31
+ for m in cfgs_mask:
32
+ mask_generator = _MaskGenerator(
33
+ crop_size=crop_size,
34
+ num_frames=fpc,
35
+ spatial_patch_size=patch_size,
36
+ temporal_patch_size=tubelet_size,
37
+ spatial_pred_mask_scale=m.get("spatial_scale"),
38
+ temporal_pred_mask_scale=m.get("temporal_scale"),
39
+ aspect_ratio=m.get("aspect_ratio"),
40
+ npred=m.get("num_blocks"),
41
+ max_context_frames_ratio=m.get("max_temporal_keep", 1.0),
42
+ max_keep=m.get("max_keep", None),
43
+ full_complement=m.get("full_complement", False),
44
+ pred_full_complement=m.get("pred_full_complement", False),
45
+ inv_block=m.get("inv_block", False),
46
+ )
47
+ self.mask_generators[fpc].append(mask_generator)
48
+
49
+ def step(self):
50
+ for fpc in self.mask_generators:
51
+ for mask_generator in self.mask_generators[fpc]:
52
+ mask_generator.step()
53
+
54
+ def __call__(self, batch):
55
+
56
+ # Batch: [buffer, label, clip_indices]
57
+ filtered_batches = {fpc: [] for fpc in self.mask_generators}
58
+ for sample in batch:
59
+ fpc = len(sample[-1][-1])
60
+ filtered_batches[fpc] += [sample]
61
+
62
+ fpc_collations = []
63
+ for fpc in filtered_batches:
64
+ fpc_batch = filtered_batches[fpc]
65
+ batch_size = len(fpc_batch)
66
+ if batch_size == 0:
67
+ continue
68
+ collated_batch = torch.utils.data.default_collate(fpc_batch)
69
+ collated_masks_pred, collated_masks_enc = [], []
70
+ for i, mask_generator in enumerate(self.mask_generators[fpc]):
71
+ masks_enc, masks_pred = mask_generator(batch_size)
72
+ collated_masks_enc.append(masks_enc)
73
+ collated_masks_pred.append(masks_pred)
74
+ fpc_collations += [(collated_batch, collated_masks_enc, collated_masks_pred)]
75
+
76
+ return fpc_collations
77
+
78
+
79
+ class _MaskGenerator(object):
80
+
81
+ def __init__(
82
+ self,
83
+ crop_size=(224, 224),
84
+ num_frames=16,
85
+ spatial_patch_size=(16, 16),
86
+ temporal_patch_size=2,
87
+ spatial_pred_mask_scale=(0.2, 0.8),
88
+ temporal_pred_mask_scale=(1.0, 1.0),
89
+ aspect_ratio=(0.3, 3.0),
90
+ npred=1,
91
+ max_context_frames_ratio=1.0,
92
+ max_keep=None,
93
+ inv_block=False,
94
+ full_complement=False,
95
+ pred_full_complement=False,
96
+ ):
97
+ super(_MaskGenerator, self).__init__()
98
+ if not isinstance(crop_size, tuple):
99
+ crop_size = (crop_size,) * 2
100
+ if not isinstance(spatial_patch_size, tuple):
101
+ spatial_patch_size = (spatial_patch_size,) * 2
102
+ self.crop_size = crop_size
103
+ self.height, self.width = [crop_size[i] // spatial_patch_size[i] for i in (0, 1)]
104
+ self.duration = num_frames // temporal_patch_size
105
+ self.full_complement = full_complement
106
+ self.pred_full_complement = pred_full_complement
107
+
108
+ self.spatial_patch_size = spatial_patch_size
109
+ self.temporal_patch_size = temporal_patch_size
110
+
111
+ self.aspect_ratio = aspect_ratio
112
+ self.spatial_pred_mask_scale = spatial_pred_mask_scale
113
+ self.temporal_pred_mask_scale = temporal_pred_mask_scale
114
+ self.npred = npred
115
+ self.max_context_duration = max(
116
+ 1, int(self.duration * max_context_frames_ratio)
117
+ ) # maximum number of time-steps (frames) spanned by context mask
118
+ self.max_keep = max_keep # maximum number of patches to keep in context
119
+ self._itr_counter = Value("i", -1) # collator is shared across worker processes
120
+ self.inv_block = inv_block
121
+
122
+ def step(self):
123
+ i = self._itr_counter
124
+ with i.get_lock():
125
+ i.value += 1
126
+ v = i.value
127
+ return v
128
+
129
+ def _sample_block_size(self, generator, temporal_scale, spatial_scale, aspect_ratio_scale):
130
+ # -- Sample temporal block mask scale
131
+ _rand = torch.rand(1, generator=generator).item()
132
+ min_t, max_t = temporal_scale
133
+ temporal_mask_scale = min_t + _rand * (max_t - min_t)
134
+ t = max(1, int(self.duration * temporal_mask_scale))
135
+
136
+ # -- Sample spatial block mask scale
137
+ _rand = torch.rand(1, generator=generator).item()
138
+ min_s, max_s = spatial_scale
139
+ spatial_mask_scale = min_s + _rand * (max_s - min_s)
140
+ spatial_num_keep = int(self.height * self.width * spatial_mask_scale)
141
+
142
+ # -- Sample block aspect-ratio
143
+ _rand = torch.rand(1, generator=generator).item()
144
+ min_ar, max_ar = aspect_ratio_scale
145
+ aspect_ratio = min_ar + _rand * (max_ar - min_ar)
146
+
147
+ # -- Compute block height and width (given scale and aspect-ratio)
148
+ h = int(round(math.sqrt(spatial_num_keep * aspect_ratio)))
149
+ w = int(round(math.sqrt(spatial_num_keep / aspect_ratio)))
150
+ h = min(h, self.height)
151
+ w = min(w, self.width)
152
+
153
+ return (t, h, w)
154
+
155
+ def _sample_block_mask(self, b_size):
156
+ t, h, w = b_size
157
+ top = torch.randint(0, self.height - h + 1, (1,))
158
+ left = torch.randint(0, self.width - w + 1, (1,))
159
+ start = torch.randint(0, self.duration - t + 1, (1,))
160
+
161
+ mask = torch.ones((self.duration, self.height, self.width), dtype=torch.int32)
162
+ mask[start : start + t, top : top + h, left : left + w] = 0
163
+
164
+ # Context mask will only span the first X frames
165
+ # (X=self.max_context_frames)
166
+ if self.max_context_duration < self.duration:
167
+ mask[self.max_context_duration :, :, :] = 0
168
+
169
+ # --
170
+ return mask
171
+
172
+ def __call__(self, batch_size):
173
+ """
174
+ Create encoder and predictor masks when collating imgs into a batch
175
+ # 1. sample pred block size using seed
176
+ # 2. sample several pred block locations for each image (w/o seed)
177
+ # 3. return pred masks and complement (enc mask)
178
+ """
179
+ seed = self.step()
180
+ g = torch.Generator()
181
+ g.manual_seed(seed)
182
+ p_size = self._sample_block_size(
183
+ generator=g,
184
+ temporal_scale=self.temporal_pred_mask_scale,
185
+ spatial_scale=self.spatial_pred_mask_scale,
186
+ aspect_ratio_scale=self.aspect_ratio,
187
+ )
188
+
189
+ collated_masks_pred, collated_masks_enc = [], []
190
+ min_keep_enc = min_keep_pred = self.duration * self.height * self.width
191
+ for _ in range(batch_size):
192
+
193
+ empty_context = True
194
+ while empty_context:
195
+
196
+ mask_e = torch.ones((self.duration, self.height, self.width), dtype=torch.int32)
197
+ for _ in range(self.npred):
198
+ mask_e *= self._sample_block_mask(p_size)
199
+ mask_e = mask_e.flatten()
200
+
201
+ mask_p = torch.argwhere(mask_e == 0).squeeze()
202
+ mask_e = torch.nonzero(mask_e).squeeze()
203
+
204
+ empty_context = len(mask_e) == 0
205
+ if not empty_context:
206
+ min_keep_pred = min(min_keep_pred, len(mask_p))
207
+ min_keep_enc = min(min_keep_enc, len(mask_e))
208
+ collated_masks_pred.append(mask_p)
209
+ collated_masks_enc.append(mask_e)
210
+
211
+ if self.max_keep is not None:
212
+ min_keep_enc = min(min_keep_enc, self.max_keep)
213
+
214
+ collated_masks_enc = [cm[:min_keep_enc] for cm in collated_masks_enc]
215
+ collated_masks_pred = [cm[:min_keep_pred] for cm in collated_masks_pred]
216
+ if self.full_complement: # predictor mask is just complement of encoder mask
217
+ collated_masks_pred = [
218
+ torch.tensor(
219
+ sorted(list(set(range(int(self.duration * self.height * self.width))) - set(cm.tolist()))),
220
+ dtype=cm.dtype,
221
+ )
222
+ for cm in collated_masks_enc
223
+ ]
224
+ elif self.pred_full_complement:
225
+ collated_masks_enc = [
226
+ torch.tensor(
227
+ sorted(list(set(range(int(self.duration * self.height * self.width))) - set(cm.tolist()))),
228
+ dtype=cm.dtype,
229
+ )
230
+ for cm in collated_masks_pred
231
+ ]
232
+
233
+ collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc)
234
+ collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred)
235
+
236
+ if self.inv_block:
237
+ return collated_masks_pred, collated_masks_enc # predict context from block
238
+ else:
239
+ return collated_masks_enc, collated_masks_pred
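A hedged sketch of drawing masks with the `_MaskGenerator` above (the settings are illustrative, not a training configuration): 224x224 crops, 16 frames, 16x16 patches and 2-frame tubelets give an 8 x 14 x 14 patch grid.

```python
# Hedged sketch: illustrative mask settings.
gen = _MaskGenerator(
    crop_size=(224, 224),
    num_frames=16,
    spatial_patch_size=(16, 16),
    temporal_patch_size=2,
    spatial_pred_mask_scale=(0.2, 0.8),
    temporal_pred_mask_scale=(1.0, 1.0),
    aspect_ratio=(0.3, 3.0),
    npred=2,
)
masks_enc, masks_pred = gen(batch_size=2)
# Each is a [2, K] tensor of flattened patch indices: kept context patches vs. prediction targets.
print(masks_enc.shape, masks_pred.shape)
```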
src/masks/utils.py ADDED
@@ -0,0 +1,21 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+
9
+ def apply_masks(x, masks, concat=True):
10
+ """
11
+ :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)]
12
+ :param masks: list of tensors of shape [B, K] containing indices of K patches in [N] to keep
13
+ """
14
+ all_x = []
15
+ for m in masks:
16
+ mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
17
+ all_x += [torch.gather(x, dim=1, index=mask_keep)]
18
+ if not concat:
19
+ return all_x
20
+
21
+ return torch.cat(all_x, dim=0)
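A hedged sketch of `apply_masks` on a toy token grid: each mask keeps three patch indices per sample, and the gathered results for the two masks are concatenated along the batch dimension.

```python
# Hedged sketch: toy shapes, x is [B=2, N=8, D=4].
import torch

x = torch.randn(2, 8, 4)
masks = [
    torch.tensor([[0, 1, 2], [1, 2, 3]]),
    torch.tensor([[4, 5, 6], [5, 6, 7]]),
]
out = apply_masks(x, masks)
print(out.shape)  # torch.Size([4, 3, 4])
```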
src/models/__pycache__/attentive_pooler.cpython-312.pyc ADDED
Binary file (6.34 kB). View file
 
src/models/__pycache__/vision_transformer.cpython-312.pyc ADDED
Binary file (16.2 kB). View file
 
src/models/ac_predictor.py ADDED
@@ -0,0 +1,200 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from functools import partial
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from src.models.utils.modules import ACBlock as Block
13
+ from src.models.utils.modules import build_action_block_causal_attention_mask
14
+ from src.utils.tensors import trunc_normal_
15
+
16
+
17
+ class VisionTransformerPredictorAC(nn.Module):
18
+ """Action Conditioned Vision Transformer Predictor"""
19
+
20
+ def __init__(
21
+ self,
22
+ img_size=(224, 224),
23
+ patch_size=16,
24
+ num_frames=1,
25
+ tubelet_size=2,
26
+ embed_dim=768,
27
+ predictor_embed_dim=1024,
28
+ depth=24,
29
+ num_heads=16,
30
+ mlp_ratio=4.0,
31
+ qkv_bias=True,
32
+ qk_scale=None,
33
+ drop_rate=0.0,
34
+ attn_drop_rate=0.0,
35
+ drop_path_rate=0.0,
36
+ norm_layer=nn.LayerNorm,
37
+ init_std=0.02,
38
+ uniform_power=True,
39
+ use_silu=False,
40
+ wide_silu=True,
41
+ is_frame_causal=True,
42
+ use_activation_checkpointing=False,
43
+ use_rope=True,
44
+ action_embed_dim=7,
45
+ use_extrinsics=False,
46
+ **kwargs
47
+ ):
48
+ super().__init__()
49
+ self.is_frame_causal = is_frame_causal
50
+ self.use_extrinsics = use_extrinsics
51
+
52
+ # Map input to predictor dimension
53
+ self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
54
+ self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
55
+ self.state_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
56
+ self.extrinsics_encoder = nn.Linear(action_embed_dim - 1, predictor_embed_dim, bias=True)
57
+
58
+ # Determine positional embedding
59
+ if type(img_size) is int:
60
+ img_size = (img_size, img_size)
61
+ self.img_height, self.img_width = img_size
62
+ self.patch_size = patch_size
63
+ # --
64
+ self.num_frames = num_frames
65
+ self.tubelet_size = tubelet_size
66
+ self.is_video = num_frames > 1
67
+
68
+ self.grid_height = img_size[0] // self.patch_size
69
+ self.grid_width = img_size[1] // self.patch_size
70
+ self.use_activation_checkpointing = use_activation_checkpointing
71
+
72
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
73
+
74
+ # Position embedding
75
+ self.uniform_power = uniform_power
76
+
77
+ # Attention Blocks
78
+ self.use_rope = use_rope
79
+ self.predictor_blocks = nn.ModuleList(
80
+ [
81
+ Block(
82
+ use_rope=use_rope,
83
+ grid_size=self.grid_height,
84
+ dim=predictor_embed_dim,
85
+ num_heads=num_heads,
86
+ mlp_ratio=mlp_ratio,
87
+ qkv_bias=qkv_bias,
88
+ qk_scale=qk_scale,
89
+ drop=drop_rate,
90
+ act_layer=nn.SiLU if use_silu else nn.GELU,
91
+ wide_silu=wide_silu,
92
+ attn_drop=attn_drop_rate,
93
+ drop_path=dpr[i],
94
+ norm_layer=norm_layer,
95
+ )
96
+ for i in range(depth)
97
+ ]
98
+ )
99
+
100
+ # Normalize & project back to input dimension
101
+ self.predictor_norm = norm_layer(predictor_embed_dim)
102
+ self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
103
+
104
+ # ------ initialize weights
105
+ self.init_std = init_std
106
+ self.apply(self._init_weights)
107
+ self._rescale_blocks()
108
+
109
+ attn_mask = None
110
+ if self.is_frame_causal:
111
+ grid_depth = self.num_frames // self.tubelet_size
112
+ grid_height = self.img_height // self.patch_size
113
+ grid_width = self.img_width // self.patch_size
114
+ attn_mask = build_action_block_causal_attention_mask(
115
+ grid_depth, grid_height, grid_width, add_tokens=3 if use_extrinsics else 2
116
+ )
117
+ self.attn_mask = attn_mask
118
+
119
+ def _init_weights(self, m):
120
+ if isinstance(m, nn.Linear):
121
+ trunc_normal_(m.weight, std=self.init_std)
122
+ if isinstance(m, nn.Linear) and m.bias is not None:
123
+ nn.init.constant_(m.bias, 0)
124
+ elif isinstance(m, nn.LayerNorm):
125
+ nn.init.constant_(m.bias, 0)
126
+ nn.init.constant_(m.weight, 1.0)
127
+
128
+ def _rescale_blocks(self):
129
+ def rescale(param, layer_id):
130
+ param.div_(math.sqrt(2.0 * layer_id))
131
+
132
+ for layer_id, layer in enumerate(self.predictor_blocks):
133
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
134
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
135
+
136
+ def forward(self, x, actions, states, extrinsics=None):
137
+ """
138
+ :param x: context tokens
139
+ """
140
+ # Map tokens to predictor dimensions
141
+ x = self.predictor_embed(x)
142
+ B, N_ctxt, D = x.size()
143
+ T = N_ctxt // (self.grid_height * self.grid_width)
144
+
145
+ # Interleave action tokens
146
+ s = self.state_encoder(states).unsqueeze(2)
147
+ a = self.action_encoder(actions).unsqueeze(2)
148
+ x = x.view(B, T, self.grid_height * self.grid_width, D) # [B, T, H*W, D]
149
+ if self.use_extrinsics:
150
+ e = self.extrinsics_encoder(extrinsics).unsqueeze(2)
151
+ x = torch.cat([a, s, e, x], dim=2).flatten(1, 2) # [B, T*(H*W+3), D]
152
+ else:
153
+ x = torch.cat([a, s, x], dim=2).flatten(1, 2) # [B, T*(H*W+2), D]
154
+
155
+ cond_tokens = 3 if self.use_extrinsics else 2
156
+ attn_mask = self.attn_mask[: x.size(1), : x.size(1)].to(x.device, non_blocking=True)
157
+
158
+ # Fwd prop
159
+ for i, blk in enumerate(self.predictor_blocks):
160
+ if self.use_activation_checkpointing:
161
+ x = torch.utils.checkpoint.checkpoint(
162
+ blk,
163
+ x,
164
+ mask=None,
165
+ attn_mask=attn_mask,
166
+ T=T,
167
+ H=self.grid_height,
168
+ W=self.grid_width,
169
+ action_tokens=cond_tokens,
170
+ use_reentrant=False,
171
+ )
172
+ else:
173
+ x = blk(
174
+ x,
175
+ mask=None,
176
+ attn_mask=attn_mask,
177
+ T=T,
178
+ H=self.grid_height,
179
+ W=self.grid_width,
180
+ action_tokens=cond_tokens,
181
+ )
182
+
183
+ # Split out action and frame tokens
184
+ x = x.view(B, T, cond_tokens + self.grid_height * self.grid_width, D) # [B, T, K+H*W, D]
185
+ x = x[:, :, cond_tokens:, :].flatten(1, 2)
186
+
187
+ x = self.predictor_norm(x)
188
+ x = self.predictor_proj(x)
189
+
190
+ return x
191
+
192
+
193
+ def vit_ac_predictor(**kwargs):
194
+ model = VisionTransformerPredictorAC(
195
+ mlp_ratio=4,
196
+ qkv_bias=True,
197
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
198
+ **kwargs
199
+ )
200
+ return model
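A rough forward-pass sketch for `vit_ac_predictor` (illustrative, not part of the upload). The clip size, depth, and the 7-dimensional action/state vectors are arbitrary example choices; it assumes the repo root is importable and `torch`/`timm` are installed.

```python
# Illustrative sketch (arbitrary sizes): predict frame features conditioned on actions/states.
import torch

from src.models.ac_predictor import vit_ac_predictor

model = vit_ac_predictor(
    img_size=(256, 256), patch_size=16, num_frames=8, tubelet_size=2,
    embed_dim=1024, predictor_embed_dim=1024, depth=2, num_heads=16, action_embed_dim=7,
)
B, T = 2, 8 // 2                                 # batch size, number of tubelets
tokens_per_frame = (256 // 16) ** 2              # 16 x 16 patches per tubelet
x = torch.randn(B, T * tokens_per_frame, 1024)   # context tokens from the encoder
actions = torch.randn(B, T, 7)                   # one action vector per tubelet
states = torch.randn(B, T, 7)                    # one state vector per tubelet
out = model(x, actions, states)                  # -> [B, T * tokens_per_frame, 1024]
print(out.shape)
```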
src/models/attentive_pooler.py ADDED
@@ -0,0 +1,137 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import math
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from src.models.utils.modules import Block, CrossAttention, CrossAttentionBlock
13
+ from src.utils.tensors import trunc_normal_
14
+
15
+
16
+ class AttentivePooler(nn.Module):
17
+ """Attentive Pooler"""
18
+
19
+ def __init__(
20
+ self,
21
+ num_queries=1,
22
+ embed_dim=768,
23
+ num_heads=12,
24
+ mlp_ratio=4.0,
25
+ depth=1,
26
+ norm_layer=nn.LayerNorm,
27
+ init_std=0.02,
28
+ qkv_bias=True,
29
+ complete_block=True,
30
+ use_activation_checkpointing=False,
31
+ ):
32
+ super().__init__()
33
+ self.use_activation_checkpointing = use_activation_checkpointing
34
+ self.query_tokens = nn.Parameter(torch.zeros(1, num_queries, embed_dim))
35
+
36
+ self.complete_block = complete_block
37
+ if complete_block:
38
+ self.cross_attention_block = CrossAttentionBlock(
39
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, norm_layer=norm_layer
40
+ )
41
+ else:
42
+ self.cross_attention_block = CrossAttention(dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias)
43
+
44
+ self.blocks = None
45
+ if depth > 1:
46
+ self.blocks = nn.ModuleList(
47
+ [
48
+ Block(
49
+ dim=embed_dim,
50
+ num_heads=num_heads,
51
+ mlp_ratio=mlp_ratio,
52
+ qkv_bias=qkv_bias,
53
+ qk_scale=False,
54
+ norm_layer=norm_layer,
55
+ )
56
+ for i in range(depth - 1)
57
+ ]
58
+ )
59
+
60
+ self.init_std = init_std
61
+ trunc_normal_(self.query_tokens, std=self.init_std)
62
+ self.apply(self._init_weights)
63
+ self._rescale_blocks()
64
+
65
+ def _rescale_blocks(self):
66
+ def rescale(param, layer_id):
67
+ param.div_(math.sqrt(2.0 * layer_id))
68
+
69
+ layer_id = 0
70
+ if self.blocks is not None:
71
+ for layer_id, layer in enumerate(self.blocks):
72
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
73
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
74
+
75
+ if self.complete_block:
76
+ rescale(self.cross_attention_block.mlp.fc2.weight.data, layer_id + 1)
77
+
78
+ def _init_weights(self, m):
79
+ if isinstance(m, nn.Linear):
80
+ trunc_normal_(m.weight, std=self.init_std)
81
+ if isinstance(m, nn.Linear) and m.bias is not None:
82
+ nn.init.constant_(m.bias, 0)
83
+ elif isinstance(m, nn.LayerNorm):
84
+ nn.init.constant_(m.bias, 0)
85
+ nn.init.constant_(m.weight, 1.0)
86
+ elif isinstance(m, nn.Conv2d):
87
+ trunc_normal_(m.weight, std=self.init_std)
88
+ if m.bias is not None:
89
+ nn.init.constant_(m.bias, 0)
90
+
91
+ def forward(self, x):
92
+ if self.blocks is not None:
93
+ for blk in self.blocks:
94
+ if self.use_activation_checkpointing:
95
+ x = torch.utils.checkpoint.checkpoint(blk, x, False, None, use_reentrant=False)
96
+ else:
97
+ x = blk(x)
98
+ q = self.query_tokens.repeat(len(x), 1, 1)
99
+ q = self.cross_attention_block(q, x)
100
+ return q
101
+
102
+
103
+ class AttentiveClassifier(nn.Module):
104
+ """Attentive Classifier"""
105
+
106
+ def __init__(
107
+ self,
108
+ embed_dim=768,
109
+ num_heads=12,
110
+ mlp_ratio=4.0,
111
+ depth=1,
112
+ norm_layer=nn.LayerNorm,
113
+ init_std=0.02,
114
+ qkv_bias=True,
115
+ num_classes=1000,
116
+ complete_block=True,
117
+ use_activation_checkpointing=False,
118
+ ):
119
+ super().__init__()
120
+ self.pooler = AttentivePooler(
121
+ num_queries=1,
122
+ embed_dim=embed_dim,
123
+ num_heads=num_heads,
124
+ mlp_ratio=mlp_ratio,
125
+ depth=depth,
126
+ norm_layer=norm_layer,
127
+ init_std=init_std,
128
+ qkv_bias=qkv_bias,
129
+ complete_block=complete_block,
130
+ use_activation_checkpointing=use_activation_checkpointing,
131
+ )
132
+ self.linear = nn.Linear(embed_dim, num_classes, bias=True)
133
+
134
+ def forward(self, x):
135
+ x = self.pooler(x).squeeze(1)
136
+ x = self.linear(x)
137
+ return x
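A minimal probing sketch (illustrative, not part of the upload): `AttentiveClassifier` pools a frozen encoder's patch tokens into class logits. The feature shape and the 10-class head are arbitrary example values; it assumes `torch`/`timm` are installed and the repo root is importable.

```python
# Illustrative sketch: attentive probing on top of frozen [B, N, D] features.
import torch

from src.models.attentive_pooler import AttentiveClassifier

feats = torch.randn(2, 196, 768)                                 # e.g. 14 x 14 patch tokens
head = AttentiveClassifier(embed_dim=768, num_heads=12, num_classes=10)
logits = head(feats)                                             # -> [2, 10]
print(logits.shape)
```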
src/models/predictor.py ADDED
@@ -0,0 +1,253 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from functools import partial
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from src.masks.utils import apply_masks
13
+ from src.models.utils.modules import Block
14
+ from src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed
15
+ from src.utils.tensors import repeat_interleave_batch, trunc_normal_
16
+
17
+
18
+ class VisionTransformerPredictor(nn.Module):
19
+ """Vision Transformer"""
20
+
21
+ def __init__(
22
+ self,
23
+ img_size=(224, 224),
24
+ patch_size=16,
25
+ num_frames=1,
26
+ tubelet_size=2,
27
+ embed_dim=768,
28
+ predictor_embed_dim=384,
29
+ depth=6,
30
+ num_heads=12,
31
+ mlp_ratio=4.0,
32
+ qkv_bias=True,
33
+ qk_scale=None,
34
+ drop_rate=0.0,
35
+ attn_drop_rate=0.0,
36
+ drop_path_rate=0.0,
37
+ norm_layer=nn.LayerNorm,
38
+ init_std=0.02,
39
+ uniform_power=False,
40
+ use_mask_tokens=False,
41
+ num_mask_tokens=2,
42
+ zero_init_mask_tokens=True,
43
+ use_silu=False,
44
+ wide_silu=True,
45
+ use_activation_checkpointing=False,
46
+ return_all_tokens=False,
47
+ chop_last_n_tokens=0,
48
+ use_rope=False,
49
+ **kwargs
50
+ ):
51
+ super().__init__()
52
+ self.return_all_tokens = return_all_tokens
53
+ self.chop_last_n_tokens = chop_last_n_tokens
54
+
55
+ # Map input to predictor dimension
56
+ self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
57
+
58
+ # Mask tokens
59
+ self.mask_tokens = None
60
+ self.num_mask_tokens = 0
61
+ if use_mask_tokens:
62
+ self.num_mask_tokens = num_mask_tokens
63
+ self.mask_tokens = nn.ParameterList(
64
+ [nn.Parameter(torch.zeros(1, 1, predictor_embed_dim)) for i in range(num_mask_tokens)]
65
+ )
66
+
67
+ # Determine positional embedding
68
+ if type(img_size) is int:
69
+ img_size = (img_size, img_size)
70
+ self.img_height, self.img_width = img_size
71
+ self.patch_size = patch_size
72
+ # --
73
+ self.num_frames = num_frames
74
+ self.tubelet_size = tubelet_size
75
+ self.is_video = num_frames > 1
76
+
77
+ self.grid_height = img_size[0] // self.patch_size
78
+ self.grid_width = img_size[1] // self.patch_size
79
+ self.grid_depth = num_frames // self.tubelet_size
80
+ self.use_activation_checkpointing = use_activation_checkpointing
81
+
82
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
83
+
84
+ if self.is_video:
85
+ self.num_patches = num_patches = (
86
+ (num_frames // tubelet_size) * (img_size[0] // patch_size) * (img_size[1] // patch_size)
87
+ )
88
+ else:
89
+ self.num_patches = num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
90
+ # Position embedding
91
+ self.uniform_power = uniform_power
92
+
93
+ self.predictor_pos_embed = None
94
+ if not use_rope:
95
+ self.predictor_pos_embed = nn.Parameter(
96
+ torch.zeros(1, num_patches, predictor_embed_dim), requires_grad=False
97
+ )
98
+
99
+ # Attention Blocks
100
+ self.use_rope = use_rope
101
+ self.predictor_blocks = nn.ModuleList(
102
+ [
103
+ Block(
104
+ use_rope=use_rope,
105
+ grid_size=self.grid_height,
106
+ grid_depth=self.grid_depth,
107
+ dim=predictor_embed_dim,
108
+ num_heads=num_heads,
109
+ mlp_ratio=mlp_ratio,
110
+ qkv_bias=qkv_bias,
111
+ qk_scale=qk_scale,
112
+ drop=drop_rate,
113
+ act_layer=nn.SiLU if use_silu else nn.GELU,
114
+ wide_silu=wide_silu,
115
+ attn_drop=attn_drop_rate,
116
+ drop_path=dpr[i],
117
+ norm_layer=norm_layer,
118
+ )
119
+ for i in range(depth)
120
+ ]
121
+ )
122
+
123
+ # Normalize & project back to input dimension
124
+ self.predictor_norm = norm_layer(predictor_embed_dim)
125
+ self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
126
+
127
+ # ------ initialize weights
128
+ if self.predictor_pos_embed is not None:
129
+ self._init_pos_embed(self.predictor_pos_embed.data) # sincos pos-embed
130
+ self.init_std = init_std
131
+ if not zero_init_mask_tokens:
132
+ for mt in self.mask_tokens:
133
+ trunc_normal_(mt, std=init_std)
134
+ self.apply(self._init_weights)
135
+ self._rescale_blocks()
136
+
137
+ def _init_pos_embed(self, pos_embed):
138
+ embed_dim = pos_embed.size(-1)
139
+ grid_size = self.img_height // self.patch_size # TODO: update; currently assumes square input
140
+ if self.is_video:
141
+ grid_depth = self.num_frames // self.tubelet_size
142
+ sincos = get_3d_sincos_pos_embed(
143
+ embed_dim, grid_size, grid_depth, cls_token=False, uniform_power=self.uniform_power
144
+ )
145
+ else:
146
+ sincos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False)
147
+ pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0))
148
+
149
+ def _init_weights(self, m):
150
+ if isinstance(m, nn.Linear):
151
+ trunc_normal_(m.weight, std=self.init_std)
152
+ if isinstance(m, nn.Linear) and m.bias is not None:
153
+ nn.init.constant_(m.bias, 0)
154
+ elif isinstance(m, nn.LayerNorm):
155
+ nn.init.constant_(m.bias, 0)
156
+ nn.init.constant_(m.weight, 1.0)
157
+
158
+ def _rescale_blocks(self):
159
+ def rescale(param, layer_id):
160
+ param.div_(math.sqrt(2.0 * layer_id))
161
+
162
+ for layer_id, layer in enumerate(self.predictor_blocks):
163
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
164
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
165
+
166
+ def forward(self, x, masks_x, masks_y, mask_index=1, has_cls=False):
167
+ """
168
+ :param x: context tokens
169
+ :param masks_x: indices of context tokens in input
170
+ :param masks_y: indices of target tokens in input
171
+ """
172
+ assert (masks_x is not None) and (masks_y is not None), "Cannot run predictor without mask indices"
173
+ if not isinstance(masks_x, list):
174
+ masks_x = [masks_x]
175
+ if not isinstance(masks_y, list):
176
+ masks_y = [masks_y]
177
+
178
+ # Batch Size
179
+ B = len(x) // len(masks_x)
180
+
181
+ # Map context tokens to predictor dimensions
182
+ x = self.predictor_embed(x)
183
+ if has_cls:
184
+ x_cls = x[:, :1, :]
185
+ x = x[:, 1:, :]
186
+ _, N_ctxt, D = x.shape
187
+
188
+ # Add positional embedding to ctxt tokens
189
+ if not self.use_rope:
190
+ x_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1)
191
+ x += apply_masks(x_pos_embed, masks_x)
192
+
193
+ # Make target tokens
194
+ mask_index = mask_index % self.num_mask_tokens
195
+ pred_tokens = self.mask_tokens[mask_index]
196
+ pred_tokens = pred_tokens.repeat(B, self.num_patches, 1)
197
+ pred_tokens = apply_masks(pred_tokens, masks_y)
198
+ # -- add pos embed
199
+ if not self.use_rope:
200
+ pos_embs = self.predictor_pos_embed.repeat(B, 1, 1)
201
+ pos_embs = apply_masks(pos_embs, masks_y)
202
+ pos_embs = repeat_interleave_batch(pos_embs, B, repeat=len(masks_x))
203
+ pred_tokens += pos_embs
204
+
205
+ # Concatenate context & target tokens
206
+ x = x.repeat(len(masks_x), 1, 1)
207
+ x = torch.cat([x, pred_tokens], dim=1)
208
+
209
+ # Positions of context & target tokens
210
+ masks_x = torch.cat(masks_x, dim=0)
211
+ masks_y = torch.cat(masks_y, dim=0)
212
+ masks = torch.cat([masks_x, masks_y], dim=1)
213
+
214
+ # Put tokens in sorted order
215
+ argsort = torch.argsort(masks, dim=1) # [B, N]
216
+ masks = torch.stack([masks[i, row] for i, row in enumerate(argsort)], dim=0)
217
+ x = torch.stack([x[i, row, :] for i, row in enumerate(argsort)], dim=0)
218
+
219
+ # Remove the last n tokens of sorted sequence before processing
220
+ if self.chop_last_n_tokens > 0:
221
+ x = x[:, : -self.chop_last_n_tokens]
222
+ masks = masks[:, : -self.chop_last_n_tokens]
223
+
224
+ if has_cls:
225
+ x = torch.cat([x_cls, x], dim=1)
226
+
227
+ # Fwd prop
228
+ for i, blk in enumerate(self.predictor_blocks):
229
+ if self.use_activation_checkpointing:
230
+ x = torch.utils.checkpoint.checkpoint(blk, x, masks, None, use_reentrant=False)
231
+ else:
232
+ x = blk(x, mask=masks, attn_mask=None)
233
+ x = self.predictor_norm(x)
234
+
235
+ if has_cls:
236
+ x = x[:, 1:, :]
237
+
238
+ # Return output corresponding to target tokens
239
+ if not self.return_all_tokens:
240
+ reverse_argsort = torch.argsort(argsort, dim=1) # [B, N]
241
+ x = torch.stack([x[i, row, :] for i, row in enumerate(reverse_argsort)], dim=0)
242
+ x = x[:, N_ctxt:]
243
+
244
+ x = self.predictor_proj(x)
245
+
246
+ return x
247
+
248
+
249
+ def vit_predictor(**kwargs):
250
+ model = VisionTransformerPredictor(
251
+ mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs
252
+ )
253
+ return model
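An illustrative masked-prediction sketch for `vit_predictor` (not part of the upload). The split into 98 context and 98 target patches of a single-tubelet 224x224 clip is an arbitrary example; it assumes `torch`/`timm` are installed and the repo root is importable.

```python
# Illustrative sketch: predict features of masked-out patches from visible context patches.
import torch

from src.models.predictor import vit_predictor

model = vit_predictor(
    img_size=(224, 224), patch_size=16, num_frames=2, tubelet_size=2,
    embed_dim=768, predictor_embed_dim=384, depth=2, num_heads=12,
    use_mask_tokens=True, num_mask_tokens=1,
)
B, N = 2, 14 * 14                                            # 196 patch positions per clip
ctx_idx = torch.arange(0, 98).unsqueeze(0).repeat(B, 1)      # visible (context) patches
tgt_idx = torch.arange(98, N).unsqueeze(0).repeat(B, 1)      # masked (target) patches
ctx_tokens = torch.randn(B, ctx_idx.size(1), 768)            # encoder output for the context
pred = model(ctx_tokens, [ctx_idx], [tgt_idx])               # -> [B, 98, 768]
print(pred.shape)
```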
src/models/utils/__pycache__/modules.cpython-312.pyc ADDED
Binary file (30.1 kB).
 
src/models/utils/__pycache__/patch_embed.cpython-312.pyc ADDED
Binary file (2.31 kB).
 
src/models/utils/__pycache__/pos_embs.cpython-312.pyc ADDED
Binary file (4.2 kB).
 
src/models/utils/modules.py ADDED
@@ -0,0 +1,610 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from timm.models.layers import drop_path
10
+
11
+
12
+ def build_action_block_causal_attention_mask(T, H, W, add_tokens=1):
13
+ N_T = add_tokens + (H * W)
14
+ N = T * N_T
15
+ mask = torch.zeros(N, N).bool()
16
+ mask_block = torch.ones(N_T, N_T).bool()
17
+ local_window_time = T
18
+
19
+ for t1 in range(T):
20
+ for t2 in range(max(0, t1 - local_window_time + 1), t1 + 1):
21
+ mask[t1 * N_T : (t1 + 1) * N_T, t2 * N_T : (t2 + 1) * N_T] = mask_block
22
+
23
+ return mask
24
+
25
+
26
+ def rotate_queries_or_keys(x, pos):
27
+ B, num_heads, N, D = x.size()
28
+ assert D % 2 == 0, "Embedding dimension must be a multiple of 2 for block matrix rotation"
29
+
30
+ # -- compute angle for each position
31
+ omega = torch.arange(D // 2, dtype=x.dtype, device=x.device)
32
+ omega /= D / 2.0
33
+ omega = 1.0 / 10000**omega # (D/2,)
34
+ freq = torch.einsum("..., f -> ... f", pos, omega) # (..., N, D/2), outer product
35
+
36
+ # -- build rotation matrix and apply
37
+ emb_sin = freq.sin() # (..., N, D/2)
38
+ emb_cos = freq.cos() # (..., N, D/2)
39
+
40
+ emb_sin = emb_sin.squeeze(-1).repeat(1, 1, 1, 2)
41
+ emb_cos = emb_cos.squeeze(-1).repeat(1, 1, 1, 2)
42
+
43
+ # --
44
+ y = x.unflatten(-1, (-1, 2))
45
+ y1, y2 = y.unbind(
46
+ dim=-1,
47
+ )
48
+ y = torch.stack((-y2, y1), dim=-1)
49
+ y = y.flatten(-2)
50
+ return (x * emb_cos) + (y * emb_sin)
51
+
52
+
53
+ class DropPath(nn.Module):
54
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
55
+
56
+ def __init__(self, drop_prob=None):
57
+ super(DropPath, self).__init__()
58
+ self.drop_prob = drop_prob
59
+
60
+ def forward(self, x):
61
+ return drop_path(x, self.drop_prob, self.training)
62
+
63
+ def extra_repr(self) -> str:
64
+ return "p={}".format(self.drop_prob)
65
+
66
+
67
+ class MLP(nn.Module):
68
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
69
+ super().__init__()
70
+ out_features = out_features or in_features
71
+ hidden_features = hidden_features or in_features
72
+ self.fc1 = nn.Linear(in_features, hidden_features)
73
+ self.act = act_layer()
74
+ self.fc2 = nn.Linear(hidden_features, out_features)
75
+ self.drop = nn.Dropout(drop)
76
+
77
+ def forward(self, x):
78
+ x = self.fc1(x)
79
+ x = self.act(x)
80
+ x = self.drop(x)
81
+ x = self.fc2(x)
82
+ x = self.drop(x)
83
+ return x
84
+
85
+
86
+ class SwiGLUFFN(nn.Module):
87
+ def __init__(
88
+ self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.0, wide_silu=True
89
+ ):
90
+ super().__init__()
91
+ out_features = out_features or in_features
92
+ swiglu_hidden_features = hidden_features = hidden_features or in_features
93
+ if wide_silu:
94
+ swiglu_hidden_features = int(2 * hidden_features / 3)
95
+ align_as = 8
96
+ swiglu_hidden_features = (swiglu_hidden_features + align_as - 1) // align_as * align_as
97
+ self.fc1 = nn.Linear(in_features, swiglu_hidden_features)
98
+ self.fc2 = nn.Linear(in_features, swiglu_hidden_features)
99
+ self.act = act_layer()
100
+ self.fc3 = nn.Linear(swiglu_hidden_features, out_features)
101
+
102
+ def forward(self, x):
103
+ x1 = self.fc1(x)
104
+ x2 = self.fc2(x)
105
+ hidden = F.silu(x1) * x2
106
+ return self.fc3(hidden)
107
+
108
+
109
+ class ACRoPEAttention(nn.Module):
110
+ def __init__(
111
+ self,
112
+ dim,
113
+ num_heads=8,
114
+ qkv_bias=False,
115
+ qk_scale=None,
116
+ attn_drop=0.0,
117
+ proj_drop=0.0,
118
+ use_sdpa=True,
119
+ is_causal=False,
120
+ grid_size=16,
121
+ ):
122
+ super().__init__()
123
+ self.num_heads = num_heads
124
+ self.head_dim = head_dim = dim // num_heads
125
+ self.scale = qk_scale or head_dim**-0.5
126
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
127
+ self.attn_drop = nn.Dropout(attn_drop)
128
+ self.proj = nn.Linear(dim, dim)
129
+ self.proj_drop_prob = proj_drop
130
+ self.proj_drop = nn.Dropout(proj_drop)
131
+ self.use_sdpa = use_sdpa
132
+ # --
133
+ self.d_dim = int(2 * ((head_dim // 3) // 2))
134
+ self.h_dim = int(2 * ((head_dim // 3) // 2))
135
+ self.w_dim = int(2 * ((head_dim // 3) // 2))
136
+ self.grid_size = grid_size
137
+ self.is_causal = is_causal
138
+
139
+ def _get_frame_pos(self, ids, H_patches, W_patches):
140
+ tokens_per_frame = int(H_patches * W_patches)
141
+ return ids // tokens_per_frame
142
+
143
+ def _get_height_pos(self, ids, H_patches, W_patches):
144
+ # Remove frame component from ids
145
+ tokens_per_frame = int(H_patches * W_patches)
146
+ tokens_per_row = W_patches
147
+ frame_ids = self._get_frame_pos(ids, H_patches, W_patches)
148
+ ids = ids - tokens_per_frame * frame_ids
149
+ # --
150
+ return ids // tokens_per_row
151
+
152
+ def separate_positions(self, ids, H_patches, W_patches):
153
+ tokens_per_frame = int(H_patches * W_patches)
154
+ tokens_per_row = W_patches
155
+ frame_ids = self._get_frame_pos(ids, H_patches, W_patches)
156
+ # --
157
+ height_ids = self._get_height_pos(ids, H_patches, W_patches)
158
+ # --
159
+ # Remove frame component from ids (1st term) and height component (2nd term)
160
+ width_ids = (ids - tokens_per_frame * frame_ids) - tokens_per_row * height_ids
161
+ return 1.0 * frame_ids, 1.0 * height_ids, 1.0 * width_ids
162
+
163
+ def forward(self, x, mask=None, attn_mask=None, T=None, H=None, W=None, action_tokens=0):
164
+ B, N, C = x.size()
165
+
166
+ # -- compute position of each frame token
167
+ if mask is not None:
168
+ mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1)
169
+ d_mask, h_mask, w_mask = self.separate_positions(mask, H, W)
170
+ else:
171
+ mask = torch.arange(int(T * H * W), device=x.device)
172
+ d_mask, h_mask, w_mask = self.separate_positions(mask, H, W)
173
+
174
+ # -- snap spatial positions to grid size
175
+ h_mask *= self.grid_size / H
176
+ w_mask *= self.grid_size / W
177
+
178
+ # -- split out action tokens from sequence
179
+ if action_tokens > 0:
180
+ x = x.view(B, -1, action_tokens + H * W, C) # [B, T, A+H*W, C] with A = action_tokens
181
+
182
+ action_q, action_k, action_v = [], [], []
183
+ for i in range(action_tokens):
184
+ a = x[:, :, i : i + 1, :].flatten(1, 2)
185
+ # Note action tokens do not work with masking
186
+ # -- compute qkv for action tokens and rotate
187
+ qkv = self.qkv(a).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
188
+ q, k, v = qkv[0], qkv[1], qkv[2] # [B, num_heads, N, D]
189
+ # --
190
+ qd = rotate_queries_or_keys(q[..., : self.d_dim], pos=torch.arange(T, device=x.device))
191
+ kd = rotate_queries_or_keys(k[..., : self.d_dim], pos=torch.arange(T, device=x.device))
192
+ qr = q[..., self.d_dim :]
193
+ kr = k[..., self.d_dim :]
194
+ action_q += [torch.cat([qd, qr], dim=-1).view(B, self.num_heads, T, 1, -1)]
195
+ action_k += [torch.cat([kd, kr], dim=-1).view(B, self.num_heads, T, 1, -1)]
196
+ action_v += [v.view(B, self.num_heads, T, 1, -1)]
197
+
198
+ action_q = torch.cat(action_q, dim=3).flatten(2, 3)
199
+ action_k = torch.cat(action_k, dim=3).flatten(2, 3)
200
+ action_v = torch.cat(action_v, dim=3).flatten(2, 3)
201
+ x = x[:, :, action_tokens:, :].flatten(1, 2)
202
+
203
+ # -- compute qkv for frame tokens and rotate
204
+ qkv = self.qkv(x).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
205
+ q, k, v = qkv[0], qkv[1], qkv[2] # [B, num_heads, N, D]
206
+
207
+ s = 0
208
+ # Rotate depth
209
+ qd = rotate_queries_or_keys(q[..., s : s + self.d_dim], pos=d_mask)
210
+ kd = rotate_queries_or_keys(k[..., s : s + self.d_dim], pos=d_mask)
211
+ s += self.d_dim
212
+ # Rotate height dim
213
+ qh = rotate_queries_or_keys(q[..., s : s + self.h_dim], pos=h_mask)
214
+ kh = rotate_queries_or_keys(k[..., s : s + self.h_dim], pos=h_mask)
215
+ s += self.h_dim
216
+ # Rotate width dim
217
+ qw = rotate_queries_or_keys(q[..., s : s + self.w_dim], pos=w_mask)
218
+ kw = rotate_queries_or_keys(k[..., s : s + self.w_dim], pos=w_mask)
219
+ s += self.w_dim
220
+
221
+ # Combine rotated dimension
222
+ if s < self.head_dim:
223
+ qr = q[..., s:]
224
+ kr = k[..., s:]
225
+ q = torch.cat([qd, qh, qw, qr], dim=-1)
226
+ k = torch.cat([kd, kh, kw, kr], dim=-1)
227
+ else:
228
+ q = torch.cat([qd, qh, qw], dim=-1)
229
+ k = torch.cat([kd, kh, kw], dim=-1)
230
+
231
+ if action_tokens > 0:
232
+
233
+ def merge_(tx, ta):
234
+ """tx, tx in [B, num_heads, N, D]"""
235
+ tx = tx.view(B, self.num_heads, T, H * W, -1) # [B, T, H*W, D]
236
+ ta = ta.view(B, self.num_heads, T, action_tokens, -1) # [B, T, A, D]
237
+ return torch.cat([ta, tx], dim=3).flatten(2, 3)
238
+
239
+ q = merge_(q, action_q)
240
+ k = merge_(k, action_k)
241
+ v = merge_(v, action_v)
242
+
243
+ if attn_mask is not None or self.use_sdpa:
244
+ with torch.backends.cuda.sdp_kernel():
245
+ x = F.scaled_dot_product_attention(
246
+ q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask
247
+ )
248
+ attn = None
249
+ else:
250
+ attn = (q @ k.transpose(-2, -1)) * self.scale # [B, num_heads, N, N]
251
+ attn = attn.softmax(dim=-1)
252
+ attn = self.attn_drop(attn)
253
+ x = attn @ v
254
+
255
+ x = x.transpose(1, 2).reshape(B, N, C)
256
+ x = self.proj(x)
257
+ x = self.proj_drop(x)
258
+ return x
259
+
260
+
261
+ class RoPEAttention(nn.Module):
262
+ def __init__(
263
+ self,
264
+ dim,
265
+ num_heads=8,
266
+ qkv_bias=False,
267
+ qk_scale=None,
268
+ attn_drop=0.0,
269
+ proj_drop=0.0,
270
+ use_sdpa=True,
271
+ grid_size=14,
272
+ is_causal=False,
273
+ ):
274
+ super().__init__()
275
+ self.num_heads = num_heads
276
+ self.head_dim = head_dim = dim // num_heads
277
+ self.scale = qk_scale or head_dim**-0.5
278
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
279
+ self.attn_drop = nn.Dropout(attn_drop)
280
+ self.proj = nn.Linear(dim, dim)
281
+ self.proj_drop_prob = proj_drop
282
+ self.proj_drop = nn.Dropout(proj_drop)
283
+ self.use_sdpa = use_sdpa
284
+ # --
285
+ self.d_dim = int(2 * ((head_dim // 3) // 2))
286
+ self.h_dim = int(2 * ((head_dim // 3) // 2))
287
+ self.w_dim = int(2 * ((head_dim // 3) // 2))
288
+ self.grid_size = grid_size
289
+ self.is_causal = is_causal
290
+
291
+ def _get_frame_pos(self, ids, H_patches=None, W_patches=None):
292
+ if H_patches is None or W_patches is None:
293
+ tokens_per_frame = int(self.grid_size * self.grid_size)
294
+ else:
295
+ tokens_per_frame = int(H_patches * W_patches)
296
+ return ids // tokens_per_frame
297
+
298
+ def _get_height_pos(self, ids, H_patches=None, W_patches=None):
299
+ # Remove frame component from ids
300
+ if H_patches is None or W_patches is None:
301
+ tokens_per_frame = int(self.grid_size * self.grid_size)
302
+ tokens_per_row = self.grid_size
303
+ else:
304
+ tokens_per_frame = int(H_patches * W_patches)
305
+ tokens_per_row = W_patches
306
+ frame_ids = self._get_frame_pos(ids, H_patches, W_patches)
307
+ ids = ids - tokens_per_frame * frame_ids
308
+ # --
309
+ return ids // tokens_per_row
310
+
311
+ def separate_positions(self, ids, H_patches=None, W_patches=None):
312
+ if H_patches is None or W_patches is None:
313
+ tokens_per_frame = int(self.grid_size * self.grid_size)
314
+ tokens_per_row = self.grid_size
315
+ else:
316
+ tokens_per_frame = int(H_patches * W_patches)
317
+ tokens_per_row = W_patches
318
+ frame_ids = self._get_frame_pos(ids, H_patches, W_patches)
319
+ # --
320
+ height_ids = self._get_height_pos(ids, H_patches, W_patches)
321
+ # --
322
+ # Remove frame component from ids (1st term) and height component (2nd term)
323
+ width_ids = (ids - tokens_per_frame * frame_ids) - tokens_per_row * height_ids
324
+ return frame_ids, height_ids, width_ids
325
+
326
+ def forward(self, x, mask=None, attn_mask=None, T=None, H_patches=None, W_patches=None):
327
+ B, N, C = x.size()
328
+ grid_depth = int(N // (self.grid_size * self.grid_size))
329
+
330
+ qkv = self.qkv(x).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
331
+ q, k, v = qkv[0], qkv[1], qkv[2] # [B, num_heads, N, D]
332
+
333
+ if mask is not None:
334
+ mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1)
335
+ d_mask, h_mask, w_mask = self.separate_positions(mask, H_patches, W_patches)
336
+ else:
337
+ if T is None or H_patches is None or W_patches is None:
338
+ mask = torch.arange(int(grid_depth * self.grid_size * self.grid_size), device=x.device)
339
+ else:
340
+ mask = torch.arange(int(T * H_patches * W_patches), device=x.device)
341
+ d_mask, h_mask, w_mask = self.separate_positions(mask, H_patches, W_patches)
342
+
343
+ s = 0
344
+ # Rotate depth
345
+ qd = rotate_queries_or_keys(q[..., s : s + self.d_dim], pos=d_mask)
346
+ kd = rotate_queries_or_keys(k[..., s : s + self.d_dim], pos=d_mask)
347
+ s += self.d_dim
348
+ # Rotate height dim
349
+ qh = rotate_queries_or_keys(q[..., s : s + self.h_dim], pos=h_mask)
350
+ kh = rotate_queries_or_keys(k[..., s : s + self.h_dim], pos=h_mask)
351
+ s += self.h_dim
352
+ # Rotate width dim
353
+ qw = rotate_queries_or_keys(q[..., s : s + self.w_dim], pos=w_mask)
354
+ kw = rotate_queries_or_keys(k[..., s : s + self.w_dim], pos=w_mask)
355
+ s += self.w_dim
356
+
357
+ # Combine rotated dimension
358
+ if s < self.head_dim:
359
+ qr = q[..., s:]
360
+ kr = k[..., s:]
361
+ q = torch.cat([qd, qh, qw, qr], dim=-1)
362
+ k = torch.cat([kd, kh, kw, kr], dim=-1)
363
+ else:
364
+ q = torch.cat([qd, qh, qw], dim=-1)
365
+ k = torch.cat([kd, kh, kw], dim=-1)
366
+
367
+ if attn_mask is not None or self.use_sdpa:
368
+ with torch.backends.cuda.sdp_kernel():
369
+ x = F.scaled_dot_product_attention(
370
+ q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask
371
+ )
372
+ attn = None
373
+ else:
374
+ attn = (q @ k.transpose(-2, -1)) * self.scale # [B, num_heads, N, N]
375
+ attn = attn.softmax(dim=-1)
376
+ attn = self.attn_drop(attn)
377
+ x = attn @ v
378
+
379
+ x = x.transpose(1, 2).reshape(B, N, C)
380
+ x = self.proj(x)
381
+ x = self.proj_drop(x)
382
+ return x
383
+
384
+
385
+ class Attention(nn.Module):
386
+ def __init__(
387
+ self,
388
+ dim,
389
+ num_heads=8,
390
+ qkv_bias=False,
391
+ qk_scale=None,
392
+ attn_drop=0.0,
393
+ proj_drop=0.0,
394
+ use_sdpa=True,
395
+ is_causal=False,
396
+ ):
397
+ super().__init__()
398
+ self.num_heads = num_heads
399
+ head_dim = dim // num_heads
400
+ self.scale = qk_scale or head_dim**-0.5
401
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
402
+ self.attn_drop = nn.Dropout(attn_drop)
403
+ self.proj = nn.Linear(dim, dim)
404
+ self.proj_drop_prob = proj_drop
405
+ self.proj_drop = nn.Dropout(proj_drop)
406
+ self.use_sdpa = use_sdpa
407
+ self.is_causal = is_causal
408
+
409
+ def forward(self, x, mask=None, attn_mask=None):
410
+ B, N, C = x.shape
411
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
412
+ q, k, v = qkv[0], qkv[1], qkv[2] # [B, num_heads, N, D]
413
+
414
+ if attn_mask is not None or self.use_sdpa:
415
+ with torch.backends.cuda.sdp_kernel():
416
+ x = F.scaled_dot_product_attention(
417
+ q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask
418
+ )
419
+ attn = None
420
+ else:
421
+ attn = (q @ k.transpose(-2, -1)) * self.scale # [B, num_heads, N, N]
422
+ attn = attn.softmax(dim=-1)
423
+ attn = self.attn_drop(attn)
424
+ x = attn @ v
425
+
426
+ x = x.transpose(1, 2).reshape(B, N, C)
427
+ x = self.proj(x)
428
+ x = self.proj_drop(x)
429
+ return x
430
+
431
+
432
+ class ACBlock(nn.Module):
433
+ def __init__(
434
+ self,
435
+ dim,
436
+ num_heads,
437
+ mlp_ratio=4.0,
438
+ qkv_bias=False,
439
+ qk_scale=None,
440
+ drop=0.0,
441
+ attn_drop=0.0,
442
+ drop_path=0.0,
443
+ act_layer=nn.GELU,
444
+ wide_silu=True,
445
+ norm_layer=nn.LayerNorm,
446
+ use_sdpa=True,
447
+ is_causal=False,
448
+ grid_size=16,
449
+ use_rope=False,
450
+ **kwargs,
451
+ ):
452
+ super().__init__()
453
+ self.norm1 = norm_layer(dim)
454
+ if use_rope:
455
+ self.attn = ACRoPEAttention(
456
+ dim,
457
+ num_heads=num_heads,
458
+ qkv_bias=qkv_bias,
459
+ qk_scale=qk_scale,
460
+ attn_drop=attn_drop,
461
+ use_sdpa=use_sdpa,
462
+ is_causal=is_causal,
463
+ grid_size=grid_size,
464
+ proj_drop=drop,
465
+ )
466
+ else:
467
+ self.attn = Attention(
468
+ dim,
469
+ num_heads=num_heads,
470
+ qkv_bias=qkv_bias,
471
+ qk_scale=qk_scale,
472
+ attn_drop=attn_drop,
473
+ use_sdpa=use_sdpa,
474
+ is_causal=is_causal,
475
+ proj_drop=drop,
476
+ )
477
+
478
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
479
+ self.norm2 = norm_layer(dim)
480
+ mlp_hidden_dim = int(dim * mlp_ratio)
481
+ if act_layer is nn.SiLU:
482
+ self.mlp = SwiGLUFFN(
483
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, wide_silu=wide_silu, drop=drop
484
+ )
485
+ else:
486
+ self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
487
+
488
+ def forward(self, x, mask=None, attn_mask=None, T=None, H=None, W=None, action_tokens=0):
489
+ y = self.norm1(x)
490
+ if isinstance(self.attn, ACRoPEAttention):
491
+ y = self.attn(y, mask=mask, attn_mask=attn_mask, T=T, H=H, W=W, action_tokens=action_tokens)
492
+ else:
493
+ y = self.attn(y, mask=mask, attn_mask=attn_mask)
494
+ x = x + self.drop_path(y)
495
+ y = self.norm2(x)
496
+ x = x + self.drop_path(self.mlp(y))
497
+ return x
498
+
499
+
500
+ class Block(nn.Module):
501
+ def __init__(
502
+ self,
503
+ dim,
504
+ num_heads,
505
+ mlp_ratio=4.0,
506
+ qkv_bias=False,
507
+ qk_scale=None,
508
+ drop=0.0,
509
+ attn_drop=0.0,
510
+ drop_path=0.0,
511
+ act_layer=nn.GELU,
512
+ wide_silu=True,
513
+ norm_layer=nn.LayerNorm,
514
+ use_sdpa=True,
515
+ is_causal=False,
516
+ grid_size=16,
517
+ use_rope=False,
518
+ **kwargs,
519
+ ):
520
+ super().__init__()
521
+ self.norm1 = norm_layer(dim)
522
+ if use_rope:
523
+ self.attn = RoPEAttention(
524
+ dim,
525
+ num_heads=num_heads,
526
+ qkv_bias=qkv_bias,
527
+ qk_scale=qk_scale,
528
+ attn_drop=attn_drop,
529
+ use_sdpa=use_sdpa,
530
+ is_causal=is_causal,
531
+ grid_size=grid_size,
532
+ proj_drop=drop,
533
+ )
534
+ else:
535
+ self.attn = Attention(
536
+ dim,
537
+ num_heads=num_heads,
538
+ qkv_bias=qkv_bias,
539
+ qk_scale=qk_scale,
540
+ attn_drop=attn_drop,
541
+ use_sdpa=use_sdpa,
542
+ is_causal=is_causal,
543
+ proj_drop=drop,
544
+ )
545
+
546
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
547
+ self.norm2 = norm_layer(dim)
548
+ mlp_hidden_dim = int(dim * mlp_ratio)
549
+ if act_layer is nn.SiLU:
550
+ self.mlp = SwiGLUFFN(
551
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, wide_silu=wide_silu, drop=drop
552
+ )
553
+ else:
554
+ self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
555
+
556
+ def forward(self, x, mask=None, attn_mask=None, T=None, H_patches=None, W_patches=None):
557
+ if isinstance(self.attn, RoPEAttention):
558
+ y = self.attn(self.norm1(x), mask=mask, attn_mask=attn_mask, T=T, H_patches=H_patches, W_patches=W_patches)
559
+ else:
560
+ y = self.attn(self.norm1(x), mask=mask, attn_mask=attn_mask)
561
+ x = x + self.drop_path(y)
562
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
563
+ return x
564
+
565
+
566
+ class CrossAttention(nn.Module):
567
+ def __init__(self, dim, num_heads=12, qkv_bias=False, use_sdpa=True):
568
+ super().__init__()
569
+ self.num_heads = num_heads
570
+ head_dim = dim // num_heads
571
+ self.scale = head_dim**-0.5
572
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
573
+ self.kv = nn.Linear(dim, int(dim * 2), bias=qkv_bias)
574
+ # self.proj = nn.Linear(dim, dim)
575
+ self.use_sdpa = use_sdpa
576
+
577
+ def forward(self, q, x):
578
+ B, n, C = q.shape
579
+ q = self.q(q).reshape(B, n, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
580
+
581
+ B, N, C = x.shape
582
+ kv = self.kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
583
+ k, v = kv[0], kv[1] # (batch_size, num_heads, seq_len, feature_dim_per_head)
584
+
585
+ if self.use_sdpa:
586
+ with torch.backends.cuda.sdp_kernel():
587
+ q = F.scaled_dot_product_attention(q, k, v)
588
+ else:
589
+ xattn = (q @ k.transpose(-2, -1)) * self.scale
590
+ xattn = xattn.softmax(dim=-1) # (batch_size, num_heads, query_len, seq_len)
591
+ q = xattn @ v
592
+
593
+ q = q.transpose(1, 2).reshape(B, n, C)
594
+ return q
595
+
596
+
597
+ class CrossAttentionBlock(nn.Module):
598
+ def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=False, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
599
+ super().__init__()
600
+ self.norm1 = norm_layer(dim)
601
+ self.xattn = CrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
602
+ self.norm2 = norm_layer(dim)
603
+ mlp_hidden_dim = int(dim * mlp_ratio)
604
+ self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
605
+
606
+ def forward(self, q, x):
607
+ y = self.xattn(q, self.norm1(x))
608
+ q = q + y
609
+ q = q + self.mlp(self.norm2(q))
610
+ return q
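A small illustrative check (not part of the upload) of `build_action_block_causal_attention_mask`: every token of frame t, including its `add_tokens` conditioning slots, may attend to all tokens of frames up to and including t. The tiny 4-frame, 2x2 grid is an arbitrary example; it assumes `torch`/`timm` are installed and the repo root is importable.

```python
# Illustrative sketch: inspect the block-causal attention mask for a tiny 4-frame grid.
import torch

from src.models.utils.modules import build_action_block_causal_attention_mask

T, H, W, add_tokens = 4, 2, 2, 2
mask = build_action_block_causal_attention_mask(T, H, W, add_tokens=add_tokens)
block = add_tokens + H * W                # tokens per frame, conditioning slots included
print(mask.shape)                         # torch.Size([24, 24]) == [T * block, T * block]
print(mask[:block, : 2 * block].int())    # frame 0 attends to itself, not to frame 1
```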
src/models/utils/patch_embed.py ADDED
@@ -0,0 +1,52 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch.nn as nn
7
+ from einops import rearrange
8
+
9
+
10
+ class PatchEmbed(nn.Module):
11
+ """
12
+ Image to Patch Embedding
13
+ """
14
+
15
+ def __init__(self, patch_size=16, in_chans=3, embed_dim=768):
16
+ super().__init__()
17
+ self.patch_size = patch_size
18
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
19
+
20
+ def forward(self, x):
21
+ B, C, H, W = x.shape
22
+ x = self.proj(x).flatten(2).transpose(1, 2)
23
+ return x
24
+
25
+
26
+ class PatchEmbed3D(nn.Module):
27
+ """
28
+ Video clip to Patch (tubelet) Embedding
29
+ """
30
+
31
+ def __init__(
32
+ self,
33
+ patch_size=16,
34
+ tubelet_size=2,
35
+ in_chans=3,
36
+ embed_dim=768,
37
+ ):
38
+ super().__init__()
39
+ self.patch_size = patch_size
40
+ self.tubelet_size = tubelet_size
41
+
42
+ self.proj = nn.Conv3d(
43
+ in_channels=in_chans,
44
+ out_channels=embed_dim,
45
+ kernel_size=(tubelet_size, patch_size, patch_size),
46
+ stride=(tubelet_size, patch_size, patch_size),
47
+ )
48
+
49
+ def forward(self, x, **kwargs):
50
+ B, C, T, H, W = x.shape
51
+ x = self.proj(x).flatten(2).transpose(1, 2)
52
+ return x
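An illustrative tokenization sketch for `PatchEmbed3D` (not part of the upload), using a 16-frame 224x224 clip as an arbitrary example; it assumes `torch` and `einops` are installed and the repo root is importable.

```python
# Illustrative sketch: turn a [B, C, T, H, W] clip into a flat sequence of tubelet tokens.
import torch

from src.models.utils.patch_embed import PatchEmbed3D

clip = torch.randn(1, 3, 16, 224, 224)
embed = PatchEmbed3D(patch_size=16, tubelet_size=2, in_chans=3, embed_dim=768)
tokens = embed(clip)
print(tokens.shape)  # [1, (16 // 2) * (224 // 16) * (224 // 16), 768] == [1, 1568, 768]
```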
src/models/utils/pos_embs.py ADDED
@@ -0,0 +1,93 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+
8
+
9
+ def get_3d_sincos_pos_embed(embed_dim, grid_size, grid_depth, cls_token=False, uniform_power=False):
10
+ """
11
+ grid_size: int of the grid height and width
12
+ grid_depth: int of the grid depth
13
+ returns:
14
+ pos_embed: [grid_depth*grid_size*grid_size, embed_dim] (w/o cls_token)
15
+ or [1+grid_depth*grid_size*grid_size, embed_dim] (w/ cls_token)
16
+ """
17
+ grid_d = np.arange(grid_depth, dtype=float)
18
+ grid_h = np.arange(grid_size, dtype=float)
19
+ grid_w = np.arange(grid_size, dtype=float)
20
+ grid_h, grid_d, grid_w = np.meshgrid(
21
+ grid_h, grid_d, grid_w
22
+ ) # order of meshgrid is very important for indexing as [d,h,w]
23
+
24
+ if not uniform_power:
25
+ h_embed_dim = embed_dim // 4
26
+ w_embed_dim = embed_dim // 4
27
+ d_embed_dim = embed_dim // 2
28
+ else:
29
+ h_embed_dim = w_embed_dim = d_embed_dim = int(np.ceil(embed_dim / 6) * 2)
30
+
31
+ emb_h = get_1d_sincos_pos_embed_from_grid(h_embed_dim, grid_h) # (T*H*W, D1)
32
+ emb_w = get_1d_sincos_pos_embed_from_grid(w_embed_dim, grid_w) # (T*H*W, D2)
33
+ emb_d = get_1d_sincos_pos_embed_from_grid(d_embed_dim, grid_d) # (T*H*W, D3)
34
+ pos_embed = np.concatenate([emb_d, emb_h, emb_w], axis=1)
35
+ pos_embed = pos_embed[:, :embed_dim]
36
+ if cls_token:
37
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
38
+ return pos_embed
39
+
40
+
41
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
42
+ """
43
+ grid_size: int of the grid height and width
44
+ returns:
45
+ pos_embed: [grid_size*grid_size, embed_dim] (w/o cls_token)
46
+ or [1+grid_size*grid_size, embed_dim] (w/ cls_token)
47
+ """
48
+ grid_h = np.arange(grid_size, dtype=float)
49
+ grid_w = np.arange(grid_size, dtype=float)
50
+ grid_w, grid_h = np.meshgrid(grid_w, grid_h) # order of meshgrid is very important for indexing as [h, w]
51
+
52
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_h) # (H*W, D/2)
53
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_w) # (H*W, D/2)
54
+ pos_embed = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
55
+ if cls_token:
56
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
57
+ return pos_embed
58
+
59
+
60
+ def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
61
+ """
62
+ embed_dim: output dimension for each position
63
+ grid_size: int of the grid length
64
+ returns:
65
+ pos_embed: [grid_size, embed_dim] (w/o cls_token)
66
+ or [1+grid_size, embed_dim] (w/ cls_token)
67
+ """
68
+ grid = np.arange(grid_size, dtype=float)
69
+ pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid)
70
+ if cls_token:
71
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
72
+ return pos_embed
73
+
74
+
75
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
76
+ """
77
+ embed_dim: output dimension for each position
78
+ pos: a list of positions to be encoded: size (M,)
79
+ returns: (M, D)
80
+ """
81
+ assert embed_dim % 2 == 0
82
+ omega = np.arange(embed_dim // 2, dtype=float)
83
+ omega /= embed_dim / 2.0
84
+ omega = 1.0 / 10000**omega # (D/2,)
85
+
86
+ pos = pos.reshape(-1) # (M,)
87
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
88
+
89
+ emb_sin = np.sin(out) # (M, D/2)
90
+ emb_cos = np.cos(out) # (M, D/2)
91
+
92
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
93
+ return emb
src/models/vision_transformer.py ADDED
@@ -0,0 +1,487 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from functools import partial
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from src.masks.utils import apply_masks
13
+ from src.models.utils.modules import Block
14
+ from src.models.utils.patch_embed import PatchEmbed, PatchEmbed3D
15
+ from src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed
16
+ from src.utils.tensors import trunc_normal_
17
+
18
+
19
+ class VisionTransformer(nn.Module):
20
+ """Vision Transformer"""
21
+
22
+ def __init__(
23
+ self,
24
+ img_size=(224, 224),
25
+ patch_size=16,
26
+ num_frames=1,
27
+ tubelet_size=2,
28
+ in_chans=3,
29
+ embed_dim=768,
30
+ depth=12,
31
+ num_heads=12,
32
+ mlp_ratio=4.0,
33
+ qkv_bias=True,
34
+ qk_scale=None,
35
+ drop_rate=0.0,
36
+ attn_drop_rate=0.0,
37
+ drop_path_rate=0.0,
38
+ norm_layer=nn.LayerNorm,
39
+ init_std=0.02,
40
+ out_layers=None,
41
+ uniform_power=False,
42
+ use_silu=False,
43
+ wide_silu=True,
44
+ use_sdpa=True,
45
+ use_activation_checkpointing=False,
46
+ use_rope=False,
47
+ handle_nonsquare_inputs=True,
48
+ **kwargs
49
+ ):
50
+ super().__init__()
51
+ self.num_features = self.embed_dim = embed_dim
52
+ self.num_heads = num_heads
53
+ self.out_layers = out_layers
54
+ self.handle_nonsquare_inputs = handle_nonsquare_inputs
55
+
56
+ if type(img_size) is int:
57
+ img_size = (img_size, img_size)
58
+ self.img_height, self.img_width = img_size
59
+ self.patch_size = patch_size
60
+ self.num_frames = num_frames
61
+ self.tubelet_size = tubelet_size
62
+ self.is_video = num_frames > 1
63
+
64
+ self.use_activation_checkpointing = use_activation_checkpointing
65
+
66
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
67
+
68
+ # Tokenize pixels with convolution
69
+ if self.is_video:
70
+ self.patch_embed = PatchEmbed3D(
71
+ patch_size=patch_size, tubelet_size=tubelet_size, in_chans=in_chans, embed_dim=embed_dim
72
+ )
73
+ self.num_patches = (num_frames // tubelet_size) * (img_size[0] // patch_size) * (img_size[1] // patch_size)
74
+ else:
75
+ self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
76
+ self.num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
77
+
78
+ # Position embedding
79
+ self.uniform_power = uniform_power
80
+ self.use_rope = use_rope
81
+ if self.use_rope:
82
+ self.pos_embed = None
83
+ else:
84
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim), requires_grad=False)
85
+
86
+ # Attention Blocks
87
+ self.blocks = nn.ModuleList(
88
+ [
89
+ Block(
90
+ use_rope=use_rope,
91
+ grid_size=img_size[0] // patch_size,
92
+ grid_depth=num_frames // tubelet_size,
93
+ dim=embed_dim,
94
+ num_heads=num_heads,
95
+ mlp_ratio=mlp_ratio,
96
+ use_sdpa=use_sdpa,
97
+ qkv_bias=qkv_bias,
98
+ qk_scale=qk_scale,
99
+ drop=drop_rate,
100
+ act_layer=nn.SiLU if use_silu else nn.GELU,
101
+ wide_silu=wide_silu,
102
+ attn_drop=attn_drop_rate,
103
+ drop_path=dpr[i],
104
+ norm_layer=norm_layer,
105
+ )
106
+ for i in range(depth)
107
+ ]
108
+ )
109
+ self.norm = norm_layer(embed_dim)
110
+
111
+ # ------ initialize weights
112
+ if self.pos_embed is not None:
113
+ self._init_pos_embed(self.pos_embed.data) # sincos pos-embed
114
+ self.init_std = init_std
115
+ self.apply(self._init_weights)
116
+ self._rescale_blocks()
117
+
118
+ def _init_pos_embed(self, pos_embed):
119
+ embed_dim = pos_embed.size(-1)
120
+ grid_size = self.img_height // self.patch_size # TODO: update; currently assumes square input
121
+ if self.is_video:
122
+ grid_depth = self.num_frames // self.tubelet_size
123
+ sincos = get_3d_sincos_pos_embed(
124
+ embed_dim, grid_size, grid_depth, cls_token=False, uniform_power=self.uniform_power
125
+ )
126
+ else:
127
+ sincos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False)
128
+ pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0))
129
+
130
+ def _init_weights(self, m):
131
+ if isinstance(m, nn.Linear):
132
+ trunc_normal_(m.weight, std=self.init_std)
133
+ if isinstance(m, nn.Linear) and m.bias is not None:
134
+ nn.init.constant_(m.bias, 0)
135
+ elif isinstance(m, nn.LayerNorm):
136
+ nn.init.constant_(m.bias, 0)
137
+ nn.init.constant_(m.weight, 1.0)
138
+ elif isinstance(m, nn.Conv2d):
139
+ trunc_normal_(m.weight, std=self.init_std)
140
+ if m.bias is not None:
141
+ nn.init.constant_(m.bias, 0)
142
+ elif isinstance(m, nn.Conv3d):
143
+ trunc_normal_(m.weight, std=self.init_std)
144
+ if m.bias is not None:
145
+ nn.init.constant_(m.bias, 0)
146
+
147
+ def _rescale_blocks(self):
148
+ def rescale(param, layer_id):
149
+ param.div_(math.sqrt(2.0 * layer_id))
150
+
151
+ for layer_id, layer in enumerate(self.blocks):
152
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
153
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
154
+
155
+ def get_num_layers(self):
156
+ return len(self.blocks)
157
+
158
+ def no_weight_decay(self):
159
+ return {}
160
+
161
+ def forward(self, x, masks=None):
162
+ """
163
+ :param x: input image/video
164
+ :param masks: indices of patch tokens to mask (remove)
165
+ """
166
+ if masks is not None and not isinstance(masks, list):
167
+ masks = [masks]
168
+
169
+ # Tokenize input
170
+ # Image
171
+ if x.ndim == 4:
172
+ _, _, H, W = x.shape
173
+ T = 1
174
+ # Video
175
+ elif x.ndim == 5:
176
+ _, _, T, H, W = x.shape
177
+ T = T // self.tubelet_size
178
+ H_patches = H // self.patch_size
179
+ W_patches = W // self.patch_size
180
+ if not self.handle_nonsquare_inputs:
181
+ T = H_patches = W_patches = None
182
+
183
+ if not self.use_rope:
184
+ pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
185
+ x = self.patch_embed(x)
186
+ x += pos_embed
187
+ else:
188
+ x = self.patch_embed(x)
189
+
190
+ # Mask away unwanted tokens (if masks provided)
191
+ if masks is not None:
192
+ x = apply_masks(x, masks)
193
+ masks = torch.cat(masks, dim=0)
194
+
195
+ # Fwd prop
196
+ outs = []
197
+ for i, blk in enumerate(self.blocks):
198
+ if self.use_activation_checkpointing:
199
+ x = torch.utils.checkpoint.checkpoint(
200
+ blk, x, masks, None, T=T, H_patches=H_patches, W_patches=W_patches, use_reentrant=False
201
+ )
202
+ else:
203
+ x = blk(x, mask=masks, attn_mask=None, T=T, H_patches=H_patches, W_patches=W_patches)
204
+ if self.out_layers is not None and i in self.out_layers:
205
+ outs.append(self.norm(x))
206
+
207
+ if self.out_layers is not None:
208
+ return outs
209
+
210
+ if self.norm is not None:
211
+ x = self.norm(x)
212
+
213
+ return x
214
+
215
+ def interpolate_pos_encoding(self, x, pos_embed):
216
+
217
+ _, N, dim = pos_embed.shape
218
+
219
+ if self.is_video:
220
+
221
+ # If pos_embed already correct size, just return
222
+ _, _, T, H, W = x.shape
223
+ if H == self.img_height and W == self.img_width and T == self.num_frames:
224
+ return pos_embed
225
+
226
+ # Just chop off last N tokens of positional embedding
227
+ elif H == self.img_height and W == self.img_width and T < self.num_frames:
228
+ new_N = int((T // self.tubelet_size) * (H // self.patch_size) * (W // self.patch_size))
229
+ return pos_embed[:, :new_N, :]
230
+
231
+ # Convert depth, height, width of input to be measured in patches
232
+ # instead of pixels/frames
233
+ T = T // self.tubelet_size
234
+ H = H // self.patch_size
235
+ W = W // self.patch_size
236
+
237
+ # Compute the initialized shape of the positional embedding measured
238
+ # in patches
239
+ N_t = self.num_frames // self.tubelet_size
240
+ N_h = self.img_height // self.patch_size
241
+ N_w = self.img_width // self.patch_size
242
+ assert N_h * N_w * N_t == N, "Positional embedding initialized incorrectly"
243
+
244
+ # Compute scale factor for spatio-temporal interpolation
245
+ scale_factor = (T / N_t, H / N_h, W / N_w)
246
+
247
+ pos_embed = nn.functional.interpolate(
248
+ pos_embed.reshape(1, N_t, N_h, N_w, dim).permute(0, 4, 1, 2, 3),
249
+ scale_factor=scale_factor,
250
+ mode="trilinear",
251
+ )
252
+ pos_embed = pos_embed.permute(0, 2, 3, 4, 1).view(1, -1, dim)
253
+ return pos_embed
254
+
255
+ else:
256
+
257
+ # If pos_embed already correct size, just return
258
+ _, _, H, W = x.shape
259
+ if H == self.img_height and W == self.img_width:
260
+ return pos_embed
261
+
262
+ # Compute scale factor for spatial interpolation
263
+ npatch = (H // self.patch_size) * (W // self.patch_size)
264
+ scale_factor = math.sqrt(npatch / N)
265
+
266
+ pos_embed = nn.functional.interpolate(
267
+ pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
268
+ scale_factor=scale_factor,
269
+ mode="bicubic",
270
+ )
271
+ pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
272
+ return pos_embed
273
+
274
+
275
+ def vit_large(patch_size=16, **kwargs):
276
+ model = VisionTransformer(
277
+ patch_size=patch_size,
278
+ embed_dim=1024,
279
+ depth=24,
280
+ num_heads=16,
281
+ mlp_ratio=4,
282
+ qkv_bias=True,
283
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
284
+ **kwargs
285
+ )
286
+ return model
287
+
288
+
289
+ def vit_huge(patch_size=16, **kwargs):
290
+ model = VisionTransformer(
291
+ patch_size=patch_size,
292
+ embed_dim=1280,
293
+ depth=32,
294
+ num_heads=16,
295
+ mlp_ratio=4,
296
+ qkv_bias=True,
297
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
298
+ **kwargs
299
+ )
300
+ return model
301
+
302
+
303
+ def vit_giant_xformers(patch_size=16, **kwargs):
304
+ model = VisionTransformer(
305
+ patch_size=patch_size,
306
+ embed_dim=1408,
307
+ depth=40,
308
+ num_heads=22,
309
+ mlp_ratio=48 / 11,
310
+ qkv_bias=True,
311
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
312
+ **kwargs
313
+ )
314
+ return model
315
+
316
+
317
+ # We do not use any of the following ViT definitions in V-JEPA 2, but retain them for
318
+ # compatibility reasons.
319
+ def vit_synthetic(patch_size=16, **kwargs):
320
+ # For performance testing only
321
+ model = VisionTransformer(
322
+ patch_size=patch_size,
323
+ embed_dim=1,
324
+ depth=1,
325
+ num_heads=1,
326
+ mlp_ratio=4,
327
+ qkv_bias=True,
328
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
329
+ **kwargs
330
+ )
331
+ return model
332
+
333
+
334
+ def vit_tiny(patch_size=16, **kwargs):
335
+ model = VisionTransformer(
336
+ patch_size=patch_size,
337
+ embed_dim=192,
338
+ depth=12,
339
+ num_heads=3,
340
+ mlp_ratio=4,
341
+ qkv_bias=True,
342
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
343
+ **kwargs
344
+ )
345
+ return model
346
+
347
+
348
+ def vit_small(patch_size=16, **kwargs):
349
+ model = VisionTransformer(
350
+ patch_size=patch_size,
351
+ embed_dim=384,
352
+ depth=12,
353
+ num_heads=6,
354
+ mlp_ratio=4,
355
+ qkv_bias=True,
356
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
357
+ **kwargs
358
+ )
359
+ return model
360
+
361
+
362
+ def vit_base(patch_size=16, **kwargs):
363
+ model = VisionTransformer(
364
+ patch_size=patch_size,
365
+ embed_dim=768,
366
+ depth=12,
367
+ num_heads=12,
368
+ mlp_ratio=4,
369
+ qkv_bias=True,
370
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
371
+ **kwargs
372
+ )
373
+ return model
374
+
375
+
376
+ def vit_large_rope(patch_size=16, **kwargs):
377
+ model = VisionTransformer(
378
+ patch_size=patch_size,
379
+ embed_dim=1024,
380
+ depth=24,
381
+ num_heads=16,
382
+ mlp_ratio=4,
383
+ qkv_bias=True,
384
+ use_rope=True,
385
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
386
+ **kwargs
387
+ )
388
+ return model
389
+
390
+
391
+ def vit_huge_rope(patch_size=16, **kwargs):
392
+ model = VisionTransformer(
393
+ patch_size=patch_size,
394
+ embed_dim=1280,
395
+ depth=32,
396
+ num_heads=16,
397
+ mlp_ratio=4,
398
+ qkv_bias=True,
399
+ use_rope=True,
400
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
401
+ **kwargs
402
+ )
403
+ return model
404
+
405
+
406
+ def vit_giant(patch_size=16, **kwargs):
407
+ model = VisionTransformer(
408
+ patch_size=patch_size,
409
+ embed_dim=1408,
410
+ depth=40,
411
+ num_heads=16,
412
+ mlp_ratio=48 / 11,
413
+ qkv_bias=True,
414
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
415
+ **kwargs
416
+ )
417
+ return model
418
+
419
+
420
+ def vit_giant_rope(patch_size=16, **kwargs):
421
+ model = VisionTransformer(
422
+ patch_size=patch_size,
423
+ embed_dim=1408,
424
+ depth=40,
425
+ num_heads=16,
426
+ mlp_ratio=48 / 11,
427
+ qkv_bias=True,
428
+ use_rope=True,
429
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
430
+ **kwargs
431
+ )
432
+ return model
433
+
434
+
435
+ def vit_giant_xformers_rope(patch_size=16, **kwargs):
436
+ model = VisionTransformer(
437
+ patch_size=patch_size,
438
+ embed_dim=1408,
439
+ depth=40,
440
+ num_heads=22,
441
+ mlp_ratio=48 / 11,
442
+ qkv_bias=True,
443
+ use_rope=True,
444
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
445
+ **kwargs
446
+ )
447
+ return model
448
+
449
+
450
+ def vit_gigantic(patch_size=16, **kwargs):
451
+ model = VisionTransformer(
452
+ patch_size=patch_size,
453
+ embed_dim=1664,
454
+ depth=48,
455
+ num_heads=16,
456
+ mlp_ratio=64 / 13,
457
+ qkv_bias=True,
458
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
459
+ **kwargs
460
+ )
461
+ return model
462
+
463
+
464
+ def vit_gigantic_xformers(patch_size=16, **kwargs):
465
+ model = VisionTransformer(
466
+ patch_size=patch_size,
467
+ embed_dim=1664,
468
+ depth=48,
469
+ num_heads=26,
470
+ mlp_ratio=64 / 13,
471
+ qkv_bias=True,
472
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
473
+ **kwargs
474
+ )
475
+ return model
476
+
477
+
478
+ VIT_EMBED_DIMS = {
479
+ "vit_synthetic": 1,
480
+ "vit_tiny": 192,
481
+ "vit_small": 384,
482
+ "vit_base": 768,
483
+ "vit_large": 1024,
484
+ "vit_huge": 1280,
485
+ "vit_giant": 1408,
486
+ "vit_gigantic": 1664,
487
+ }
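For orientation, here is a minimal usage sketch of the factory functions above. The constructor keywords (`img_size`, `num_frames`, `tubelet_size`) are assumptions inferred from the attributes referenced in `interpolate_pos_encoding`, not a documented signature:

```python
# Hypothetical usage sketch -- the keyword names below are assumptions, not a documented API.
import torch

from src.models.vision_transformer import VIT_EMBED_DIMS, vit_large

# Build a video encoder for 16-frame 224x224 RGB clips.
encoder = vit_large(
    patch_size=16,    # spatial patch size in pixels
    img_size=224,     # assumed keyword for spatial resolution
    num_frames=16,    # assumed keyword for clip length
    tubelet_size=2,   # assumed keyword for temporal patch size
)

clips = torch.randn(2, 3, 16, 224, 224)  # [B, C, T, H, W]
with torch.no_grad():
    tokens = encoder(clips)              # [B, num_tokens, embed_dim]

assert tokens.shape[-1] == VIT_EMBED_DIMS["vit_large"]  # 1024
```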
src/utils/__pycache__/tensors.cpython-312.pyc ADDED
Binary file (2.18 kB)
src/utils/checkpoint_loader.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import random
8
+ import time
9
+ from typing import Any
10
+
11
+ import torch
12
+ from torch.serialization import MAP_LOCATION
13
+
14
+ from src.utils.logging import get_logger
15
+
16
+ logger = get_logger(os.path.basename(__file__))
17
+
18
+
19
+ def robust_checkpoint_loader(r_path: str, map_location: MAP_LOCATION = "cpu", max_retries: int = 3) -> Any:
20
+ """
21
+ Loads a checkpoint from a path, retrying up to max_retries times (with exponential backoff) if loading fails.
22
+ """
23
+ retries = 0
24
+
25
+ while retries < max_retries:
26
+ try:
27
+ return torch.load(r_path, map_location=map_location)
28
+ except Exception as e:
29
+ logger.warning(f"Encountered exception when loading checkpoint {e}")
30
+ retries += 1
31
+ if retries < max_retries:
32
+ sleep_time_s = (2**retries) * random.uniform(1.0, 1.1)
33
+ logger.warning(f"Sleeping {sleep_time_s}s and trying again, count {retries}/{max_retries}")
34
+ time.sleep(sleep_time_s)
35
+ continue
36
+ else:
37
+ raise e
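A short usage sketch of `robust_checkpoint_loader`; the path and the layout of the returned dict are placeholders that depend on the training script:

```python
from src.utils.checkpoint_loader import robust_checkpoint_loader

# Retries with exponential backoff (2**retries seconds plus jitter) and
# re-raises the last exception once max_retries is exhausted.
ckpt = robust_checkpoint_loader("/path/to/checkpoint.pt", map_location="cpu", max_retries=3)
state_dict = ckpt.get("encoder", ckpt)  # key layout is checkpoint-specific
```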
src/utils/distributed.py ADDED
@@ -0,0 +1,101 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ from pathlib import Path
8
+
9
+ import torch
10
+ import torch.distributed as dist
11
+
12
+ from src.utils.logging import get_logger
13
+
14
+ logger = get_logger()
15
+
16
+
17
+ def init_distributed(port=37129, rank_and_world_size=(None, None)):
18
+ # try to set all environment variables to avoid triggering a segfault
19
+ # environment variables can be reallocated during the execution of torch.distributed.init_process_group
20
+ # the idea is a race condition may trigger if init_progress_group is modifying an environment variable at
21
+ # the same time as Python, so we try to set all environs before initializing distributed
22
+ if "SLURM_JOB_ID" in os.environ:
23
+ # Use the slurm_tmpdir (if it exists) instead of /tmp
24
+ tmpdir = Path(f"/scratch/slurm_tmpdir/{os.environ['SLURM_JOB_ID']}")
25
+ if tmpdir.exists():
26
+ os.environ["TMPDIR"] = str(tmpdir)
27
+
28
+ if dist.is_available() and dist.is_initialized():
29
+ return dist.get_world_size(), dist.get_rank()
30
+
31
+ rank, world_size = rank_and_world_size
32
+ os.environ["MASTER_ADDR"] = "localhost"
33
+
34
+ if (rank is None) or (world_size is None):
35
+ try:
36
+ world_size = int(os.environ["SLURM_NTASKS"])
37
+ rank = int(os.environ["SLURM_PROCID"])
38
+ os.environ["MASTER_ADDR"] = os.environ["HOSTNAME"]
39
+ except Exception:
40
+ logger.info("SLURM vars not set (distributed training not available)")
41
+ world_size, rank = 1, 0
42
+ return world_size, rank
43
+
44
+ try:
45
+ os.environ["MASTER_PORT"] = str(port)
46
+ torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=rank)
47
+ except Exception as e:
48
+ world_size, rank = 1, 0
49
+ logger.info(f"Rank: {rank}. Distributed training not available {e}")
50
+
51
+ return world_size, rank
52
+
53
+
54
+ class AllGather(torch.autograd.Function):
55
+
56
+ @staticmethod
57
+ def forward(ctx, x):
58
+ if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1):
59
+ x = x.contiguous()
60
+ outputs = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
61
+ dist.all_gather(outputs, x)
62
+ return torch.cat(outputs, 0)
63
+ return x
64
+
65
+ @staticmethod
66
+ def backward(ctx, grads):
67
+ if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1):
68
+ s = (grads.shape[0] // dist.get_world_size()) * dist.get_rank()
69
+ e = (grads.shape[0] // dist.get_world_size()) * (dist.get_rank() + 1)
70
+ grads = grads.contiguous()
71
+ dist.all_reduce(grads)
72
+ return grads[s:e]
73
+ return grads
74
+
75
+
76
+ class AllReduceSum(torch.autograd.Function):
77
+
78
+ @staticmethod
79
+ def forward(ctx, x):
80
+ if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1):
81
+ x = x.contiguous()
82
+ dist.all_reduce(x)
83
+ return x
84
+
85
+ @staticmethod
86
+ def backward(ctx, grads):
87
+ return grads
88
+
89
+
90
+ class AllReduce(torch.autograd.Function):
91
+
92
+ @staticmethod
93
+ def forward(ctx, x):
94
+ if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1):
95
+ x = x.contiguous() / dist.get_world_size()
96
+ dist.all_reduce(x)
97
+ return x
98
+
99
+ @staticmethod
100
+ def backward(ctx, grads):
101
+ return grads
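A sketch of how these utilities are typically wired together; it degrades gracefully to single-process mode when no SLURM variables are set:

```python
import torch

from src.utils.distributed import AllReduce, init_distributed

world_size, rank = init_distributed(port=37129)

# AllReduce averages a tensor across ranks in forward and passes gradients
# straight through in backward, so it can sit inside a loss computation.
device = "cuda" if torch.cuda.is_available() else "cpu"
loss = torch.tensor(0.5, device=device)
loss = AllReduce.apply(loss)  # no-op when world_size == 1
```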
src/utils/logging.py ADDED
@@ -0,0 +1,108 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import os
8
+ import subprocess
9
+ import sys
10
+
11
+ import torch
12
+
13
+
14
+ def gpu_timer(closure, log_timings=True):
15
+ """Helper to time gpu-time to execute closure()"""
16
+ log_timings = log_timings and torch.cuda.is_available()
17
+
18
+ elapsed_time = -1.0
19
+ if log_timings:
20
+ start = torch.cuda.Event(enable_timing=True)
21
+ end = torch.cuda.Event(enable_timing=True)
22
+ start.record()
23
+
24
+ result = closure()
25
+
26
+ if log_timings:
27
+ end.record()
28
+ torch.cuda.synchronize()
29
+ elapsed_time = start.elapsed_time(end)
30
+
31
+ return result, elapsed_time
32
+
33
+
34
+ LOG_FORMAT = "[%(levelname)-8s][%(asctime)s][%(name)-20s][%(funcName)-25s] %(message)s"
35
+ DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
36
+
37
+
38
+ def get_logger(name=None, force=False):
39
+ logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=LOG_FORMAT, datefmt=DATE_FORMAT, force=force)
40
+ return logging.getLogger(name=name)
41
+
42
+
43
+ class CSVLogger(object):
44
+
45
+ def __init__(self, fname, *argv, **kwargs):
46
+ self.fname = fname
47
+ self.types = []
48
+ mode = kwargs.get("mode", "+a")
49
+ self.delim = kwargs.get("delim", ",")
50
+ # -- print headers
51
+ with open(self.fname, mode) as f:
52
+ for i, v in enumerate(argv, 1):
53
+ self.types.append(v[0])
54
+ if i < len(argv):
55
+ print(v[1], end=self.delim, file=f)
56
+ else:
57
+ print(v[1], end="\n", file=f)
58
+
59
+ def log(self, *argv):
60
+ with open(self.fname, "+a") as f:
61
+ for i, tv in enumerate(zip(self.types, argv), 1):
62
+ end = self.delim if i < len(argv) else "\n"
63
+ print(tv[0] % tv[1], end=end, file=f)
64
+
65
+
66
+ class AverageMeter(object):
67
+ """computes and stores the average and current value"""
68
+
69
+ def __init__(self):
70
+ self.reset()
71
+
72
+ def reset(self):
73
+ self.val = 0
74
+ self.avg = 0
75
+ self.max = float("-inf")
76
+ self.min = float("inf")
77
+ self.sum = 0
78
+ self.count = 0
79
+
80
+ def update(self, val, n=1):
81
+ self.val = val
82
+ try:
83
+ self.max = max(val, self.max)
84
+ self.min = min(val, self.min)
85
+ except Exception:
86
+ pass
87
+ self.sum += val * n
88
+ self.count += n
89
+ self.avg = self.sum / self.count
90
+
91
+
92
+ def jepa_rootpath():
93
+ this_file = os.path.abspath(__file__)
94
+ return "/".join(this_file.split("/")[:-3])
95
+
96
+
97
+ def git_information():
98
+ jepa_root = jepa_rootpath()
99
+ try:
100
+ resp = (
101
+ subprocess.check_output(["git", "-C", jepa_root, "rev-parse", "HEAD", "--abbrev-ref", "HEAD"])
102
+ .decode("ascii")
103
+ .strip()
104
+ )
105
+ commit, branch = resp.split("\n")
106
+ return f"branch: {branch}\ncommit: {commit}\n"
107
+ except Exception:
108
+ return "unknown"
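A minimal sketch of the logging helpers working together; the file name and column formats are illustrative:

```python
import torch

from src.utils.logging import AverageMeter, CSVLogger, get_logger, gpu_timer

logger = get_logger(__name__)

# Each positional argument is a (printf-format, header) pair; log() fills the columns positionally.
csv_logger = CSVLogger("train_log.csv", ("%d", "epoch"), ("%.5f", "loss"), ("%.1f", "time-ms"))

loss_meter = AverageMeter()
for epoch in range(3):
    loss, elapsed_ms = gpu_timer(lambda: torch.rand(1).mean())  # elapsed_ms is -1.0 without CUDA
    loss_meter.update(float(loss))
    csv_logger.log(epoch, loss_meter.avg, elapsed_ms)
    logger.info(f"epoch {epoch}: avg loss {loss_meter.avg:.4f}")
```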
src/utils/monitoring.py ADDED
@@ -0,0 +1,171 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import dataclasses
7
+ import threading
8
+ import time
9
+ from typing import Dict, Tuple
10
+
11
+ import psutil
12
+
13
+
14
+ @dataclasses.dataclass
15
+ class ResourceStatsSample:
16
+ timestamp: float
17
+ cpu_percent: float
18
+ read_count: int
19
+ write_count: int
20
+ read_bytes: int
21
+ write_bytes: int
22
+ read_chars: int
23
+ write_chars: int
24
+ cpu_times_user: float
25
+ cpu_times_system: float
26
+ cpu_times_children_user: float
27
+ cpu_times_children_system: float
28
+ cpu_times_iowait: float
29
+ cpu_affinity: str
30
+ cpu_num: int
31
+ num_threads: int
32
+ num_voluntary_ctx_switches: int
33
+ num_involuntary_ctx_switches: int
34
+
35
+ def as_tuple(self) -> Tuple:
36
+ """Return values mirroring fields."""
37
+ return dataclasses.astuple(self)
38
+
39
+ def fields(self) -> Tuple[dataclasses.Field, ...]:
40
+ """Return fields in this dataclass."""
41
+ return dataclasses.fields(self.__class__)
42
+
43
+
44
+ class ResourceMonitoringThread(threading.Thread):
45
+ def __init__(self, pid=None, refresh_interval=None, stats_callback_fn=None):
46
+ """Starts a thread to monitor pid every refresh_interval seconds.
47
+
48
+ Passes a ResourceStatsSample object to the callback."""
49
+ super(ResourceMonitoringThread, self).__init__()
50
+ if refresh_interval is None:
51
+ refresh_interval = 5
52
+ self.is_running_event = threading.Event()
53
+ self.p = psutil.Process(pid)
54
+ self.refresh_interval = refresh_interval
55
+ if stats_callback_fn is None:
56
+ # Default callback
57
+ def stats_callback_fn(resource_sample: ResourceStatsSample):
58
+ print(f"PID {self.p.pid} Stats: {resource_sample.resource_stats}")
59
+
60
+ elif not callable(stats_callback_fn):
61
+ raise ValueError("Callback needs to be callable, got {}".format(type(stats_callback_fn)))
62
+ self.stats_callback_fn = stats_callback_fn
63
+
64
+ def stop(self) -> None:
65
+ self.is_running_event.set()
66
+
67
+ def run(self) -> None:
68
+ while not self.is_running_event.is_set():
69
+ self.sample_counters()
70
+ self.is_running_event.wait(self.refresh_interval)
71
+
72
+ def log_sample(self, resource_sample: ResourceStatsSample) -> None:
73
+ self.stats_callback_fn(resource_sample)
74
+
75
+ def sample_counters(self) -> None:
76
+ if not self.p.is_running():
77
+ self.stop()
78
+ return
79
+
80
+ with self.p.oneshot():
81
+ cpu_percent = self.p.cpu_percent()
82
+ cpu_times = self.p.cpu_times()
83
+ io_counters = self.p.io_counters()
84
+ cpu_affinity = self.p.cpu_affinity()
85
+ cpu_num = self.p.cpu_num()
86
+ num_threads = self.p.num_threads()
87
+ num_ctx_switches = self.p.num_ctx_switches()
88
+ timestamp = time.time()
89
+
90
+ read_count = io_counters.read_count
91
+ write_count = io_counters.write_count
92
+ read_bytes = io_counters.read_bytes
93
+ write_bytes = io_counters.write_bytes
94
+ read_chars = io_counters.read_chars
95
+ write_chars = io_counters.write_chars
96
+
97
+ def compress_cpu_affinity(cpu_affinity):
98
+ """Change list representation to interval/range representation."""
99
+ if not cpu_affinity:
100
+ return ""
101
+ cpu_affinity_compressed = []
102
+ min_x = None
103
+ max_x = None
104
+ last_x = None
105
+
106
+ # Find contiguous ranges
107
+ for x in cpu_affinity:
108
+ if last_x is None:
109
+ # Start interval
110
+ min_x = x
111
+ max_x = x
112
+ last_x = x
113
+ continue
114
+ elif x == (last_x + 1):
115
+ # Move interval up
116
+ max_x = x
117
+ elif max_x is not None:
118
+ # Interval ended, start again
119
+ if min_x == max_x:
120
+ cpu_affinity_compressed.append("{}".format(min_x))
121
+ else:
122
+ cpu_affinity_compressed.append("{}-{}".format(min_x, max_x))
123
+ min_x = x
124
+ max_x = x
125
+ last_x = x
126
+ # Terminate last range
127
+ if max_x is not None:
128
+ if min_x == max_x:
129
+ cpu_affinity_compressed.append("{}".format(min_x))
130
+ else:
131
+ cpu_affinity_compressed.append("{}-{}".format(min_x, max_x))
132
+
133
+ # Concat
134
+ cpu_affinity_compressed = ",".join(cpu_affinity_compressed)
135
+
136
+ return cpu_affinity_compressed
137
+
138
+ cpu_affinity = compress_cpu_affinity(cpu_affinity)
139
+
140
+ resource_sample = ResourceStatsSample(
141
+ timestamp=timestamp,
142
+ cpu_percent=cpu_percent,
143
+ read_count=read_count,
144
+ write_count=write_count,
145
+ read_bytes=read_bytes,
146
+ write_bytes=write_bytes,
147
+ read_chars=read_chars,
148
+ write_chars=write_chars,
149
+ cpu_times_user=cpu_times.user,
150
+ cpu_times_system=cpu_times.system,
151
+ cpu_times_children_user=cpu_times.children_user,
152
+ cpu_times_children_system=cpu_times.children_system,
153
+ cpu_times_iowait=cpu_times.iowait,
154
+ cpu_affinity=cpu_affinity,
155
+ cpu_num=cpu_num,
156
+ num_threads=num_threads,
157
+ num_voluntary_ctx_switches=num_ctx_switches.voluntary,
158
+ num_involuntary_ctx_switches=num_ctx_switches.involuntary,
159
+ )
160
+ self.log_sample(resource_sample)
161
+
162
+
163
+ if __name__ == "__main__":
164
+ import multiprocessing
165
+
166
+ pid = multiprocessing.current_process().pid
167
+ monitor_thread = ResourceMonitoringThread(pid, 1)
168
+ monitor_thread.start()
169
+ time.sleep(5)
170
+ print("Shutdown")
171
+ monitor_thread.stop()
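The `__main__` block above exercises the default print callback; a sketch of supplying a custom callback (field names follow the dataclass above):

```python
import os
import time

from src.utils.monitoring import ResourceMonitoringThread, ResourceStatsSample


def log_cpu(sample: ResourceStatsSample) -> None:
    # Invoked from the monitoring thread every refresh_interval seconds.
    print(f"[monitor] cpu={sample.cpu_percent:.1f}% threads={sample.num_threads}")


monitor = ResourceMonitoringThread(os.getpid(), refresh_interval=2, stats_callback_fn=log_cpu)
monitor.start()
time.sleep(6)
monitor.stop()
```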
src/utils/schedulers.py ADDED
@@ -0,0 +1,93 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+
8
+
9
+ class WSDSchedule(object):
10
+
11
+ def __init__(self, optimizer, warmup_steps, anneal_steps, T_max, start_lr, ref_lr, final_lr=0.0):
12
+ self.optimizer = optimizer
13
+ self.start_lr = start_lr
14
+ self.ref_lr = ref_lr
15
+ self.final_lr = final_lr
16
+ self.anneal_steps = anneal_steps
17
+ self.warmup_steps = warmup_steps
18
+ self.T_max = T_max - warmup_steps - anneal_steps
19
+ self._step = 0.0
20
+
21
+ def step(self):
22
+ self._step += 1
23
+ if self._step < self.warmup_steps:
24
+ progress = float(self._step) / float(max(1, self.warmup_steps))
25
+ new_lr = self.start_lr + progress * (self.ref_lr - self.start_lr)
26
+ elif self._step < self.T_max + self.warmup_steps:
27
+ new_lr = self.ref_lr
28
+ else:
29
+ _step = self._step - (self.T_max + self.warmup_steps)
30
+ progress = float(_step) / float(max(1, self.anneal_steps))
31
+ new_lr = self.ref_lr + progress * (self.final_lr - self.ref_lr)
32
+
33
+ for group in self.optimizer.param_groups:
34
+ group["lr"] = new_lr
35
+ if "lr_scale" in group:
36
+ group["lr"] *= group["lr_scale"]
37
+
38
+ return new_lr
39
+
40
+
41
+ class WarmupCosineSchedule(object):
42
+
43
+ def __init__(self, optimizer, warmup_steps, start_lr, ref_lr, T_max, last_epoch=-1, final_lr=0.0):
44
+ self.optimizer = optimizer
45
+ self.start_lr = start_lr
46
+ self.ref_lr = ref_lr
47
+ self.final_lr = final_lr
48
+ self.warmup_steps = warmup_steps
49
+ self.T_max = T_max - warmup_steps
50
+ self._step = 0.0
51
+
52
+ def step(self):
53
+ self._step += 1
54
+ if self._step < self.warmup_steps:
55
+ progress = float(self._step) / float(max(1, self.warmup_steps))
56
+ new_lr = self.start_lr + progress * (self.ref_lr - self.start_lr)
57
+ else:
58
+ # -- progress after warmup
59
+ progress = float(self._step - self.warmup_steps) / float(max(1, self.T_max))
60
+ new_lr = max(
61
+ self.final_lr,
62
+ self.final_lr + (self.ref_lr - self.final_lr) * 0.5 * (1.0 + math.cos(math.pi * progress)),
63
+ )
64
+
65
+ for group in self.optimizer.param_groups:
66
+ group["lr"] = new_lr
67
+
68
+ return new_lr
69
+
70
+
71
+ class CosineWDSchedule(object):
72
+
73
+ def __init__(self, optimizer, ref_wd, T_max, final_wd=0.0):
74
+ self.optimizer = optimizer
75
+ self.ref_wd = ref_wd
76
+ self.final_wd = final_wd
77
+ self.T_max = T_max
78
+ self._step = 0.0
79
+
80
+ def step(self):
81
+ self._step += 1
82
+ progress = self._step / self.T_max
83
+ new_wd = self.final_wd + (self.ref_wd - self.final_wd) * 0.5 * (1.0 + math.cos(math.pi * progress))
84
+
85
+ if self.final_wd <= self.ref_wd:
86
+ new_wd = max(self.final_wd, new_wd)
87
+ else:
88
+ new_wd = min(self.final_wd, new_wd)
89
+
90
+ for group in self.optimizer.param_groups:
91
+ if ("WD_exclude" not in group) or not group["WD_exclude"]:
92
+ group["weight_decay"] = new_wd
93
+ return new_wd
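A sketch of stepping both schedulers once per optimizer update, mirroring how `step()` is written above; the hyperparameters are placeholders:

```python
import torch

from src.utils.schedulers import CosineWDSchedule, WarmupCosineSchedule

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)

total_steps = 1_000
lr_sched = WarmupCosineSchedule(
    optimizer, warmup_steps=100, start_lr=1e-6, ref_lr=1e-3, T_max=total_steps, final_lr=1e-6
)
wd_sched = CosineWDSchedule(optimizer, ref_wd=0.05, T_max=total_steps, final_wd=0.01)

for step in range(total_steps):
    # ... forward / backward / optimizer.step() would go here ...
    new_lr = lr_sched.step()  # schedulers mutate optimizer.param_groups in place
    new_wd = wd_sched.step()
```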
src/utils/tensors.py ADDED
@@ -0,0 +1,53 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from logging import getLogger
8
+
9
+ import torch
10
+
11
+ logger = getLogger()
12
+
13
+
14
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
15
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
16
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
17
+ def norm_cdf(x):
18
+ # Computes standard normal cumulative distribution function
19
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
20
+
21
+ with torch.no_grad():
22
+ # Values are generated by using a truncated uniform distribution and
23
+ # then using the inverse CDF for the normal distribution.
24
+ # Get upper and lower cdf values
25
+ lower = norm_cdf((a - mean) / std)
26
+ upper = norm_cdf((b - mean) / std)
27
+
28
+ # Uniformly fill tensor with values from [lower, upper], then translate to
29
+ # [2*lower-1, 2*upper-1].
30
+ tensor.uniform_(2 * lower - 1, 2 * upper - 1)
31
+
32
+ # Use inverse cdf transform for normal distribution to get truncated
33
+ # standard normal
34
+ tensor.erfinv_()
35
+
36
+ # Transform to proper mean, std
37
+ tensor.mul_(std * math.sqrt(2.0))
38
+ tensor.add_(mean)
39
+
40
+ # Clamp to ensure it's in the proper range
41
+ tensor.clamp_(min=a, max=b)
42
+ return tensor
43
+
44
+
45
+ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
46
+ # type: (Tensor, float, float, float, float) -> Tensor
47
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
48
+
49
+
50
+ def repeat_interleave_batch(x, B, repeat):
51
+ N = len(x) // B
52
+ x = torch.cat([torch.cat([x[i * B : (i + 1) * B] for _ in range(repeat)], dim=0) for i in range(N)], dim=0)
53
+ return x
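Two quick illustrations of the helpers above; the shapes are arbitrary:

```python
import torch
import torch.nn as nn

from src.utils.tensors import repeat_interleave_batch, trunc_normal_

# Truncated-normal init, as typically applied to Linear or positional-embedding weights.
layer = nn.Linear(16, 16)
trunc_normal_(layer.weight, std=0.02)

# repeat_interleave_batch repeats each batch-of-B chunk `repeat` times, preserving chunk order.
x = torch.arange(6).float()                    # two chunks of B=3: [0,1,2] and [3,4,5]
y = repeat_interleave_batch(x, B=3, repeat=2)
# y -> tensor([0., 1., 2., 0., 1., 2., 3., 4., 5., 3., 4., 5.])
```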
src/utils/wrappers.py ADDED
@@ -0,0 +1,43 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch.nn as nn
7
+
8
+
9
+ class MultiSeqWrapper(nn.Module):
10
+
11
+ def __init__(self, backbone):
12
+ super().__init__()
13
+ self.backbone = backbone
14
+
15
+ def forward(self, x, masks=None):
16
+ """
17
+ :param x: [list] List of Tensors of different seq lengths
18
+ :param masks: [list[list]] List of Tensors (outer index: masks for a given seq length, inner index: multi-masks for that seq length)
19
+ """
20
+ if masks is None:
21
+ return [self.backbone(xi) for xi in x]
22
+
23
+ outs = [[] for _ in x]
24
+ for i, (xi, mi) in enumerate(zip(x, masks)):
25
+ for mij in mi:
26
+ outs[i] += [self.backbone(xi, masks=mij)]
27
+ return outs
28
+
29
+
30
+ class PredictorMultiSeqWrapper(nn.Module):
31
+
32
+ def __init__(self, backbone):
33
+ super().__init__()
34
+ self.backbone = backbone
35
+
36
+ def forward(self, x, masks_x, masks_y, has_cls=False):
37
+ n = 0
38
+ outs = [[] for _ in x]
39
+ for i, (xi, mxi, myi) in enumerate(zip(x, masks_x, masks_y)):
40
+ for xij, mxij, myij in zip(xi, mxi, myi):
41
+ outs[i] += [self.backbone(xij, mxij, myij, mask_index=i, has_cls=has_cls)]
42
+ n += 1
43
+ return outs
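A toy sketch of `MultiSeqWrapper` with a stand-in backbone (`nn.Identity` here, just to show the call pattern of one forward pass per sequence length):

```python
import torch
import torch.nn as nn

from src.utils.wrappers import MultiSeqWrapper

wrapper = MultiSeqWrapper(nn.Identity())  # any encoder taking a token tensor works here

# Two batches with different sequence lengths; without masks each is encoded independently.
x = [torch.randn(2, 16, 8), torch.randn(2, 64, 8)]
outs = wrapper(x)  # list of two outputs, one per sequence length
```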
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4aa645548d069f0235573e314b319436bc9c7f4a7aa6e2c07f494de56a57b955
3
+ size 17210205