alexandraroze commited on
Commit
50bd1fc
·
1 Parent(s): def2cca
README.md CHANGED
@@ -10,3 +10,226 @@ pinned: false
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
13
+
14
+
15
+ # Cross Attention Classifier
16
+
17
+ Ниже технические детали того, как устроен репозиторий и как обучить модель.
18
+ В самой последней секции "Описание подходов" подробно описано, как я пришла к этому подходу, с какими проблемами встретилась, а также описаны два других подхода, которые я решила не реализовывать (будет мини-эссе, готовьтесь).
19
+
20
+ ## Описание проекта
21
+
22
+ В проекте используется self-supervised обучение (BYOL) и последующая классификацию изображений с помощью Cross Attention.
23
+ - **train_byol.py** — скрипт для обучения модели-энкодера по методу BYOL.
24
+ - **train_cross_classifier.py** — скрипт для обучения классификатора, который использует предварительно обученный энкодер и Cross Attention.
25
+ - **app.py** — Streamlit-приложение для инференса и визуализации предсказаний (генерация случайных изображений и получение метки от модели).
26
+
27
+ ## Структура репозитория
28
+
29
+ ```
30
+ .
31
+ ├── src
32
+ │ ├── dataset.py # Реализация датасетов (RandomAugmentedDataset и RandomPairDataset)
33
+ │ ├── inference.py # Класс для инференса (CrossAttentionInference) и вспомогательные методы
34
+ │ └── models.py # Определения моделей (BYOL, VGGLikeEncode, CrossAttentionClassifier)
35
+ ├── train_byol.py # Скрипт обучения модели BYOL
36
+ ├── train_cross_classifier.py # Скрипт обучения Cross Attention Classifier (использует готовый энкодер)
37
+ ├── app.py # Streamlit-приложение для инференса
38
+ ├── requirements.txt # Список Python-зависимостей (pip install -r requirements.txt)
39
+ └── pyproject.toml / poetry.lock # Файл для установки зависимостей через Poetry
40
+ ```
41
+
42
+ ## Установка зависимостей
43
+
44
+ Можно установить зависимости двумя способами:
45
+ 1. **Через `pip` и `requirements.txt`:**
46
+ ```bash
47
+ pip install -r requirements.txt
48
+ ```
49
+ 2. **Через Poetry:**
50
+ ```bash
51
+ poetry install
52
+ ```
53
+
54
+ ## Как обучить модель
55
+
56
+ ### 1. Обучение энкодера с помощью BYOL
57
+ Нужно запустить:
58
+ ```bash
59
+ python train_byol.py
60
+ ```
61
+ - Этот скрипт обучает модель энкодера (`VGGLikeEncode`) методом BYOL на данных, сгенерированных `RandomAugmentedDataset`.
62
+ - После обучения лучшая модель (с минимальным `val_loss`) сохраняется в `best_byol.pth`.
63
+
64
+ ### 2. Обучение Cross Attention Classifier
65
+
66
+ `best_byol.pth` (веса энкодера) должны лежать в корневой папке (можно указать другой путь). Затем нужно запустить:
67
+ ```bash
68
+ python train_cross_classifier.py
69
+ ```
70
+ - Этот скрипт использует предобученный энкодер и обучает классификатор для определения, содержат ли картинки одинаковую геометрическую фигуру.
71
+ - По итогам сохранит веса модели-классификатора в `best_attention_classifier.pth`.
72
+
73
+ ## Как запустить инференс
74
+
75
+ ### Запуск через Streamlit-приложение
76
+
77
+ 1. Файл весов `best_attention_classifier.pth` должен лежать в корневой папке
78
+ 2. Нужно запустить Streamlit-приложение:
79
+ ```bash
80
+ streamlit run app.py
81
+ ```
82
+ 3. Дальше, нужно перейти по адресу, который выдаст Streamlit (по умолчанию [http://localhost:8501](http://localhost:8501)).
83
+ 4. Нажмите кнопку **«Сгенерировать изображения»**. Приложение сгенерирует пару случайных изображений и покажет предсказанную моделью метку.
84
+
85
+ ### Использование класса инференса в коде напрямую
86
+
87
+ Можно использовать модель напрямую (без интерфейса Streamlit), импортируйте класс из `src/inference.py`, передайте путь к весам модели и вызовите метод предсказания. Пример:
88
+
89
+ ```python
90
+ import torch
91
+ from src.inference import CrossAttentionInference
92
+
93
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
94
+
95
+ inference = CrossAttentionInference(
96
+ model_path="best_attention_classifier.pth",
97
+ device=device
98
+ )
99
+
100
+ pred_label, (img1, img2) = inference.predict_random_pair()
101
+ print(f"Предсказанная метка: {pred_label}")
102
+ ```
103
+
104
+ ---
105
+
106
+ ## Описание подходов
107
+
108
+
109
+ ### BYOL + Cross-attention (выбранный подход)
110
+
111
+ Когда я обдумывала финальную архитектуру, я поняла, что у креативности нет предела, поэтому каждое решение должно быть обосновано не только фразой "прикольно, можно попробовать", но и существующими проблемами, которые хочется решить.
112
+ И я решила отталкиваться от реальных задач, а именно от проблемы с отсутствием данных в медицинской сфере.
113
+ В текущей задаче такой проблемы, очевидно, нет, мы можем сгенерировать хоть миллион изображений и для всех будет лейбл.
114
+ Но вот что, если у нас нет возможности сгенерировать миллион изображений? Или если у нас есть 100к изображений, но только 10000 из них размечены?
115
+ Например, у нас есть неплохой банк изображений с разными опухолями, но размечены только 10% из них. Как можно использовать эти данные для обучения модели, чтобы она могла классифицировать новые изображения?
116
+ Саму же задачу можно перенести на задачу вида "мониторинг прогрессирования заболевания" или "сравнение патологий".
117
+
118
+ Поэтому я решила использовать self-supervised обучение для того, чтобы обучить модель на неразмеченных данных и затем дообучить ее на небольшом датасете с разметкой.
119
+
120
+ #### BYOL
121
+ BYOL [Bootstrap Your Own Latent](https://arxiv.org/pdf/2006.07733) — это метод self-supervised обучения, который позволяет обучить модель на неразмеченных данных.
122
+ Важной особенностью конкретно этого подхода заключается в том, что этот метод не требует негативных пар для обучения, как некоторые другие contrastive методы.
123
+ В BYOL используется две копии одной и той же модели, которые обучаются предсказывать друг друга на основе двух views (аугментаций) одного изображения.
124
+ Не схлапываться в один вектор помогает то, что архитектура не симметрична, так как в одной из веток добавляется MLP предиктор, а также stop gradient операция.
125
+ В итоге модель учится извлекать признаки из изображения, которые можно использовать для дообучения на меньшем датасете.
126
+ В данном случае я использовала энкодер с похожей на `VGG` архитектурой.
127
+ Выбрала VGG я потому, что использовать, например, ResNet со skip-connection нет смысла, так как изображение всего ли 32x32, и через несколько слоев feature мапа была уже 8x8.
128
+
129
+ В целом сама задача заставляет балансировать между сложными подходами и реальной возможностью обучить модель на подобном датасете, так как реализовать можно (почти) что угодно, но переобучиться на таком датасете достаточно легко.
130
+
131
+ Изначально я планировала добавить в датасет для предобучения другие фигуры (треугольники, звездочки и тд), но сами эти фигуры занимают несколько пикселей, и аугментации их сильно искаж��ют. В целом на таких маленьких изображениях почти все аугментации становятся агрессивными.
132
+ Поэтому я остановилась на двух фигурах, но добавила в аугментации реверс цвета, повороты, сдвигы, гауссовский шум и тд.
133
+
134
+ #### Cross Attention
135
+ Честно скажу, именно на этот подход меня вдохновила задача с собеседования, где нужно было сопоставить два снимка одной и той же области.
136
+ Я нашла статью - [An Adaptive Remote Sensing Image-Matching Network Based on Cross Attention and Deformable Convolution](https://www.researchgate.net/publication/388063503_An_Adaptive_Remote_Sensing_Image-Matching_Network_Based_on_Cross_Attention_and_Deformable_Convolution)
137
+ где авторы решают похожую задачу (они тоже кстати используют VGG), но для более сложных изображений (сопоставление фотографий со спутника).
138
+ Я помню в чем заключается проблема cross-attention - у него квадратичная сложность, и если изображение имеет размер 512x512, то это уже становится проблемой.
139
+ Но так как в задаче изображения 32x32, я решила, что будет уместно применить данный подход (в предложенных подходах дальше я опишу, как бы решала задачу, если бы изображения были больше).
140
+ Также, я добавила position эмбеддинги, так как при переходе к cross-attention информация о позиции теряется.
141
+
142
+ Почему cross-attention? Он позволяет каждому пикселю (точнее, каждому патчу) в одном изображении "смотреть" на все патчи в другом изображении.
143
+ Таким образом, если фигуры находятся в противоположных углах, модель это учтет.
144
+ Ну и плюс тенденции последних лет - внимание, внимание, внимание.
145
+
146
+ #### Итоговая архитектура
147
+ Сама архитектура представляет собой два VGGLike энкодера с shared весами, предобученных с помощью BYOL, после которых идет слой MultiheadAttention, а затем классификационная голова.
148
+ Во время предобучения VGGLike энкодера последним слоем был AdaptiveAvgPool2d. Этот слой не использовался во время обучения классфикатора, так как на вход MultiheadAttention требовалась информативная карта признаков (я использовала 8x8).
149
+
150
+ Таким образом, когда на вход поступает два изображения, каждое из них проходит через энкодер, после чего происходит cross-attention между ними, и на выходе получается вероятность того, что изображения содержат одинаковую фигуру.
151
+ Это не самый сложный подход, который можно было придумать, но он позволяет взглянуть на задачу под другим углом - в реальности у нас нет датасета с неограниченным количеством размеченных данных, и нужно уметь работать с тем, что есть.
152
+
153
+ ### Метрики
154
+
155
+ Вот здесь можно посмотреть метрики в wandb:
156
+ - [Обучение BYOL](https://wandb.ai/alexandraroze/contrastive_learning_byol/reports/-BYOL--VmlldzoxMTQzMjA1Mw?accessToken=nh0kzpepsr0faflptx63n91kljc5wl6mt3wi3ay4wxpjmua55bf32nm36qjby0ai)
157
+ - [Обучение cross-attention классификатора](https://api.wandb.ai/links/alexandraroze/hmtnzhv9)
158
+
159
+
160
+ ## Другие подходы
161
+
162
+
163
+ ### Swin transformer
164
+
165
+
166
+ 1. Каждое 32×32 изображение делится на патчи размером 4×4, это даёт 64 патча на изображение.
167
+ Каждый патч выпрямляется и проходит через линейный слой для получения векторного представления.
168
+
169
+ 2. Далее мы применим early fusion (так как если применить late fusion, нам придется применять cross-attention, чтобы действительно учест�� взаимодействие между патчами из разных изображений).
170
+ После извлечения патчей из двух изображений мы просто конкатенируем их по оси последовательности, получая 128 токенов.
171
+
172
+ 3. В window multi-head attention мы делим эту последовательность на окна фиксированного размера. Допустим, каждое окно включает 16 токенов подряд. Это значит, что фигура, находящаяся в определённом блоке патчей, будет анализироваться локально вместе со смежными патчами.
173
+ Применяем self-attention и затем сдвигаем окна (в целом, как и должно быть в swin blockе).
174
+ Дальше идет patch merging, и мы получаем 16 патчей на одно изображение (то есть 32 патча на два изображения).
175
+ Достаточно еще двух таких слоев (16 -> 4 -> 1), чтобы у нас остался один патч на изображение.
176
+
177
+ 4. Далее мы используем global average pooling, и передаем выход в классификационную голову.
178
+
179
+ #### Почему я не стала реализовывать этот подход
180
+ Swin transformer хорошо сработает на крупных изображениях с мелкими деталями, но в данной задаче спустя всего 3 слоя мы уже получаем один токен на изображение.
181
+ В первом слое локальное внимание ограничено работает сразу для двух изображений, а в следующем слое остается уже не так много токенов, чтобы извлекать информацию о фигурах.
182
+
183
+
184
+
185
+ ### Siamese Network с Triplet Loss
186
+
187
+ Вместо простой классификации мы обучаем энкодер, который преобразует изображения в эмбеддинги так, чтобы похожие изображения (круг-круг, квадрат-квадрат) были ближе друг к другу, а разные изображения (круг-квадрат) были дальше.
188
+
189
+ Используем Triplet Loss, где берём три изображения:
190
+ - Anchor – произвольное изображение (например, квадрат).
191
+ - Positive – ещё один квадрат.
192
+ - Negative – круг.
193
+
194
+ Модель минимизирует расстояние между anchor и positive и максимизирует его для negative. Чтобы модели было сложнее, используем hard negatives. Например, генерировать изображение с одинаковыми характеристиками, такими как положение фигуры, цвет, блюр, но с другой фигурой.
195
+
196
+ Используем легкий shared CNN энкодер. Энкодер обрабатывает изображения независимо, но выходные эмбеддинги сравниваются через triplet loss.
197
+ Важно, чтобы размерность эмбеддинга была достаточно низкой, чтобы не переобучиться на простой структуре.
198
+
199
+ Получаем эмбеддинги двух изображений, считаем евклидово расстояние.
200
+ Если меньше порога - фигуры одинаковые, иначе разные.
201
+
202
+
203
+ #### Почему я не стала реализовывать этот подход
204
+ Я уже делала это на своей работе, поэтому хотелось попробовать что-нибудь новое :)
205
+
206
+
207
+ ### Проблемы с которыми я столкнулась
208
+ 1. Вначале я решила написать и обучить всю архитектуру целиком (энкодер + cross-attention классификатор), но сразу же столкнулась с тем, что модель просто не обучалась.
209
+ Чтобы это отдебажить, я решила начать с малого - создала простой датасет и научила простую CNN предсказывать метку для двух изображений сразу.
210
+ Дальше, я добавляла углубление в энкодер, параллельно мониторя количество параметров, чтобы понимать, какое количество сэмплов мне нужно для обучения.
211
+ Та��им образом, я дошла до финальной архитектуры.
212
+ 2. У меня все еще сохранялись проблемы во время обучения (сеть не обучалась). Мониторинг нормы градиентов и весов помог мне понять, что веса из-за attention просто зануляются. Это я решила изменением оптимизатора на AdamW и уменьшением learning rate.
213
+ 3. Изначально планировалось показать улучшения в обучении с помощью self-supervised обучения (в сравнении с обучением с нуля), но по факту при тех же самых условиях обучение проходило одинаково в обоих случаях.
214
+ Это можно объяснить тем, что изображения были слишком маленькими и простыми, и подходу без предобучения также не требовалось много времени и большого количества данных.
215
+ Чтобы self-supervised метод действительно хорошо сработал (особенно без негативных примеров), нужны сложные аугментации, в этом же случае сложные аугментации сильно искажали изображения.
216
+ 4. Судя по кривой обучения, итоговый классификатор очень долгое время находился на плато, так как первые 7-8 эпох из 10 лосс не падал, а точность оставалась на уровне 50%. Это можно объяснить тем, что градиенты очень маленькие или очень шумные. Также, все зависит от исходной инициализации, и при маленьком датасете это может стать проблемой, так как по началу накапливается недостаточно сигналов, чтобы сойти с плато.
217
+
218
+
219
+ ## Что бы я точно не стала делать
220
+ Здесь я опишу подходы, которые сразу пришли мне в голову, но которые я бы точно не стала делать по итогу.
221
+ Опять же, я отталкивалась от переноса задачи на реальные данные.
222
+
223
+ 1. Сверточная сеть, которая принимает на вход одно изображение, и выдает для него класс (круг или квадрат).
224
+ Соответственно, получив предсказания для двух изображений, мы можем сделать вывод о том, содержат ли они одинаковую фигуру.
225
+ Этот подход очень простой, решает задачу в лоб, но он не масштабируется, так как при переносе на реальные кейсы терпит крах, потому что далеко не всегда у нас есть два четко разделенных класса (да и в целом у нас может и не быть классов, а только изображения, которые нужно сопоставить между собой).
226
+
227
+ 2. Детекция + классификация.
228
+ Можно было достаточно просто обучить детектор, который находил бы как определенный класс (круг или квадрат), так и просто "фигуру" без класса (казалось бы, решение предыдущей проблемы).
229
+ В реальности же этот подход тоже не масштабируется, так как 1) это дорогостоящая разметка, 2) детекторы могут ошибаться (и для таких кейсов мы бы тогда вообще могли ничего не предсказать), 3) задача может состоять в сравнении нескольких разнородных объектов на изображении, а не одного (например, образование новых опухолей).
230
+ То же само касается и сегментации.
231
+
232
+
233
+ Здесь стоит сделать важную поправку, что есть реальные задачи, где эти подходы могут сработать (например, мы точно знаем, что на изображении нас интересует только один объект, а все остальное - неинформативный фон).
234
+
235
+
app.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+
4
+ from src.inference import CrossAttentionInference
5
+
6
+
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+
9
+ inference = CrossAttentionInference(
10
+ model_path="best_attention_classifier.pth",
11
+ device=device
12
+ )
13
+
14
+ st.title("Random Image Inference")
15
+
16
+ st.write(
17
+ "Нажмите кнопку ниже, чтобы сгенерировать пару случайных изображений и получить предсказание модели."
18
+ )
19
+
20
+ if st.button("Сгенерировать изображения"):
21
+ pred_label, (img1, img2) = inference.predict_random_pair()
22
+
23
+ col1, col2 = st.columns(2)
24
+
25
+ with col1:
26
+ st.image(img1, caption="Image 1", use_container_width=True)
27
+ with col2:
28
+ st.image(img2, caption="Image 2", use_container_width=True)
29
+
30
+ st.write(f"**Предсказанная метка**: {pred_label}")
best_attention_classifier.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6496c9cd313964eef9b899924b8429b11b257cf6ea78a18af3a06df2ed16afb8
3
+ size 1527896
best_byol.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a09baac108efe3034c6f47aa174ba7afb3e8b2732aad4598f9e24ca642337d0
3
+ size 1158467
pyproject.toml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "cels-test"
3
+ version = "0.1.0"
4
+ description = ""
5
+ authors = [
6
+ {name = "AleksandraSorokovikova",email = "alexandraroze2000@gmail.com"}
7
+ ]
8
+ readme = "README.md"
9
+ requires-python = ">=3.10"
10
+ dependencies = [
11
+ "torch (>=2.6.0,<3.0.0)",
12
+ "torchvision (>=0.21.0,<0.22.0)",
13
+ "matplotlib (>=3.10.0,<4.0.0)",
14
+ "wandb (>=0.19.6,<0.20.0)",
15
+ "tqdm (>=4.67.1,<5.0.0)",
16
+ "streamlit (>=1.42.0,<2.0.0)"
17
+ ]
18
+
19
+
20
+ [build-system]
21
+ requires = ["poetry-core>=2.0.0,<3.0.0"]
22
+ build-backend = "poetry.core.masonry.api"
src/__init__.py ADDED
File without changes
src/dataset.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from typing import Optional, Tuple
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torchvision.transforms as T
7
+ from PIL import Image, ImageDraw, ImageFilter
8
+ from torch.utils.data import Dataset
9
+
10
+
11
+ def generate_image(
12
+ size: int = 32,
13
+ contrast: Tuple[int, int] = (90, 110),
14
+ blur_radius: Tuple[float, float] = (0.5, 1.5),
15
+ shape: Optional[str] = None,
16
+ max_background_intensity: int = 128,
17
+ min_shape_intensity: Optional[int] = None,
18
+ shape_size: Optional[int] = None,
19
+ location: str = 'random',
20
+ random_intensity: bool = False
21
+ ) -> Tuple[Image.Image, str]:
22
+ """
23
+ Generate an image with a shape (circle or square) on a background.
24
+ :param size: size of the image
25
+ :param contrast: contrast of the shape
26
+ :param blur_radius: radius of the Gaussian blur
27
+ :param shape: shape type (circle or square)
28
+ :param max_background_intensity: maximum intensity of the background
29
+ :param min_shape_intensity: minimum intensity of the shape
30
+ :param shape_size: size of the shape
31
+ :param location: location of the shape ('random' or 'center')
32
+ :param random_intensity: whether to randomly invert the shape intensity
33
+ """
34
+ background_intensity = random.randint(0, max_background_intensity)
35
+ background = Image.new('L', (size, size), background_intensity)
36
+
37
+ if shape:
38
+ assert shape in ['circle', 'square'], "Wrong shape type"
39
+ else:
40
+ shape = random.choice(['circle', 'square'])
41
+
42
+ if not min_shape_intensity:
43
+ random_contrast = random.randint(*contrast)
44
+ min_shape_intensity = min(background_intensity + random_contrast, 255)
45
+ shape_intensity = random.randint(min_shape_intensity, 255)
46
+
47
+ mask = Image.new('L', (size, size), 0)
48
+ draw = ImageDraw.Draw(mask)
49
+
50
+ if not shape_size:
51
+ min_size = 8
52
+ max_size = size // 2
53
+ shape_size = random.randint(min_size, max_size)
54
+
55
+ if location == 'random':
56
+ max_pos = size - shape_size - 1
57
+ top_left_x = random.randint(0, max_pos)
58
+ top_left_y = random.randint(0, max_pos)
59
+ else:
60
+ top_left_x = (size - shape_size) // 2
61
+ top_left_y = (size - shape_size) // 2
62
+
63
+ if shape == 'square':
64
+ draw.rectangle([top_left_x, top_left_y, top_left_x + shape_size, top_left_y + shape_size], fill=255)
65
+ else:
66
+ draw.ellipse([top_left_x, top_left_y, top_left_x + shape_size, top_left_y + shape_size], fill=255)
67
+
68
+ if blur_radius:
69
+ random_blur_radius = random.uniform(*blur_radius)
70
+ mask = mask.filter(ImageFilter.GaussianBlur(radius=random_blur_radius))
71
+ else:
72
+ mask = mask.filter(ImageFilter.SMOOTH)
73
+
74
+ shape_img = Image.new('L', (size, size), shape_intensity)
75
+ img = Image.composite(shape_img, background, mask)
76
+
77
+ if random_intensity and random.random() < 0.5:
78
+ img = Image.eval(img, lambda x: 255 - x)
79
+
80
+ return img, shape
81
+
82
+
83
+ class RandomPairDataset(Dataset):
84
+ def __init__(
85
+ self,
86
+ shape_params: Optional[dict] = None,
87
+ num_samples: int = 1000,
88
+ train: bool = True,
89
+ fixed_test_data: Optional[list] = None
90
+ ):
91
+ """
92
+ Dataset for training a model to compare two images.
93
+ :param shape_params: parameters for generate_image function
94
+ :param num_samples: number of samples in the dataset
95
+ :param train: whether to generate training or test data
96
+ :param fixed_test_data: fixed test data (optional)
97
+ """
98
+ self.train = train
99
+ self.num_samples = num_samples
100
+ self.transform = T.Compose([
101
+ T.ToTensor(),
102
+ T.Normalize(mean=(0.5,), std=(0.5,))
103
+ ])
104
+ if not shape_params:
105
+ self.shape_params = {}
106
+ else:
107
+ self.shape_params = shape_params
108
+
109
+ if not self.train:
110
+ if fixed_test_data is None:
111
+ self.data = [self._generate_pair() for _ in range(num_samples)]
112
+ else:
113
+ self.data = fixed_test_data
114
+
115
+ def __len__(self) -> int:
116
+ return self.num_samples
117
+
118
+ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
119
+ if self.train:
120
+ img1, shape1, img2, shape2, label = self._generate_pair()
121
+ else:
122
+ img1, shape1, img2, shape2, label = self.data[idx]
123
+
124
+ img1 = self.transform(img1)
125
+ img2 = self.transform(img2)
126
+
127
+ return img1, img2, torch.tensor(label, dtype=torch.float32)
128
+
129
+ def _generate_pair(self) -> Tuple[Image.Image, str, Image.Image, str, int]:
130
+ img1, shape1 = generate_image(**self.shape_params)
131
+ img2, shape2 = generate_image(**self.shape_params)
132
+ label = 1 if shape1 == shape2 else 0
133
+
134
+ return img1, shape1, img2, shape2, label
135
+
136
+
137
+ class RandomAugmentedDataset(Dataset):
138
+ def __init__(
139
+ self,
140
+ augmentations: T.Compose,
141
+ shape_params: Optional[dict] = None,
142
+ num_samples: int = 1000,
143
+ train: bool = True,
144
+ fixed_test_data: Optional[list] = None
145
+ ):
146
+ """
147
+ Dataset for training a model with contrastive learning.
148
+ :param augmentations: augmentations to apply to the images
149
+ :param shape_params: parameters for generate_image function
150
+ :param num_samples: number of samples in the dataset
151
+ :param train: whether to generate training or test data
152
+ :param fixed_test_data: fixed test data (optional
153
+ """
154
+ self.train = train
155
+ self.num_samples = num_samples
156
+ self.augmentations = augmentations
157
+ if not shape_params:
158
+ self.shape_params = {}
159
+ else:
160
+ self.shape_params = shape_params
161
+
162
+ if not self.train:
163
+ if fixed_test_data is None:
164
+ self.data = [self._generate_single() for _ in range(num_samples)]
165
+ else:
166
+ self.data = fixed_test_data
167
+
168
+ def __len__(self) -> int:
169
+ return self.num_samples
170
+
171
+ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
172
+ if self.train:
173
+ img, _ = self._generate_single()
174
+ else:
175
+ img, _ = self.data[idx]
176
+ view_1, view_2 = self.augmentations(img), self.augmentations(img)
177
+
178
+ return view_1, view_2
179
+
180
+ def _generate_single(self) -> Tuple[Image.Image, int]:
181
+ img, shape = generate_image(**self.shape_params)
182
+ label = 1 if shape == "circle" else 0
183
+
184
+ return img, label
185
+
186
+
187
+ class AddGaussianNoise(object):
188
+ def __init__(self, mean: float = 0.0, std: float = 0.05):
189
+ self.mean = mean
190
+ self.std = std
191
+
192
+ def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
193
+ noise = torch.randn(tensor.size()) * self.std + self.mean
194
+ tensor = tensor + noise
195
+ return torch.clamp(tensor, 0., 1.)
196
+
197
+ def __repr__(self):
198
+ return f'{self.__class__.__name__}(mean={self.mean}, std={self.std})'
199
+
200
+
201
+ class ColorInversion(object):
202
+ def __call__(self, image: Image.Image) -> Image.Image:
203
+ return Image.eval(image, lambda x: 255 - x)
204
+
205
+ def __repr__(self):
206
+ return f'{self.__class__.__name__}()'
207
+
208
+
209
+ def get_byol_transforms() -> T.Compose:
210
+ """
211
+ Get augmentations for training with BYOL.
212
+ """
213
+ augmentations = T.Compose([
214
+ T.RandomResizedCrop(size=32, scale=(0.8, 1.0), ratio=(0.9, 1.1)),
215
+ T.RandomHorizontalFlip(p=0.5),
216
+ T.RandomVerticalFlip(p=0.5),
217
+ T.RandomRotation(degrees=15),
218
+ T.ColorJitter(brightness=0.2, contrast=0.2),
219
+ T.RandomApply([T.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0))], p=0.5),
220
+ T.RandomApply([ColorInversion()]),
221
+ T.ToTensor(),
222
+ T.RandomApply([AddGaussianNoise(mean=0.0, std=0.05)], p=0.5),
223
+ T.Normalize(mean=(0.5,), std=(0.5,))
224
+ ])
225
+ return augmentations
226
+
227
+
228
+ def tensor_to_image(tensor: torch.Tensor) -> Image.Image:
229
+ """
230
+ Convert a tensor to a PIL image.
231
+ """
232
+ img_norm = tensor.cpu()[0]
233
+ img_denorm = img_norm * 0.5 + 0.5
234
+ arr = (img_denorm.numpy() * 255).astype(np.uint8)
235
+ pil_img = Image.fromarray(arr, mode='L')
236
+ return pil_img
src/inference.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as T
3
+
4
+ from typing import Optional
5
+
6
+ from src.dataset import generate_image
7
+ from src.models import CrossAttentionClassifier, VGGLikeEncode
8
+
9
+
10
+ class CrossAttentionInference:
11
+ def __init__(
12
+ self,
13
+ model_path: str,
14
+ shape_params: Optional[dict] = None,
15
+ device: torch.device = torch.device("cpu"),
16
+ ):
17
+ if not shape_params:
18
+ self.shape_params = {}
19
+ else:
20
+ self.shape_params = shape_params
21
+ self.device = device
22
+
23
+ self.encoder = VGGLikeEncode(
24
+ in_channels=1,
25
+ out_channels=128,
26
+ feature_dim=32,
27
+ apply_pooling=False
28
+ )
29
+ self.model = CrossAttentionClassifier(encoder=self.encoder)
30
+
31
+ state_dict = torch.load(model_path, map_location=device)
32
+ self.model.load_state_dict(state_dict)
33
+
34
+ self.model.eval()
35
+ self.model.to(device)
36
+
37
+ self.transform = T.Compose([
38
+ T.ToTensor(),
39
+ T.Normalize(mean=(0.5,), std=(0.5,))
40
+ ])
41
+
42
+ def pil_to_tensor(self, img):
43
+ return self.transform(img).unsqueeze(0).to(self.device)
44
+
45
+
46
+ def predict_random_pair(self):
47
+ img1, _ = generate_image(**self.shape_params)
48
+ img2, _ = generate_image(**self.shape_params)
49
+
50
+ img1_tensor = self.pil_to_tensor(img1)
51
+ img2_tensor = self.pil_to_tensor(img2)
52
+
53
+ with torch.no_grad():
54
+ logits, _ = self.model(img1_tensor, img2_tensor)
55
+
56
+ preds = (torch.sigmoid(logits) > 0.5).float()
57
+ predicted_label = int(preds.item())
58
+
59
+ return predicted_label, (img1, img2)
src/models.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch import Tensor
7
+
8
+
9
+ class VGGLikeEncode(nn.Module):
10
+ def __init__(
11
+ self,
12
+ in_channels: int = 1,
13
+ out_channels: int = 128,
14
+ feature_dim: int = 32,
15
+ apply_pooling: bool = False
16
+ ):
17
+ """
18
+ VGG-like encoder for grayscale images.
19
+ :param in_channels: number of input channels
20
+ :param out_channels: number of output channels
21
+ :param feature_dim: number of channels in the intermediate layers
22
+ :param apply_pooling: whether to apply global average pooling at the end
23
+ """
24
+ super().__init__()
25
+ self.apply_pooling = apply_pooling
26
+
27
+ self.block1 = nn.Sequential(
28
+ nn.Conv2d(in_channels, feature_dim, kernel_size=3, padding=1),
29
+ nn.BatchNorm2d(feature_dim),
30
+ nn.ReLU(inplace=True),
31
+ nn.Conv2d(feature_dim, feature_dim, kernel_size=3, padding=1),
32
+ nn.ReLU(inplace=True),
33
+ nn.MaxPool2d(kernel_size=2)
34
+ )
35
+
36
+ self.block2 = nn.Sequential(
37
+ nn.Conv2d(feature_dim, feature_dim * 2, kernel_size=3, padding=1),
38
+ nn.ReLU(inplace=True),
39
+ nn.BatchNorm2d(feature_dim * 2),
40
+ nn.Conv2d(feature_dim * 2, feature_dim * 2, kernel_size=3, padding=1),
41
+ nn.ReLU(inplace=True),
42
+ nn.MaxPool2d(kernel_size=2)
43
+ )
44
+
45
+ self.block3 = nn.Sequential(
46
+ nn.Conv2d(feature_dim * 2, out_channels, kernel_size=3, padding=1),
47
+ nn.ReLU(inplace=True),
48
+ nn.BatchNorm2d(out_channels),
49
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
50
+ nn.ReLU(inplace=True),
51
+ nn.MaxPool2d(kernel_size=1)
52
+ )
53
+
54
+ self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
55
+ self.blocks = [self.block1, self.block2, self.block3]
56
+
57
+ def forward(self, x: Tensor) -> Tensor:
58
+ x = self.block1(x)
59
+ x = self.block2(x)
60
+ x = self.block3(x)
61
+ if self.apply_pooling:
62
+ x = self.global_avg_pool(x).view(x.shape[0], -1)
63
+ return x
64
+
65
+ def get_conv_layer(self, block_num: int):
66
+ if block_num >= len(self.blocks):
67
+ return None
68
+ return self.blocks[block_num][0]
69
+
70
+
71
+ class CrossAttentionClassifier(nn.Module):
72
+ def __init__(
73
+ self,
74
+ feature_dim: int = 32,
75
+ num_heads: int = 4,
76
+ linear_dim: int = 128,
77
+ out_channels: int = 128,
78
+ encoder: Optional[VGGLikeEncode] = None
79
+ ):
80
+ """
81
+ Cross-attention classifier for comparing two grayscale images.
82
+ :param feature_dim: number of channels in the intermediate layers
83
+ :param num_heads: number of attention heads
84
+ :param linear_dim: number of units in the linear layer
85
+ :param out_channels: number of output channels
86
+ :param encoder: encoder to use
87
+ """
88
+ super(CrossAttentionClassifier, self).__init__()
89
+ if encoder:
90
+ self.encoder = encoder
91
+ else:
92
+ self.encoder = VGGLikeEncode(in_channels=1, feature_dim=feature_dim, out_channels=out_channels)
93
+
94
+ self.out_channels = out_channels
95
+ self.seq_len = 8 * 8
96
+ self.pos_embedding = nn.Parameter(torch.randn(self.seq_len, 1, out_channels) * 0.01)
97
+
98
+ self.cross_attention = nn.MultiheadAttention(
99
+ embed_dim=out_channels,
100
+ num_heads=num_heads,
101
+ batch_first=False
102
+ )
103
+
104
+ self.norm = nn.LayerNorm(out_channels)
105
+
106
+ self.classifier = nn.Sequential(
107
+ nn.Linear(out_channels, linear_dim),
108
+ nn.ReLU(),
109
+ nn.Linear(linear_dim, 1)
110
+ )
111
+
112
+ def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]:
113
+ feat1 = self.encoder(img1)
114
+ feat2 = self.encoder(img2)
115
+
116
+ B, C, H, W = feat1.shape
117
+ seq_len = H * W
118
+
119
+ feat1_flat = feat1.view(B, C, seq_len).permute(2, 0, 1)
120
+ feat2_flat = feat2.view(B, C, seq_len).permute(2, 0, 1)
121
+
122
+ feat1_flat = feat1_flat + self.pos_embedding
123
+ feat2_flat = feat2_flat + self.pos_embedding
124
+
125
+ feat1_flat = self.norm(feat1_flat)
126
+ feat2_flat = self.norm(feat2_flat)
127
+
128
+ attn_output, attn_weights = self.cross_attention(
129
+ query=feat1_flat,
130
+ key=feat2_flat,
131
+ value=feat2_flat,
132
+ need_weights=True,
133
+ average_attn_weights=True
134
+ )
135
+ pooled_features = attn_output.mean(dim=0)
136
+ logits = self.classifier(pooled_features).squeeze(-1)
137
+
138
+ return logits, attn_weights
139
+
140
+
141
+ class NormalizedMSELoss(nn.Module):
142
+ def __init__(self):
143
+ """
144
+ Normalized MSE loss for BYOL training.
145
+ """
146
+ super(NormalizedMSELoss, self).__init__()
147
+
148
+ def forward(self, view1: Tensor, view2: Tensor) -> Tensor:
149
+ v1 = F.normalize(view1, dim=-1)
150
+ v2 = F.normalize(view2, dim=-1)
151
+ return 2 - 2 * (v1 * v2).sum(dim=-1)
152
+
153
+
154
+ class MLP(nn.Module):
155
+ def __init__(self, input_dim: int, projection_dim: int = 128, hidden_dim: int = 512):
156
+ """
157
+ MLP for BYOL training.
158
+ :param input_dim: input dimension
159
+ :param projection_dim: projection dimension
160
+ :param hidden_dim: hidden dimension
161
+ """
162
+ super(MLP, self).__init__()
163
+
164
+ self.net = nn.Sequential(
165
+ nn.Linear(input_dim, hidden_dim),
166
+ nn.BatchNorm1d(hidden_dim),
167
+ nn.ReLU(inplace=True),
168
+ nn.Linear(hidden_dim, projection_dim)
169
+ )
170
+
171
+ def forward(self, x: Tensor) -> Tensor:
172
+ return self.net(x)
173
+
174
+
175
+ class EncoderProjecter(nn.Module):
176
+ def __init__(self, encoder: nn.Module, hidden_dim: int = 512, projection_out_dim: int = 128):
177
+ """
178
+ Encoder followed by a projection MLP.
179
+ :param encoder: encoder to use
180
+ :param hidden_dim: hidden dimension
181
+ :param projection_out_dim: projection output dimension
182
+ """
183
+ super(EncoderProjecter, self).__init__()
184
+
185
+ self.encoder = encoder
186
+ self.projection = MLP(input_dim=128, projection_dim=projection_out_dim, hidden_dim=hidden_dim)
187
+
188
+ def forward(self, x: Tensor) -> Tensor:
189
+ h = self.encoder(x)
190
+ return self.projection(h)
191
+
192
+
193
+ # https://arxiv.org/pdf/2006.07733
194
+ class BYOL(nn.Module):
195
+ def __init__(
196
+ self,
197
+ hidden_dim: int = 512,
198
+ projection_out_dim: int = 128,
199
+ target_decay: float = 0.9975
200
+ ):
201
+ """
202
+ BYOL model for self-supervised learning.
203
+ :param hidden_dim: hidden dimension
204
+ :param projection_out_dim: projection output dimension
205
+ :param target_decay: target network decay rate
206
+ """
207
+ super(BYOL, self).__init__()
208
+ encoder = VGGLikeEncode(in_channels=1, out_channels=128, feature_dim=32, apply_pooling=True)
209
+ self.online_network = EncoderProjecter(encoder)
210
+ self.online_predictor = MLP(input_dim=128, projection_dim=projection_out_dim, hidden_dim=hidden_dim)
211
+
212
+ self.target_network = EncoderProjecter(encoder)
213
+ self.target_network.load_state_dict(self.online_network.state_dict())
214
+
215
+ self.target_network.eval()
216
+ for param in self.target_network.parameters():
217
+ param.requires_grad = False
218
+ self.target_decay = target_decay
219
+ self.loss_function = NormalizedMSELoss()
220
+
221
+ @torch.no_grad()
222
+ def soft_update_target_network(self):
223
+ for online_p, target_p in zip(self.online_network.parameters(), self.target_network.parameters()):
224
+ target_p.data = target_p.data * self.target_decay + online_p.data * (1. - self.target_decay)
225
+
226
+ def forward(self, view: Tensor) -> Tuple[Tensor, Tensor]:
227
+ online_proj = self.online_network(view)
228
+ target_proj = self.target_network(view)
229
+
230
+ return online_proj, target_proj
231
+
232
+ def loss(self, view1: Tensor, view2: Tensor) -> Tensor:
233
+ online_proj1, target_proj1 = self(view1)
234
+ online_proj2, target_proj2 = self(view2)
235
+
236
+ online_prediction_1 = self.online_predictor(online_proj1)
237
+ online_prediction_2 = self.online_predictor(online_proj2)
238
+
239
+ loss1 = self.loss_function(online_prediction_1, target_proj2.detach())
240
+ loss2 = self.loss_function(online_prediction_2, target_proj1.detach())
241
+ return torch.mean(loss1 + loss2)
train_byol.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import torch
4
+ import wandb
5
+ from torch import nn, optim
6
+ from torch.nn.functional import cosine_similarity
7
+ from torch.optim import lr_scheduler
8
+ from torch.utils.data import DataLoader
9
+ from tqdm import tqdm
10
+ from typing_extensions import Optional
11
+
12
+ from src.dataset import RandomAugmentedDataset, get_byol_transforms
13
+ from src.models import BYOL
14
+
15
+
16
+ def get_data_loaders(
17
+ batch_size: int,
18
+ num_train_samples: int,
19
+ num_val_samples: int,
20
+ shape_params: Optional[dict] = None,
21
+ num_workers: int = 0
22
+ ):
23
+ augmentations = get_byol_transforms()
24
+
25
+ train_dataset = RandomAugmentedDataset(
26
+ augmentations,
27
+ shape_params,
28
+ num_samples=num_train_samples,
29
+ train=True
30
+ )
31
+ val_dataset = RandomAugmentedDataset(
32
+ augmentations,
33
+ shape_params,
34
+ num_samples=num_val_samples,
35
+ train=False
36
+ )
37
+
38
+ train_loader = DataLoader(
39
+ train_dataset,
40
+ batch_size=batch_size,
41
+ shuffle=True,
42
+ num_workers=num_workers
43
+ )
44
+ val_loader = DataLoader(
45
+ val_dataset,
46
+ batch_size=batch_size,
47
+ shuffle=False,
48
+ num_workers=num_workers
49
+ )
50
+
51
+ return train_loader, val_loader
52
+
53
+
54
+ def build_model(lr: float):
55
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
56
+ model = BYOL().to(device)
57
+
58
+ optimizer = optim.Adam(
59
+ list(model.online_network.parameters()) + list(model.online_predictor.parameters()),
60
+ lr=lr
61
+ )
62
+ scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=2)
63
+
64
+ return model, optimizer, scheduler, device
65
+
66
+
67
+ def train_epoch(
68
+ model: nn.Module,
69
+ optimizer: optim.Optimizer,
70
+ train_loader: DataLoader,
71
+ device: torch.device
72
+ ) -> dict:
73
+ model.train()
74
+ running_train_loss = 0.0
75
+ total_cos_sim, total_l2_dist, total_feat_norm, total_grad_norm = 0.0, 0.0, 0.0, 0.0
76
+ num_train_batches = 0
77
+
78
+ for (view_1, view_2) in tqdm(train_loader, desc="Training"):
79
+ view_1 = view_1.to(device)
80
+ view_2 = view_2.to(device)
81
+
82
+ loss = model.loss(view_1, view_2)
83
+
84
+ optimizer.zero_grad()
85
+ loss.backward()
86
+
87
+ with torch.no_grad():
88
+ online_proj1, target_proj1 = model(view_1)
89
+ online_proj2, target_proj2 = model(view_2)
90
+
91
+ cos_sim = cosine_similarity(online_proj1, target_proj2).mean().item()
92
+ l2_dist = torch.norm(online_proj1 - target_proj2, dim=-1).mean().item()
93
+ feat_norm = torch.norm(online_proj1, dim=-1).mean().item()
94
+
95
+ grad_norm = torch.norm(
96
+ torch.cat([
97
+ p.grad.flatten()
98
+ for p in model.online_network.parameters()
99
+ if p.grad is not None
100
+ ])
101
+ ).item()
102
+
103
+ total_cos_sim += cos_sim
104
+ total_l2_dist += l2_dist
105
+ total_feat_norm += feat_norm
106
+ total_grad_norm += grad_norm
107
+
108
+ optimizer.step()
109
+ model.soft_update_target_network()
110
+
111
+ running_train_loss += loss.item()
112
+ num_train_batches += 1
113
+
114
+ train_loss = running_train_loss / num_train_batches
115
+ train_cos_sim = total_cos_sim / num_train_batches
116
+ train_l2_dist = total_l2_dist / num_train_batches
117
+ train_feat_norm = total_feat_norm / num_train_batches
118
+ train_grad_norm = total_grad_norm / num_train_batches
119
+
120
+ return {
121
+ "loss": train_loss,
122
+ "cos_sim": train_cos_sim,
123
+ "l2_dist": train_l2_dist,
124
+ "feat_norm": train_feat_norm,
125
+ "grad_norm": train_grad_norm,
126
+ }
127
+
128
+
129
+ @torch.no_grad()
130
+ def validate(
131
+ model: nn.Module,
132
+ val_loader: DataLoader,
133
+ device: torch.device
134
+ ) -> dict:
135
+ model.eval()
136
+ running_val_loss = 0.0
137
+ total_cos_sim, total_l2_dist, total_feat_norm = 0.0, 0.0, 0.0
138
+ num_val_batches = 0
139
+
140
+ for (view_1, view_2) in tqdm(val_loader, desc="Validation"):
141
+ view_1 = view_1.to(device)
142
+ view_2 = view_2.to(device)
143
+
144
+ loss = model.loss(view_1, view_2)
145
+ running_val_loss += loss.item()
146
+
147
+ online_proj1, target_proj1 = model(view_1)
148
+ online_proj2, target_proj2 = model(view_2)
149
+
150
+ cos_sim = cosine_similarity(online_proj1, target_proj2).mean().item()
151
+ l2_dist = torch.norm(online_proj1 - target_proj2, dim=-1).mean().item()
152
+ feat_norm = torch.norm(online_proj1, dim=-1).mean().item()
153
+
154
+ total_cos_sim += cos_sim
155
+ total_l2_dist += l2_dist
156
+ total_feat_norm += feat_norm
157
+ num_val_batches += 1
158
+
159
+ val_loss = running_val_loss / num_val_batches
160
+ val_cos_sim = total_cos_sim / num_val_batches
161
+ val_l2_dist = total_l2_dist / num_val_batches
162
+ val_feat_norm = total_feat_norm / num_val_batches
163
+
164
+ return {
165
+ "loss": val_loss,
166
+ "cos_sim": val_cos_sim,
167
+ "l2_dist": val_l2_dist,
168
+ "feat_norm": val_feat_norm
169
+ }
170
+
171
+
172
+ def train(
173
+ model: nn.Module,
174
+ optimizer: optim.Optimizer,
175
+ scheduler,
176
+ device: torch.device,
177
+ train_loader: DataLoader,
178
+ val_loader: DataLoader,
179
+ num_epochs: int,
180
+ early_stopping_patience: int = 3,
181
+ save_path: str = "best_byol.pth"
182
+ ):
183
+ best_loss = float("inf")
184
+ epochs_no_improve = 0
185
+ print("Start training...")
186
+
187
+ for epoch in range(num_epochs):
188
+ print(f"Epoch {epoch + 1}/{num_epochs}")
189
+
190
+ train_metrics = train_epoch(model, optimizer, train_loader, device)
191
+
192
+ val_metrics = validate(model, val_loader, device)
193
+
194
+ wandb.log({
195
+ "epoch": epoch + 1,
196
+ "train_loss": train_metrics["loss"],
197
+ "train_cos_sim": train_metrics["cos_sim"],
198
+ "train_l2_dist": train_metrics["l2_dist"],
199
+ "train_feat_norm": train_metrics["feat_norm"],
200
+ "train_grad_norm": train_metrics["grad_norm"],
201
+ "val_loss": val_metrics["loss"],
202
+ "val_cos_sim": val_metrics["cos_sim"],
203
+ "val_l2_dist": val_metrics["l2_dist"],
204
+ "val_feat_norm": val_metrics["feat_norm"],
205
+ })
206
+
207
+ print(
208
+ f"Train Loss: {train_metrics['loss']:.4f} | "
209
+ f"CosSim: {train_metrics['cos_sim']:.4f} | "
210
+ f"L2Dist: {train_metrics['l2_dist']:.4f}"
211
+ )
212
+ print(
213
+ f"Val Loss: {val_metrics['loss']:.4f} | "
214
+ f"CosSim: {val_metrics['cos_sim']:.4f} | "
215
+ f"L2Dist: {val_metrics['l2_dist']:.4f}"
216
+ )
217
+
218
+ current_val_loss = val_metrics["loss"]
219
+ if current_val_loss < best_loss or val_metrics['cos_sim'] >= 0.86:
220
+ best_loss = current_val_loss
221
+ encoder_state_dict = model.online_network.encoder.state_dict()
222
+ torch.save(encoder_state_dict, save_path)
223
+ epochs_no_improve = 0
224
+ else:
225
+ epochs_no_improve += 1
226
+
227
+ scheduler.step(val_metrics["cos_sim"])
228
+
229
+ if epochs_no_improve >= early_stopping_patience:
230
+ print(f"Early stopping on epoch {epoch + 1}")
231
+ break
232
+
233
+
234
+ def main(config: dict):
235
+ wandb.init(project="contrastive_learning_byol", config=config)
236
+
237
+ train_loader, val_loader = get_data_loaders(
238
+ batch_size=config["batch_size"],
239
+ num_train_samples=config["num_train_samples"],
240
+ num_val_samples=config["num_val_samples"],
241
+ shape_params=config["shape_params"]
242
+ )
243
+
244
+ model, optimizer, scheduler, device = build_model(
245
+ lr=config["lr"]
246
+ )
247
+
248
+ train(
249
+ model=model,
250
+ optimizer=optimizer,
251
+ scheduler=scheduler,
252
+ device=device,
253
+ train_loader=train_loader,
254
+ val_loader=val_loader,
255
+ num_epochs=config["num_epochs"],
256
+ early_stopping_patience=config["early_stopping_patience"],
257
+ save_path=config["save_path"]
258
+ )
259
+
260
+ wandb.finish()
261
+
262
+
263
+ if __name__ == "__main__":
264
+ # parser = argparse.ArgumentParser(description="Train BYOL model")
265
+ # parser.add_argument("--batch_size", type=int, default=512)
266
+ # parser.add_argument("--lr", type=float, default=5e-4)
267
+ # parser.add_argument("--num_epochs", type=int, default=15)
268
+ # parser.add_argument("--num_train_samples", type=int, default=100000)
269
+ # parser.add_argument("--num_val_samples", type=int, default=10000)
270
+ # parser.add_argument("--random_intensity", type=int, default=1)
271
+ # parser.add_argument("--early_stopping_patience", type=int, default=3)
272
+ # parser.add_argument("--save_path", type=str, default="best_byol.pth")
273
+ # args = parser.parse_args()
274
+
275
+ # config = {
276
+ # "batch_size": args.batch_size,
277
+ # "lr": args.lr,
278
+ # "num_epochs": args.num_epochs,
279
+ # "num_train_samples": args.num_train_samples,
280
+ # "num_val_samples": args.num_val_samples,
281
+ # "shape_params": {
282
+ # "random_intensity": bool(args.random_intensity)
283
+ # },
284
+ # "early_stopping_patience": args.early_stopping_patience,
285
+ # "save_path": args.save_path
286
+ # }
287
+
288
+ config = {
289
+ "batch_size": 1024,
290
+ "lr": 1e-3,
291
+ "num_epochs": 15,
292
+ "num_train_samples": 100000,
293
+ "num_val_samples": 10000,
294
+ "shape_params": {
295
+ "random_intensity": True
296
+ },
297
+ "early_stopping_patience": 3,
298
+ "save_path": "best_byol.pth"
299
+ }
300
+
301
+ main(config)
train_cross_classifier.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import matplotlib.pyplot as plt
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+ import wandb
8
+ from torch.optim.lr_scheduler import StepLR
9
+ from torch.utils.data import DataLoader
10
+ from tqdm import tqdm
11
+ from typing_extensions import Optional
12
+
13
+ from src.dataset import RandomPairDataset
14
+ from src.models import CrossAttentionClassifier, VGGLikeEncode
15
+
16
+
17
+ def visualize_attention(attn_heatmap, epoch: int):
18
+ fig, ax = plt.subplots(figsize=(6, 6))
19
+ im = ax.imshow(attn_heatmap, cmap="hot", interpolation="nearest")
20
+ plt.colorbar(im, fraction=0.046, pad=0.04)
21
+ plt.title(f"Attention Heatmap (Flatten 64x64) | Epoch {epoch}")
22
+
23
+ wandb.log({"Flatten Attention Heatmap": wandb.Image(fig, caption=f"Flatten 64x64 | Epoch {epoch}")})
24
+
25
+ plt.close(fig)
26
+
27
+
28
+ def get_data_loaders(
29
+ num_train_samples: int,
30
+ num_val_samples: int,
31
+ batch_size: int,
32
+ num_workers: int = 0,
33
+ shape_params: Optional[dict] = None,
34
+ ):
35
+ train_dataset = RandomPairDataset(
36
+ shape_params=shape_params,
37
+ num_samples=num_train_samples,
38
+ train=True
39
+ )
40
+ val_dataset = RandomPairDataset(
41
+ shape_params=shape_params,
42
+ num_samples=num_val_samples,
43
+ train=False
44
+ )
45
+
46
+ train_loader = DataLoader(
47
+ train_dataset,
48
+ batch_size=batch_size,
49
+ shuffle=True,
50
+ num_workers=num_workers
51
+ )
52
+ val_loader = DataLoader(
53
+ val_dataset,
54
+ batch_size=batch_size,
55
+ shuffle=False,
56
+ num_workers=num_workers
57
+ )
58
+
59
+ return train_loader, val_loader
60
+
61
+
62
+ def build_model(
63
+ path_to_encoder: str,
64
+ lr: float,
65
+ weight_decay: float,
66
+ step_size: int,
67
+ gamma: float,
68
+ device: torch.device
69
+ ):
70
+ encoder = VGGLikeEncode(in_channels=1, out_channels=128, feature_dim=32, apply_pooling=False)
71
+ encoder.load_state_dict(torch.load(path_to_encoder))
72
+
73
+ model = CrossAttentionClassifier(encoder=encoder)
74
+ model = model.to(device)
75
+
76
+ criterion = nn.BCEWithLogitsLoss()
77
+
78
+ optimizer = optim.Adam(
79
+ model.parameters(),
80
+ lr=lr,
81
+ weight_decay=weight_decay
82
+ )
83
+
84
+ scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)
85
+
86
+ return model, criterion, optimizer, scheduler
87
+
88
+
89
+ def train_epoch(
90
+ model: nn.Module,
91
+ criterion: nn.Module,
92
+ optimizer: optim.Optimizer,
93
+ train_loader: DataLoader,
94
+ device: torch.device
95
+ ):
96
+ model.train()
97
+ running_loss = 0.0
98
+ correct = 0
99
+ total = 0
100
+
101
+ for img1, img2, labels in tqdm(train_loader, desc="Training", leave=False):
102
+ img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)
103
+
104
+ optimizer.zero_grad()
105
+
106
+ logits, attn_weights = model(img1, img2)
107
+ loss = criterion(logits, labels)
108
+
109
+ loss.backward()
110
+ optimizer.step()
111
+
112
+ running_loss += loss.item() * img1.size(0)
113
+
114
+ preds = (torch.sigmoid(logits) > 0.5).float()
115
+ correct += (preds == labels).sum().item()
116
+ total += labels.size(0)
117
+
118
+ epoch_loss = running_loss / len(train_loader.dataset)
119
+ epoch_acc = correct / total
120
+
121
+ return epoch_loss, epoch_acc
122
+
123
+
124
+ @torch.no_grad()
125
+ def validate(
126
+ model: nn.Module,
127
+ criterion: nn.Module,
128
+ val_loader: DataLoader,
129
+ device: torch.device
130
+ ):
131
+ model.eval()
132
+ running_loss = 0.0
133
+ correct = 0
134
+ total = 0
135
+
136
+ for img1, img2, labels in tqdm(val_loader, desc="Validation", leave=False):
137
+ img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)
138
+
139
+ logits, attn_weights = model(img1, img2)
140
+ loss = criterion(logits, labels)
141
+
142
+ running_loss += loss.item() * img1.size(0)
143
+
144
+ preds = (torch.sigmoid(logits) > 0.5).float()
145
+ correct += (preds == labels).sum().item()
146
+ total += labels.size(0)
147
+
148
+ epoch_loss = running_loss / len(val_loader.dataset)
149
+ epoch_acc = correct / total
150
+
151
+ return epoch_loss, epoch_acc
152
+
153
+
154
+ def train(
155
+ model: nn.Module,
156
+ criterion: nn.Module,
157
+ optimizer: optim.Optimizer,
158
+ scheduler,
159
+ train_loader: DataLoader,
160
+ val_loader: DataLoader,
161
+ device: torch.device,
162
+ num_epochs: int = 30,
163
+ save_path: str = "best_attention_classifier.pth"
164
+ ):
165
+ best_val_loss = float("inf")
166
+ epochs_no_improve = 0
167
+ print("Start training...")
168
+
169
+ for epoch in range(num_epochs):
170
+ print(f"Epoch {epoch + 1}/{num_epochs}")
171
+
172
+ train_loss, train_acc = train_epoch(model, criterion, optimizer, train_loader, device)
173
+
174
+ val_loss, val_acc = validate(model, criterion, val_loader, device)
175
+
176
+ scheduler.step()
177
+
178
+ wandb.log({
179
+ "epoch": epoch + 1,
180
+ "train_loss": train_loss,
181
+ "train_acc": train_acc,
182
+ "val_loss": val_loss,
183
+ "val_acc": val_acc,
184
+ "lr": optimizer.param_groups[0]["lr"],
185
+ })
186
+
187
+ print(
188
+ f"learning rate: {optimizer.param_groups[0]['lr']:.6f}, "
189
+ f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
190
+ f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}"
191
+ )
192
+
193
+ if val_loss < best_val_loss:
194
+ best_val_loss = val_loss
195
+ torch.save(model.state_dict(), save_path)
196
+ epochs_no_improve = 0
197
+ else:
198
+ epochs_no_improve += 1
199
+
200
+ with torch.no_grad():
201
+ sample_img1, sample_img2, sample_labels = next(iter(val_loader))
202
+ sample_img1, sample_img2 = sample_img1.to(device), sample_img2.to(device)
203
+
204
+ _, sample_attn_weights = model(sample_img1, sample_img2)
205
+
206
+ wandb.log({
207
+ "attention_std": sample_attn_weights.std().item(),
208
+ "attention_mean": sample_attn_weights.mean().item(),
209
+ })
210
+
211
+ attn_heatmap = sample_attn_weights[0].detach().cpu().numpy()
212
+ visualize_attention(attn_heatmap, epoch)
213
+
214
+
215
+ def main(config):
216
+ wandb.init(project="cross_attention_classifier", config=config)
217
+
218
+ train_loader, val_loader = get_data_loaders(
219
+ shape_params=config["shape_params"],
220
+ num_train_samples=config["num_train_samples"],
221
+ num_val_samples=config["num_val_samples"],
222
+ batch_size=config["batch_size"]
223
+ )
224
+
225
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
226
+
227
+ model, criterion, optimizer, scheduler = build_model(
228
+ path_to_encoder=config["path_to_encoder"],
229
+ lr=config["lr"],
230
+ weight_decay=config["weight_decay"],
231
+ step_size=config["step_size"],
232
+ gamma=config["gamma"],
233
+ device=device
234
+ )
235
+
236
+ train(
237
+ model=model,
238
+ criterion=criterion,
239
+ optimizer=optimizer,
240
+ scheduler=scheduler,
241
+ train_loader=train_loader,
242
+ val_loader=val_loader,
243
+ device=device,
244
+ num_epochs=config["num_epochs"],
245
+ save_path=config["save_path"]
246
+ )
247
+
248
+ wandb.finish()
249
+
250
+
251
+ if __name__ == "__main__":
252
+
253
+ # parser = argparse.ArgumentParser(description="Train classifier model")
254
+ # parser.add_argument("--path_to_encoder", type=str, default="best_byol.pth")
255
+ # parser.add_argument("--batch_size", type=int, default=256)
256
+ # parser.add_argument("--lr", type=float, default=8e-5)
257
+ # parser.add_argument("--weight_decay", type=float, default=1e-4)
258
+ # parser.add_argument("--step_size", type=int, default=10)
259
+ # parser.add_argument("--gamma", type=float, default=0.1)
260
+ # parser.add_argument("--num_epochs", type=int, default=10)
261
+ # parser.add_argument("--num_train_samples", type=int, default=10000)
262
+ # parser.add_argument("--num_val_samples", type=int, default=2000)
263
+ # parser.add_argument("--save_path", type=str, default="best_attention_classifier.pth")
264
+ # args = parser.parse_args()
265
+
266
+ # config = {
267
+ # "path_to_encoder": args.path_to_encoder,
268
+ # "batch_size": args.batch_size,
269
+ # "lr": args.lr,
270
+ # "weight_decay": args.weight_decay,
271
+ # "step_size": args.step_size,
272
+ # "gamma": args.gamma,
273
+ # "num_epochs": args.num_epochs,
274
+ # "num_train_samples": args.num_train_samples,
275
+ # "num_val_samples": args.num_val_samples,
276
+ # "save_path": args.save_path,
277
+ # }
278
+
279
+ config = {
280
+ "path_to_encoder": "best_byol.pth",
281
+ "batch_size": 256,
282
+ "lr": 8e-5,
283
+ "weight_decay": 1e-4,
284
+ "step_size": 10,
285
+ "gamma": 0.1,
286
+ "num_epochs": 10,
287
+ "num_train_samples": 10000,
288
+ "num_val_samples": 2000,
289
+ "save_path": "best_attention_classifier.pth",
290
+ }
291
+
292
+ if "shape_params" not in config:
293
+ config["shape_params"] = {}
294
+
295
+ main(config)