Image Feature Extraction · Transformers · Safetensors · dinov2

Add files helpful for fine-tuning (#6)
by fepegar · opened
README.md CHANGED
@@ -58,6 +58,8 @@ Underlying biases of the training datasets may not be well characterized.

## Getting started

+ ### Get some data
+
Let us first write an auxiliary function to download a chest X-ray.

```python
@@ -73,6 +75,8 @@ Let us first write an auxiliary function to download a chest X-ray.
...
```

+ ### Load the model
+
Now let us download the model and encode an image.

```python
@@ -82,13 +86,17 @@ Now let us download the model and encode an image.
>>>
>>> # Download the model
>>> repo = "microsoft/rad-dino"
- >>> model = AutoModel.from_pretrained(repo)
+ >>> rad_dino = AutoModel.from_pretrained(repo)
>>>
>>> # The processor takes a PIL image, performs resizing, center-cropping, and
>>> # intensity normalization using stats from MIMIC-CXR, and returns a
>>> # dictionary with a PyTorch tensor ready for the encoder
>>> processor = AutoImageProcessor.from_pretrained(repo)
- >>>
+ ```
+
+ ### Encode an image
+
+ ```python
>>> # Download and preprocess a chest X-ray
>>> image = download_sample_image()
>>> image.size # (width, height)
@@ -97,7 +105,7 @@ Now let us download the model and encode an image.
>>>
>>> # Encode the image!
>>> with torch.inference_mode():
- >>>     outputs = model(**inputs)
+ >>>     outputs = rad_dino(**inputs)
>>>
>>> # Look at the CLS embeddings
>>> cls_embeddings = outputs.pooler_output
@@ -124,6 +132,62 @@ We will use [`einops`](https://einops.rocks/) (install with `pip install einops`
torch.Size([1, 768, 37, 37])
```

+ ### Weights for fine-tuning
+
+ We have released a checkpoint compatible with
+ [the original DINOv2 code](https://github.com/facebookresearch/dinov2) to help
+ researchers fine-tune our model.
+
+ First, let us write code to load a
+ [`safetensors` checkpoint](https://huggingface.co/docs/safetensors).
+
+ ```python
+ >>> import safetensors
+ >>> def safetensors_to_state_dict(checkpoint_path: str) -> dict[str, torch.Tensor]:
+ ...     state_dict = {}
+ ...     with safetensors.safe_open(checkpoint_path, framework="pt") as ckpt_file:
+ ...         for key in ckpt_file.keys():
+ ...             state_dict[key] = ckpt_file.get_tensor(key)
+ ...     return state_dict
+ ...
+ ```
+
+ We can now instantiate the backbone with `torch.hub` and load the RAD-DINO weights.
+ Let us clone the DINOv2 repository so that `torch.hub` can find the code locally and we can import the head.
+
+ ```shell
+ git clone https://github.com/facebookresearch/dinov2.git
+ cd dinov2
+ ```
+
+ ```python
+ >>> import torch
+ >>> rad_dino_gh = torch.hub.load(".", "dinov2_vitb14", source="local")
+ >>> backbone_state_dict = safetensors_to_state_dict("backbone_compatible.safetensors")
+ >>> rad_dino_gh.load_state_dict(backbone_state_dict, strict=True)
+ <All keys matched successfully>
+ ```
+
+ The weights of the head are also released:
+
+ ```python
+ >>> from dinov2.layers import DINOHead
+ >>> rad_dino_head_gh = DINOHead(
+ ...     in_dim=768,
+ ...     out_dim=65536,
+ ...     hidden_dim=2048,
+ ...     bottleneck_dim=256,
+ ...     nlayers=3,
+ ... )
+ >>> head_state_dict = safetensors_to_state_dict("dino_head.safetensors")
+ >>> rad_dino_head_gh.load_state_dict(head_state_dict, strict=True)
+ <All keys matched successfully>
+ ```
+
+ ### Configs and augmentation
+
+ The configuration files [`ssl_default_config.yaml`](./ssl_default_config.yaml) and [`vitb14_cxr.yaml`](./vitb14_cxr.yaml), and the [`augmentations`](./augmentations.py) module are also available in the repository to help researchers reproduce the training procedure with our hyperparameters.
+
## Training details

### Training data
@@ -225,4 +289,4 @@ We used [SimpleITK](https://simpleitk.org/) and [Pydicom](https://pydicom.github

## Model card contact

- Fernando Pérez-García ([`fperezgarcia@microsoft.com`](mailto:fperezgarcia@microsoft.com)).
+ Fernando Pérez-García ([`fperezgarcia@microsoft.com`](mailto:fperezgarcia@microsoft.com)).
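The two loading paths above (the Hugging Face encoder and the DINOv2 `torch.hub` backbone with the compatible checkpoint) are meant to give the same features. A minimal sketch of a cross-check, assuming the README snippets have been run inside the cloned `dinov2` directory and that `backbone_compatible.safetensors` is available locally; the close numerical agreement is an expectation, not a documented guarantee:

```python
# Sketch only: compare the CLS embedding from the Hugging Face model with the
# normalized class token from the DINOv2 backbone loaded via torch.hub.
# Assumes processor, rad_dino, rad_dino_gh and download_sample_image from above.
import torch

image = download_sample_image()
inputs = processor(images=image, return_tensors="pt")

with torch.inference_mode():
    hf_cls = rad_dino(**inputs).pooler_output                     # shape (1, 768)
    gh_out = rad_dino_gh.forward_features(inputs["pixel_values"])
    gh_cls = gh_out["x_norm_clstoken"]                            # shape (1, 768)

# If the checkpoint conversion is faithful, the difference should be numerical noise.
print((hf_cls - gh_cls).abs().max())
```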
augmentations.py ADDED
@@ -0,0 +1,147 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# See LICENSE in the repo root for license information.
#
# Portions:
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import logging

from PIL import Image
from torchvision import transforms

from .transforms import (
    GaussianBlur,
    MaybeToTensor,
    make_normalize_transform,
)


logger = logging.getLogger("dinov2")


class DataAugmentationDINO(object):
    def __init__(
        self,
        global_crops_scale,
        local_crops_scale,
        local_crops_number,
        global_crops_size=224,
        local_crops_size=96,
    ):
        self.global_crops_scale = global_crops_scale
        self.local_crops_scale = local_crops_scale
        self.local_crops_number = local_crops_number
        self.global_crops_size = global_crops_size
        self.local_crops_size = local_crops_size

        logger.info("###################################")
        logger.info("Using data augmentation parameters:")
        logger.info(f"global_crops_scale: {global_crops_scale}")
        logger.info(f"local_crops_scale: {local_crops_scale}")
        logger.info(f"local_crops_number: {local_crops_number}")
        logger.info(f"global_crops_size: {global_crops_size}")
        logger.info(f"local_crops_size: {local_crops_size}")
        logger.info("###################################")

        # random resized crop and flip
        self.geometric_augmentation_global = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
                ),
                transforms.RandomHorizontalFlip(p=0.5),
            ]
        )

        self.geometric_augmentation_local = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
                ),
                transforms.RandomHorizontalFlip(p=0.5),
            ]
        )

        # color distortions / blurring
        color_jittering = transforms.Compose(
            [
                transforms.RandomApply(
                    [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
                    p=0.8,
                ),
                transforms.RandomGrayscale(p=0.2),
            ]
        )

        global_transfo1_extra = GaussianBlur(p=0.5)

        global_transfo2_extra = transforms.Compose(
            [
                GaussianBlur(p=0.1),
            ]
        )

        local_transfo_extra = GaussianBlur(p=0.5)

        # normalization
        self.normalize = transforms.Compose(
            [
                MaybeToTensor(),
                make_normalize_transform(),
            ]
        )

        self.global_transfo1 = transforms.Compose([color_jittering, global_transfo1_extra, self.normalize])
        self.global_transfo2 = transforms.Compose([color_jittering, global_transfo2_extra, self.normalize])
        self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize])

    def __call__(self, image):
        output = {}

        # global crops:
        im1_base = self.geometric_augmentation_global(image)
        global_crop_1 = self.global_transfo1(im1_base)

        im2_base = self.geometric_augmentation_global(image)
        global_crop_2 = self.global_transfo2(im2_base)

        output["global_crops"] = [global_crop_1, global_crop_2]

        # global crops for teacher:
        output["global_crops_teacher"] = [global_crop_1, global_crop_2]

        # local crops:
        local_crops = [
            self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number)
        ]
        output["local_crops"] = local_crops
        output["offsets"] = ()

        return output


def get_online_classification_augmentation_from_config(cfg) -> transforms.Compose:
    augmentation_config = cfg.evaluation.online.augmentation
    interpolation = getattr(Image.Resampling, augmentation_config.interpolation)
    resize_size = crop_size = cfg.crops.global_crops_size
    resize = transforms.Resize(resize_size, interpolation=interpolation)
    crop = transforms.CenterCrop(crop_size)
    affine = transforms.RandomAffine(
        degrees=augmentation_config.degrees,
        scale=augmentation_config.scale,
        shear=augmentation_config.shear,
        interpolation=interpolation,
    )
    transforms_list = [
        resize,
        crop,
        affine,
        MaybeToTensor(),
        make_normalize_transform(),
    ]
    if augmentation_config.horizontal_flip:
        transforms_list.append(transforms.RandomHorizontalFlip())
    return transforms.Compose(transforms_list)
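For orientation, a minimal usage sketch of `DataAugmentationDINO` with the crop settings from `vitb14_cxr.yaml`. It assumes this module has been copied over `dinov2/data/augmentations.py` in the cloned DINOv2 repository; `chest_xray.png` is a placeholder path:

```python
# Sketch only: build the multi-crop training transform with the CXR crop settings
# from vitb14_cxr.yaml. Assumes this file replaces dinov2/data/augmentations.py
# in the cloned DINOv2 repository; "chest_xray.png" is a placeholder path.
from PIL import Image
from dinov2.data.augmentations import DataAugmentationDINO

transform = DataAugmentationDINO(
    global_crops_scale=(0.50, 1.00),
    local_crops_scale=(0.20, 0.50),
    local_crops_number=8,
    global_crops_size=518,
    local_crops_size=196,
)

image = Image.open("chest_xray.png").convert("RGB")
views = transform(image)

print(len(views["global_crops"]), len(views["local_crops"]))  # 2 global, 8 local crops
print(views["global_crops"][0].shape)                         # torch.Size([3, 518, 518])
```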
backbone_compatible.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1eac0464b2a00d368aa3eea1dc029964b10320fbabc59a8a4e768c43a23d26f4
size 346338024
dino_head.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5b3599663464ed1054f7777f547db02f518581acc5becdd3eddffc8c507f3778
size 92554920
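The two `.safetensors` entries above are Git LFS pointers; the tensors themselves live on the Hub. A minimal sketch of fetching them with `huggingface_hub` so the README's `safetensors_to_state_dict` helper can read them (the repo id comes from the README; everything else is standard `hf_hub_download` usage):

```python
# Sketch only: download the DINOv2-compatible checkpoints referenced in the README.
# hf_hub_download returns the path of the cached file on disk.
from huggingface_hub import hf_hub_download

backbone_path = hf_hub_download(repo_id="microsoft/rad-dino", filename="backbone_compatible.safetensors")
head_path = hf_hub_download(repo_id="microsoft/rad-dino", filename="dino_head.safetensors")

backbone_state_dict = safetensors_to_state_dict(backbone_path)  # helper from the README
head_state_dict = safetensors_to_state_dict(head_path)
```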
ssl_default_config.yaml ADDED
@@ -0,0 +1,135 @@
MODEL:
  WEIGHTS: ''
compute_precision:
  grad_scaler: true
  teacher:
    backbone:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp16
        buffer_dtype: fp32
    dino_head:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp16
        buffer_dtype: fp32
    ibot_head:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp16
        buffer_dtype: fp32
  student:
    backbone:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp16
        buffer_dtype: fp32
    dino_head:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp32
        buffer_dtype: fp32
    ibot_head:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp32
        buffer_dtype: fp32
dino:
  loss_weight: 1.0
  head_n_prototypes: 65536
  head_bottleneck_dim: 256
  head_nlayers: 3
  head_hidden_dim: 2048
  koleo_loss_weight: 0.1
ibot:
  loss_weight: 1.0
  mask_sample_probability: 0.5
  mask_ratio_min_max:
  - 0.1
  - 0.5
  separate_head: false
  head_n_prototypes: 65536
  head_bottleneck_dim: 256
  head_nlayers: 3
  head_hidden_dim: 2048
train:
  batch_size_per_gpu: 64
  dataset_path: ImageNet:split=TRAIN
  output_dir: .
  saveckp_every_n_epoch: 5
  seed: 0
  num_workers: 10
  OFFICIAL_EPOCH_LENGTH: 0  # automatic rescaling based on the dataset len is applied if this is set to 0
  cache_dataset: true
  centering: "centering"  # or "sinkhorn_knopp"
student:
  arch: vit_large
  patch_size: 16
  drop_block_rate: 0.0
  drop_path_rate: 0.3
  layerscale: 1.0e-05
  drop_path_uniform: true
  pretrained_weights: ''
  ffn_layer: "mlp"
  block_chunks: 0
  qkv_bias: true
  proj_bias: true
  ffn_bias: true
  num_register_tokens: 0
  interpolate_antialias: false
  interpolate_offset: 0.1
  load_weights: true
  checkpoints_dir: null
teacher:
  momentum_teacher: 0.992
  final_momentum_teacher: 1
  warmup_teacher_temp: 0.04
  teacher_temp: 0.07
  warmup_teacher_temp_epochs: 30
optim:
  epochs: 100
  weight_decay: 0.04
  weight_decay_end: 0.4
  base_lr: 0.004  # learning rate for a batch size of 1024
  lr: 0.  # will be set after applying scaling rule
  warmup_epochs: 10
  min_lr: 1.0e-06
  clip_grad: 3.0
  freeze_last_layer_epochs: 1
  scaling_rule: sqrt_wrt_1024
  patch_embed_lr_mult: 0.2
  layerwise_decay: 0.9
  adamw_beta1: 0.9
  adamw_beta2: 0.999
crops:
  global_crops_scale:
  - 0.32
  - 1.0
  local_crops_number: 8
  local_crops_scale:
  - 0.05
  - 0.32
  global_crops_size: 224
  local_crops_size: 96
evaluation:
  eval_period_iterations: 12500
  dataset_str: None
  online:  # see dinov2.eval.linear_callback for documentation
    learning_rate: 1e-6  # will be multiplied by batch size and number of devices
    num_last_blocks: 1
    add_avg_pool: true
    num_update_epochs_per_eval: 3
    augmentation:
      degrees: 30
      scale:
      - 0.8
      - 1.2
      shear: 15
      interpolation: BICUBIC
      horizontal_flip: true
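DINOv2 parses its YAML configuration with OmegaConf. A minimal sketch of how the CXR config might be layered onto these defaults; the merge order mirrors how DINOv2 combines a default and an experiment config, which is an assumption here rather than a documented step of this release:

```python
# Sketch only: layer the CXR config on top of the defaults with OmegaConf,
# so that values from vitb14_cxr.yaml override ssl_default_config.yaml.
from omegaconf import OmegaConf

default_cfg = OmegaConf.load("ssl_default_config.yaml")
cxr_cfg = OmegaConf.load("vitb14_cxr.yaml")
cfg = OmegaConf.merge(default_cfg, cxr_cfg)  # later configs take precedence

print(cfg.student.arch)                             # vit_base (from vitb14_cxr.yaml)
print(cfg.crops.global_crops_size)                  # 518
print(cfg.evaluation.online.augmentation.degrees)   # 30 (from the defaults)
```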
vitb14_cxr.yaml ADDED
@@ -0,0 +1,31 @@
# this corresponds to the CXR config
train:
  batch_size_per_gpu: 40  # For nodes with v100s (32 GB), use 20.
  saveckp_every_n_epoch: 25
student:
  arch: vit_base
  block_chunks: 4
  patch_size: 14
  drop_block_rate: 0.00
  drop_path_rate: 0.30
teacher:
  warmup_teacher_temp_epochs: 50
optim:
  epochs: 100
  warmup_epochs: 5
  base_lr: 0.001
evaluation:
  eval_period_iterations: 300
  tasks:  # from the metadata.csv file of the CANDID processed dataset
  - pneumothorax
crops:
  global_crops_size: 518
  local_crops_size: 196
  global_crops_scale:
  - 0.50
  - 1.00
  local_crops_number: 8
  local_crops_scale:
  - 0.20
  - 0.50
pretrained: true
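The crop and patch sizes above determine the patch-token grid shown earlier in the README (`torch.Size([1, 768, 37, 37])`); a tiny sketch making that arithmetic explicit:

```python
# Sketch only: 518-pixel global crops with 14-pixel patches give a 37x37 token grid,
# matching the reshaped patch embeddings shown in the README.
global_crops_size = 518
patch_size = 14
tokens_per_side = global_crops_size // patch_size
print(tokens_per_side, tokens_per_side ** 2)  # 37 1369
```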