initial commit
Browse files- .gitignore +174 -0
- LICENSE +57 -0
- README.md +76 -0
- config.json +4 -0
- exaonepath.py +165 -0
- networks/__init__.py +0 -0
- networks/vit.py +569 -0
- pytorch_model.bin +3 -0
- requirements.txt +16 -0
- utils/__init__.py +0 -0
- utils/tensor_utils.py +318 -0
- utils/wsi_utils.py +514 -0
.gitignore
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# UV
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
#uv.lock
|
102 |
+
|
103 |
+
# poetry
|
104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
106 |
+
# commonly ignored for libraries.
|
107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
108 |
+
#poetry.lock
|
109 |
+
|
110 |
+
# pdm
|
111 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
112 |
+
#pdm.lock
|
113 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
114 |
+
# in version control.
|
115 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
116 |
+
.pdm.toml
|
117 |
+
.pdm-python
|
118 |
+
.pdm-build/
|
119 |
+
|
120 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
121 |
+
__pypackages__/
|
122 |
+
|
123 |
+
# Celery stuff
|
124 |
+
celerybeat-schedule
|
125 |
+
celerybeat.pid
|
126 |
+
|
127 |
+
# SageMath parsed files
|
128 |
+
*.sage.py
|
129 |
+
|
130 |
+
# Environments
|
131 |
+
.env
|
132 |
+
.venv
|
133 |
+
env/
|
134 |
+
venv/
|
135 |
+
ENV/
|
136 |
+
env.bak/
|
137 |
+
venv.bak/
|
138 |
+
|
139 |
+
# Spyder project settings
|
140 |
+
.spyderproject
|
141 |
+
.spyproject
|
142 |
+
|
143 |
+
# Rope project settings
|
144 |
+
.ropeproject
|
145 |
+
|
146 |
+
# mkdocs documentation
|
147 |
+
/site
|
148 |
+
|
149 |
+
# mypy
|
150 |
+
.mypy_cache/
|
151 |
+
.dmypy.json
|
152 |
+
dmypy.json
|
153 |
+
|
154 |
+
# Pyre type checker
|
155 |
+
.pyre/
|
156 |
+
|
157 |
+
# pytype static type analyzer
|
158 |
+
.pytype/
|
159 |
+
|
160 |
+
# Cython debug symbols
|
161 |
+
cython_debug/
|
162 |
+
|
163 |
+
# PyCharm
|
164 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
165 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
166 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
167 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
168 |
+
#.idea/
|
169 |
+
|
170 |
+
# Ruff stuff:
|
171 |
+
.ruff_cache/
|
172 |
+
|
173 |
+
# PyPI configuration file
|
174 |
+
.pypirc
|
LICENSE
CHANGED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
EXAONEPath AI Model License Agreement 1.0 - NC
|
2 |
+
|
3 |
+
This License Agreement (โAgreementโ) is entered into between you (โLicenseeโ) and LG Management Development Institute Co., Ltd. (โLicensorโ), governing the use of the EXAONEPath AI Model (โModelโ). By downloading, installing, copying, or using the Model, you agree to comply with and be bound by the terms of this Agreement. If you do not agree to all the terms, you must not download, install, copy, or use the Model. This Agreement constitutes a binding legal agreement between the Licensee and Licensor.
|
4 |
+
|
5 |
+
1. Definitions
|
6 |
+
1.1 Model: The artificial intelligence model provided by Licensor, which includes any software, algorithms, machine learning models, or related components supplied by Licensor. This definition extends to encompass all updates, enhancements, improvements, bug fixes, patches, or other modifications that may be provided by Licensor from time to time, whether automatically or manually implemented.
|
7 |
+
1.2 Derivatives: Any modifications, alterations, enhancements, improvements, adaptations, or derivative works of the Model created by Licensee or any third party. This includes changes made to the Model's architecture, parameters, data processing methods, or any other aspect of the Model that results in a modification of its functionality or output.
|
8 |
+
1.3 Output: Any data, results, content, predictions, analyses, insights, or other materials generated by the Model or Derivatives, regardless of whether they are in their original form or have been further processed or modified by the Licensee. This includes, but is not limited to, textual or numerical produced directly or indirectly through the use of the Model.
|
9 |
+
1.4 Licensor: LG Management Development Institute Co., Ltd., the owner, developer, and provider of the EXAONEPath AI Model. The Licensor holds all rights, title, and interest in the Model and is responsible for granting licenses to use the Model under the terms specified in this Agreement.
|
10 |
+
1.5 Licensee: The individual, organization, corporation, academic institution, government agency, or other entity using or intending to use the Model under the terms and conditions of this Agreement. The Licensee is responsible for ensuring compliance with the Agreement by all authorized users who access or utilize the Model on behalf of the Licensee.
|
11 |
+
|
12 |
+
2. License Grant
|
13 |
+
2.1 Grant of License: Subject to the terms and conditions outlined in this Agreement, the Licensor hereby grants the Licensee a limited, non-exclusive, non-transferable, worldwide, and revocable license to:
|
14 |
+
a. Access, download, install, and use the Model solely for research purposes. This includes evaluation, testing, academic research and experimentation.
|
15 |
+
b. Publicly disclose research results and findings derived from the use of the Model or Derivatives, including publishing papers or presentations.
|
16 |
+
c. Modify the Model and create Derivatives based on the Model, provided that such modifications and Derivatives are used exclusively for research purposes. The Licensee may conduct experiments, perform analyses, and apply custom modifications to the Model to explore its capabilities and performance under various scenarios. If the Model is modified, the modified Model must include "EXAONEPath" at the beginning of its name.
|
17 |
+
d. Distribute the Model and Derivatives in each case with a copy of this Agreement.
|
18 |
+
2.2 Scope of License: The license granted herein does not authorize the Licensee to use the Model for any purpose not explicitly permitted under this Agreement. Any use beyond the scope of this license, including any commercial application or external distribution, is strictly prohibited unless explicitly agreed upon in writing by the Licensor.
|
19 |
+
|
20 |
+
3. Restrictions
|
21 |
+
3.1 Commercial Use: The Licensee is expressly prohibited from using the Model, Derivatives, or Output for any commercial purposes, including but not limited to, developing or deploying products, services, or applications that generate revenue, whether directly or indirectly. Any commercial exploitation of the Model or its derivatives requires a separate commercial license agreement with the Licensor. Furthermore, the Licensee shall not use the Model, Derivatives or Output to develop or improve other models, except for research purposes, which is explicitly permitted.
|
22 |
+
3.2 Reverse Engineering: The Licensee shall not decompile, disassemble, reverse engineer, or attempt to derive the source code, underlying ideas, algorithms, or structure of the Model, except to the extent that such activities are expressly permitted by applicable law. Any attempt to bypass or circumvent technological protection measures applied to the Model is strictly prohibited.
|
23 |
+
3.3 Unlawful Use: The Licensee shall not use the Model and Derivatives for any illegal, fraudulent, or unauthorized activities, nor for any purpose that violates applicable laws or regulations. This includes but is not limited to the creation, distribution, or dissemination of malicious, deceptive, or unlawful content.
|
24 |
+
3.4 Ethical Use: The Licensee shall ensure that the Model or Derivatives is used in an ethical and responsible manner, adhering to the following guidelines:
|
25 |
+
a. The Model and Derivatives shall not be used to generate, propagate, or amplify false, misleading, or harmful information, including fake news, misinformation, or disinformation.
|
26 |
+
b. The Model and Derivatives shall not be employed to create, distribute, or promote content that is discriminatory, harassing, defamatory, abusive, or otherwise offensive to individuals or groups based on race, gender, sexual orientation, religion, nationality, or other protected characteristics.
|
27 |
+
c. The Model and Derivatives shall not infringe on the rights of others, including intellectual property rights, privacy rights, or any other rights recognized by law. The Licensee shall obtain all necessary permissions and consents before using the Model and Derivatives in a manner that may impact the rights of third parties.
|
28 |
+
d. The Model and Derivatives shall not be used in a way that causes harm, whether physical, mental, emotional, or financial, to individuals, organizations, or communities. The Licensee shall take all reasonable measures to prevent misuse or abuse of the Model and Derivatives that could result in harm or injury.
|
29 |
+
|
30 |
+
4. Ownership
|
31 |
+
4.1 Intellectual Property: All rights, title, and interest in and to the Model, including any modifications, Derivatives, and associated documentation, are and shall remain the exclusive property of the Licensor. The Licensee acknowledges that this Agreement does not transfer any ownership rights to the Licensee. All trademarks, service marks, and logos associated with the Model are the property of the Licensor.
|
32 |
+
4.2 Output: All output generated by the Model from Licensee Data ("Output") shall be the sole property of the Licensee. Licensor hereby waives any claim of ownership or intellectual property rights to the Output. Licensee is solely responsible for the legality, accuracy, quality, integrity, and use of the Output.
|
33 |
+
4.3 Attribution: In any publication or presentation of results obtained using the Model, the Licensee shall provide appropriate attribution to the Licensor, citing the Model's name and version, along with any relevant documentation or references specified by the Licensor.
|
34 |
+
|
35 |
+
5. No Warranty
|
36 |
+
5.1 โAs-Isโ Basis: The Model, Derivatives, and Output are provided on an โas-isโ and โas-availableโ basis, without any warranties or representations of any kind, whether express, implied, or statutory. The Licensor disclaims all warranties, including but not limited to, implied warranties of merchantability, fitness for a particular purpose, accuracy, reliability, non-infringement, or any warranty arising from the course of dealing or usage of trade.
|
37 |
+
5.2 Performance and Reliability: The Licensor does not warrant or guarantee that the Model, Derivatives or Output will meet the Licenseeโs requirements, that the operation of the Model, Derivatives or Output will be uninterrupted or error-free, or that defects in the Model will be corrected. The Licensee acknowledges that the use of the Model, Derivatives or Output is at its own risk and that the Model, Derivatives or Output may contain bugs, errors, or other limitations.
|
38 |
+
5.3 No Endorsement: The Licensor does not endorse, approve, or certify any results, conclusions, or recommendations derived from the use of the Model. The Licensee is solely responsible for evaluating the accuracy, reliability, and suitability of the Model for its intended purposes.
|
39 |
+
|
40 |
+
6. Limitation of Liability
|
41 |
+
6.1 No Liability for Damages: To the fullest extent permitted by applicable law, in no event shall the Licensor be liable for any special, incidental, indirect, consequential, exemplary, or punitive damages, including but not limited to, damages for loss of business profits, business interruption, loss of business information, loss of data, or any other pecuniary or non-pecuniary loss arising out of or in connection with the use or inability to use the Model, Derivatives or any Output, even if the Licensor has been advised of the possibility of such damages.
|
42 |
+
6.2 Indemnification: The Licensee agrees to indemnify, defend, and hold harmless the Licensor, its affiliates, officers, directors, employees, and agents from and against any claims, liabilities, damages, losses, costs, or expenses (including reasonable attorneys' fees) arising out of or related to the Licensee's use of the Model, any Derivatives, or any Output, including any violation of this Agreement or applicable laws. This includes, but is not limited to, ensuring compliance with copyright laws, privacy regulations, defamation laws, and any other applicable legal or regulatory requirements.
|
43 |
+
|
44 |
+
7. Termination
|
45 |
+
7.1 Termination by Licensor: The Licensor reserves the right to terminate this Agreement and revoke the Licenseeโs rights to use the Model at any time, with or without cause, and without prior notice if the Licensee breaches any of the terms or conditions of this Agreement. Termination shall be effective immediately upon notice.
|
46 |
+
7.2 Effect of Termination: Upon termination of this Agreement, the Licensee must immediately cease all use of the Model, Derivatives, and Output and destroy all copies of the Model, Derivatives, and Output in its possession or control, including any backup or archival copies. The Licensee shall certify in writing to the Licensor that such destruction has been completed.
|
47 |
+
7.3 Survival: The provisions of this Agreement that by their nature should survive termination, including but not limited to, Sections 4 (Ownership), 5 (No Warranty), 6 (Limitation of Liability), and this Section 7 (Termination), shall continue to apply after termination.
|
48 |
+
|
49 |
+
8. Governing Law
|
50 |
+
8.1 Governing Law: This Agreement shall be governed by and construed in accordance with the laws of the Republic of Korea, without regard to its conflict of laws principles.
|
51 |
+
8.2 Arbitration: Any disputes, controversies, or claims arising out of or relating to this Agreement, including its existence, validity, interpretation, performance, breach, or termination, shall be referred to and finally resolved by arbitration administered by the Korean Commercial Arbitration Board (KCAB) in accordance with the International Arbitration Rules of the Korean Commercial Arbitration Board in force at the time of the commencement of the arbitration. The seat of arbitration shall be Seoul, Republic of Korea. The tribunal shall consist of one arbitrator. The language of the arbitration shall be English.
|
52 |
+
|
53 |
+
9. Alterations
|
54 |
+
9.1 Modifications: The Licensor reserves the right to modify or amend this Agreement at any time, in its sole discretion. Any modifications will be effective upon posting the updated Agreement on the Licensorโs website or through other means of communication. The Licensee is responsible for reviewing the Agreement periodically for changes. Continued use of the Model after any modifications have been made constitutes acceptance of the revised Agreement.
|
55 |
+
9.2 Entire Agreement: This Agreement constitutes the entire agreement between the Licensee and Licensor concerning the subject matter hereof and supersedes all prior or contemporaneous oral or written agreements, representations, or understandings. Any terms or conditions of any purchase order or other document submitted by the Licensee in connection with the Model that are in addition to, different from, or inconsistent with the terms and conditions of this Agreement are not binding on the Licensor and are void.
|
56 |
+
|
57 |
+
By downloading, installing, or using the EXAONEPath AI Model, the Licensee acknowledges that it has read, understood, and agrees to be bound by the terms and conditions of this Agreement.
|
README.md
CHANGED
@@ -2,4 +2,80 @@
|
|
2 |
license: other
|
3 |
license_name: exaonepath
|
4 |
license_link: LICENSE
|
|
|
|
|
|
|
|
|
5 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
license: other
|
3 |
license_name: exaonepath
|
4 |
license_link: LICENSE
|
5 |
+
tags:
|
6 |
+
- lg-ai
|
7 |
+
- EXAONE-Path-2.0
|
8 |
+
- pathology
|
9 |
---
|
10 |
+
|
11 |
+
# EXAONE Path 2.0
|
12 |
+
|
13 |
+
## Introduction
|
14 |
+
In digital pathology, whole-slide images (WSIs) are often difficult to handle due to their gigapixel scale, so most approaches train patch encoders via self-supervised learning (SSL) and then aggregate the patch-level embeddings via multiple instance learning (MIL) or slide encoders for downstream tasks.
|
15 |
+
However, patch-level SSL may overlook complex domain-specific features that are essential for biomarker prediction, such as mutation status and molecular characteristics, as SSL methods rely only on basic augmentations selected for natural image domains on small patch-level area.
|
16 |
+
Moreover, SSL methods remain less data efficient than fully supervised approaches, requiring extensive computational resources and datasets to achieve competitive performance.
|
17 |
+
To address these limitations, we present EXAONE Path 2.0, a pathology foundation model that learns patch-level representations under direct slide-level supervision.
|
18 |
+
Using only 35k WSIs for training, EXAONE Path 2.0 achieves state-of-the-art average performance across 10 biomarker prediction tasks, demonstrating remarkable data efficiency.
|
19 |
+
|
20 |
+
## Quickstart
|
21 |
+
Load EXAONE Path and run inference on tile-level images.
|
22 |
+
|
23 |
+
### 1. Prerequisites ###
|
24 |
+
- NVIDIA GPU with 24GB+ VRAM
|
25 |
+
- Python 3.12+
|
26 |
+
|
27 |
+
Note: This implementation requires NVIDIA GPU and drivers. The provided environment setup specifically uses CUDA-enabled PyTorch, making NVIDIA GPU mandatory for running the model.
|
28 |
+
|
29 |
+
### 2. Setup Python environment ###
|
30 |
+
```bash
|
31 |
+
git clone https://github.com/LG-AI-EXAONE/EXAONE-Path-2.0.git
|
32 |
+
cd EXAONE-Path-2.0
|
33 |
+
pip install -r
|
34 |
+
```
|
35 |
+
|
36 |
+
### 3. Load the model & Inference
|
37 |
+
```python
|
38 |
+
from exaonepath import EXAONEPathV20
|
39 |
+
|
40 |
+
hf_token = "YOUR_HUGGING_FACE_ACCESS_TOKEN"
|
41 |
+
model = EXAONEPathV20.from_pretrained("LGAI-EXAONE/EXAONE-Path-2.0", use_auth_token=hf_token)
|
42 |
+
|
43 |
+
svs_path = "YOUR_SVS_PATH"
|
44 |
+
patch_features = model(svs_path)[0]
|
45 |
+
```
|
46 |
+
|
47 |
+
## Model Performance Comparison
|
48 |
+
|
49 |
+
| **Benchmarks** | **TITAN** | **PRISM** | **CHIEF** | **Prov-GigaPath** | **UNI2-h** | **EXAONE Path 1.0** | **EXAONE Path 2.0** |
|
50 |
+
|---|---|---|---|---|---|---|---|
|
51 |
+
| LUAD-TMB-USA1 | 0.690 | 0.645 | 0.650 | 0.674 | 0.669 | 0.692 | 0.664 |
|
52 |
+
| LUAD-EGFR-USA1 | 0.754 | 0.815 | 0.784 | 0.709 | 0.827 | 0.784 | 0.853 |
|
53 |
+
| LUAD-KRAS-USA2 | 0.541 | 0.623 | 0.468 | 0.511 | 0.469 | 0.527 | 0.645 |
|
54 |
+
| CRC-MSI-KOR | 0.937 | 0.943 | 0.927 | 0.954 | 0.981 | 0.972 | 0.938 |
|
55 |
+
| BRCA-TP53-CPTAC | 0.788 | 0.842 | 0.788 | 0.739 | 0.808 | 0.766 | 0.757 |
|
56 |
+
| BRCA-PIK3CA-CPTAC | 0.758 | 0.893 | 0.702 | 0.735 | 0.857 | 0.735 | 0.804 |
|
57 |
+
| RCC-PBRM1-CPTAC | 0.638 | 0.557 | 0.513 | 0.527 | 0.501 | 0.526 | 0.583 |
|
58 |
+
| RCC-BAP1-CPTAC | 0.719 | 0.769 | 0.731 | 0.697 | 0.716 | 0.719 | 0.807 |
|
59 |
+
| COAD-KRAS-CPTAC | 0.764 | 0.744 | 0.699 | 0.815 | 0.943 | 0.767 | 0.912 |
|
60 |
+
| COAD-TP53-CPTAC | 0.889 | 0.816 | 0.701 | 0.712 | 0.783 | 0.819 | 0.875 |
|
61 |
+
| **Average** | 0.748 | 0.765 | 0.696 | 0.707 | 0.755 | 0.731 | **0.784** |
|
62 |
+
|
63 |
+
<br>
|
64 |
+
|
65 |
+
|
66 |
+
## License
|
67 |
+
The model is licensed under [EXAONEPath AI Model License Agreement 1.0 - NC](./LICENSE)
|
68 |
+
|
69 |
+
<!-- ## Citation
|
70 |
+
If you find EXAONE Path 2.0 useful, please cite it using this BibTeX:
|
71 |
+
```
|
72 |
+
@article{yun2024exaonepath,
|
73 |
+
title={EXAONE Path 2.0 Techincal Report},
|
74 |
+
author={Yun, Juseung and Hu, Yi and Kim, Jinhyung and Jang, Jongseong and Lee, Soonyoung},
|
75 |
+
journal={arXiv preprint arXiv:2408.00380},
|
76 |
+
year={2024}
|
77 |
+
} -->
|
78 |
+
```
|
79 |
+
|
80 |
+
## Contact
|
81 |
+
LG AI Research Technical Support: <a href="mailto:contact_us1@lgresearch.ai">contact_us1@lgresearch.ai</a>
|
config.json
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"small_tile_size": 256,
|
3 |
+
"large_tile_size": 4096
|
4 |
+
}
|
exaonepath.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import typing as t
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from cucim import CuImage
|
7 |
+
from huggingface_hub import PyTorchModelHubMixin
|
8 |
+
from torchvision.transforms import functional as TF
|
9 |
+
from torchvision.transforms import v2 as T
|
10 |
+
|
11 |
+
from networks.vit import vit4k_base, vit_base, vit_global_base
|
12 |
+
from utils.tensor_utils import (
|
13 |
+
format_first_stg_act_as_second_stg_inp,
|
14 |
+
format_second_stg_act_as_third_stg_inp,
|
15 |
+
forward_with_batch_size_limit,
|
16 |
+
scale_and_normalize,
|
17 |
+
tile,
|
18 |
+
)
|
19 |
+
from utils.wsi_utils import load_slide_img, pack_slide, segment_tissue
|
20 |
+
|
21 |
+
if t.TYPE_CHECKING:
|
22 |
+
from _typeshed import StrPath
|
23 |
+
|
24 |
+
|
25 |
+
class PadToDivisible(T.Transform):
|
26 |
+
def __init__(self, size: int, pad_value: float | None = None):
|
27 |
+
super().__init__()
|
28 |
+
self.size = size
|
29 |
+
self.pad_value = pad_value
|
30 |
+
|
31 |
+
def transform(self, inpt, params):
|
32 |
+
assert isinstance(inpt, torch.Tensor) and inpt.ndim >= 3
|
33 |
+
|
34 |
+
H, W = inpt.shape[-2:]
|
35 |
+
|
36 |
+
pad_h = (self.size - H % self.size) % self.size
|
37 |
+
pad_w = (self.size - W % self.size) % self.size
|
38 |
+
|
39 |
+
if pad_h > 0 or pad_w > 0:
|
40 |
+
inpt = torch.nn.functional.pad(
|
41 |
+
inpt, (0, pad_w, 0, pad_h), value=self.pad_value
|
42 |
+
)
|
43 |
+
|
44 |
+
return inpt
|
45 |
+
|
46 |
+
|
47 |
+
class Preprocessing(T.Transform):
|
48 |
+
def __init__(
|
49 |
+
self, small_tile_size_with_this_mpp: int, small_tile_size_with_target_mpp: int
|
50 |
+
):
|
51 |
+
self.small_tile_size_with_this_mpp = small_tile_size_with_this_mpp
|
52 |
+
self.small_tile_size_with_target_mpp = small_tile_size_with_target_mpp
|
53 |
+
|
54 |
+
def transform(self, inpt, params):
|
55 |
+
assert isinstance(inpt, torch.Tensor) and inpt.ndim >= 3
|
56 |
+
|
57 |
+
# Scale the input tensor to the target MPP
|
58 |
+
if self.small_tile_size_with_this_mpp != self.small_tile_size_with_target_mpp:
|
59 |
+
inpt = TF.resize(
|
60 |
+
inpt,
|
61 |
+
[
|
62 |
+
self.small_tile_size_with_target_mpp,
|
63 |
+
self.small_tile_size_with_target_mpp,
|
64 |
+
],
|
65 |
+
)
|
66 |
+
|
67 |
+
# Normalize the input tensor
|
68 |
+
inpt = scale_and_normalize(inpt)
|
69 |
+
|
70 |
+
return inpt
|
71 |
+
|
72 |
+
|
73 |
+
class EXAONEPathV20(nn.Module, PyTorchModelHubMixin):
|
74 |
+
def __init__(
|
75 |
+
self,
|
76 |
+
small_tile_size: int = 256,
|
77 |
+
large_tile_size: int = 4096,
|
78 |
+
):
|
79 |
+
super().__init__()
|
80 |
+
|
81 |
+
self.small_tile_size = small_tile_size
|
82 |
+
self.large_tile_size = large_tile_size
|
83 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
84 |
+
|
85 |
+
self.model_first_stg = vit_base().to(self.device).eval()
|
86 |
+
self.model_second_stg = vit4k_base().to(self.device).eval()
|
87 |
+
self.model_third_stg = vit_global_base().to(self.device).eval()
|
88 |
+
|
89 |
+
def forward(
|
90 |
+
self,
|
91 |
+
svs_path: "StrPath",
|
92 |
+
target_mpp: float = 0.5,
|
93 |
+
first_stg_batch_size: int = 128,
|
94 |
+
):
|
95 |
+
small_tiles, is_tile_valid, padded_size, small_tile_size, large_tile_size = (
|
96 |
+
self._load_wsi(svs_path, target_mpp=target_mpp)
|
97 |
+
)
|
98 |
+
width, height = padded_size
|
99 |
+
|
100 |
+
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
101 |
+
with torch.no_grad():
|
102 |
+
act1 = forward_with_batch_size_limit(
|
103 |
+
self.model_first_stg,
|
104 |
+
small_tiles,
|
105 |
+
batch_size_on_gpu=first_stg_batch_size,
|
106 |
+
preproc_fn=Preprocessing(
|
107 |
+
small_tile_size_with_this_mpp=small_tile_size,
|
108 |
+
small_tile_size_with_target_mpp=self.small_tile_size,
|
109 |
+
),
|
110 |
+
device=self.device,
|
111 |
+
out_device=self.device,
|
112 |
+
dtype=torch.bfloat16,
|
113 |
+
)
|
114 |
+
act1 = format_first_stg_act_as_second_stg_inp(
|
115 |
+
act1,
|
116 |
+
height=height,
|
117 |
+
width=width,
|
118 |
+
small_tile_size=small_tile_size,
|
119 |
+
large_tile_size=large_tile_size,
|
120 |
+
)
|
121 |
+
act2: torch.Tensor = self.model_second_stg(act1)
|
122 |
+
act2_formatted = format_second_stg_act_as_third_stg_inp(
|
123 |
+
act2,
|
124 |
+
height=height,
|
125 |
+
width=width,
|
126 |
+
large_tile_size=large_tile_size,
|
127 |
+
)
|
128 |
+
act3: torch.Tensor = self.model_third_stg(act2_formatted)
|
129 |
+
return act1[is_tile_valid], act2, act3
|
130 |
+
|
131 |
+
def _load_wsi(self, svs_path: "StrPath", target_mpp: float):
|
132 |
+
# Load WSI tile
|
133 |
+
with CuImage(str(svs_path)) as wsi_obj:
|
134 |
+
try:
|
135 |
+
mpp = float(wsi_obj.metadata["aperio"]["MPP"])
|
136 |
+
except KeyError:
|
137 |
+
print(
|
138 |
+
f"Warning: MPP metadata not found, using default value of {target_mpp}"
|
139 |
+
)
|
140 |
+
mpp = target_mpp
|
141 |
+
|
142 |
+
img = load_slide_img(wsi_obj)
|
143 |
+
height, width = img.shape[:2]
|
144 |
+
mask_tensor = torch.from_numpy(segment_tissue(svs_path, seg_level=-1)[0])
|
145 |
+
mask_tensor = TF.resize(mask_tensor.unsqueeze(0), [height, width]).squeeze(
|
146 |
+
0
|
147 |
+
)
|
148 |
+
x: torch.Tensor = torch.from_numpy(img).permute(2, 0, 1)
|
149 |
+
|
150 |
+
small_tile_size = math.ceil(self.small_tile_size * (target_mpp / mpp))
|
151 |
+
large_tile_size = (
|
152 |
+
self.large_tile_size // self.small_tile_size
|
153 |
+
) * small_tile_size
|
154 |
+
pad_image = PadToDivisible(large_tile_size, 255)
|
155 |
+
pad_mask = PadToDivisible(large_tile_size, 0)
|
156 |
+
|
157 |
+
x = pad_image(x)
|
158 |
+
padded_size = (x.size(-1), x.size(-2))
|
159 |
+
|
160 |
+
x = tile(x, small_tile_size)
|
161 |
+
mask_padded = pad_mask(mask_tensor.unsqueeze(0))
|
162 |
+
mask_tile = tile(mask_padded, small_tile_size).squeeze(1)
|
163 |
+
is_tile_valid = mask_tile.sum(dim=(1, 2)) > 0
|
164 |
+
|
165 |
+
return x, is_tile_valid, padded_size, small_tile_size, large_tile_size
|
networks/__init__.py
ADDED
File without changes
|
networks/vit.py
ADDED
@@ -0,0 +1,569 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import warnings
|
3 |
+
from functools import partial
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
torch.set_float32_matmul_precision("high")
|
10 |
+
torch.backends.cuda.enable_flash_sdp(True)
|
11 |
+
|
12 |
+
|
13 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
14 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
15 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
16 |
+
def norm_cdf(x):
|
17 |
+
# Computes standard normal cumulative distribution function
|
18 |
+
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
19 |
+
|
20 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
21 |
+
warnings.warn(
|
22 |
+
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
23 |
+
"The distribution of values may be incorrect.",
|
24 |
+
stacklevel=2,
|
25 |
+
)
|
26 |
+
|
27 |
+
with torch.no_grad():
|
28 |
+
# Values are generated by using a truncated uniform distribution and
|
29 |
+
# then using the inverse CDF for the normal distribution.
|
30 |
+
# Get upper and lower cdf values
|
31 |
+
l = norm_cdf((a - mean) / std)
|
32 |
+
u = norm_cdf((b - mean) / std)
|
33 |
+
|
34 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
35 |
+
# [2l-1, 2u-1].
|
36 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
37 |
+
|
38 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
39 |
+
# standard normal
|
40 |
+
tensor.erfinv_()
|
41 |
+
|
42 |
+
# Transform to proper mean, std
|
43 |
+
tensor.mul_(std * math.sqrt(2.0))
|
44 |
+
tensor.add_(mean)
|
45 |
+
|
46 |
+
# Clamp to ensure it's in the proper range
|
47 |
+
tensor.clamp_(min=a, max=b)
|
48 |
+
return tensor
|
49 |
+
|
50 |
+
|
51 |
+
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
|
52 |
+
# type: (torch.Tensor, float, float, float, float) -> torch.Tensor
|
53 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
54 |
+
|
55 |
+
|
56 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
57 |
+
if drop_prob == 0.0 or not training:
|
58 |
+
return x
|
59 |
+
keep_prob = 1 - drop_prob
|
60 |
+
shape = (x.shape[0],) + (1,) * (
|
61 |
+
x.ndim - 1
|
62 |
+
) # work with diff dim tensors, not just 2D ConvNets
|
63 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
64 |
+
random_tensor.floor_() # binarize
|
65 |
+
output = x.div(keep_prob) * random_tensor
|
66 |
+
return output
|
67 |
+
|
68 |
+
|
69 |
+
class LayerScale(nn.Module):
|
70 |
+
def __init__(
|
71 |
+
self,
|
72 |
+
dim: int,
|
73 |
+
init_values: float = 1e-5,
|
74 |
+
inplace: bool = False,
|
75 |
+
) -> None:
|
76 |
+
super().__init__()
|
77 |
+
self.inplace = inplace
|
78 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
79 |
+
|
80 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
81 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
82 |
+
|
83 |
+
|
84 |
+
class DropPath(nn.Module):
|
85 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
86 |
+
|
87 |
+
def __init__(self, drop_prob=None):
|
88 |
+
super(DropPath, self).__init__()
|
89 |
+
self.drop_prob = drop_prob
|
90 |
+
|
91 |
+
def forward(self, x):
|
92 |
+
return drop_path(x, self.drop_prob, self.training)
|
93 |
+
|
94 |
+
|
95 |
+
class Mlp(nn.Module):
|
96 |
+
def __init__(
|
97 |
+
self,
|
98 |
+
in_features,
|
99 |
+
hidden_features=None,
|
100 |
+
out_features=None,
|
101 |
+
act_layer=nn.GELU,
|
102 |
+
drop=0.0,
|
103 |
+
):
|
104 |
+
super().__init__()
|
105 |
+
out_features = out_features or in_features
|
106 |
+
hidden_features = hidden_features or in_features
|
107 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
108 |
+
self.act = act_layer()
|
109 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
110 |
+
self.drop = nn.Dropout(drop)
|
111 |
+
self.drop_p = drop
|
112 |
+
|
113 |
+
def forward(self, x):
|
114 |
+
x = self.fc1(x)
|
115 |
+
x = self.act(x)
|
116 |
+
x = self.drop(x)
|
117 |
+
x = self.fc2(x)
|
118 |
+
x = self.drop(x)
|
119 |
+
return x
|
120 |
+
|
121 |
+
|
122 |
+
# TODO Use SelfAttention class in networks.modules
|
123 |
+
class Attention(nn.Module):
|
124 |
+
def __init__(
|
125 |
+
self,
|
126 |
+
dim,
|
127 |
+
num_heads=8,
|
128 |
+
qkv_bias=False,
|
129 |
+
qk_scale=None,
|
130 |
+
attn_drop=0.0,
|
131 |
+
proj_drop=0.0,
|
132 |
+
):
|
133 |
+
super().__init__()
|
134 |
+
self.dim = dim
|
135 |
+
self.num_heads = num_heads
|
136 |
+
head_dim = dim // num_heads
|
137 |
+
self.scale = qk_scale or head_dim**-0.5
|
138 |
+
|
139 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
140 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
141 |
+
self.attn_drop_p = attn_drop
|
142 |
+
self.proj = nn.Linear(dim, dim)
|
143 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
144 |
+
self.proj_drop_p = proj_drop
|
145 |
+
|
146 |
+
def forward(self, x):
|
147 |
+
B, N, C = x.shape
|
148 |
+
qkv = (
|
149 |
+
self.qkv(x)
|
150 |
+
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
151 |
+
.permute(2, 0, 3, 1, 4)
|
152 |
+
)
|
153 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
154 |
+
|
155 |
+
x = F.scaled_dot_product_attention(
|
156 |
+
q, k, v, dropout_p=self.attn_drop.p, scale=self.scale
|
157 |
+
)
|
158 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
159 |
+
|
160 |
+
x = self.proj(x)
|
161 |
+
x = self.proj_drop(x)
|
162 |
+
return x
|
163 |
+
|
164 |
+
|
165 |
+
class Block(nn.Module):
|
166 |
+
def __init__(
|
167 |
+
self,
|
168 |
+
dim,
|
169 |
+
num_heads,
|
170 |
+
mlp_ratio=4.0,
|
171 |
+
qkv_bias=False,
|
172 |
+
qk_scale=None,
|
173 |
+
drop=0.0,
|
174 |
+
attn_drop=0.0,
|
175 |
+
init_values=None,
|
176 |
+
drop_path=0.0,
|
177 |
+
act_layer=nn.GELU,
|
178 |
+
norm_layer=nn.LayerNorm,
|
179 |
+
):
|
180 |
+
super().__init__()
|
181 |
+
self.norm1 = norm_layer(dim)
|
182 |
+
self.attn = Attention(
|
183 |
+
dim,
|
184 |
+
num_heads=num_heads,
|
185 |
+
qkv_bias=qkv_bias,
|
186 |
+
qk_scale=qk_scale,
|
187 |
+
attn_drop=attn_drop,
|
188 |
+
proj_drop=drop,
|
189 |
+
)
|
190 |
+
self.ls1 = (
|
191 |
+
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
192 |
+
)
|
193 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
194 |
+
self.norm2 = norm_layer(dim)
|
195 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
196 |
+
self.mlp = Mlp(
|
197 |
+
in_features=dim,
|
198 |
+
hidden_features=mlp_hidden_dim,
|
199 |
+
act_layer=act_layer,
|
200 |
+
drop=drop,
|
201 |
+
)
|
202 |
+
self.ls2 = (
|
203 |
+
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
204 |
+
)
|
205 |
+
|
206 |
+
def forward(self, x):
|
207 |
+
x = x + self.drop_path(self.ls1(self.attn(self.norm1(x))))
|
208 |
+
x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))
|
209 |
+
return x
|
210 |
+
|
211 |
+
|
212 |
+
class PatchEmbed(nn.Module):
|
213 |
+
"""Image to Patch Embedding"""
|
214 |
+
|
215 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
216 |
+
super().__init__()
|
217 |
+
num_patches = (img_size // patch_size) * (img_size // patch_size)
|
218 |
+
self.img_size = img_size
|
219 |
+
self.patch_size = patch_size
|
220 |
+
self.num_patches = num_patches
|
221 |
+
|
222 |
+
self.proj = nn.Conv2d(
|
223 |
+
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
|
224 |
+
)
|
225 |
+
|
226 |
+
def forward(self, x):
|
227 |
+
B, C, H, W = x.shape
|
228 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
229 |
+
return x
|
230 |
+
|
231 |
+
|
232 |
+
class Network(nn.Module):
|
233 |
+
emb_dim: int
|
234 |
+
|
235 |
+
|
236 |
+
class VisionTransformer(Network):
|
237 |
+
"""Vision Transformer"""
|
238 |
+
|
239 |
+
def __init__(
|
240 |
+
self,
|
241 |
+
img_size=256,
|
242 |
+
patch_size=16,
|
243 |
+
in_chans=3,
|
244 |
+
num_classes=0,
|
245 |
+
embed_dim=768,
|
246 |
+
depth=12,
|
247 |
+
num_heads=12,
|
248 |
+
mlp_ratio=4.0,
|
249 |
+
qkv_bias=False,
|
250 |
+
qk_scale=None,
|
251 |
+
init_values=None, # for layerscale: None or 0 => no layerscale
|
252 |
+
drop_rate=0.0,
|
253 |
+
attn_drop_rate=0.0,
|
254 |
+
drop_path_rate=0.0,
|
255 |
+
norm_layer=nn.LayerNorm,
|
256 |
+
**kwargs
|
257 |
+
):
|
258 |
+
super().__init__()
|
259 |
+
self.num_features = self.embed_dim = embed_dim
|
260 |
+
|
261 |
+
self.patch_embed = PatchEmbed(
|
262 |
+
img_size=img_size,
|
263 |
+
patch_size=patch_size,
|
264 |
+
in_chans=in_chans,
|
265 |
+
embed_dim=embed_dim,
|
266 |
+
)
|
267 |
+
num_patches = self.patch_embed.num_patches
|
268 |
+
|
269 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
270 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
271 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
272 |
+
|
273 |
+
dpr = [
|
274 |
+
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
275 |
+
] # stochastic depth decay rule
|
276 |
+
self.blocks = nn.ModuleList(
|
277 |
+
[
|
278 |
+
Block(
|
279 |
+
dim=embed_dim,
|
280 |
+
num_heads=num_heads,
|
281 |
+
mlp_ratio=mlp_ratio,
|
282 |
+
qkv_bias=qkv_bias,
|
283 |
+
qk_scale=qk_scale,
|
284 |
+
init_values=init_values,
|
285 |
+
drop=drop_rate,
|
286 |
+
attn_drop=attn_drop_rate,
|
287 |
+
drop_path=dpr[i],
|
288 |
+
norm_layer=norm_layer,
|
289 |
+
)
|
290 |
+
for i in range(depth)
|
291 |
+
]
|
292 |
+
)
|
293 |
+
self.norm = norm_layer(embed_dim)
|
294 |
+
|
295 |
+
# Classifier head
|
296 |
+
self.head = (
|
297 |
+
nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
298 |
+
)
|
299 |
+
|
300 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
301 |
+
trunc_normal_(self.cls_token, std=0.02)
|
302 |
+
self.apply(self._init_weights)
|
303 |
+
|
304 |
+
def _init_weights(self, m):
|
305 |
+
if isinstance(m, nn.Linear):
|
306 |
+
trunc_normal_(m.weight, std=0.02)
|
307 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
308 |
+
nn.init.constant_(m.bias, 0)
|
309 |
+
elif isinstance(m, nn.LayerNorm):
|
310 |
+
nn.init.constant_(m.bias, 0)
|
311 |
+
nn.init.constant_(m.weight, 1.0)
|
312 |
+
|
313 |
+
def interpolate_pos_encoding(self, x, w, h):
|
314 |
+
npatch = x.shape[1] - 1
|
315 |
+
N = self.pos_embed.shape[1] - 1
|
316 |
+
if npatch == N and w == h:
|
317 |
+
return self.pos_embed
|
318 |
+
class_pos_embed = self.pos_embed[:, 0]
|
319 |
+
patch_pos_embed = self.pos_embed[:, 1:]
|
320 |
+
dim = x.shape[-1]
|
321 |
+
w0 = w // self.patch_embed.patch_size
|
322 |
+
h0 = h // self.patch_embed.patch_size
|
323 |
+
# we add a small number to avoid floating point error in the interpolation
|
324 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
325 |
+
w0, h0 = w0 + 0.1, h0 + 0.1
|
326 |
+
patch_pos_embed = nn.functional.interpolate(
|
327 |
+
patch_pos_embed.reshape(
|
328 |
+
1, int(math.sqrt(N)), int(math.sqrt(N)), dim
|
329 |
+
).permute(0, 3, 1, 2),
|
330 |
+
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
|
331 |
+
mode="bicubic",
|
332 |
+
)
|
333 |
+
assert (
|
334 |
+
int(w0) == patch_pos_embed.shape[-2]
|
335 |
+
and int(h0) == patch_pos_embed.shape[-1]
|
336 |
+
)
|
337 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
338 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
339 |
+
|
340 |
+
def prepare_tokens(self, x):
|
341 |
+
B, nc, w, h = x.shape
|
342 |
+
x = self.patch_embed(x) # patch linear embedding
|
343 |
+
|
344 |
+
# add the [CLS] token to the embed patch tokens
|
345 |
+
cls_tokens = self.cls_token.expand(B, -1, -1)
|
346 |
+
x = torch.cat((cls_tokens, x), dim=1) # (B, S + 1, C)
|
347 |
+
|
348 |
+
# add positional encoding to each token
|
349 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
350 |
+
return self.pos_drop(x)
|
351 |
+
|
352 |
+
def forward(self, x):
|
353 |
+
x = self.prepare_tokens(x)
|
354 |
+
for blk in self.blocks:
|
355 |
+
x = blk(x)
|
356 |
+
x = self.norm(x)
|
357 |
+
return x[:, 0]
|
358 |
+
|
359 |
+
def get_last_selfattention(self, x):
|
360 |
+
x = self.prepare_tokens(x)
|
361 |
+
for i, blk in enumerate(self.blocks):
|
362 |
+
if i < len(self.blocks) - 1:
|
363 |
+
x = blk(x)
|
364 |
+
else:
|
365 |
+
# return attention of the last block
|
366 |
+
return blk(x, return_attention=True)
|
367 |
+
|
368 |
+
def get_intermediate_layers(self, x, n=1):
|
369 |
+
x = self.prepare_tokens(x)
|
370 |
+
# we return the output tokens from the `n` last blocks
|
371 |
+
output = []
|
372 |
+
for i, blk in enumerate(self.blocks):
|
373 |
+
x = blk(x)
|
374 |
+
if len(self.blocks) - i <= n:
|
375 |
+
output.append(self.norm(x))
|
376 |
+
return output
|
377 |
+
|
378 |
+
|
379 |
+
class VisionTransformer4K(Network):
|
380 |
+
"""Vision Transformer 4K"""
|
381 |
+
|
382 |
+
def __init__(
|
383 |
+
self,
|
384 |
+
num_classes=0,
|
385 |
+
img_size=256,
|
386 |
+
input_embed_dim=384,
|
387 |
+
output_embed_dim=192,
|
388 |
+
depth=12,
|
389 |
+
num_heads=12,
|
390 |
+
mlp_ratio=4.0,
|
391 |
+
qkv_bias=False,
|
392 |
+
qk_scale=None,
|
393 |
+
init_values=None, # for layerscale: None or 0 => no layerscale
|
394 |
+
drop_rate=0.0,
|
395 |
+
attn_drop_rate=0.0,
|
396 |
+
drop_path_rate=0.0,
|
397 |
+
norm_layer=nn.LayerNorm,
|
398 |
+
num_prototypes=64,
|
399 |
+
**kwargs
|
400 |
+
):
|
401 |
+
super().__init__()
|
402 |
+
embed_dim = output_embed_dim
|
403 |
+
self.num_features = self.embed_dim = embed_dim
|
404 |
+
self.phi = nn.Sequential(
|
405 |
+
*[
|
406 |
+
nn.Linear(input_embed_dim, output_embed_dim),
|
407 |
+
nn.GELU(),
|
408 |
+
nn.Dropout(p=drop_rate),
|
409 |
+
]
|
410 |
+
)
|
411 |
+
num_patches = int(img_size // 16) ** 2
|
412 |
+
|
413 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
414 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
415 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
416 |
+
|
417 |
+
dpr = [
|
418 |
+
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
419 |
+
] # stochastic depth decay rule
|
420 |
+
self.blocks = nn.ModuleList(
|
421 |
+
[
|
422 |
+
Block(
|
423 |
+
dim=embed_dim,
|
424 |
+
num_heads=num_heads,
|
425 |
+
mlp_ratio=mlp_ratio,
|
426 |
+
qkv_bias=qkv_bias,
|
427 |
+
qk_scale=qk_scale,
|
428 |
+
init_values=init_values,
|
429 |
+
drop=drop_rate,
|
430 |
+
attn_drop=attn_drop_rate,
|
431 |
+
drop_path=dpr[i],
|
432 |
+
norm_layer=norm_layer,
|
433 |
+
)
|
434 |
+
for i in range(depth)
|
435 |
+
]
|
436 |
+
)
|
437 |
+
self.norm = norm_layer(embed_dim)
|
438 |
+
|
439 |
+
# Classifier head
|
440 |
+
self.head = (
|
441 |
+
nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
442 |
+
)
|
443 |
+
|
444 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
445 |
+
trunc_normal_(self.cls_token, std=0.02)
|
446 |
+
self.apply(self._init_weights)
|
447 |
+
|
448 |
+
def _init_weights(self, m):
|
449 |
+
if isinstance(m, nn.Linear):
|
450 |
+
trunc_normal_(m.weight, std=0.02)
|
451 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
452 |
+
nn.init.constant_(m.bias, 0)
|
453 |
+
elif isinstance(m, nn.LayerNorm):
|
454 |
+
nn.init.constant_(m.bias, 0)
|
455 |
+
nn.init.constant_(m.weight, 1.0)
|
456 |
+
|
457 |
+
def interpolate_pos_encoding(self, x, w, h):
|
458 |
+
npatch = x.shape[1] - 1
|
459 |
+
N = self.pos_embed.shape[1] - 1
|
460 |
+
if npatch == N and w == h:
|
461 |
+
return self.pos_embed
|
462 |
+
class_pos_embed = self.pos_embed[:, 0]
|
463 |
+
patch_pos_embed = self.pos_embed[:, 1:]
|
464 |
+
dim = x.shape[-1]
|
465 |
+
w0 = w // 1
|
466 |
+
h0 = h // 1
|
467 |
+
# we add a small number to avoid floating point error in the interpolation
|
468 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
469 |
+
w0, h0 = w0 + 0.1, h0 + 0.1
|
470 |
+
patch_pos_embed = nn.functional.interpolate(
|
471 |
+
patch_pos_embed.reshape(
|
472 |
+
1, int(math.sqrt(N)), int(math.sqrt(N)), dim
|
473 |
+
).permute(0, 3, 1, 2),
|
474 |
+
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
|
475 |
+
mode="bicubic",
|
476 |
+
)
|
477 |
+
assert (
|
478 |
+
int(w0) == patch_pos_embed.shape[-2]
|
479 |
+
and int(h0) == patch_pos_embed.shape[-1]
|
480 |
+
)
|
481 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
482 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
483 |
+
|
484 |
+
def prepare_tokens(self, x):
|
485 |
+
# print('preparing tokens (after crop)', x.shape)
|
486 |
+
self.mpp_feature = x
|
487 |
+
B, embed_dim, w, h = x.shape
|
488 |
+
x = x.flatten(2, 3).transpose(1, 2)
|
489 |
+
|
490 |
+
x = self.phi(x)
|
491 |
+
|
492 |
+
# add the [CLS] token to the embed patch tokens
|
493 |
+
cls_tokens = self.cls_token.expand(B, -1, -1)
|
494 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
495 |
+
|
496 |
+
# add positional encoding to each token
|
497 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
498 |
+
|
499 |
+
return self.pos_drop(x)
|
500 |
+
|
501 |
+
def forward(self, x):
|
502 |
+
x = self.prepare_tokens(x)
|
503 |
+
for blk in self.blocks:
|
504 |
+
x = blk(x)
|
505 |
+
x = self.norm(x)
|
506 |
+
return x[:, 0]
|
507 |
+
|
508 |
+
def get_last_selfattention(self, x):
|
509 |
+
x = self.prepare_tokens(x)
|
510 |
+
for i, blk in enumerate(self.blocks):
|
511 |
+
if i < len(self.blocks) - 1:
|
512 |
+
x = blk(x)
|
513 |
+
else:
|
514 |
+
# return attention of the last block
|
515 |
+
return blk(x, return_attention=True)
|
516 |
+
|
517 |
+
def get_intermediate_layers(self, x, n=1):
|
518 |
+
x = self.prepare_tokens(x)
|
519 |
+
# we return the output tokens from the `n` last blocks
|
520 |
+
output = []
|
521 |
+
for i, blk in enumerate(self.blocks):
|
522 |
+
x = blk(x)
|
523 |
+
if len(self.blocks) - i <= n:
|
524 |
+
output.append(self.norm(x))
|
525 |
+
return output
|
526 |
+
|
527 |
+
|
528 |
+
def vit_base(patch_size=16, **kwargs):
|
529 |
+
model = VisionTransformer(
|
530 |
+
patch_size=patch_size,
|
531 |
+
embed_dim=768,
|
532 |
+
depth=12,
|
533 |
+
num_heads=12,
|
534 |
+
mlp_ratio=4,
|
535 |
+
qkv_bias=True,
|
536 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
537 |
+
**kwargs
|
538 |
+
)
|
539 |
+
return model
|
540 |
+
|
541 |
+
|
542 |
+
def vit4k_base(patch_size=16, **kwargs):
|
543 |
+
model = VisionTransformer4K(
|
544 |
+
patch_size=patch_size,
|
545 |
+
input_embed_dim=768,
|
546 |
+
output_embed_dim=768,
|
547 |
+
depth=6,
|
548 |
+
num_heads=12,
|
549 |
+
mlp_ratio=4,
|
550 |
+
qkv_bias=True,
|
551 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
552 |
+
**kwargs
|
553 |
+
)
|
554 |
+
return model
|
555 |
+
|
556 |
+
|
557 |
+
def vit_global_base(patch_size=16, **kwargs):
|
558 |
+
model = VisionTransformer4K(
|
559 |
+
patch_size=patch_size,
|
560 |
+
input_embed_dim=768,
|
561 |
+
output_embed_dim=768,
|
562 |
+
depth=2,
|
563 |
+
num_heads=6,
|
564 |
+
mlp_ratio=4,
|
565 |
+
qkv_bias=True,
|
566 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
567 |
+
**kwargs
|
568 |
+
)
|
569 |
+
return model
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:17cc4f62887c3dc97380f0b824278cc16f75638f46657f3487204a0633462373
|
3 |
+
size 576596452
|
requirements.txt
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--extra-index-url https://download.pytorch.org/whl/cu124
|
2 |
+
|
3 |
+
torch==2.6.0+cu124
|
4 |
+
torchvision==0.21.0+cu124
|
5 |
+
|
6 |
+
# configs
|
7 |
+
pydantic==2.10.3
|
8 |
+
|
9 |
+
# data
|
10 |
+
openslide-bin==4.0.0.6
|
11 |
+
openslide-python==1.4.1
|
12 |
+
cucim-cu12==25.2
|
13 |
+
rectangle_packer==2.0.2
|
14 |
+
opencv-python-headless==4.11.0.86
|
15 |
+
|
16 |
+
huggingface_hub
|
utils/__init__.py
ADDED
File without changes
|
utils/tensor_utils.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import typing as t
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torchvision.transforms.functional as TF
|
5 |
+
|
6 |
+
|
7 |
+
def tile(x: torch.Tensor, size: int, pad_value: int | float | None = None):
|
8 |
+
C, H, W = x.shape[-3:]
|
9 |
+
|
10 |
+
pad_h = (size - H % size) % size
|
11 |
+
pad_w = (size - W % size) % size
|
12 |
+
if pad_h > 0 or pad_w > 0:
|
13 |
+
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), value=pad_value)
|
14 |
+
|
15 |
+
nh, nw = x.size(-2) // size, x.size(-1) // size
|
16 |
+
return (
|
17 |
+
x.view(-1, C, nh, size, nw, size)
|
18 |
+
.permute(0, 2, 4, 1, 3, 5)
|
19 |
+
.reshape(-1, C, size, size)
|
20 |
+
)
|
21 |
+
|
22 |
+
|
23 |
+
def small_tiles_to_large_tiles(
|
24 |
+
small_tiles: torch.Tensor,
|
25 |
+
width: int,
|
26 |
+
large_tile_size: int,
|
27 |
+
sampled_large_tiles_idx: list | torch.Tensor | None = None,
|
28 |
+
) -> torch.Tensor:
|
29 |
+
|
30 |
+
has_channel = small_tiles.ndim == 4
|
31 |
+
small_tile_size = small_tiles.size(-1)
|
32 |
+
num_small_tiles = small_tiles.size(0)
|
33 |
+
|
34 |
+
nw = width // small_tile_size
|
35 |
+
nh = num_small_tiles // nw
|
36 |
+
|
37 |
+
r = large_tile_size // small_tile_size
|
38 |
+
|
39 |
+
num_large_tiles = (nh // r) * (nw // r)
|
40 |
+
large_tile_indices = (
|
41 |
+
range(num_large_tiles)
|
42 |
+
if sampled_large_tiles_idx is None
|
43 |
+
else sampled_large_tiles_idx
|
44 |
+
)
|
45 |
+
|
46 |
+
tiles = []
|
47 |
+
for k in large_tile_indices:
|
48 |
+
start_row = (k // (nw // r)) * r
|
49 |
+
start_col = (k % (nw // r)) * r
|
50 |
+
for i in range(start_row, start_row + r):
|
51 |
+
for j in range(start_col, start_col + r):
|
52 |
+
tiles.append(small_tiles[i * nw + j])
|
53 |
+
|
54 |
+
stacked = torch.stack(tiles, dim=0).view(-1, r, r, *small_tiles.shape[1:])
|
55 |
+
if has_channel:
|
56 |
+
large_tiles = stacked.permute(0, 3, 1, 4, 2, 5).reshape(
|
57 |
+
-1, small_tiles.size(1), large_tile_size, large_tile_size
|
58 |
+
)
|
59 |
+
else:
|
60 |
+
large_tiles = stacked.permute(0, 1, 3, 2, 4).reshape(
|
61 |
+
-1, large_tile_size, large_tile_size
|
62 |
+
)
|
63 |
+
return large_tiles
|
64 |
+
|
65 |
+
|
66 |
+
def small_tile_flags_to_large_tile_flags(
|
67 |
+
small_tile_flags: torch.Tensor,
|
68 |
+
width: int,
|
69 |
+
small_tile_size: int,
|
70 |
+
large_tile_size: int,
|
71 |
+
aggregation: t.Literal["any", "all"] = "any",
|
72 |
+
):
|
73 |
+
small_tile_flags = small_tile_flags.view(-1, 1, 1)
|
74 |
+
num_small_tiles = small_tile_flags.size(0)
|
75 |
+
nw = width // small_tile_size
|
76 |
+
r = large_tile_size // small_tile_size
|
77 |
+
num_large_tiles = num_small_tiles // r**2
|
78 |
+
large_tile_flags = small_tiles_to_large_tiles(
|
79 |
+
small_tile_flags,
|
80 |
+
width=nw,
|
81 |
+
large_tile_size=r,
|
82 |
+
).view(num_large_tiles, -1)
|
83 |
+
return (
|
84 |
+
large_tile_flags.any(-1) if aggregation == "any" else large_tile_flags.all(-1)
|
85 |
+
)
|
86 |
+
|
87 |
+
|
88 |
+
def format_first_stg_act_as_second_stg_inp(
|
89 |
+
x: torch.Tensor,
|
90 |
+
height: int,
|
91 |
+
width: int,
|
92 |
+
small_tile_size: int,
|
93 |
+
large_tile_size: int,
|
94 |
+
):
|
95 |
+
assert height % small_tile_size == 0 and width % small_tile_size == 0
|
96 |
+
D = x.size(1)
|
97 |
+
nh, nw = height // small_tile_size, width // small_tile_size
|
98 |
+
r = large_tile_size // small_tile_size
|
99 |
+
x = x.view(-1, nh, nw, D)
|
100 |
+
x = x.permute(0, 3, 1, 2).reshape(-1, D, nh // r, r, nw // r, r)
|
101 |
+
x = x.permute(0, 2, 4, 1, 3, 5).reshape(-1, D, r, r)
|
102 |
+
return x
|
103 |
+
|
104 |
+
|
105 |
+
def format_second_stg_inp_as_first_stg_act(
|
106 |
+
x: torch.Tensor, height: int, width: int, small_tile_size: int, large_tile_size: int
|
107 |
+
):
|
108 |
+
D = x.size(1)
|
109 |
+
nh, nw = height // small_tile_size, width // small_tile_size
|
110 |
+
r = large_tile_size // small_tile_size
|
111 |
+
x = x.view(-1, nh // r, nw // r, D, r, r)
|
112 |
+
x = x.permute(0, 3, 1, 4, 2, 5).reshape(-1, D, nh, nw)
|
113 |
+
x = x.permute(0, 2, 3, 1).reshape(-1, D)
|
114 |
+
return x
|
115 |
+
|
116 |
+
|
117 |
+
def format_second_stg_act_as_third_stg_inp(
|
118 |
+
x: torch.Tensor,
|
119 |
+
height: int,
|
120 |
+
width: int,
|
121 |
+
large_tile_size: int,
|
122 |
+
):
|
123 |
+
D = x.size(1)
|
124 |
+
nh = height // large_tile_size
|
125 |
+
nw = width // large_tile_size
|
126 |
+
return x.view(-1, nh, nw, D).permute(0, 3, 1, 2).contiguous()
|
127 |
+
|
128 |
+
|
129 |
+
def forward_with_batch_size_limit(
|
130 |
+
net,
|
131 |
+
x: torch.Tensor,
|
132 |
+
batch_size_on_gpu: int,
|
133 |
+
device: str | torch.device,
|
134 |
+
out_device: str | torch.device,
|
135 |
+
preproc_fn: t.Callable[[torch.Tensor], torch.Tensor] | None = None,
|
136 |
+
dtype: torch.dtype = torch.float32,
|
137 |
+
):
|
138 |
+
features = list()
|
139 |
+
for start_idx in range(0, x.size(0), batch_size_on_gpu):
|
140 |
+
end_idx = min(x.size(0), start_idx + batch_size_on_gpu)
|
141 |
+
batch = x[start_idx:end_idx].to(device=device, non_blocking=True)
|
142 |
+
batch = preproc_fn(batch) if preproc_fn else batch
|
143 |
+
batch = batch.to(dtype=dtype, non_blocking=True)
|
144 |
+
actual_bs = end_idx - start_idx
|
145 |
+
batch = pad_to_batch(batch, batch_size_on_gpu)
|
146 |
+
batch: torch.Tensor = forward_compiled(net, batch)
|
147 |
+
# batch = net(batch)
|
148 |
+
features.append(batch[:actual_bs].to(device=out_device, non_blocking=True))
|
149 |
+
if torch.device(out_device).type == "cpu" and torch.device(device).type == "cuda":
|
150 |
+
torch.cuda.synchronize()
|
151 |
+
return torch.cat(features)
|
152 |
+
|
153 |
+
|
154 |
+
@t.overload
|
155 |
+
def backward_with_batch_size_limit(
|
156 |
+
net,
|
157 |
+
x: torch.Tensor,
|
158 |
+
grad: torch.Tensor,
|
159 |
+
batch_size_on_gpu: int,
|
160 |
+
device: str | torch.device,
|
161 |
+
out_device: str | torch.device,
|
162 |
+
dtype: torch.dtype,
|
163 |
+
ret_grad: t.Literal[True],
|
164 |
+
) -> torch.Tensor: ...
|
165 |
+
|
166 |
+
|
167 |
+
@t.overload
|
168 |
+
def backward_with_batch_size_limit(
|
169 |
+
net,
|
170 |
+
x: torch.Tensor,
|
171 |
+
grad: torch.Tensor,
|
172 |
+
batch_size_on_gpu: int,
|
173 |
+
device: str | torch.device,
|
174 |
+
out_device: str | torch.device,
|
175 |
+
dtype: torch.dtype,
|
176 |
+
ret_grad: t.Literal[False],
|
177 |
+
) -> None: ...
|
178 |
+
|
179 |
+
|
180 |
+
def backward_with_batch_size_limit(
|
181 |
+
net,
|
182 |
+
x: torch.Tensor,
|
183 |
+
grad: torch.Tensor,
|
184 |
+
batch_size_on_gpu: int,
|
185 |
+
device: str | torch.device,
|
186 |
+
out_device: str | torch.device,
|
187 |
+
dtype: torch.dtype,
|
188 |
+
ret_grad: bool,
|
189 |
+
):
|
190 |
+
assert x.size(0) == grad.size(0)
|
191 |
+
|
192 |
+
grads = []
|
193 |
+
total = x.size(0)
|
194 |
+
for start in range(0, total, batch_size_on_gpu):
|
195 |
+
end = min(total, start + batch_size_on_gpu)
|
196 |
+
actual_bs = end - start
|
197 |
+
|
198 |
+
batch = x[start:end].to(device=device, dtype=dtype, non_blocking=True)
|
199 |
+
batch = pad_to_batch(batch, batch_size_on_gpu)
|
200 |
+
if ret_grad:
|
201 |
+
batch.requires_grad_(True)
|
202 |
+
|
203 |
+
with torch.autocast(device_type="cuda", dtype=dtype):
|
204 |
+
out = net(batch)
|
205 |
+
# out = forward_compiled(net, batch)
|
206 |
+
|
207 |
+
grad_batch = grad[start:end].to(device=device, dtype=dtype, non_blocking=True)
|
208 |
+
grad_batch = pad_to_batch(grad_batch, batch_size_on_gpu)
|
209 |
+
|
210 |
+
with torch._dynamo.utils.maybe_enable_compiled_autograd(
|
211 |
+
True, fullgraph=True, dynamic=False
|
212 |
+
):
|
213 |
+
out.backward(grad_batch)
|
214 |
+
# out.backward(grad_batch)
|
215 |
+
|
216 |
+
if ret_grad:
|
217 |
+
assert batch.grad is not None
|
218 |
+
grads.append(batch.grad[:actual_bs].to(out_device, non_blocking=True))
|
219 |
+
|
220 |
+
if ret_grad:
|
221 |
+
if (
|
222 |
+
torch.device(out_device).type == "cpu"
|
223 |
+
and torch.device(device).type == "cuda"
|
224 |
+
):
|
225 |
+
torch.cuda.synchronize()
|
226 |
+
return torch.cat(grads)
|
227 |
+
|
228 |
+
|
229 |
+
@torch.compile(fullgraph=True, dynamic=False)
|
230 |
+
def forward_compiled(net, x: torch.Tensor) -> torch.Tensor:
|
231 |
+
return net(x)
|
232 |
+
|
233 |
+
|
234 |
+
def pad_to_batch(t: torch.Tensor, batch_size: int) -> torch.Tensor:
|
235 |
+
assert (
|
236 |
+
t.size(0) <= batch_size
|
237 |
+
), f"'{t.shape}' size tensor cannot be padded to be batch size of '{batch_size}'"
|
238 |
+
pad = batch_size - t.size(0)
|
239 |
+
return torch.cat([t, t.new_zeros((pad,) + t.shape[1:])], dim=0) if pad > 0 else t
|
240 |
+
|
241 |
+
|
242 |
+
def scale_and_normalize(x: torch.Tensor, inplace: bool = False):
|
243 |
+
x = x.clamp_(0, 255) if inplace else x.clamp(0, 255)
|
244 |
+
x = TF.normalize(
|
245 |
+
x / 255, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=inplace
|
246 |
+
)
|
247 |
+
return x
|
248 |
+
|
249 |
+
|
250 |
+
def combine_tile_list(tile_list: list[torch.Tensor], ncols: int):
|
251 |
+
"""
|
252 |
+
Combines a flat list of tile tensors (each with shape (C, H, W)) into one output tensor,
|
253 |
+
arranging them in a grid with the specified number of columns. The tiles in the final row
|
254 |
+
or column may have different sizes.
|
255 |
+
|
256 |
+
Args:
|
257 |
+
tile_list (list of torch.Tensor): A flat list of tile tensors, each with shape
|
258 |
+
(channels, tile_height, tile_width). It is assumed
|
259 |
+
that the number of channels is consistent across all tiles.
|
260 |
+
ncols (int): Number of columns to arrange the tiles in.
|
261 |
+
|
262 |
+
Returns:
|
263 |
+
torch.Tensor: A tensor of shape (channels, total_height, total_width), where:
|
264 |
+
- total_height is the sum of maximum tile heights in each row.
|
265 |
+
- total_width is the sum of maximum tile widths in each column.
|
266 |
+
"""
|
267 |
+
if not tile_list:
|
268 |
+
raise ValueError("tile_list is empty")
|
269 |
+
|
270 |
+
ntiles = len(tile_list)
|
271 |
+
nrows = (ntiles + ncols - 1) // ncols # Ceiling division to get the number of rows
|
272 |
+
|
273 |
+
# Convert the flat tile list into a nested list (rows of tiles)
|
274 |
+
nested_tiles = [tile_list[i * ncols : (i + 1) * ncols] for i in range(nrows)]
|
275 |
+
|
276 |
+
# Compute the maximum tile height for each row
|
277 |
+
row_heights = [max(tile.shape[1] for tile in row) for row in nested_tiles]
|
278 |
+
|
279 |
+
# Compute the maximum tile width for each column (consider only rows that have a tile in that column)
|
280 |
+
col_widths = []
|
281 |
+
for col in range(ncols):
|
282 |
+
max_width = 0
|
283 |
+
for row in nested_tiles:
|
284 |
+
if col < len(row):
|
285 |
+
tile_w = row[col].shape[2]
|
286 |
+
if tile_w > max_width:
|
287 |
+
max_width = tile_w
|
288 |
+
col_widths.append(max_width)
|
289 |
+
|
290 |
+
# Calculate the total output dimensions
|
291 |
+
total_height = sum(row_heights)
|
292 |
+
total_width = sum(col_widths)
|
293 |
+
|
294 |
+
# Determine the number of channels from the first tile
|
295 |
+
channels = tile_list[0].shape[0]
|
296 |
+
|
297 |
+
# Preallocate the output tensor (this avoids repeated concatenation and extra memory copies)
|
298 |
+
out_tensor = torch.zeros(
|
299 |
+
channels,
|
300 |
+
total_height,
|
301 |
+
total_width,
|
302 |
+
dtype=tile_list[0].dtype,
|
303 |
+
device=tile_list[0].device,
|
304 |
+
)
|
305 |
+
|
306 |
+
# Place each tile in its proper location by calculating offsets
|
307 |
+
y_offset = 0
|
308 |
+
for i, row in enumerate(nested_tiles):
|
309 |
+
x_offset = 0
|
310 |
+
for j, tile in enumerate(row):
|
311 |
+
tile_h, tile_w = tile.shape[1], tile.shape[2]
|
312 |
+
out_tensor[
|
313 |
+
:, y_offset : y_offset + tile_h, x_offset : x_offset + tile_w
|
314 |
+
] = tile
|
315 |
+
x_offset += col_widths[j]
|
316 |
+
y_offset += row_heights[i]
|
317 |
+
|
318 |
+
return out_tensor
|
utils/wsi_utils.py
ADDED
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import typing as t
|
2 |
+
from concurrent.futures import ThreadPoolExecutor
|
3 |
+
from pathlib import Path
|
4 |
+
from tracemalloc import start
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
import rpack
|
9 |
+
from openslide import OpenSlide
|
10 |
+
from PIL import Image
|
11 |
+
from scipy.ndimage import binary_fill_holes
|
12 |
+
from skimage import filters
|
13 |
+
from skimage.morphology import remove_small_objects
|
14 |
+
|
15 |
+
if t.TYPE_CHECKING:
|
16 |
+
from _typeshed import StrPath
|
17 |
+
|
18 |
+
try:
|
19 |
+
from skimage import img_as_ubyte # type: ignore
|
20 |
+
except:
|
21 |
+
from skimage.util import img_as_ubyte # type: ignore
|
22 |
+
|
23 |
+
|
24 |
+
def find_contours(arr: np.ndarray, only_outer: bool = True, convex: bool = False):
|
25 |
+
"""Find contours in a binary image
|
26 |
+
|
27 |
+
Parameters
|
28 |
+
----------
|
29 |
+
arr : np.ndarray
|
30 |
+
Binary image
|
31 |
+
only_outer : bool
|
32 |
+
If True, only find external contours
|
33 |
+
convex : bool
|
34 |
+
If True, return convex hull of contours
|
35 |
+
|
36 |
+
Returns
|
37 |
+
-------
|
38 |
+
contours : list
|
39 |
+
List of contours
|
40 |
+
"""
|
41 |
+
mode = cv2.RETR_EXTERNAL if only_outer else cv2.RETR_LIST
|
42 |
+
cresults = cv2.findContours(arr.astype(np.uint8), mode, cv2.CHAIN_APPROX_SIMPLE)
|
43 |
+
|
44 |
+
contours = cresults[1] if len(cresults) == 3 else cresults[0]
|
45 |
+
contours = list(contours) if isinstance(contours, tuple) else contours
|
46 |
+
|
47 |
+
if convex:
|
48 |
+
contours = [cv2.convexHull(cnt) for cnt in contours]
|
49 |
+
return contours
|
50 |
+
|
51 |
+
|
52 |
+
def merge_overlapping_bboxes(bboxes: list):
|
53 |
+
"""Merge overlapping bounding boxes
|
54 |
+
|
55 |
+
Parameters
|
56 |
+
----------
|
57 |
+
bboxes : list
|
58 |
+
List of bounding boxes in format (x, y, width, height)
|
59 |
+
"""
|
60 |
+
candidate_count = 0
|
61 |
+
while candidate_count < len(bboxes):
|
62 |
+
candidate_count += 1
|
63 |
+
overlap = False
|
64 |
+
candidate_box = bboxes.pop(0)
|
65 |
+
for index, compare_box in enumerate(bboxes):
|
66 |
+
overlapping, new_bbox = merge_if_overlapping(candidate_box, compare_box)
|
67 |
+
if overlapping:
|
68 |
+
overlap = True
|
69 |
+
candidate_count = 0
|
70 |
+
bboxes.pop(index)
|
71 |
+
bboxes.append(new_bbox)
|
72 |
+
break
|
73 |
+
if not overlap:
|
74 |
+
bboxes.append(candidate_box)
|
75 |
+
|
76 |
+
|
77 |
+
def merge_if_overlapping(a: tuple, b: tuple):
|
78 |
+
"""Check if two bounding boxes overlap and merge them if they do
|
79 |
+
|
80 |
+
Parameters
|
81 |
+
----------
|
82 |
+
a : tuple
|
83 |
+
First bounding box in format (x, y, width, height)
|
84 |
+
b : tuple
|
85 |
+
Second bounding box in format (x, y, width, height)
|
86 |
+
|
87 |
+
Returns
|
88 |
+
-------
|
89 |
+
overlapping : bool
|
90 |
+
True if boxes overlap
|
91 |
+
new_bbox : tuple
|
92 |
+
Merged bounding box if overlapping, empty list otherwise
|
93 |
+
"""
|
94 |
+
bottom = np.max([a[0], b[0]])
|
95 |
+
top = np.min([a[0] + a[2], b[0] + b[2]])
|
96 |
+
left = np.max([a[1], b[1]])
|
97 |
+
right = np.min([a[1] + a[3], b[1] + b[3]])
|
98 |
+
|
99 |
+
do_intersect = bottom < top and left < right
|
100 |
+
|
101 |
+
if do_intersect:
|
102 |
+
x_min = np.min([a[1], b[1]])
|
103 |
+
y_min = np.min([a[0], b[0]])
|
104 |
+
x_max = np.max([a[1] + a[3], b[1] + b[3]])
|
105 |
+
y_max = np.max([a[0] + a[2], b[0] + b[2]])
|
106 |
+
new_bbox = (y_min, x_min, y_max - y_min, x_max - x_min)
|
107 |
+
return True, new_bbox
|
108 |
+
|
109 |
+
return False, []
|
110 |
+
|
111 |
+
|
112 |
+
|
113 |
+
def load_slide_img(
|
114 |
+
wsi,
|
115 |
+
level: int = 0,
|
116 |
+
) -> np.ndarray:
|
117 |
+
"""Load slide image with specific level
|
118 |
+
|
119 |
+
Parameters
|
120 |
+
----------
|
121 |
+
wsi : CuImage
|
122 |
+
The CuImage object
|
123 |
+
level : int
|
124 |
+
Slide level to load
|
125 |
+
|
126 |
+
Returns
|
127 |
+
-------
|
128 |
+
slide_img : np.ndarray
|
129 |
+
Numpy array with RGB channels
|
130 |
+
"""
|
131 |
+
slide_img = np.asarray(wsi.read_region(level=level, device="gpu", num_workers=32))
|
132 |
+
if slide_img.shape[2] == 4:
|
133 |
+
slide_img = slide_img[:, :, :-1]
|
134 |
+
return slide_img
|
135 |
+
|
136 |
+
|
137 |
+
def rgb2gray(img):
|
138 |
+
"""Convert RGB image to grayscale
|
139 |
+
|
140 |
+
Parameters
|
141 |
+
----------
|
142 |
+
img : np.ndarray
|
143 |
+
RGB image with 3 channels
|
144 |
+
|
145 |
+
Returns
|
146 |
+
-------
|
147 |
+
gray : np.ndarray
|
148 |
+
Grayscale image
|
149 |
+
"""
|
150 |
+
return np.dot(img, [0.299, 0.587, 0.114])
|
151 |
+
|
152 |
+
|
153 |
+
def thresh_slide(gray, thresh_val, sigma=13):
|
154 |
+
"""Threshold gray image to binary image
|
155 |
+
|
156 |
+
Parameters
|
157 |
+
----------
|
158 |
+
gray : np.ndarray
|
159 |
+
2D grayscale image
|
160 |
+
thresh_val : float
|
161 |
+
Thresholding value
|
162 |
+
sigma : int
|
163 |
+
Gaussian smoothing sigma
|
164 |
+
|
165 |
+
Returns
|
166 |
+
-------
|
167 |
+
bw_img : np.ndarray
|
168 |
+
Binary image
|
169 |
+
"""
|
170 |
+
smooth = filters.gaussian(gray, sigma=sigma)
|
171 |
+
smooth /= np.amax(smooth)
|
172 |
+
bw_img = smooth < thresh_val
|
173 |
+
return bw_img
|
174 |
+
|
175 |
+
|
176 |
+
|
177 |
+
def get_tissue_bboxes(
|
178 |
+
mask: np.ndarray, wsi_width: int, wsi_height: int, min_tissue_size: int = 10000
|
179 |
+
):
|
180 |
+
scale = wsi_height / mask.shape[0]
|
181 |
+
|
182 |
+
contours = find_contours(mask)
|
183 |
+
areas = []
|
184 |
+
for cnt in contours:
|
185 |
+
area = cv2.contourArea(cnt)
|
186 |
+
areas.append(area)
|
187 |
+
|
188 |
+
large_contours = []
|
189 |
+
large_areas = []
|
190 |
+
for i, cnt in enumerate(contours):
|
191 |
+
area_mm = areas[i]
|
192 |
+
if area_mm >= min_tissue_size:
|
193 |
+
large_contours.append(cnt)
|
194 |
+
large_areas.append(area_mm)
|
195 |
+
|
196 |
+
areas = large_areas
|
197 |
+
|
198 |
+
boxes = [cv2.boundingRect(c) for c in large_contours]
|
199 |
+
|
200 |
+
return (
|
201 |
+
[cv2.boundingRect(c) for c in large_contours]
|
202 |
+
if boxes
|
203 |
+
else [[0, 0, wsi_width, wsi_height]]
|
204 |
+
)
|
205 |
+
|
206 |
+
|
207 |
+
def get_tissue_positions_and_packed_size(
|
208 |
+
boxes,
|
209 |
+
wsi_width: int,
|
210 |
+
wsi_height: int,
|
211 |
+
scale: float,
|
212 |
+
) -> tuple[list[tuple[int, int]], tuple[int, int]]:
|
213 |
+
if len(boxes) > 1:
|
214 |
+
merge_overlapping_bboxes(boxes)
|
215 |
+
boxes = np.array(boxes, dtype=np.float32) * scale
|
216 |
+
if len(boxes.shape) == 1:
|
217 |
+
boxes = boxes[None]
|
218 |
+
boxes[:, :2] = np.floor(boxes[:, :2])
|
219 |
+
boxes[:, 0] = np.clip(boxes[:, 0], 0, wsi_width - 1)
|
220 |
+
boxes[:, 1] = np.clip(boxes[:, 1], 0, wsi_height - 1)
|
221 |
+
boxes[:, 2:] = np.ceil(boxes[:, 2:])
|
222 |
+
boxes[:, 2] = np.clip(boxes[:, 2], 0, wsi_width - boxes[:, 0])
|
223 |
+
boxes[:, 3] = np.clip(boxes[:, 3], 0, wsi_height - boxes[:, 1])
|
224 |
+
boxes = boxes.astype(np.int32)
|
225 |
+
|
226 |
+
box_sizes = [(int(box[2]), int(box[3])) for box in boxes]
|
227 |
+
positions = rpack.pack(box_sizes) # at processing spacing
|
228 |
+
packed_size: tuple[int, int] = rpack.bbox_size(
|
229 |
+
box_sizes, positions
|
230 |
+
) # width, height
|
231 |
+
|
232 |
+
counter = 0
|
233 |
+
for sdf in np.arange(0.5, 0.96, 0.05):
|
234 |
+
# asymmetry_factor = min(packed_size)/max(packed_size)
|
235 |
+
# if asymmetry_factor < sdf:
|
236 |
+
rparams = {
|
237 |
+
"max_height": int(max(packed_size) * sdf),
|
238 |
+
"max_width": int(max(packed_size) * sdf),
|
239 |
+
}
|
240 |
+
try:
|
241 |
+
positions = rpack.pack(box_sizes, **rparams) # at processing spacing
|
242 |
+
packed_size: tuple[int, int] = rpack.bbox_size(box_sizes, positions)
|
243 |
+
break
|
244 |
+
except rpack.PackingImpossibleError as ex:
|
245 |
+
counter += 1
|
246 |
+
|
247 |
+
return positions, (int(packed_size[0]), int(packed_size[1]))
|
248 |
+
|
249 |
+
|
250 |
+
def pack_slide(
|
251 |
+
wsi_arr: np.ndarray,
|
252 |
+
mask: np.ndarray,
|
253 |
+
min_tissue_size: int = 10000,
|
254 |
+
):
|
255 |
+
H, W = wsi_arr.shape[:2]
|
256 |
+
boxes = get_tissue_bboxes(mask, W, H, min_tissue_size=min_tissue_size)
|
257 |
+
if len(boxes) > 0:
|
258 |
+
positions, packed_size = get_tissue_positions_and_packed_size(
|
259 |
+
boxes, W, H, H / mask.shape[0]
|
260 |
+
)
|
261 |
+
img_out = np.full(
|
262 |
+
(packed_size[1], packed_size[0]) + wsi_arr.shape[2:],
|
263 |
+
255,
|
264 |
+
dtype=wsi_arr.dtype,
|
265 |
+
)
|
266 |
+
mask_out = np.zeros((packed_size[1], packed_size[0]), dtype=np.bool)
|
267 |
+
for i, pos in enumerate(positions):
|
268 |
+
box = boxes[i]
|
269 |
+
img_out[pos[1] : pos[1] + box[3], pos[0] : pos[0] + box[2]] = wsi_arr[
|
270 |
+
box[1] : box[1] + box[3], box[0] : box[0] + box[2]
|
271 |
+
]
|
272 |
+
mask_out[pos[1] : pos[1] + box[3], pos[0] : pos[0] + box[2]] = mask[
|
273 |
+
box[1] : box[1] + box[3], box[0] : box[0] + box[2]
|
274 |
+
]
|
275 |
+
else:
|
276 |
+
img_out = wsi_arr
|
277 |
+
mask_out = mask
|
278 |
+
|
279 |
+
return img_out, mask_out
|
280 |
+
|
281 |
+
|
282 |
+
def get_level_downsamples(wsi: OpenSlide):
|
283 |
+
level_downsamples = []
|
284 |
+
dim_0 = wsi.level_dimensions[0]
|
285 |
+
|
286 |
+
for downsample, dim in zip(wsi.level_downsamples, wsi.level_dimensions):
|
287 |
+
estimated_downsample = (dim_0[0] / float(dim[0]), dim_0[1] / float(dim[1]))
|
288 |
+
(
|
289 |
+
level_downsamples.append(estimated_downsample)
|
290 |
+
if estimated_downsample != (downsample, downsample)
|
291 |
+
else level_downsamples.append((downsample, downsample))
|
292 |
+
)
|
293 |
+
|
294 |
+
return level_downsamples
|
295 |
+
|
296 |
+
|
297 |
+
def segment_tissue(
|
298 |
+
wsi_path: Path,
|
299 |
+
seg_level=-1,
|
300 |
+
sthresh=8,
|
301 |
+
sthresh_up=255,
|
302 |
+
mthresh=7,
|
303 |
+
close=4,
|
304 |
+
filter_params={"a_t": 1, "a_h": 1, "max_n_holes": 100},
|
305 |
+
ref_patch_size=512,
|
306 |
+
):
|
307 |
+
"""
|
308 |
+
Segment the tissue via HSV -> Median thresholding -> Binary threshold
|
309 |
+
"""
|
310 |
+
|
311 |
+
def _filter_contours(contours, hierarchy, filter_params):
|
312 |
+
"""
|
313 |
+
Filter contours by: area.
|
314 |
+
"""
|
315 |
+
filtered = []
|
316 |
+
|
317 |
+
# find indices of foreground contours (parent == -1)
|
318 |
+
hierarchy_1 = np.flatnonzero(hierarchy[:, 1] == -1)
|
319 |
+
all_holes = []
|
320 |
+
|
321 |
+
# loop through foreground contour indices
|
322 |
+
for cont_idx in hierarchy_1:
|
323 |
+
# actual contour
|
324 |
+
cont = contours[cont_idx]
|
325 |
+
# indices of holes contained in this contour (children of parent contour)
|
326 |
+
holes = np.flatnonzero(hierarchy[:, 1] == cont_idx)
|
327 |
+
# take contour area (includes holes)
|
328 |
+
a = cv2.contourArea(cont)
|
329 |
+
# calculate the contour area of each hole
|
330 |
+
hole_areas = [cv2.contourArea(contours[hole_idx]) for hole_idx in holes]
|
331 |
+
# actual area of foreground contour region
|
332 |
+
a = a - np.array(hole_areas).sum()
|
333 |
+
if a == 0:
|
334 |
+
continue
|
335 |
+
if tuple((filter_params["a_t"],)) < tuple((a,)):
|
336 |
+
filtered.append(cont_idx)
|
337 |
+
all_holes.append(holes)
|
338 |
+
|
339 |
+
foreground_contours = [contours[cont_idx] for cont_idx in filtered]
|
340 |
+
|
341 |
+
hole_contours = []
|
342 |
+
|
343 |
+
for hole_ids in all_holes:
|
344 |
+
unfiltered_holes = [contours[idx] for idx in hole_ids]
|
345 |
+
unfilered_holes = sorted(
|
346 |
+
unfiltered_holes, key=cv2.contourArea, reverse=True
|
347 |
+
)
|
348 |
+
# take max_n_holes largest holes by area
|
349 |
+
unfilered_holes = unfilered_holes[: filter_params["max_n_holes"]]
|
350 |
+
filtered_holes = []
|
351 |
+
|
352 |
+
# filter these holes
|
353 |
+
for hole in unfilered_holes:
|
354 |
+
if cv2.contourArea(hole) > filter_params["a_h"]:
|
355 |
+
filtered_holes.append(hole)
|
356 |
+
|
357 |
+
hole_contours.append(filtered_holes)
|
358 |
+
|
359 |
+
return foreground_contours, hole_contours
|
360 |
+
|
361 |
+
def draw_white_bands(img: np.ndarray, thickness: int):
|
362 |
+
height, width = img.shape[:2]
|
363 |
+
white = [255, 255, 255] # ํฐ์ (B, G, R)
|
364 |
+
|
365 |
+
# cv2.copyMakeBorder ํจ์๋ฅผ ์ฌ์ฉํด ํฐ์ ๋ ๋ฅผ ์ถ๊ฐ
|
366 |
+
# ๋๊ป 30ํฝ์
์ ์์ชฝ ํฐ์ ๋ ๊ทธ๋ฆฌ๊ธฐ
|
367 |
+
cv2.rectangle(img, (0, 0), (width, thickness), white, -1)
|
368 |
+
|
369 |
+
# ๋๊ป 30ํฝ์
์ ์๋์ชฝ ํฐ์ ๋ ๊ทธ๋ฆฌ๊ธฐ
|
370 |
+
cv2.rectangle(img, (0, height - thickness), (width, height), white, -1)
|
371 |
+
|
372 |
+
# ๋๊ป 30ํฝ์
์ ์ผ์ชฝ ํฐ์ ๋ ๊ทธ๋ฆฌ๊ธฐ
|
373 |
+
cv2.rectangle(img, (0, 0), (thickness, height), white, -1)
|
374 |
+
|
375 |
+
# ๋๊ป 30ํฝ์
์ ์ค๋ฅธ์ชฝ ํฐ์ ๋ ๊ทธ๋ฆฌ๊ธฐ
|
376 |
+
cv2.rectangle(img, (width - thickness, 0), (width, height), white, -1)
|
377 |
+
|
378 |
+
with OpenSlide(str(wsi_path)) as wsi:
|
379 |
+
if seg_level < 0:
|
380 |
+
seg_level = wsi.get_best_level_for_downsample(64)
|
381 |
+
|
382 |
+
img = np.asarray(
|
383 |
+
wsi.read_region(
|
384 |
+
location=(0, 0), level=seg_level, size=wsi.level_dimensions[seg_level]
|
385 |
+
)
|
386 |
+
)
|
387 |
+
|
388 |
+
img_rgb = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
|
389 |
+
draw_white_bands(img_rgb, thickness=20)
|
390 |
+
img_gray = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
|
391 |
+
|
392 |
+
H, W = img_rgb.shape[:2]
|
393 |
+
|
394 |
+
B_8, G_8, R_8 = cv2.split(img_rgb)
|
395 |
+
B = B_8.astype(np.int32)
|
396 |
+
G = G_8.astype(np.int32)
|
397 |
+
R = R_8.astype(np.int32)
|
398 |
+
|
399 |
+
mask = (R >= 0) & (R <= 110) & (G >= 0) & (G <= 110) & (B >= 0) & (B <= 110)
|
400 |
+
|
401 |
+
color_difference1 = np.abs((R) - (G)) <= 15
|
402 |
+
color_difference2 = np.abs((G) - (B)) <= 15
|
403 |
+
color_difference3 = np.abs((R) - (B)) <= 15
|
404 |
+
color_difference = color_difference1 & color_difference2 & color_difference3
|
405 |
+
|
406 |
+
final_mask = mask & color_difference
|
407 |
+
|
408 |
+
laplacian = cv2.Laplacian(img_gray, cv2.CV_64F)
|
409 |
+
laplacian_abs = cv2.convertScaleAbs(laplacian)
|
410 |
+
mask = laplacian_abs <= 15
|
411 |
+
img_rgb[mask] = [255, 255, 255]
|
412 |
+
|
413 |
+
img_hsv = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2HSV) # Convert to HSV space
|
414 |
+
img_med = cv2.medianBlur(
|
415 |
+
img_hsv[:, :, 1], mthresh
|
416 |
+
) # Apply median blurring #same to median filter
|
417 |
+
|
418 |
+
# Thresholding
|
419 |
+
_, img_thresh = cv2.threshold(img_med, sthresh, sthresh_up, cv2.THRESH_BINARY)
|
420 |
+
# Morphological closing
|
421 |
+
if close > 0:
|
422 |
+
kernel = np.ones((close, close), np.uint8)
|
423 |
+
img_thresh = cv2.morphologyEx(img_thresh, cv2.MORPH_CLOSE, kernel)
|
424 |
+
|
425 |
+
# before k-medicon
|
426 |
+
scale = get_level_downsamples(wsi)[seg_level]
|
427 |
+
scaled_ref_patch_area = int(ref_patch_size**2 / (scale[0] * scale[1]))
|
428 |
+
filter_params = filter_params.copy()
|
429 |
+
filter_params["a_t"] = filter_params["a_t"] * scaled_ref_patch_area
|
430 |
+
filter_params["a_h"] = filter_params["a_h"] * scaled_ref_patch_area
|
431 |
+
|
432 |
+
# Find and filter contours
|
433 |
+
contours, hierarchy = cv2.findContours(
|
434 |
+
img_thresh, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE
|
435 |
+
)
|
436 |
+
|
437 |
+
hierarchy = np.squeeze(hierarchy, axis=(0,))[:, 2:]
|
438 |
+
foreground_contours, hole_contours = _filter_contours(
|
439 |
+
contours, hierarchy, filter_params
|
440 |
+
) # Necessary for filtering out artifacts
|
441 |
+
|
442 |
+
mask = np.zeros(img_rgb.shape[:2], dtype=np.uint8)
|
443 |
+
for i, cont in enumerate(foreground_contours):
|
444 |
+
if cont is None or len(cont) == 0:
|
445 |
+
print(f"Warning: Empty contour at index {i}")
|
446 |
+
continue
|
447 |
+
|
448 |
+
if (
|
449 |
+
cont[:, :, 0].max() >= W
|
450 |
+
or cont[:, :, 1].max() >= H
|
451 |
+
or cont[:, :, 0].min() < 0
|
452 |
+
or cont[:, :, 1].min() < 0
|
453 |
+
):
|
454 |
+
print(f"Warning: Contour {i} coordinates out of bounds!")
|
455 |
+
continue
|
456 |
+
|
457 |
+
# Fill the main tissue contour
|
458 |
+
cv2.fillPoly(mask, [cont], 255) # type: ignore
|
459 |
+
|
460 |
+
# Remove holes if they exist
|
461 |
+
if i < len(hole_contours) and hole_contours[i]:
|
462 |
+
for hole in hole_contours[i]: # type: ignore
|
463 |
+
cv2.fillPoly(mask, [hole], 0) # type: ignore
|
464 |
+
mask = mask.astype(np.bool)
|
465 |
+
if not mask.any():
|
466 |
+
mask[:, :] = True # If no mask, return full mask
|
467 |
+
|
468 |
+
return mask, img_rgb
|
469 |
+
|
470 |
+
|
471 |
+
def get_mask_path_by_wsi_path(wsi_path: Path, wsi_dir: Path, mask_dir: Path) -> Path:
|
472 |
+
wsi_path, wsi_dir, mask_dir = (
|
473 |
+
wsi_path.absolute(),
|
474 |
+
wsi_dir.absolute(),
|
475 |
+
mask_dir.absolute(),
|
476 |
+
)
|
477 |
+
rel_path = wsi_path.relative_to(wsi_dir)
|
478 |
+
stitch_path_prefix = mask_dir / rel_path
|
479 |
+
stitch_path_prefix = stitch_path_prefix.parent / rel_path.stem
|
480 |
+
extensions = ["jpg", "jpeg", "png", "webp"]
|
481 |
+
extensions += [ext.upper() for ext in extensions]
|
482 |
+
stitch_paths = [
|
483 |
+
stitch_path_prefix.parent / (rel_path.stem + f".{ext}") for ext in extensions
|
484 |
+
]
|
485 |
+
stitch_paths += [
|
486 |
+
stitch_path_prefix.parent / rel_path.stem / (rel_path.stem + f".{ext}")
|
487 |
+
for ext in extensions
|
488 |
+
]
|
489 |
+
ret = None
|
490 |
+
for stitch_path in stitch_paths:
|
491 |
+
if stitch_path.exists():
|
492 |
+
ret = stitch_path
|
493 |
+
if ret is None:
|
494 |
+
raise FileNotFoundError(
|
495 |
+
f"No mask for wsi '{wsi_path}' in mask dir '{mask_dir}' (candidates: {', '.join([str(p) for p in stitch_paths])})"
|
496 |
+
)
|
497 |
+
return ret
|
498 |
+
|
499 |
+
|
500 |
+
def read_mask(mask_path: Path) -> np.ndarray:
|
501 |
+
img = Image.open(mask_path)
|
502 |
+
w, h = img.size
|
503 |
+
return np.asarray(img).reshape((h, w, -1)).max(-1) > 0
|
504 |
+
|
505 |
+
|
506 |
+
def read_mask_by_wsi_path(wsi_path: Path, wsi_dir: Path, mask_dir: Path) -> np.ndarray:
|
507 |
+
wsi_path, wsi_dir, mask_dir = (
|
508 |
+
wsi_path.absolute(),
|
509 |
+
wsi_dir.absolute(),
|
510 |
+
mask_dir.absolute(),
|
511 |
+
)
|
512 |
+
mask_path = get_mask_path_by_wsi_path(wsi_path, wsi_dir, mask_dir)
|
513 |
+
return read_mask(mask_path)
|
514 |
+
|