mjpyeon committed
Commit 0dce87a · 1 Parent(s): bcf3441

initial commit

.gitignore ADDED
@@ -0,0 +1,174 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
169
+
170
+ # Ruff stuff:
171
+ .ruff_cache/
172
+
173
+ # PyPI configuration file
174
+ .pypirc
LICENSE CHANGED
@@ -0,0 +1,57 @@
1
+ EXAONEPath AI Model License Agreement 1.0 - NC
2
+
3
+ This License Agreement (“Agreement”) is entered into between you (“Licensee”) and LG Management Development Institute Co., Ltd. (“Licensor”), governing the use of the EXAONEPath AI Model (“Model”). By downloading, installing, copying, or using the Model, you agree to comply with and be bound by the terms of this Agreement. If you do not agree to all the terms, you must not download, install, copy, or use the Model. This Agreement constitutes a binding legal agreement between the Licensee and Licensor.
4
+
5
+ 1. Definitions
6
+ 1.1 Model: The artificial intelligence model provided by Licensor, which includes any software, algorithms, machine learning models, or related components supplied by Licensor. This definition extends to encompass all updates, enhancements, improvements, bug fixes, patches, or other modifications that may be provided by Licensor from time to time, whether automatically or manually implemented.
7
+ 1.2 Derivatives: Any modifications, alterations, enhancements, improvements, adaptations, or derivative works of the Model created by Licensee or any third party. This includes changes made to the Model's architecture, parameters, data processing methods, or any other aspect of the Model that results in a modification of its functionality or output.
8
+ 1.3 Output: Any data, results, content, predictions, analyses, insights, or other materials generated by the Model or Derivatives, regardless of whether they are in their original form or have been further processed or modified by the Licensee. This includes, but is not limited to, textual or numerical content produced directly or indirectly through the use of the Model.
9
+ 1.4 Licensor: LG Management Development Institute Co., Ltd., the owner, developer, and provider of the EXAONEPath AI Model. The Licensor holds all rights, title, and interest in the Model and is responsible for granting licenses to use the Model under the terms specified in this Agreement.
10
+ 1.5 Licensee: The individual, organization, corporation, academic institution, government agency, or other entity using or intending to use the Model under the terms and conditions of this Agreement. The Licensee is responsible for ensuring compliance with the Agreement by all authorized users who access or utilize the Model on behalf of the Licensee.
11
+
12
+ 2. License Grant
13
+ 2.1 Grant of License: Subject to the terms and conditions outlined in this Agreement, the Licensor hereby grants the Licensee a limited, non-exclusive, non-transferable, worldwide, and revocable license to:
14
+ a. Access, download, install, and use the Model solely for research purposes. This includes evaluation, testing, academic research and experimentation.
15
+ b. Publicly disclose research results and findings derived from the use of the Model or Derivatives, including publishing papers or presentations.
16
+ c. Modify the Model and create Derivatives based on the Model, provided that such modifications and Derivatives are used exclusively for research purposes. The Licensee may conduct experiments, perform analyses, and apply custom modifications to the Model to explore its capabilities and performance under various scenarios. If the Model is modified, the modified Model must include "EXAONEPath" at the beginning of its name.
17
+ d. Distribute the Model and Derivatives in each case with a copy of this Agreement.
18
+ 2.2 Scope of License: The license granted herein does not authorize the Licensee to use the Model for any purpose not explicitly permitted under this Agreement. Any use beyond the scope of this license, including any commercial application or external distribution, is strictly prohibited unless explicitly agreed upon in writing by the Licensor.
19
+
20
+ 3. Restrictions
21
+ 3.1 Commercial Use: The Licensee is expressly prohibited from using the Model, Derivatives, or Output for any commercial purposes, including but not limited to, developing or deploying products, services, or applications that generate revenue, whether directly or indirectly. Any commercial exploitation of the Model or its derivatives requires a separate commercial license agreement with the Licensor. Furthermore, the Licensee shall not use the Model, Derivatives or Output to develop or improve other models, except for research purposes, which is explicitly permitted.
22
+ 3.2 Reverse Engineering: The Licensee shall not decompile, disassemble, reverse engineer, or attempt to derive the source code, underlying ideas, algorithms, or structure of the Model, except to the extent that such activities are expressly permitted by applicable law. Any attempt to bypass or circumvent technological protection measures applied to the Model is strictly prohibited.
23
+ 3.3 Unlawful Use: The Licensee shall not use the Model and Derivatives for any illegal, fraudulent, or unauthorized activities, nor for any purpose that violates applicable laws or regulations. This includes but is not limited to the creation, distribution, or dissemination of malicious, deceptive, or unlawful content.
24
+ 3.4 Ethical Use: The Licensee shall ensure that the Model or Derivatives is used in an ethical and responsible manner, adhering to the following guidelines:
25
+ a. The Model and Derivatives shall not be used to generate, propagate, or amplify false, misleading, or harmful information, including fake news, misinformation, or disinformation.
26
+ b. The Model and Derivatives shall not be employed to create, distribute, or promote content that is discriminatory, harassing, defamatory, abusive, or otherwise offensive to individuals or groups based on race, gender, sexual orientation, religion, nationality, or other protected characteristics.
27
+ c. The Model and Derivatives shall not infringe on the rights of others, including intellectual property rights, privacy rights, or any other rights recognized by law. The Licensee shall obtain all necessary permissions and consents before using the Model and Derivatives in a manner that may impact the rights of third parties.
28
+ d. The Model and Derivatives shall not be used in a way that causes harm, whether physical, mental, emotional, or financial, to individuals, organizations, or communities. The Licensee shall take all reasonable measures to prevent misuse or abuse of the Model and Derivatives that could result in harm or injury.
29
+
30
+ 4. Ownership
31
+ 4.1 Intellectual Property: All rights, title, and interest in and to the Model, including any modifications, Derivatives, and associated documentation, are and shall remain the exclusive property of the Licensor. The Licensee acknowledges that this Agreement does not transfer any ownership rights to the Licensee. All trademarks, service marks, and logos associated with the Model are the property of the Licensor.
32
+ 4.2 Output: All output generated by the Model from Licensee Data ("Output") shall be the sole property of the Licensee. Licensor hereby waives any claim of ownership or intellectual property rights to the Output. Licensee is solely responsible for the legality, accuracy, quality, integrity, and use of the Output.
33
+ 4.3 Attribution: In any publication or presentation of results obtained using the Model, the Licensee shall provide appropriate attribution to the Licensor, citing the Model's name and version, along with any relevant documentation or references specified by the Licensor.
34
+
35
+ 5. No Warranty
36
+ 5.1 “As-Is” Basis: The Model, Derivatives, and Output are provided on an “as-is” and “as-available” basis, without any warranties or representations of any kind, whether express, implied, or statutory. The Licensor disclaims all warranties, including but not limited to, implied warranties of merchantability, fitness for a particular purpose, accuracy, reliability, non-infringement, or any warranty arising from the course of dealing or usage of trade.
37
+ 5.2 Performance and Reliability: The Licensor does not warrant or guarantee that the Model, Derivatives or Output will meet the Licensee’s requirements, that the operation of the Model, Derivatives or Output will be uninterrupted or error-free, or that defects in the Model will be corrected. The Licensee acknowledges that the use of the Model, Derivatives or Output is at its own risk and that the Model, Derivatives or Output may contain bugs, errors, or other limitations.
38
+ 5.3 No Endorsement: The Licensor does not endorse, approve, or certify any results, conclusions, or recommendations derived from the use of the Model. The Licensee is solely responsible for evaluating the accuracy, reliability, and suitability of the Model for its intended purposes.
39
+
40
+ 6. Limitation of Liability
41
+ 6.1 No Liability for Damages: To the fullest extent permitted by applicable law, in no event shall the Licensor be liable for any special, incidental, indirect, consequential, exemplary, or punitive damages, including but not limited to, damages for loss of business profits, business interruption, loss of business information, loss of data, or any other pecuniary or non-pecuniary loss arising out of or in connection with the use or inability to use the Model, Derivatives or any Output, even if the Licensor has been advised of the possibility of such damages.
42
+ 6.2 Indemnification: The Licensee agrees to indemnify, defend, and hold harmless the Licensor, its affiliates, officers, directors, employees, and agents from and against any claims, liabilities, damages, losses, costs, or expenses (including reasonable attorneys' fees) arising out of or related to the Licensee's use of the Model, any Derivatives, or any Output, including any violation of this Agreement or applicable laws. This includes, but is not limited to, ensuring compliance with copyright laws, privacy regulations, defamation laws, and any other applicable legal or regulatory requirements.
43
+
44
+ 7. Termination
45
+ 7.1 Termination by Licensor: The Licensor reserves the right to terminate this Agreement and revoke the Licensee’s rights to use the Model at any time, with or without cause, and without prior notice if the Licensee breaches any of the terms or conditions of this Agreement. Termination shall be effective immediately upon notice.
46
+ 7.2 Effect of Termination: Upon termination of this Agreement, the Licensee must immediately cease all use of the Model, Derivatives, and Output and destroy all copies of the Model, Derivatives, and Output in its possession or control, including any backup or archival copies. The Licensee shall certify in writing to the Licensor that such destruction has been completed.
47
+ 7.3 Survival: The provisions of this Agreement that by their nature should survive termination, including but not limited to, Sections 4 (Ownership), 5 (No Warranty), 6 (Limitation of Liability), and this Section 7 (Termination), shall continue to apply after termination.
48
+
49
+ 8. Governing Law
50
+ 8.1 Governing Law: This Agreement shall be governed by and construed in accordance with the laws of the Republic of Korea, without regard to its conflict of laws principles.
51
+ 8.2 Arbitration: Any disputes, controversies, or claims arising out of or relating to this Agreement, including its existence, validity, interpretation, performance, breach, or termination, shall be referred to and finally resolved by arbitration administered by the Korean Commercial Arbitration Board (KCAB) in accordance with the International Arbitration Rules of the Korean Commercial Arbitration Board in force at the time of the commencement of the arbitration. The seat of arbitration shall be Seoul, Republic of Korea. The tribunal shall consist of one arbitrator. The language of the arbitration shall be English.
52
+
53
+ 9. Alterations
54
+ 9.1 Modifications: The Licensor reserves the right to modify or amend this Agreement at any time, in its sole discretion. Any modifications will be effective upon posting the updated Agreement on the Licensor’s website or through other means of communication. The Licensee is responsible for reviewing the Agreement periodically for changes. Continued use of the Model after any modifications have been made constitutes acceptance of the revised Agreement.
55
+ 9.2 Entire Agreement: This Agreement constitutes the entire agreement between the Licensee and Licensor concerning the subject matter hereof and supersedes all prior or contemporaneous oral or written agreements, representations, or understandings. Any terms or conditions of any purchase order or other document submitted by the Licensee in connection with the Model that are in addition to, different from, or inconsistent with the terms and conditions of this Agreement are not binding on the Licensor and are void.
56
+
57
+ By downloading, installing, or using the EXAONEPath AI Model, the Licensee acknowledges that it has read, understood, and agrees to be bound by the terms and conditions of this Agreement.
README.md CHANGED
@@ -2,4 +2,80 @@
2
  license: other
3
  license_name: exaonepath
4
  license_link: LICENSE
5
+ tags:
6
+ - lg-ai
7
+ - EXAONE-Path-2.0
8
+ - pathology
9
  ---
10
+
11
+ # EXAONE Path 2.0
12
+
13
+ ## Introduction
14
+ In digital pathology, whole-slide images (WSIs) are often difficult to handle due to their gigapixel scale, so most approaches train patch encoders via self-supervised learning (SSL) and then aggregate the patch-level embeddings via multiple instance learning (MIL) or slide encoders for downstream tasks.
15
+ However, patch-level SSL may overlook complex domain-specific features that are essential for biomarker prediction, such as mutation status and molecular characteristics, as SSL methods rely only on basic augmentations selected for natural image domains and applied to small patch-level regions.
16
+ Moreover, SSL methods remain less data efficient than fully supervised approaches, requiring extensive computational resources and datasets to achieve competitive performance.
17
+ To address these limitations, we present EXAONE Path 2.0, a pathology foundation model that learns patch-level representations under direct slide-level supervision.
18
+ Using only 35k WSIs for training, EXAONE Path 2.0 achieves state-of-the-art average performance across 10 biomarker prediction tasks, demonstrating remarkable data efficiency.
19
+
20
+ ## Quickstart
21
+ Load EXAONE Path 2.0 and run inference on a whole-slide image to extract tile-level features.
22
+
23
+ ### 1. Prerequisites ###
24
+ - NVIDIA GPU with 24GB+ VRAM
25
+ - Python 3.12+
26
+
27
+ Note: This implementation requires an NVIDIA GPU and drivers. The provided environment setup uses CUDA-enabled PyTorch, so an NVIDIA GPU is mandatory for running the model.
28
+
29
+ ### 2. Setup Python environment ###
30
+ ```bash
31
+ git clone https://github.com/LG-AI-EXAONE/EXAONE-Path-2.0.git
32
+ cd EXAONE-Path-2.0
33
+ pip install -r requirements.txt
34
+ ```
35
+
36
+ ### 3. Load the model & Inference
37
+ ```python
38
+ from exaonepath import EXAONEPathV20
39
+
40
+ hf_token = "YOUR_HUGGING_FACE_ACCESS_TOKEN"
41
+ model = EXAONEPathV20.from_pretrained("LGAI-EXAONE/EXAONE-Path-2.0", use_auth_token=hf_token)
42
+
43
+ svs_path = "YOUR_SVS_PATH"
44
+ patch_features = model(svs_path)[0]
45
+ ```
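+
+ The snippet above keeps only the first element of the returned tuple. Going by `exaonepath.py` in this commit, `forward` returns three activations (patch-level, 4096-px-region-level, and slide-level features), so a hedged sketch for inspecting all of them could look like the following; the variable names are illustrative only:
+
+ ```python
+ # Hedged sketch: unpack all three outputs returned by EXAONEPathV20.forward
+ patch_feats, region_feats, slide_feat = model(svs_path)
+ print(patch_feats.shape, region_feats.shape, slide_feat.shape)
+ ```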
46
+
47
+ ## Model Performance Comparison
48
+
49
+ | **Benchmarks** | **TITAN** | **PRISM** | **CHIEF** | **Prov-GigaPath** | **UNI2-h** | **EXAONE Path 1.0** | **EXAONE Path 2.0** |
50
+ |---|---|---|---|---|---|---|---|
51
+ | LUAD-TMB-USA1 | 0.690 | 0.645 | 0.650 | 0.674 | 0.669 | 0.692 | 0.664 |
52
+ | LUAD-EGFR-USA1 | 0.754 | 0.815 | 0.784 | 0.709 | 0.827 | 0.784 | 0.853 |
53
+ | LUAD-KRAS-USA2 | 0.541 | 0.623 | 0.468 | 0.511 | 0.469 | 0.527 | 0.645 |
54
+ | CRC-MSI-KOR | 0.937 | 0.943 | 0.927 | 0.954 | 0.981 | 0.972 | 0.938 |
55
+ | BRCA-TP53-CPTAC | 0.788 | 0.842 | 0.788 | 0.739 | 0.808 | 0.766 | 0.757 |
56
+ | BRCA-PIK3CA-CPTAC | 0.758 | 0.893 | 0.702 | 0.735 | 0.857 | 0.735 | 0.804 |
57
+ | RCC-PBRM1-CPTAC | 0.638 | 0.557 | 0.513 | 0.527 | 0.501 | 0.526 | 0.583 |
58
+ | RCC-BAP1-CPTAC | 0.719 | 0.769 | 0.731 | 0.697 | 0.716 | 0.719 | 0.807 |
59
+ | COAD-KRAS-CPTAC | 0.764 | 0.744 | 0.699 | 0.815 | 0.943 | 0.767 | 0.912 |
60
+ | COAD-TP53-CPTAC | 0.889 | 0.816 | 0.701 | 0.712 | 0.783 | 0.819 | 0.875 |
61
+ | **Average** | 0.748 | 0.765 | 0.696 | 0.707 | 0.755 | 0.731 | **0.784** |
62
+
63
+ <br>
64
+
65
+
66
+ ## License
67
+ The model is licensed under [EXAONEPath AI Model License Agreement 1.0 - NC](./LICENSE)
68
+
69
+ <!-- ## Citation
70
+ If you find EXAONE Path 2.0 useful, please cite it using this BibTeX:
71
+ ```
72
+ @article{yun2024exaonepath,
73
+ title={EXAONE Path 2.0 Technical Report},
74
+ author={Yun, Juseung and Hu, Yi and Kim, Jinhyung and Jang, Jongseong and Lee, Soonyoung},
75
+ journal={arXiv preprint arXiv:2408.00380},
76
+ year={2024}
77
+ }
78
+ ```
+ -->
79
+
80
+ ## Contact
81
+ LG AI Research Technical Support: <a href="mailto:contact_us1@lgresearch.ai">contact_us1@lgresearch.ai</a>
config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "small_tile_size": 256,
3
+ "large_tile_size": 4096
4
+ }
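These two values mirror the defaults of `EXAONEPathV20.__init__` in `exaonepath.py` (256-px small tiles grouped into 4096-px regions), and `PyTorchModelHubMixin` uses them to re-create the model on `from_pretrained`. A minimal sketch, assuming you only want to build the architecture locally without pulling from the Hub:

```python
# Minimal sketch (assumption: run from the repository root so `exaonepath` is importable).
from exaonepath import EXAONEPathV20

# Equivalent to what config.json encodes for from_pretrained (weights not loaded here).
model = EXAONEPathV20(small_tile_size=256, large_tile_size=4096)
```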
exaonepath.py ADDED
@@ -0,0 +1,165 @@
1
+ import math
2
+ import typing as t
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from cucim import CuImage
7
+ from huggingface_hub import PyTorchModelHubMixin
8
+ from torchvision.transforms import functional as TF
9
+ from torchvision.transforms import v2 as T
10
+
11
+ from networks.vit import vit4k_base, vit_base, vit_global_base
12
+ from utils.tensor_utils import (
13
+ format_first_stg_act_as_second_stg_inp,
14
+ format_second_stg_act_as_third_stg_inp,
15
+ forward_with_batch_size_limit,
16
+ scale_and_normalize,
17
+ tile,
18
+ )
19
+ from utils.wsi_utils import load_slide_img, pack_slide, segment_tissue
20
+
21
+ if t.TYPE_CHECKING:
22
+ from _typeshed import StrPath
23
+
24
+
25
+ class PadToDivisible(T.Transform):
26
+ def __init__(self, size: int, pad_value: float | None = None):
27
+ super().__init__()
28
+ self.size = size
29
+ self.pad_value = pad_value
30
+
31
+ def transform(self, inpt, params):
32
+ assert isinstance(inpt, torch.Tensor) and inpt.ndim >= 3
33
+
34
+ H, W = inpt.shape[-2:]
35
+
36
+ pad_h = (self.size - H % self.size) % self.size
37
+ pad_w = (self.size - W % self.size) % self.size
38
+
39
+ if pad_h > 0 or pad_w > 0:
40
+ inpt = torch.nn.functional.pad(
41
+ inpt, (0, pad_w, 0, pad_h), value=self.pad_value
42
+ )
43
+
44
+ return inpt
45
+
46
+
47
+ class Preprocessing(T.Transform):
48
+ def __init__(
49
+ self, small_tile_size_with_this_mpp: int, small_tile_size_with_target_mpp: int
50
+ ):
51
+ self.small_tile_size_with_this_mpp = small_tile_size_with_this_mpp
52
+ self.small_tile_size_with_target_mpp = small_tile_size_with_target_mpp
53
+
54
+ def transform(self, inpt, params):
55
+ assert isinstance(inpt, torch.Tensor) and inpt.ndim >= 3
56
+
57
+ # Scale the input tensor to the target MPP
58
+ if self.small_tile_size_with_this_mpp != self.small_tile_size_with_target_mpp:
59
+ inpt = TF.resize(
60
+ inpt,
61
+ [
62
+ self.small_tile_size_with_target_mpp,
63
+ self.small_tile_size_with_target_mpp,
64
+ ],
65
+ )
66
+
67
+ # Normalize the input tensor
68
+ inpt = scale_and_normalize(inpt)
69
+
70
+ return inpt
71
+
72
+
73
+ class EXAONEPathV20(nn.Module, PyTorchModelHubMixin):
74
+ def __init__(
75
+ self,
76
+ small_tile_size: int = 256,
77
+ large_tile_size: int = 4096,
78
+ ):
79
+ super().__init__()
80
+
81
+ self.small_tile_size = small_tile_size
82
+ self.large_tile_size = large_tile_size
83
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
84
+
85
+ self.model_first_stg = vit_base().to(self.device).eval()
86
+ self.model_second_stg = vit4k_base().to(self.device).eval()
87
+ self.model_third_stg = vit_global_base().to(self.device).eval()
88
+
89
+ def forward(
90
+ self,
91
+ svs_path: "StrPath",
92
+ target_mpp: float = 0.5,
93
+ first_stg_batch_size: int = 128,
94
+ ):
95
+ small_tiles, is_tile_valid, padded_size, small_tile_size, large_tile_size = (
96
+ self._load_wsi(svs_path, target_mpp=target_mpp)
97
+ )
98
+ width, height = padded_size
99
+
100
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
101
+ with torch.no_grad():
102
+ act1 = forward_with_batch_size_limit(
103
+ self.model_first_stg,
104
+ small_tiles,
105
+ batch_size_on_gpu=first_stg_batch_size,
106
+ preproc_fn=Preprocessing(
107
+ small_tile_size_with_this_mpp=small_tile_size,
108
+ small_tile_size_with_target_mpp=self.small_tile_size,
109
+ ),
110
+ device=self.device,
111
+ out_device=self.device,
112
+ dtype=torch.bfloat16,
113
+ )
114
+ act1 = format_first_stg_act_as_second_stg_inp(
115
+ act1,
116
+ height=height,
117
+ width=width,
118
+ small_tile_size=small_tile_size,
119
+ large_tile_size=large_tile_size,
120
+ )
121
+ act2: torch.Tensor = self.model_second_stg(act1)
122
+ act2_formatted = format_second_stg_act_as_third_stg_inp(
123
+ act2,
124
+ height=height,
125
+ width=width,
126
+ large_tile_size=large_tile_size,
127
+ )
128
+ act3: torch.Tensor = self.model_third_stg(act2_formatted)
129
+ return act1[is_tile_valid], act2, act3
130
+
131
+ def _load_wsi(self, svs_path: "StrPath", target_mpp: float):
132
+ # Load WSI tile
133
+ with CuImage(str(svs_path)) as wsi_obj:
134
+ try:
135
+ mpp = float(wsi_obj.metadata["aperio"]["MPP"])
136
+ except KeyError:
137
+ print(
138
+ f"Warning: MPP metadata not found, using default value of {target_mpp}"
139
+ )
140
+ mpp = target_mpp
141
+
142
+ img = load_slide_img(wsi_obj)
143
+ height, width = img.shape[:2]
144
+ mask_tensor = torch.from_numpy(segment_tissue(svs_path, seg_level=-1)[0])
145
+ mask_tensor = TF.resize(mask_tensor.unsqueeze(0), [height, width]).squeeze(
146
+ 0
147
+ )
148
+ x: torch.Tensor = torch.from_numpy(img).permute(2, 0, 1)
149
+
150
+ small_tile_size = math.ceil(self.small_tile_size * (target_mpp / mpp))
151
+ large_tile_size = (
152
+ self.large_tile_size // self.small_tile_size
153
+ ) * small_tile_size
154
+ pad_image = PadToDivisible(large_tile_size, 255)
155
+ pad_mask = PadToDivisible(large_tile_size, 0)
156
+
157
+ x = pad_image(x)
158
+ padded_size = (x.size(-1), x.size(-2))
159
+
160
+ x = tile(x, small_tile_size)
161
+ mask_padded = pad_mask(mask_tensor.unsqueeze(0))
162
+ mask_tile = tile(mask_padded, small_tile_size).squeeze(1)
163
+ is_tile_valid = mask_tile.sum(dim=(1, 2)) > 0
164
+
165
+ return x, is_tile_valid, padded_size, small_tile_size, large_tile_size
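In short, `_load_wsi` rescales the nominal tile sizes to the slide's own MPP before padding and tiling, and `forward` then runs the hierarchy patch ViT → region ViT → global ViT. A minimal sketch of the tile-size arithmetic above, with illustrative MPP values (not taken from any particular slide):

```python
# Hedged sketch of the tile-size arithmetic in _load_wsi (MPP values are illustrative).
import math

target_mpp, slide_mpp = 0.5, 0.25      # microns per pixel: model target vs. slide metadata
small, large = 256, 4096               # tile sizes defined at the target MPP

# Small-tile size read from the slide so that resizing it later hits the target MPP:
small_at_slide_mpp = math.ceil(small * (target_mpp / slide_mpp))   # 512
# The large tile keeps the same 16x16 grid of small tiles:
large_at_slide_mpp = (large // small) * small_at_slide_mpp         # 8192
print(small_at_slide_mpp, large_at_slide_mpp)
```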
networks/__init__.py ADDED
File without changes
networks/vit.py ADDED
@@ -0,0 +1,569 @@
1
+ import math
2
+ import warnings
3
+ from functools import partial
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ torch.set_float32_matmul_precision("high")
10
+ torch.backends.cuda.enable_flash_sdp(True)
11
+
12
+
13
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
14
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
15
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
16
+ def norm_cdf(x):
17
+ # Computes standard normal cumulative distribution function
18
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
19
+
20
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
21
+ warnings.warn(
22
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
23
+ "The distribution of values may be incorrect.",
24
+ stacklevel=2,
25
+ )
26
+
27
+ with torch.no_grad():
28
+ # Values are generated by using a truncated uniform distribution and
29
+ # then using the inverse CDF for the normal distribution.
30
+ # Get upper and lower cdf values
31
+ l = norm_cdf((a - mean) / std)
32
+ u = norm_cdf((b - mean) / std)
33
+
34
+ # Uniformly fill tensor with values from [l, u], then translate to
35
+ # [2l-1, 2u-1].
36
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
37
+
38
+ # Use inverse cdf transform for normal distribution to get truncated
39
+ # standard normal
40
+ tensor.erfinv_()
41
+
42
+ # Transform to proper mean, std
43
+ tensor.mul_(std * math.sqrt(2.0))
44
+ tensor.add_(mean)
45
+
46
+ # Clamp to ensure it's in the proper range
47
+ tensor.clamp_(min=a, max=b)
48
+ return tensor
49
+
50
+
51
+ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
52
+ # type: (torch.Tensor, float, float, float, float) -> torch.Tensor
53
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
54
+
55
+
56
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
57
+ if drop_prob == 0.0 or not training:
58
+ return x
59
+ keep_prob = 1 - drop_prob
60
+ shape = (x.shape[0],) + (1,) * (
61
+ x.ndim - 1
62
+ ) # work with diff dim tensors, not just 2D ConvNets
63
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
64
+ random_tensor.floor_() # binarize
65
+ output = x.div(keep_prob) * random_tensor
66
+ return output
67
+
68
+
69
+ class LayerScale(nn.Module):
70
+ def __init__(
71
+ self,
72
+ dim: int,
73
+ init_values: float = 1e-5,
74
+ inplace: bool = False,
75
+ ) -> None:
76
+ super().__init__()
77
+ self.inplace = inplace
78
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
79
+
80
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
81
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
82
+
83
+
84
+ class DropPath(nn.Module):
85
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
86
+
87
+ def __init__(self, drop_prob=None):
88
+ super(DropPath, self).__init__()
89
+ self.drop_prob = drop_prob
90
+
91
+ def forward(self, x):
92
+ return drop_path(x, self.drop_prob, self.training)
93
+
94
+
95
+ class Mlp(nn.Module):
96
+ def __init__(
97
+ self,
98
+ in_features,
99
+ hidden_features=None,
100
+ out_features=None,
101
+ act_layer=nn.GELU,
102
+ drop=0.0,
103
+ ):
104
+ super().__init__()
105
+ out_features = out_features or in_features
106
+ hidden_features = hidden_features or in_features
107
+ self.fc1 = nn.Linear(in_features, hidden_features)
108
+ self.act = act_layer()
109
+ self.fc2 = nn.Linear(hidden_features, out_features)
110
+ self.drop = nn.Dropout(drop)
111
+ self.drop_p = drop
112
+
113
+ def forward(self, x):
114
+ x = self.fc1(x)
115
+ x = self.act(x)
116
+ x = self.drop(x)
117
+ x = self.fc2(x)
118
+ x = self.drop(x)
119
+ return x
120
+
121
+
122
+ # TODO Use SelfAttention class in networks.modules
123
+ class Attention(nn.Module):
124
+ def __init__(
125
+ self,
126
+ dim,
127
+ num_heads=8,
128
+ qkv_bias=False,
129
+ qk_scale=None,
130
+ attn_drop=0.0,
131
+ proj_drop=0.0,
132
+ ):
133
+ super().__init__()
134
+ self.dim = dim
135
+ self.num_heads = num_heads
136
+ head_dim = dim // num_heads
137
+ self.scale = qk_scale or head_dim**-0.5
138
+
139
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
140
+ self.attn_drop = nn.Dropout(attn_drop)
141
+ self.attn_drop_p = attn_drop
142
+ self.proj = nn.Linear(dim, dim)
143
+ self.proj_drop = nn.Dropout(proj_drop)
144
+ self.proj_drop_p = proj_drop
145
+
146
+ def forward(self, x):
147
+ B, N, C = x.shape
148
+ qkv = (
149
+ self.qkv(x)
150
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
151
+ .permute(2, 0, 3, 1, 4)
152
+ )
153
+ q, k, v = qkv[0], qkv[1], qkv[2]
154
+
155
+ x = F.scaled_dot_product_attention(
156
+ q, k, v, dropout_p=self.attn_drop.p, scale=self.scale
157
+ )
158
+ x = x.transpose(1, 2).reshape(B, N, C)
159
+
160
+ x = self.proj(x)
161
+ x = self.proj_drop(x)
162
+ return x
163
+
164
+
165
+ class Block(nn.Module):
166
+ def __init__(
167
+ self,
168
+ dim,
169
+ num_heads,
170
+ mlp_ratio=4.0,
171
+ qkv_bias=False,
172
+ qk_scale=None,
173
+ drop=0.0,
174
+ attn_drop=0.0,
175
+ init_values=None,
176
+ drop_path=0.0,
177
+ act_layer=nn.GELU,
178
+ norm_layer=nn.LayerNorm,
179
+ ):
180
+ super().__init__()
181
+ self.norm1 = norm_layer(dim)
182
+ self.attn = Attention(
183
+ dim,
184
+ num_heads=num_heads,
185
+ qkv_bias=qkv_bias,
186
+ qk_scale=qk_scale,
187
+ attn_drop=attn_drop,
188
+ proj_drop=drop,
189
+ )
190
+ self.ls1 = (
191
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
192
+ )
193
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
194
+ self.norm2 = norm_layer(dim)
195
+ mlp_hidden_dim = int(dim * mlp_ratio)
196
+ self.mlp = Mlp(
197
+ in_features=dim,
198
+ hidden_features=mlp_hidden_dim,
199
+ act_layer=act_layer,
200
+ drop=drop,
201
+ )
202
+ self.ls2 = (
203
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
204
+ )
205
+
206
+ def forward(self, x):
207
+ x = x + self.drop_path(self.ls1(self.attn(self.norm1(x))))
208
+ x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))
209
+ return x
210
+
211
+
212
+ class PatchEmbed(nn.Module):
213
+ """Image to Patch Embedding"""
214
+
215
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
216
+ super().__init__()
217
+ num_patches = (img_size // patch_size) * (img_size // patch_size)
218
+ self.img_size = img_size
219
+ self.patch_size = patch_size
220
+ self.num_patches = num_patches
221
+
222
+ self.proj = nn.Conv2d(
223
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
224
+ )
225
+
226
+ def forward(self, x):
227
+ B, C, H, W = x.shape
228
+ x = self.proj(x).flatten(2).transpose(1, 2)
229
+ return x
230
+
231
+
232
+ class Network(nn.Module):
233
+ emb_dim: int
234
+
235
+
236
+ class VisionTransformer(Network):
237
+ """Vision Transformer"""
238
+
239
+ def __init__(
240
+ self,
241
+ img_size=256,
242
+ patch_size=16,
243
+ in_chans=3,
244
+ num_classes=0,
245
+ embed_dim=768,
246
+ depth=12,
247
+ num_heads=12,
248
+ mlp_ratio=4.0,
249
+ qkv_bias=False,
250
+ qk_scale=None,
251
+ init_values=None, # for layerscale: None or 0 => no layerscale
252
+ drop_rate=0.0,
253
+ attn_drop_rate=0.0,
254
+ drop_path_rate=0.0,
255
+ norm_layer=nn.LayerNorm,
256
+ **kwargs
257
+ ):
258
+ super().__init__()
259
+ self.num_features = self.embed_dim = embed_dim
260
+
261
+ self.patch_embed = PatchEmbed(
262
+ img_size=img_size,
263
+ patch_size=patch_size,
264
+ in_chans=in_chans,
265
+ embed_dim=embed_dim,
266
+ )
267
+ num_patches = self.patch_embed.num_patches
268
+
269
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
270
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
271
+ self.pos_drop = nn.Dropout(p=drop_rate)
272
+
273
+ dpr = [
274
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
275
+ ] # stochastic depth decay rule
276
+ self.blocks = nn.ModuleList(
277
+ [
278
+ Block(
279
+ dim=embed_dim,
280
+ num_heads=num_heads,
281
+ mlp_ratio=mlp_ratio,
282
+ qkv_bias=qkv_bias,
283
+ qk_scale=qk_scale,
284
+ init_values=init_values,
285
+ drop=drop_rate,
286
+ attn_drop=attn_drop_rate,
287
+ drop_path=dpr[i],
288
+ norm_layer=norm_layer,
289
+ )
290
+ for i in range(depth)
291
+ ]
292
+ )
293
+ self.norm = norm_layer(embed_dim)
294
+
295
+ # Classifier head
296
+ self.head = (
297
+ nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
298
+ )
299
+
300
+ trunc_normal_(self.pos_embed, std=0.02)
301
+ trunc_normal_(self.cls_token, std=0.02)
302
+ self.apply(self._init_weights)
303
+
304
+ def _init_weights(self, m):
305
+ if isinstance(m, nn.Linear):
306
+ trunc_normal_(m.weight, std=0.02)
307
+ if isinstance(m, nn.Linear) and m.bias is not None:
308
+ nn.init.constant_(m.bias, 0)
309
+ elif isinstance(m, nn.LayerNorm):
310
+ nn.init.constant_(m.bias, 0)
311
+ nn.init.constant_(m.weight, 1.0)
312
+
313
+ def interpolate_pos_encoding(self, x, w, h):
314
+ npatch = x.shape[1] - 1
315
+ N = self.pos_embed.shape[1] - 1
316
+ if npatch == N and w == h:
317
+ return self.pos_embed
318
+ class_pos_embed = self.pos_embed[:, 0]
319
+ patch_pos_embed = self.pos_embed[:, 1:]
320
+ dim = x.shape[-1]
321
+ w0 = w // self.patch_embed.patch_size
322
+ h0 = h // self.patch_embed.patch_size
323
+ # we add a small number to avoid floating point error in the interpolation
324
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
325
+ w0, h0 = w0 + 0.1, h0 + 0.1
326
+ patch_pos_embed = nn.functional.interpolate(
327
+ patch_pos_embed.reshape(
328
+ 1, int(math.sqrt(N)), int(math.sqrt(N)), dim
329
+ ).permute(0, 3, 1, 2),
330
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
331
+ mode="bicubic",
332
+ )
333
+ assert (
334
+ int(w0) == patch_pos_embed.shape[-2]
335
+ and int(h0) == patch_pos_embed.shape[-1]
336
+ )
337
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
338
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
339
+
340
+ def prepare_tokens(self, x):
341
+ B, nc, w, h = x.shape
342
+ x = self.patch_embed(x) # patch linear embedding
343
+
344
+ # add the [CLS] token to the embed patch tokens
345
+ cls_tokens = self.cls_token.expand(B, -1, -1)
346
+ x = torch.cat((cls_tokens, x), dim=1) # (B, S + 1, C)
347
+
348
+ # add positional encoding to each token
349
+ x = x + self.interpolate_pos_encoding(x, w, h)
350
+ return self.pos_drop(x)
351
+
352
+ def forward(self, x):
353
+ x = self.prepare_tokens(x)
354
+ for blk in self.blocks:
355
+ x = blk(x)
356
+ x = self.norm(x)
357
+ return x[:, 0]
358
+
359
+ def get_last_selfattention(self, x):
360
+ x = self.prepare_tokens(x)
361
+ for i, blk in enumerate(self.blocks):
362
+ if i < len(self.blocks) - 1:
363
+ x = blk(x)
364
+ else:
365
+ # return attention of the last block
366
+ return blk(x, return_attention=True)
367
+
368
+ def get_intermediate_layers(self, x, n=1):
369
+ x = self.prepare_tokens(x)
370
+ # we return the output tokens from the `n` last blocks
371
+ output = []
372
+ for i, blk in enumerate(self.blocks):
373
+ x = blk(x)
374
+ if len(self.blocks) - i <= n:
375
+ output.append(self.norm(x))
376
+ return output
377
+
378
+
379
+ class VisionTransformer4K(Network):
380
+ """Vision Transformer 4K"""
381
+
382
+ def __init__(
383
+ self,
384
+ num_classes=0,
385
+ img_size=256,
386
+ input_embed_dim=384,
387
+ output_embed_dim=192,
388
+ depth=12,
389
+ num_heads=12,
390
+ mlp_ratio=4.0,
391
+ qkv_bias=False,
392
+ qk_scale=None,
393
+ init_values=None, # for layerscale: None or 0 => no layerscale
394
+ drop_rate=0.0,
395
+ attn_drop_rate=0.0,
396
+ drop_path_rate=0.0,
397
+ norm_layer=nn.LayerNorm,
398
+ num_prototypes=64,
399
+ **kwargs
400
+ ):
401
+ super().__init__()
402
+ embed_dim = output_embed_dim
403
+ self.num_features = self.embed_dim = embed_dim
404
+ self.phi = nn.Sequential(
405
+ *[
406
+ nn.Linear(input_embed_dim, output_embed_dim),
407
+ nn.GELU(),
408
+ nn.Dropout(p=drop_rate),
409
+ ]
410
+ )
411
+ num_patches = int(img_size // 16) ** 2
412
+
413
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
414
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
415
+ self.pos_drop = nn.Dropout(p=drop_rate)
416
+
417
+ dpr = [
418
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
419
+ ] # stochastic depth decay rule
420
+ self.blocks = nn.ModuleList(
421
+ [
422
+ Block(
423
+ dim=embed_dim,
424
+ num_heads=num_heads,
425
+ mlp_ratio=mlp_ratio,
426
+ qkv_bias=qkv_bias,
427
+ qk_scale=qk_scale,
428
+ init_values=init_values,
429
+ drop=drop_rate,
430
+ attn_drop=attn_drop_rate,
431
+ drop_path=dpr[i],
432
+ norm_layer=norm_layer,
433
+ )
434
+ for i in range(depth)
435
+ ]
436
+ )
437
+ self.norm = norm_layer(embed_dim)
438
+
439
+ # Classifier head
440
+ self.head = (
441
+ nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
442
+ )
443
+
444
+ trunc_normal_(self.pos_embed, std=0.02)
445
+ trunc_normal_(self.cls_token, std=0.02)
446
+ self.apply(self._init_weights)
447
+
448
+ def _init_weights(self, m):
449
+ if isinstance(m, nn.Linear):
450
+ trunc_normal_(m.weight, std=0.02)
451
+ if isinstance(m, nn.Linear) and m.bias is not None:
452
+ nn.init.constant_(m.bias, 0)
453
+ elif isinstance(m, nn.LayerNorm):
454
+ nn.init.constant_(m.bias, 0)
455
+ nn.init.constant_(m.weight, 1.0)
456
+
457
+ def interpolate_pos_encoding(self, x, w, h):
458
+ npatch = x.shape[1] - 1
459
+ N = self.pos_embed.shape[1] - 1
460
+ if npatch == N and w == h:
461
+ return self.pos_embed
462
+ class_pos_embed = self.pos_embed[:, 0]
463
+ patch_pos_embed = self.pos_embed[:, 1:]
464
+ dim = x.shape[-1]
465
+ w0 = w // 1
466
+ h0 = h // 1
467
+ # we add a small number to avoid floating point error in the interpolation
468
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
469
+ w0, h0 = w0 + 0.1, h0 + 0.1
470
+ patch_pos_embed = nn.functional.interpolate(
471
+ patch_pos_embed.reshape(
472
+ 1, int(math.sqrt(N)), int(math.sqrt(N)), dim
473
+ ).permute(0, 3, 1, 2),
474
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
475
+ mode="bicubic",
476
+ )
477
+ assert (
478
+ int(w0) == patch_pos_embed.shape[-2]
479
+ and int(h0) == patch_pos_embed.shape[-1]
480
+ )
481
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
482
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
483
+
484
+ def prepare_tokens(self, x):
485
+ # print('preparing tokens (after crop)', x.shape)
486
+ self.mpp_feature = x
487
+ B, embed_dim, w, h = x.shape
488
+ x = x.flatten(2, 3).transpose(1, 2)
489
+
490
+ x = self.phi(x)
491
+
492
+ # add the [CLS] token to the embed patch tokens
493
+ cls_tokens = self.cls_token.expand(B, -1, -1)
494
+ x = torch.cat((cls_tokens, x), dim=1)
495
+
496
+ # add positional encoding to each token
497
+ x = x + self.interpolate_pos_encoding(x, w, h)
498
+
499
+ return self.pos_drop(x)
500
+
501
+ def forward(self, x):
502
+ x = self.prepare_tokens(x)
503
+ for blk in self.blocks:
504
+ x = blk(x)
505
+ x = self.norm(x)
506
+ return x[:, 0]
507
+
508
+ def get_last_selfattention(self, x):
509
+ x = self.prepare_tokens(x)
510
+ for i, blk in enumerate(self.blocks):
511
+ if i < len(self.blocks) - 1:
512
+ x = blk(x)
513
+ else:
514
+ # return attention of the last block
515
+ return blk(x, return_attention=True)
516
+
517
+ def get_intermediate_layers(self, x, n=1):
518
+ x = self.prepare_tokens(x)
519
+ # we return the output tokens from the `n` last blocks
520
+ output = []
521
+ for i, blk in enumerate(self.blocks):
522
+ x = blk(x)
523
+ if len(self.blocks) - i <= n:
524
+ output.append(self.norm(x))
525
+ return output
526
+
527
+
528
+ def vit_base(patch_size=16, **kwargs):
529
+ model = VisionTransformer(
530
+ patch_size=patch_size,
531
+ embed_dim=768,
532
+ depth=12,
533
+ num_heads=12,
534
+ mlp_ratio=4,
535
+ qkv_bias=True,
536
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
537
+ **kwargs
538
+ )
539
+ return model
540
+
541
+
542
+ def vit4k_base(patch_size=16, **kwargs):
543
+ model = VisionTransformer4K(
544
+ patch_size=patch_size,
545
+ input_embed_dim=768,
546
+ output_embed_dim=768,
547
+ depth=6,
548
+ num_heads=12,
549
+ mlp_ratio=4,
550
+ qkv_bias=True,
551
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
552
+ **kwargs
553
+ )
554
+ return model
555
+
556
+
557
+ def vit_global_base(patch_size=16, **kwargs):
558
+ model = VisionTransformer4K(
559
+ patch_size=patch_size,
560
+ input_embed_dim=768,
561
+ output_embed_dim=768,
562
+ depth=2,
563
+ num_heads=6,
564
+ mlp_ratio=4,
565
+ qkv_bias=True,
566
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
567
+ **kwargs
568
+ )
569
+ return model
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:17cc4f62887c3dc97380f0b824278cc16f75638f46657f3487204a0633462373
3
+ size 576596452
requirements.txt ADDED
@@ -0,0 +1,16 @@
1
+ --extra-index-url https://download.pytorch.org/whl/cu124
2
+
3
+ torch==2.6.0+cu124
4
+ torchvision==0.21.0+cu124
5
+
6
+ # configs
7
+ pydantic==2.10.3
8
+
9
+ # data
10
+ openslide-bin==4.0.0.6
11
+ openslide-python==1.4.1
12
+ cucim-cu12==25.2
13
+ rectangle_packer==2.0.2
14
+ opencv-python-headless==4.11.0.86
15
+
16
+ huggingface_hub
utils/__init__.py ADDED
File without changes
utils/tensor_utils.py ADDED
@@ -0,0 +1,318 @@
1
+ import typing as t
2
+
3
+ import torch
4
+ import torchvision.transforms.functional as TF
5
+
6
+
7
+ def tile(x: torch.Tensor, size: int, pad_value: int | float | None = None):
8
+ C, H, W = x.shape[-3:]
9
+
10
+ pad_h = (size - H % size) % size
11
+ pad_w = (size - W % size) % size
12
+ if pad_h > 0 or pad_w > 0:
13
+ x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), value=pad_value)
14
+
15
+ nh, nw = x.size(-2) // size, x.size(-1) // size
16
+ return (
17
+ x.view(-1, C, nh, size, nw, size)
18
+ .permute(0, 2, 4, 1, 3, 5)
19
+ .reshape(-1, C, size, size)
20
+ )
21
+
22
+
23
+ def small_tiles_to_large_tiles(
24
+ small_tiles: torch.Tensor,
25
+ width: int,
26
+ large_tile_size: int,
27
+ sampled_large_tiles_idx: list | torch.Tensor | None = None,
28
+ ) -> torch.Tensor:
29
+
30
+ has_channel = small_tiles.ndim == 4
31
+ small_tile_size = small_tiles.size(-1)
32
+ num_small_tiles = small_tiles.size(0)
33
+
34
+ nw = width // small_tile_size
35
+ nh = num_small_tiles // nw
36
+
37
+ r = large_tile_size // small_tile_size
38
+
39
+ num_large_tiles = (nh // r) * (nw // r)
40
+ large_tile_indices = (
41
+ range(num_large_tiles)
42
+ if sampled_large_tiles_idx is None
43
+ else sampled_large_tiles_idx
44
+ )
45
+
46
+ tiles = []
47
+ for k in large_tile_indices:
48
+ start_row = (k // (nw // r)) * r
49
+ start_col = (k % (nw // r)) * r
50
+ for i in range(start_row, start_row + r):
51
+ for j in range(start_col, start_col + r):
52
+ tiles.append(small_tiles[i * nw + j])
53
+
54
+ stacked = torch.stack(tiles, dim=0).view(-1, r, r, *small_tiles.shape[1:])
55
+ if has_channel:
56
+ large_tiles = stacked.permute(0, 3, 1, 4, 2, 5).reshape(
57
+ -1, small_tiles.size(1), large_tile_size, large_tile_size
58
+ )
59
+ else:
60
+ large_tiles = stacked.permute(0, 1, 3, 2, 4).reshape(
61
+ -1, large_tile_size, large_tile_size
62
+ )
63
+ return large_tiles
64
+
65
+
66
+ def small_tile_flags_to_large_tile_flags(
67
+ small_tile_flags: torch.Tensor,
68
+ width: int,
69
+ small_tile_size: int,
70
+ large_tile_size: int,
71
+ aggregation: t.Literal["any", "all"] = "any",
72
+ ):
73
+ small_tile_flags = small_tile_flags.view(-1, 1, 1)
74
+ num_small_tiles = small_tile_flags.size(0)
75
+ nw = width // small_tile_size
76
+ r = large_tile_size // small_tile_size
77
+ num_large_tiles = num_small_tiles // r**2
78
+ large_tile_flags = small_tiles_to_large_tiles(
79
+ small_tile_flags,
80
+ width=nw,
81
+ large_tile_size=r,
82
+ ).view(num_large_tiles, -1)
83
+ return (
84
+ large_tile_flags.any(-1) if aggregation == "any" else large_tile_flags.all(-1)
85
+ )
86
+
87
+
88
+ def format_first_stg_act_as_second_stg_inp(
89
+ x: torch.Tensor,
90
+ height: int,
91
+ width: int,
92
+ small_tile_size: int,
93
+ large_tile_size: int,
94
+ ):
95
+ assert height % small_tile_size == 0 and width % small_tile_size == 0
96
+ D = x.size(1)
97
+ nh, nw = height // small_tile_size, width // small_tile_size
98
+ r = large_tile_size // small_tile_size
99
+ x = x.view(-1, nh, nw, D)
100
+ x = x.permute(0, 3, 1, 2).reshape(-1, D, nh // r, r, nw // r, r)
101
+ x = x.permute(0, 2, 4, 1, 3, 5).reshape(-1, D, r, r)
102
+ return x
103
+
104
+
105
+ def format_second_stg_inp_as_first_stg_act(
106
+ x: torch.Tensor, height: int, width: int, small_tile_size: int, large_tile_size: int
107
+ ):
108
+ D = x.size(1)
109
+ nh, nw = height // small_tile_size, width // small_tile_size
110
+ r = large_tile_size // small_tile_size
111
+ x = x.view(-1, nh // r, nw // r, D, r, r)
112
+ x = x.permute(0, 3, 1, 4, 2, 5).reshape(-1, D, nh, nw)
113
+ x = x.permute(0, 2, 3, 1).reshape(-1, D)
114
+ return x
115
+
116
+
117
+ def format_second_stg_act_as_third_stg_inp(
118
+ x: torch.Tensor,
119
+ height: int,
120
+ width: int,
121
+ large_tile_size: int,
122
+ ):
123
+ D = x.size(1)
124
+ nh = height // large_tile_size
125
+ nw = width // large_tile_size
126
+ return x.view(-1, nh, nw, D).permute(0, 3, 1, 2).contiguous()
127
+
128
+
129
+ def forward_with_batch_size_limit(
130
+ net,
131
+ x: torch.Tensor,
132
+ batch_size_on_gpu: int,
133
+ device: str | torch.device,
134
+ out_device: str | torch.device,
135
+ preproc_fn: t.Callable[[torch.Tensor], torch.Tensor] | None = None,
136
+ dtype: torch.dtype = torch.float32,
137
+ ):
138
+ features = list()
139
+ for start_idx in range(0, x.size(0), batch_size_on_gpu):
140
+ end_idx = min(x.size(0), start_idx + batch_size_on_gpu)
141
+ batch = x[start_idx:end_idx].to(device=device, non_blocking=True)
142
+ batch = preproc_fn(batch) if preproc_fn else batch
143
+ batch = batch.to(dtype=dtype, non_blocking=True)
144
+ actual_bs = end_idx - start_idx
145
+ batch = pad_to_batch(batch, batch_size_on_gpu)
146
+ batch: torch.Tensor = forward_compiled(net, batch)
147
+ # batch = net(batch)
148
+ features.append(batch[:actual_bs].to(device=out_device, non_blocking=True))
149
+ if torch.device(out_device).type == "cpu" and torch.device(device).type == "cuda":
150
+ torch.cuda.synchronize()
151
+ return torch.cat(features)
152
+
153
+
154
+ @t.overload
155
+ def backward_with_batch_size_limit(
156
+ net,
157
+ x: torch.Tensor,
158
+ grad: torch.Tensor,
159
+ batch_size_on_gpu: int,
160
+ device: str | torch.device,
161
+ out_device: str | torch.device,
162
+ dtype: torch.dtype,
163
+ ret_grad: t.Literal[True],
164
+ ) -> torch.Tensor: ...
165
+
166
+
167
+ @t.overload
168
+ def backward_with_batch_size_limit(
169
+ net,
170
+ x: torch.Tensor,
171
+ grad: torch.Tensor,
172
+ batch_size_on_gpu: int,
173
+ device: str | torch.device,
174
+ out_device: str | torch.device,
175
+ dtype: torch.dtype,
176
+ ret_grad: t.Literal[False],
177
+ ) -> None: ...
178
+
179
+
180
+ def backward_with_batch_size_limit(
181
+ net,
182
+ x: torch.Tensor,
183
+ grad: torch.Tensor,
184
+ batch_size_on_gpu: int,
185
+ device: str | torch.device,
186
+ out_device: str | torch.device,
187
+ dtype: torch.dtype,
188
+ ret_grad: bool,
189
+ ):
190
+ assert x.size(0) == grad.size(0)
191
+
192
+ grads = []
193
+ total = x.size(0)
194
+ for start in range(0, total, batch_size_on_gpu):
195
+ end = min(total, start + batch_size_on_gpu)
196
+ actual_bs = end - start
197
+
198
+ batch = x[start:end].to(device=device, dtype=dtype, non_blocking=True)
199
+ batch = pad_to_batch(batch, batch_size_on_gpu)
200
+ if ret_grad:
201
+ batch.requires_grad_(True)
202
+
203
+ with torch.autocast(device_type="cuda", dtype=dtype):
204
+ out = net(batch)
205
+ # out = forward_compiled(net, batch)
206
+
207
+ grad_batch = grad[start:end].to(device=device, dtype=dtype, non_blocking=True)
208
+ grad_batch = pad_to_batch(grad_batch, batch_size_on_gpu)
209
+
210
+ with torch._dynamo.utils.maybe_enable_compiled_autograd(
211
+ True, fullgraph=True, dynamic=False
212
+ ):
213
+ out.backward(grad_batch)
214
+ # out.backward(grad_batch)
215
+
216
+ if ret_grad:
217
+ assert batch.grad is not None
218
+ grads.append(batch.grad[:actual_bs].to(out_device, non_blocking=True))
219
+
220
+ if ret_grad:
221
+ if (
222
+ torch.device(out_device).type == "cpu"
223
+ and torch.device(device).type == "cuda"
224
+ ):
225
+ torch.cuda.synchronize()
226
+ return torch.cat(grads)
227
+
228
+
229
+ @torch.compile(fullgraph=True, dynamic=False)
230
+ def forward_compiled(net, x: torch.Tensor) -> torch.Tensor:
231
+ return net(x)
232
+
233
+
234
+ def pad_to_batch(t: torch.Tensor, batch_size: int) -> torch.Tensor:
235
+ assert (
236
+ t.size(0) <= batch_size
237
+ ), f"'{t.shape}' size tensor cannot be padded to be batch size of '{batch_size}'"
238
+ pad = batch_size - t.size(0)
239
+ return torch.cat([t, t.new_zeros((pad,) + t.shape[1:])], dim=0) if pad > 0 else t
240
+
241
+
242
+ def scale_and_normalize(x: torch.Tensor, inplace: bool = False):
243
+ x = x.clamp_(0, 255) if inplace else x.clamp(0, 255)
244
+ x = TF.normalize(
245
+ x / 255, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=inplace
246
+ )
247
+ return x
248
+
249
+
250
+ def combine_tile_list(tile_list: list[torch.Tensor], ncols: int):
251
+ """
252
+ Combines a flat list of tile tensors (each with shape (C, H, W)) into one output tensor,
253
+ arranging them in a grid with the specified number of columns. The tiles in the final row
254
+ or column may have different sizes.
255
+
256
+ Args:
257
+ tile_list (list of torch.Tensor): A flat list of tile tensors, each with shape
258
+ (channels, tile_height, tile_width). It is assumed
259
+ that the number of channels is consistent across all tiles.
260
+ ncols (int): Number of columns to arrange the tiles in.
261
+
262
+ Returns:
263
+ torch.Tensor: A tensor of shape (channels, total_height, total_width), where:
264
+ - total_height is the sum of maximum tile heights in each row.
265
+ - total_width is the sum of maximum tile widths in each column.
266
+ """
267
+ if not tile_list:
268
+ raise ValueError("tile_list is empty")
269
+
270
+ ntiles = len(tile_list)
271
+ nrows = (ntiles + ncols - 1) // ncols # Ceiling division to get the number of rows
272
+
273
+ # Convert the flat tile list into a nested list (rows of tiles)
274
+ nested_tiles = [tile_list[i * ncols : (i + 1) * ncols] for i in range(nrows)]
275
+
276
+ # Compute the maximum tile height for each row
277
+ row_heights = [max(tile.shape[1] for tile in row) for row in nested_tiles]
278
+
279
+ # Compute the maximum tile width for each column (consider only rows that have a tile in that column)
280
+ col_widths = []
281
+ for col in range(ncols):
282
+ max_width = 0
283
+ for row in nested_tiles:
284
+ if col < len(row):
285
+ tile_w = row[col].shape[2]
286
+ if tile_w > max_width:
287
+ max_width = tile_w
288
+ col_widths.append(max_width)
289
+
290
+ # Calculate the total output dimensions
291
+ total_height = sum(row_heights)
292
+ total_width = sum(col_widths)
293
+
294
+ # Determine the number of channels from the first tile
295
+ channels = tile_list[0].shape[0]
296
+
297
+ # Preallocate the output tensor (this avoids repeated concatenation and extra memory copies)
298
+ out_tensor = torch.zeros(
299
+ channels,
300
+ total_height,
301
+ total_width,
302
+ dtype=tile_list[0].dtype,
303
+ device=tile_list[0].device,
304
+ )
305
+
306
+ # Place each tile in its proper location by calculating offsets
307
+ y_offset = 0
308
+ for i, row in enumerate(nested_tiles):
309
+ x_offset = 0
310
+ for j, tile in enumerate(row):
311
+ tile_h, tile_w = tile.shape[1], tile.shape[2]
312
+ out_tensor[
313
+ :, y_offset : y_offset + tile_h, x_offset : x_offset + tile_w
314
+ ] = tile
315
+ x_offset += col_widths[j]
316
+ y_offset += row_heights[i]
317
+
318
+ return out_tensor
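A small sanity check for `combine_tile_list`, with four tiles of mixed sizes arranged in two columns (tile sizes are arbitrary):

    import torch

    tiles = [
        torch.ones(3, 4, 5), torch.ones(3, 4, 3),   # first row: height 4, widths 5 and 3
        torch.ones(3, 2, 5), torch.ones(3, 2, 3),   # second row: height 2
    ]
    canvas = combine_tile_list(tiles, ncols=2)
    assert canvas.shape == (3, 4 + 2, 5 + 3)        # (C, sum of row heights, sum of column widths)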
utils/wsi_utils.py ADDED
@@ -0,0 +1,514 @@
1
+ import typing as t
2
+ from concurrent.futures import ThreadPoolExecutor
3
+ from pathlib import Path
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import rpack
9
+ from openslide import OpenSlide
10
+ from PIL import Image
11
+ from scipy.ndimage import binary_fill_holes
12
+ from skimage import filters
13
+ from skimage.morphology import remove_small_objects
14
+
15
+ if t.TYPE_CHECKING:
16
+ from _typeshed import StrPath
17
+
18
+ try:
19
+ from skimage import img_as_ubyte # type: ignore
20
+ except ImportError:
21
+ from skimage.util import img_as_ubyte # type: ignore
22
+
23
+
24
+ def find_contours(arr: np.ndarray, only_outer: bool = True, convex: bool = False):
25
+ """Find contours in a binary image
26
+
27
+ Parameters
28
+ ----------
29
+ arr : np.ndarray
30
+ Binary image
31
+ only_outer : bool
32
+ If True, only find external contours
33
+ convex : bool
34
+ If True, return convex hull of contours
35
+
36
+ Returns
37
+ -------
38
+ contours : list
39
+ List of contours
40
+ """
41
+ mode = cv2.RETR_EXTERNAL if only_outer else cv2.RETR_LIST
42
+ cresults = cv2.findContours(arr.astype(np.uint8), mode, cv2.CHAIN_APPROX_SIMPLE)
43
+
44
+ contours = cresults[1] if len(cresults) == 3 else cresults[0]
45
+ contours = list(contours) if isinstance(contours, tuple) else contours
46
+
47
+ if convex:
48
+ contours = [cv2.convexHull(cnt) for cnt in contours]
49
+ return contours
50
+
51
+
52
+ def merge_overlapping_bboxes(bboxes: list):
53
+ """Merge overlapping bounding boxes
54
+
55
+ Parameters
56
+ ----------
57
+ bboxes : list
58
+ List of bounding boxes in format (x, y, width, height).
+ The list is modified in place; overlapping boxes are merged into their union.
60
+ """
60
+ candidate_count = 0
61
+ while candidate_count < len(bboxes):
62
+ candidate_count += 1
63
+ overlap = False
64
+ candidate_box = bboxes.pop(0)
65
+ for index, compare_box in enumerate(bboxes):
66
+ overlapping, new_bbox = merge_if_overlapping(candidate_box, compare_box)
67
+ if overlapping:
68
+ overlap = True
69
+ candidate_count = 0
70
+ bboxes.pop(index)
71
+ bboxes.append(new_bbox)
72
+ break
73
+ if not overlap:
74
+ bboxes.append(candidate_box)
75
+
76
+
77
+ def merge_if_overlapping(a: tuple, b: tuple):
78
+ """Check if two bounding boxes overlap and merge them if they do
79
+
80
+ Parameters
81
+ ----------
82
+ a : tuple
83
+ First bounding box in format (x, y, width, height)
84
+ b : tuple
85
+ Second bounding box in format (x, y, width, height)
86
+
87
+ Returns
88
+ -------
89
+ overlapping : bool
90
+ True if boxes overlap
91
+ new_bbox : tuple
92
+ Merged bounding box if overlapping, empty list otherwise
93
+ """
94
+ bottom = np.max([a[0], b[0]])
95
+ top = np.min([a[0] + a[2], b[0] + b[2]])
96
+ left = np.max([a[1], b[1]])
97
+ right = np.min([a[1] + a[3], b[1] + b[3]])
98
+
99
+ do_intersect = bottom < top and left < right
100
+
101
+ if do_intersect:
102
+ x_min = np.min([a[1], b[1]])
103
+ y_min = np.min([a[0], b[0]])
104
+ x_max = np.max([a[1] + a[3], b[1] + b[3]])
105
+ y_max = np.max([a[0] + a[2], b[0] + b[2]])
106
+ new_bbox = (y_min, x_min, y_max - y_min, x_max - x_min)
107
+ return True, new_bbox
108
+
109
+ return False, []
110
+
111
+
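A quick numeric check of the overlap/merge helpers above, with boxes in (x, y, width, height) form (values chosen for illustration):

    a = (0, 0, 10, 10)
    b = (5, 5, 10, 10)
    overlapping, merged = merge_if_overlapping(a, b)
    assert overlapping and tuple(merged) == (0, 0, 15, 15)

    boxes = [a, b, (100, 100, 5, 5)]
    merge_overlapping_bboxes(boxes)    # in place: the two overlapping boxes collapse into one
    assert len(boxes) == 2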
112
+
113
+ def load_slide_img(
114
+ wsi,
115
+ level: int = 0,
116
+ ) -> np.ndarray:
117
+ """Load slide image with specific level
118
+
119
+ Parameters
120
+ ----------
121
+ wsi : CuImage
122
+ The CuImage object
123
+ level : int
124
+ Slide level to load
125
+
126
+ Returns
127
+ -------
128
+ slide_img : np.ndarray
129
+ Numpy array with RGB channels
130
+ """
131
+ slide_img = np.asarray(wsi.read_region(level=level, device="gpu", num_workers=32))
132
+ if slide_img.shape[2] == 4:
133
+ slide_img = slide_img[:, :, :-1]
134
+ return slide_img
135
+
136
+
137
+ def rgb2gray(img):
138
+ """Convert RGB image to grayscale
139
+
140
+ Parameters
141
+ ----------
142
+ img : np.ndarray
143
+ RGB image with 3 channels
144
+
145
+ Returns
146
+ -------
147
+ gray : np.ndarray
148
+ Grayscale image
149
+ """
150
+ return np.dot(img, [0.299, 0.587, 0.114])
151
+
152
+
153
+ def thresh_slide(gray, thresh_val, sigma=13):
154
+ """Threshold gray image to binary image
155
+
156
+ Parameters
157
+ ----------
158
+ gray : np.ndarray
159
+ 2D grayscale image
160
+ thresh_val : float
161
+ Thresholding value
162
+ sigma : int
163
+ Gaussian smoothing sigma
164
+
165
+ Returns
166
+ -------
167
+ bw_img : np.ndarray
168
+ Binary image
169
+ """
170
+ smooth = filters.gaussian(gray, sigma=sigma)
171
+ smooth /= np.amax(smooth)
172
+ bw_img = smooth < thresh_val
173
+ return bw_img
174
+
175
+
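The two helpers above are typically chained: convert a low-resolution slide image to grayscale, then threshold the smoothed result into a tissue/background mask. A hedged sketch with a random stand-in image and an illustrative threshold of 0.8:

    import numpy as np

    thumb = np.random.randint(0, 256, (512, 512, 3)).astype(np.float64)   # stand-in thumbnail
    gray = rgb2gray(thumb)
    tissue_mask = thresh_slide(gray, thresh_val=0.8, sigma=13)            # True where darker than background
    print(tissue_mask.shape, tissue_mask.dtype)                           # (512, 512) bool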
176
+
177
+ def get_tissue_bboxes(
178
+ mask: np.ndarray, wsi_width: int, wsi_height: int, min_tissue_size: int = 10000
179
+ ):
180
+ scale = wsi_height / mask.shape[0]
181
+
182
+ contours = find_contours(mask)
183
+ areas = []
184
+ for cnt in contours:
185
+ area = cv2.contourArea(cnt)
186
+ areas.append(area)
187
+
188
+ large_contours = []
189
+ large_areas = []
190
+ for i, cnt in enumerate(contours):
191
+ area_mm = areas[i]
192
+ if area_mm >= min_tissue_size:
193
+ large_contours.append(cnt)
194
+ large_areas.append(area_mm)
195
+
196
+ areas = large_areas
197
+
198
+ boxes = [cv2.boundingRect(c) for c in large_contours]
199
+
200
+ return boxes if boxes else [[0, 0, wsi_width, wsi_height]]
205
+
206
+
207
+ def get_tissue_positions_and_packed_size(
208
+ boxes,
209
+ wsi_width: int,
210
+ wsi_height: int,
211
+ scale: float,
212
+ ) -> tuple[list[tuple[int, int]], tuple[int, int]]:
213
+ if len(boxes) > 1:
214
+ merge_overlapping_bboxes(boxes)
215
+ boxes = np.array(boxes, dtype=np.float32) * scale
216
+ if len(boxes.shape) == 1:
217
+ boxes = boxes[None]
218
+ boxes[:, :2] = np.floor(boxes[:, :2])
219
+ boxes[:, 0] = np.clip(boxes[:, 0], 0, wsi_width - 1)
220
+ boxes[:, 1] = np.clip(boxes[:, 1], 0, wsi_height - 1)
221
+ boxes[:, 2:] = np.ceil(boxes[:, 2:])
222
+ boxes[:, 2] = np.clip(boxes[:, 2], 0, wsi_width - boxes[:, 0])
223
+ boxes[:, 3] = np.clip(boxes[:, 3], 0, wsi_height - boxes[:, 1])
224
+ boxes = boxes.astype(np.int32)
225
+
226
+ box_sizes = [(int(box[2]), int(box[3])) for box in boxes]
227
+ positions = rpack.pack(box_sizes) # at processing spacing
228
+ packed_size: tuple[int, int] = rpack.bbox_size(
229
+ box_sizes, positions
230
+ ) # width, height
231
+
232
+ counter = 0
233
+ for sdf in np.arange(0.5, 0.96, 0.05):
234
+ # asymmetry_factor = min(packed_size)/max(packed_size)
235
+ # if asymmetry_factor < sdf:
236
+ rparams = {
237
+ "max_height": int(max(packed_size) * sdf),
238
+ "max_width": int(max(packed_size) * sdf),
239
+ }
240
+ try:
241
+ positions = rpack.pack(box_sizes, **rparams) # at processing spacing
242
+ packed_size: tuple[int, int] = rpack.bbox_size(box_sizes, positions)
243
+ break
244
+ except rpack.PackingImpossibleError as ex:
245
+ counter += 1
246
+
247
+ return positions, (int(packed_size[0]), int(packed_size[1]))
248
+
249
+
250
+ def pack_slide(
251
+ wsi_arr: np.ndarray,
252
+ mask: np.ndarray,
253
+ min_tissue_size: int = 10000,
254
+ ):
255
+ H, W = wsi_arr.shape[:2]
256
+ boxes = get_tissue_bboxes(mask, W, H, min_tissue_size=min_tissue_size)
257
+ if len(boxes) > 0:
258
+ positions, packed_size = get_tissue_positions_and_packed_size(
259
+ boxes, W, H, H / mask.shape[0]
260
+ )
261
+ img_out = np.full(
262
+ (packed_size[1], packed_size[0]) + wsi_arr.shape[2:],
263
+ 255,
264
+ dtype=wsi_arr.dtype,
265
+ )
266
+ mask_out = np.zeros((packed_size[1], packed_size[0]), dtype=bool)
267
+ for i, pos in enumerate(positions):
268
+ box = boxes[i]
269
+ img_out[pos[1] : pos[1] + box[3], pos[0] : pos[0] + box[2]] = wsi_arr[
270
+ box[1] : box[1] + box[3], box[0] : box[0] + box[2]
271
+ ]
272
+ mask_out[pos[1] : pos[1] + box[3], pos[0] : pos[0] + box[2]] = mask[
273
+ box[1] : box[1] + box[3], box[0] : box[0] + box[2]
274
+ ]
275
+ else:
276
+ img_out = wsi_arr
277
+ mask_out = mask
278
+
279
+ return img_out, mask_out
280
+
281
+
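`pack_slide` ties the pieces together: tissue bounding boxes are extracted from the low-resolution mask, merged and scaled to the image resolution, re-packed with rpack, and copied onto a smaller white canvas. A synthetic end-to-end sketch (the blob positions and `min_tissue_size` value are illustrative, not from this repo):

    import numpy as np

    # White image with two dark tissue blobs far apart, plus a 10x-downsampled mask.
    wsi_arr = np.full((4000, 6000, 3), 255, dtype=np.uint8)
    wsi_arr[500:1500, 500:1500] = 100
    wsi_arr[2500:3500, 4500:5500] = 100
    mask = np.zeros((400, 600), dtype=bool)
    mask[50:150, 50:150] = True
    mask[250:350, 450:550] = True

    packed_img, packed_mask = pack_slide(wsi_arr, mask, min_tissue_size=100)
    print(wsi_arr.shape, "->", packed_img.shape)   # the packed canvas is much smaller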
282
+ def get_level_downsamples(wsi: OpenSlide):
283
+ level_downsamples = []
284
+ dim_0 = wsi.level_dimensions[0]
285
+
286
+ for downsample, dim in zip(wsi.level_downsamples, wsi.level_dimensions):
287
+ estimated_downsample = (dim_0[0] / float(dim[0]), dim_0[1] / float(dim[1]))
288
+ level_downsamples.append(
289
+ estimated_downsample
290
+ if estimated_downsample != (downsample, downsample)
291
+ else (downsample, downsample)
292
+ )
293
+
294
+ return level_downsamples
295
+
296
+
297
+ def segment_tissue(
298
+ wsi_path: Path,
299
+ seg_level=-1,
300
+ sthresh=8,
301
+ sthresh_up=255,
302
+ mthresh=7,
303
+ close=4,
304
+ filter_params={"a_t": 1, "a_h": 1, "max_n_holes": 100},
305
+ ref_patch_size=512,
306
+ ):
307
+ """
308
+ Segment the tissue via HSV -> Median thresholding -> Binary threshold
309
+ """
310
+
311
+ def _filter_contours(contours, hierarchy, filter_params):
312
+ """
313
+ Filter contours by: area.
314
+ """
315
+ filtered = []
316
+
317
+ # find indices of foreground contours (parent == -1)
318
+ hierarchy_1 = np.flatnonzero(hierarchy[:, 1] == -1)
319
+ all_holes = []
320
+
321
+ # loop through foreground contour indices
322
+ for cont_idx in hierarchy_1:
323
+ # actual contour
324
+ cont = contours[cont_idx]
325
+ # indices of holes contained in this contour (children of parent contour)
326
+ holes = np.flatnonzero(hierarchy[:, 1] == cont_idx)
327
+ # take contour area (includes holes)
328
+ a = cv2.contourArea(cont)
329
+ # calculate the contour area of each hole
330
+ hole_areas = [cv2.contourArea(contours[hole_idx]) for hole_idx in holes]
331
+ # actual area of foreground contour region
332
+ a = a - np.array(hole_areas).sum()
333
+ if a == 0:
334
+ continue
335
+ if a > filter_params["a_t"]:
336
+ filtered.append(cont_idx)
337
+ all_holes.append(holes)
338
+
339
+ foreground_contours = [contours[cont_idx] for cont_idx in filtered]
340
+
341
+ hole_contours = []
342
+
343
+ for hole_ids in all_holes:
344
+ unfiltered_holes = [contours[idx] for idx in hole_ids]
345
+ unfiltered_holes = sorted(
346
+ unfiltered_holes, key=cv2.contourArea, reverse=True
347
+ )
348
+ # take max_n_holes largest holes by area
349
+ unfiltered_holes = unfiltered_holes[: filter_params["max_n_holes"]]
350
+ filtered_holes = []
351
+
352
+ # filter these holes
353
+ for hole in unfiltered_holes:
354
+ if cv2.contourArea(hole) > filter_params["a_h"]:
355
+ filtered_holes.append(hole)
356
+
357
+ hole_contours.append(filtered_holes)
358
+
359
+ return foreground_contours, hole_contours
360
+
361
+ def draw_white_bands(img: np.ndarray, thickness: int):
362
+ height, width = img.shape[:2]
363
+ white = [255, 255, 255]  # white fill color
364
+
365
+ # Paint a filled white band of the given thickness along each image border
366
+ # Top band
367
+ cv2.rectangle(img, (0, 0), (width, thickness), white, -1)
368
+
369
+ # Bottom band
370
+ cv2.rectangle(img, (0, height - thickness), (width, height), white, -1)
371
+
372
+ # Left band
373
+ cv2.rectangle(img, (0, 0), (thickness, height), white, -1)
374
+
375
+ # Right band
376
+ cv2.rectangle(img, (width - thickness, 0), (width, height), white, -1)
377
+
378
+ with OpenSlide(str(wsi_path)) as wsi:
379
+ if seg_level < 0:
380
+ seg_level = wsi.get_best_level_for_downsample(64)
381
+
382
+ img = np.asarray(
383
+ wsi.read_region(
384
+ location=(0, 0), level=seg_level, size=wsi.level_dimensions[seg_level]
385
+ )
386
+ )
387
+
388
+ img_rgb = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
389
+ draw_white_bands(img_rgb, thickness=20)
390
+ img_gray = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
391
+
392
+ H, W = img_rgb.shape[:2]
393
+
394
+ R_8, G_8, B_8 = cv2.split(img_rgb)  # img_rgb is in RGB channel order
395
+ R = R_8.astype(np.int32)
396
+ G = G_8.astype(np.int32)
397
+ B = B_8.astype(np.int32)
398
+
399
+ mask = (R >= 0) & (R <= 110) & (G >= 0) & (G <= 110) & (B >= 0) & (B <= 110)
400
+
401
+ color_difference1 = np.abs((R) - (G)) <= 15
402
+ color_difference2 = np.abs((G) - (B)) <= 15
403
+ color_difference3 = np.abs((R) - (B)) <= 15
404
+ color_difference = color_difference1 & color_difference2 & color_difference3
405
+
406
+ final_mask = mask & color_difference
407
+
408
+ laplacian = cv2.Laplacian(img_gray, cv2.CV_64F)
409
+ laplacian_abs = cv2.convertScaleAbs(laplacian)
410
+ mask = laplacian_abs <= 15
411
+ img_rgb[mask] = [255, 255, 255]
412
+
413
+ img_hsv = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2HSV) # Convert to HSV space
414
+ img_med = cv2.medianBlur(
415
+ img_hsv[:, :, 1], mthresh
416
+ ) # Apply median blurring #same to median filter
417
+
418
+ # Thresholding
419
+ _, img_thresh = cv2.threshold(img_med, sthresh, sthresh_up, cv2.THRESH_BINARY)
420
+ # Morphological closing
421
+ if close > 0:
422
+ kernel = np.ones((close, close), np.uint8)
423
+ img_thresh = cv2.morphologyEx(img_thresh, cv2.MORPH_CLOSE, kernel)
424
+
425
+ # before k-medicon
426
+ scale = get_level_downsamples(wsi)[seg_level]
427
+ scaled_ref_patch_area = int(ref_patch_size**2 / (scale[0] * scale[1]))
428
+ filter_params = filter_params.copy()
429
+ filter_params["a_t"] = filter_params["a_t"] * scaled_ref_patch_area
430
+ filter_params["a_h"] = filter_params["a_h"] * scaled_ref_patch_area
431
+
432
+ # Find and filter contours
433
+ contours, hierarchy = cv2.findContours(
434
+ img_thresh, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE
435
+ )
436
+
437
+ hierarchy = np.squeeze(hierarchy, axis=(0,))[:, 2:]
438
+ foreground_contours, hole_contours = _filter_contours(
439
+ contours, hierarchy, filter_params
440
+ ) # Necessary for filtering out artifacts
441
+
442
+ mask = np.zeros(img_rgb.shape[:2], dtype=np.uint8)
443
+ for i, cont in enumerate(foreground_contours):
444
+ if cont is None or len(cont) == 0:
445
+ print(f"Warning: Empty contour at index {i}")
446
+ continue
447
+
448
+ if (
449
+ cont[:, :, 0].max() >= W
450
+ or cont[:, :, 1].max() >= H
451
+ or cont[:, :, 0].min() < 0
452
+ or cont[:, :, 1].min() < 0
453
+ ):
454
+ print(f"Warning: Contour {i} coordinates out of bounds!")
455
+ continue
456
+
457
+ # Fill the main tissue contour
458
+ cv2.fillPoly(mask, [cont], 255) # type: ignore
459
+
460
+ # Remove holes if they exist
461
+ if i < len(hole_contours) and hole_contours[i]:
462
+ for hole in hole_contours[i]: # type: ignore
463
+ cv2.fillPoly(mask, [hole], 0) # type: ignore
464
+ mask = mask.astype(bool)
465
+ if not mask.any():
466
+ mask[:, :] = True # If no mask, return full mask
467
+
468
+ return mask, img_rgb
469
+
470
+
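A minimal usage sketch for `segment_tissue` with default parameters (the slide path is a placeholder):

    from pathlib import Path

    mask, thumb_rgb = segment_tissue(Path("/data/slides/example.svs"))
    print(mask.shape, mask.dtype, mask.mean())   # boolean tissue mask and its foreground fraction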
471
+ def get_mask_path_by_wsi_path(wsi_path: Path, wsi_dir: Path, mask_dir: Path) -> Path:
472
+ wsi_path, wsi_dir, mask_dir = (
473
+ wsi_path.absolute(),
474
+ wsi_dir.absolute(),
475
+ mask_dir.absolute(),
476
+ )
477
+ rel_path = wsi_path.relative_to(wsi_dir)
478
+ stitch_path_prefix = mask_dir / rel_path
479
+ stitch_path_prefix = stitch_path_prefix.parent / rel_path.stem
480
+ extensions = ["jpg", "jpeg", "png", "webp"]
481
+ extensions += [ext.upper() for ext in extensions]
482
+ stitch_paths = [
483
+ stitch_path_prefix.parent / (rel_path.stem + f".{ext}") for ext in extensions
484
+ ]
485
+ stitch_paths += [
486
+ stitch_path_prefix.parent / rel_path.stem / (rel_path.stem + f".{ext}")
487
+ for ext in extensions
488
+ ]
489
+ ret = None
490
+ for stitch_path in stitch_paths:
491
+ if stitch_path.exists():
492
+ ret = stitch_path
493
+ if ret is None:
494
+ raise FileNotFoundError(
495
+ f"No mask for wsi '{wsi_path}' in mask dir '{mask_dir}' (candidates: {', '.join([str(p) for p in stitch_paths])})"
496
+ )
497
+ return ret
498
+
499
+
500
+ def read_mask(mask_path: Path) -> np.ndarray:
501
+ img = Image.open(mask_path)
502
+ w, h = img.size
503
+ return np.asarray(img).reshape((h, w, -1)).max(-1) > 0
504
+
505
+
506
+ def read_mask_by_wsi_path(wsi_path: Path, wsi_dir: Path, mask_dir: Path) -> np.ndarray:
507
+ wsi_path, wsi_dir, mask_dir = (
508
+ wsi_path.absolute(),
509
+ wsi_dir.absolute(),
510
+ mask_dir.absolute(),
511
+ )
512
+ mask_path = get_mask_path_by_wsi_path(wsi_path, wsi_dir, mask_dir)
513
+ return read_mask(mask_path)
514
+
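The mask-lookup helpers mirror the WSI directory layout under `mask_dir` and try several extensions (jpg/jpeg/png/webp, both cases, optionally nested in a per-slide subdirectory). A hedged usage sketch with placeholder paths:

    from pathlib import Path

    wsi_dir = Path("/data/wsi")
    mask_dir = Path("/data/masks")
    wsi_path = wsi_dir / "case_001" / "slide_01.svs"

    # Resolves e.g. /data/masks/case_001/slide_01.png (or another supported extension)
    # and returns it as a boolean array of shape (H, W).
    tissue_mask = read_mask_by_wsi_path(wsi_path, wsi_dir, mask_dir)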