project-monai committed (verified)
Commit b6d2bca · Parent: 8354516

Upload renalStructures_UNEST_segmentation version 0.2.6
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ docs/demos.png filter=lfs diff=lfs merge=lfs -text
37
+ docs/renal.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
configs/inference.json ADDED
@@ -0,0 +1,137 @@
1
+ {
2
+ "imports": [
3
+ "$import glob",
4
+ "$import os"
5
+ ],
6
+ "bundle_root": "/models/renalStructures_UNEST_segmentation",
7
+ "output_dir": "$@bundle_root + '/eval'",
8
+ "dataset_dir": "$@bundle_root + './dataset/spleen'",
9
+ "datalist": "$list(sorted(glob.glob(@dataset_dir + '/*.nii.gz')))",
10
+ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
11
+ "network_def": {
12
+ "_target_": "scripts.networks.unest.UNesT",
13
+ "in_channels": 1,
14
+ "out_channels": 4
15
+ },
16
+ "network": "$@network_def.to(@device)",
17
+ "preprocessing": {
18
+ "_target_": "Compose",
19
+ "transforms": [
20
+ {
21
+ "_target_": "LoadImaged",
22
+ "keys": "image"
23
+ },
24
+ {
25
+ "_target_": "EnsureChannelFirstd",
26
+ "keys": "image",
27
+ "channel_dim": "no_channel"
28
+ },
29
+ {
30
+ "_target_": "Orientationd",
31
+ "keys": "image",
32
+ "axcodes": "RAS"
33
+ },
34
+ {
35
+ "_target_": "Spacingd",
36
+ "keys": "image",
37
+ "pixdim": [
38
+ 1.0,
39
+ 1.0,
40
+ 1.0
41
+ ],
42
+ "mode": "bilinear"
43
+ },
44
+ {
45
+ "_target_": "ScaleIntensityRanged",
46
+ "keys": "image",
47
+ "a_min": -175,
48
+ "a_max": 250,
49
+ "b_min": 0.0,
50
+ "b_max": 1.0,
51
+ "clip": true
52
+ },
53
+ {
54
+ "_target_": "EnsureTyped",
55
+ "keys": "image"
56
+ }
57
+ ]
58
+ },
59
+ "dataset": {
60
+ "_target_": "Dataset",
61
+ "data": "$[{'image': i} for i in @datalist]",
62
+ "transform": "@preprocessing"
63
+ },
64
+ "dataloader": {
65
+ "_target_": "DataLoader",
66
+ "dataset": "@dataset",
67
+ "batch_size": 1,
68
+ "shuffle": false,
69
+ "num_workers": 4
70
+ },
71
+ "inferer": {
72
+ "_target_": "SlidingWindowInferer",
73
+ "roi_size": [
74
+ 96,
75
+ 96,
76
+ 96
77
+ ],
78
+ "sw_batch_size": 4,
79
+ "overlap": 0.5
80
+ },
81
+ "postprocessing": {
82
+ "_target_": "Compose",
83
+ "transforms": [
84
+ {
85
+ "_target_": "Activationsd",
86
+ "keys": "pred",
87
+ "softmax": true
88
+ },
89
+ {
90
+ "_target_": "Invertd",
91
+ "keys": "pred",
92
+ "transform": "@preprocessing",
93
+ "orig_keys": "image",
94
+ "nearest_interp": false,
95
+ "to_tensor": true
96
+ },
97
+ {
98
+ "_target_": "AsDiscreted",
99
+ "keys": "pred",
100
+ "argmax": true
101
+ },
102
+ {
103
+ "_target_": "SaveImaged",
104
+ "keys": "pred",
105
+ "output_dir": "@output_dir"
106
+ }
107
+ ]
108
+ },
109
+ "handlers": [
110
+ {
111
+ "_target_": "CheckpointLoader",
112
+ "load_path": "$@bundle_root + '/models/model.pt'",
113
+ "load_dict": {
114
+ "model": "@network"
115
+ },
116
+ "strict": "True"
117
+ },
118
+ {
119
+ "_target_": "StatsHandler",
120
+ "iteration_log": false
121
+ }
122
+ ],
123
+ "evaluator": {
124
+ "_target_": "SupervisedEvaluator",
125
+ "device": "@device",
126
+ "val_data_loader": "@dataloader",
127
+ "network": "@network",
128
+ "inferer": "@inferer",
129
+ "postprocessing": "@postprocessing",
130
+ "val_handlers": "@handlers",
131
+ "amp": false
132
+ },
133
+ "evaluating": [
134
+ "$setattr(torch.backends.cudnn, 'benchmark', True)",
135
+ "$@evaluator.run()"
136
+ ]
137
+ }
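The config above is normally executed with the `python -m monai.bundle run evaluating ...` command shown in docs/README.md. As a rough illustration only (not part of the bundle), the same workflow can be driven from Python with `monai.bundle.ConfigParser`; the relative paths and the `bundle_root`/`dataset_dir` overrides below are assumptions about a local checkout with the checkpoint and input volumes already in place.

```python
# Illustrative sketch: build and run the evaluator defined in configs/inference.json.
# Assumes ./models/model.pt and a local folder of *.nii.gz CT volumes exist.
from monai.bundle import ConfigParser

parser = ConfigParser()
parser.read_config("configs/inference.json")
parser.read_meta("configs/metadata.json")
parser["bundle_root"] = "."          # override the hard-coded /models/... location
parser["dataset_dir"] = "./dataset"  # assumed local folder of *.nii.gz CT volumes

evaluator = parser.get_parsed_content("evaluator")  # resolves @network, @dataloader, handlers
evaluator.run()  # equivalent to the "evaluating" expressions above
```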
configs/logging.conf ADDED
@@ -0,0 +1,21 @@
1
+ [loggers]
2
+ keys=root
3
+
4
+ [handlers]
5
+ keys=consoleHandler
6
+
7
+ [formatters]
8
+ keys=fullFormatter
9
+
10
+ [logger_root]
11
+ level=INFO
12
+ handlers=consoleHandler
13
+
14
+ [handler_consoleHandler]
15
+ class=StreamHandler
16
+ level=INFO
17
+ formatter=fullFormatter
18
+ args=(sys.stdout,)
19
+
20
+ [formatter_fullFormatter]
21
+ format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
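This is a standard `logging.config` file that `monai.bundle run` applies when it is passed via `--logging_file`; the plain-Python equivalent is a single `fileConfig` call (sketch below, path assumed relative to the bundle root).

```python
# Minimal equivalent of passing --logging_file configs/logging.conf to monai.bundle.
import logging
import logging.config

logging.config.fileConfig("configs/logging.conf", disable_existing_loggers=False)
logging.getLogger(__name__).info("console logging configured at INFO level")
```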
configs/metadata.json ADDED
@@ -0,0 +1,94 @@
1
+ {
2
+ "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json",
3
+ "version": "0.2.6",
4
+ "changelog": {
5
+ "0.2.6": "update to huggingface hosting",
6
+ "0.2.5": "update large files",
7
+ "0.2.4": "fix black 24.1 format error",
8
+ "0.2.3": "update AddChanneld with EnsureChannelFirstd and remove meta_dict",
9
+ "0.2.2": "add name tag",
10
+ "0.2.1": "fix license Copyright error",
11
+ "0.2.0": "update license files",
12
+ "0.1.3": "Add training pipeline for fine-tuning models, support MONAI Label active learning",
13
+ "0.1.2": "fixed the dimension in convolution according to MONAI 1.0 update",
14
+ "0.1.1": "fixed the model state dict name",
15
+ "0.1.0": "complete the model package"
16
+ },
17
+ "monai_version": "1.4.0",
18
+ "pytorch_version": "2.4.0",
19
+ "numpy_version": "1.24.4",
20
+ "optional_packages_version": {
21
+ "nibabel": "5.2.1",
22
+ "pytorch-ignite": "0.4.11",
23
+ "einops": "0.7.0",
24
+ "fire": "0.6.0",
25
+ "timm": "0.6.7",
26
+ "torchvision": "0.19.0",
27
+ "tensorboard": "2.17.0"
28
+ },
29
+ "name": "Renal structures UNEST segmentation",
30
+ "task": "Renal segmentation",
31
+ "description": "A transformer-based model for renal segmentation from CT image",
32
+ "authors": "Vanderbilt University + MONAI team",
33
+ "copyright": "Copyright (c) MONAI Consortium",
34
+ "data_source": "RawData.zip",
35
+ "data_type": "nibabel",
36
+ "image_classes": "single channel data, intensity scaled to [0, 1]",
37
+ "label_classes": "1: Kideny Cortex, 2:Medulla, 3:Pelvicalyceal system",
38
+ "pred_classes": "1: Kideny Cortex, 2:Medulla, 3:Pelvicalyceal system",
39
+ "eval_metrics": {
40
+ "mean_dice": 0.85
41
+ },
42
+ "intended_use": "This is an example, not to be used for diagnostic purposes",
43
+ "references": [
44
+ "Tang, Yucheng, et al. 'Self-supervised pre-training of swin transformers for 3d medical image analysis. arXiv preprint arXiv:2111.14791 (2021). https://arxiv.org/abs/2111.14791."
45
+ ],
46
+ "network_data_format": {
47
+ "inputs": {
48
+ "image": {
49
+ "type": "image",
50
+ "format": "hounsfield",
51
+ "modality": "CT",
52
+ "num_channels": 1,
53
+ "spatial_shape": [
54
+ 96,
55
+ 96,
56
+ 96
57
+ ],
58
+ "dtype": "float32",
59
+ "value_range": [
60
+ 0,
61
+ 1
62
+ ],
63
+ "is_patch_data": true,
64
+ "channel_def": {
65
+ "0": "image"
66
+ }
67
+ }
68
+ },
69
+ "outputs": {
70
+ "pred": {
71
+ "type": "image",
72
+ "format": "segmentation",
73
+ "num_channels": 4,
74
+ "spatial_shape": [
75
+ 96,
76
+ 96,
77
+ 96
78
+ ],
79
+ "dtype": "float32",
80
+ "value_range": [
81
+ 0,
82
+ 1
83
+ ],
84
+ "is_patch_data": true,
85
+ "channel_def": {
86
+ "0": "background",
87
+ "1": "kidney cortex",
88
+ "2": "medulla",
89
+ "3": "pelvicalyceal system"
90
+ }
91
+ }
92
+ }
93
+ }
94
+ }
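Because the metadata is plain JSON, the declared version, dependency pins, and expected tensor layout can be inspected without MONAI installed; a small illustrative sketch (path assumed relative to the bundle root):

```python
# Print a few metadata fields that the rest of the bundle relies on.
import json

with open("configs/metadata.json") as f:
    meta = json.load(f)

print(meta["version"], "-", meta["changelog"][meta["version"]])
print("MONAI / PyTorch:", meta["monai_version"], "/", meta["pytorch_version"])
print("input patch shape:", meta["network_data_format"]["inputs"]["image"]["spatial_shape"])
print("output channels:", meta["network_data_format"]["outputs"]["pred"]["channel_def"])
```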
configs/multi_gpu_train.json ADDED
@@ -0,0 +1,36 @@
1
+ {
2
+ "device": "$torch.device(f'cuda:{dist.get_rank()}')",
3
+ "network": {
4
+ "_target_": "torch.nn.parallel.DistributedDataParallel",
5
+ "module": "$@network_def.to(@device)",
6
+ "device_ids": [
7
+ "@device"
8
+ ]
9
+ },
10
+ "train#sampler": {
11
+ "_target_": "DistributedSampler",
12
+ "dataset": "@train#dataset",
13
+ "even_divisible": true,
14
+ "shuffle": true
15
+ },
16
+ "train#dataloader#sampler": "@train#sampler",
17
+ "train#dataloader#shuffle": false,
18
+ "train#trainer#train_handlers": "$@train#handlers[: -2 if dist.get_rank() > 0 else None]",
19
+ "validate#sampler": {
20
+ "_target_": "DistributedSampler",
21
+ "dataset": "@validate#dataset",
22
+ "even_divisible": false,
23
+ "shuffle": false
24
+ },
25
+ "validate#dataloader#sampler": "@validate#sampler",
26
+ "validate#evaluator#val_handlers": "$None if dist.get_rank() > 0 else @validate#handlers",
27
+ "training": [
28
+ "$import torch.distributed as dist",
29
+ "$dist.init_process_group(backend='nccl')",
30
+ "$torch.cuda.set_device(@device)",
31
+ "$monai.utils.set_determinism(seed=123)",
32
+ "$setattr(torch.backends.cudnn, 'benchmark', True)",
33
+ "$@train#trainer.run()",
34
+ "$dist.destroy_process_group()"
35
+ ]
36
+ }
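This file is not a standalone config: it is an override layered on top of configs/train.json (wrapping the network in DistributedDataParallel and switching the dataloaders to DistributedSampler) when both files are passed to `monai.bundle`. The merge can be reproduced for inspection with `ConfigParser` as sketched below; actually running the merged config additionally requires a `torch.distributed` launcher, which is assumed and not shown here.

```python
# Inspect (without running) how multi_gpu_train.json overrides configs/train.json.
from monai.bundle import ConfigParser

parser = ConfigParser()
parser.read_config(["configs/train.json", "configs/multi_gpu_train.json"])

# The later file wins for shared top-level keys such as "network":
print(parser["network"]["_target_"])  # torch.nn.parallel.DistributedDataParallel
```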
configs/train.json ADDED
@@ -0,0 +1,321 @@
1
+ {
2
+ "imports": [
3
+ "$import glob",
4
+ "$import os",
5
+ "$import ignite"
6
+ ],
7
+ "bundle_root": "/models/renalStructures_UNEST_segmentation",
8
+ "ckpt_dir": "$@bundle_root + '/models'",
9
+ "output_dir": "$@bundle_root + '/eval'",
10
+ "dataset_dir": "$@bundle_root + './dataset'",
11
+ "images": "$list(sorted(glob.glob(@dataset_dir + '/imagesTr/*.nii.gz')))",
12
+ "labels": "$list(sorted(glob.glob(@dataset_dir + '/labelsTr/*.nii.gz')))",
13
+ "val_interval": 5,
14
+ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
15
+ "network_def": {
16
+ "_target_": "scripts.networks.unest.UNesT",
17
+ "in_channels": 1,
18
+ "out_channels": 4
19
+ },
20
+ "network": "$@network_def.to(@device)",
21
+ "loss": {
22
+ "_target_": "DiceCELoss",
23
+ "to_onehot_y": true,
24
+ "softmax": true,
25
+ "squared_pred": true,
26
+ "batch": true
27
+ },
28
+ "optimizer": {
29
+ "_target_": "torch.optim.Adam",
30
+ "params": "[email protected]()",
31
+ "lr": 0.0002
32
+ },
33
+ "train": {
34
+ "deterministic_transforms": [
35
+ {
36
+ "_target_": "LoadImaged",
37
+ "keys": [
38
+ "image",
39
+ "label"
40
+ ]
41
+ },
42
+ {
43
+ "_target_": "EnsureChannelFirstd",
44
+ "keys": [
45
+ "image",
46
+ "label"
47
+ ]
48
+ },
49
+ {
50
+ "_target_": "Orientationd",
51
+ "keys": [
52
+ "image",
53
+ "label"
54
+ ],
55
+ "axcodes": "RAS"
56
+ },
57
+ {
58
+ "_target_": "Spacingd",
59
+ "keys": [
60
+ "image",
61
+ "label"
62
+ ],
63
+ "pixdim": [
64
+ 1.0,
65
+ 1.0,
66
+ 1.0
67
+ ],
68
+ "mode": [
69
+ "bilinear",
70
+ "nearest"
71
+ ]
72
+ },
73
+ {
74
+ "_target_": "ScaleIntensityRanged",
75
+ "keys": "image",
76
+ "a_min": -175,
77
+ "a_max": 250,
78
+ "b_min": 0.0,
79
+ "b_max": 1.0,
80
+ "clip": true
81
+ },
82
+ {
83
+ "_target_": "EnsureTyped",
84
+ "keys": [
85
+ "image",
86
+ "label"
87
+ ]
88
+ }
89
+ ],
90
+ "random_transforms": [
91
+ {
92
+ "_target_": "RandCropByPosNegLabeld",
93
+ "keys": [
94
+ "image",
95
+ "label"
96
+ ],
97
+ "label_key": "label",
98
+ "spatial_size": [
99
+ 96,
100
+ 96,
101
+ 96
102
+ ],
103
+ "pos": 1,
104
+ "neg": 1,
105
+ "num_samples": 4,
106
+ "image_key": "image",
107
+ "image_threshold": 0
108
+ },
109
+ {
110
+ "_target_": "RandFlipd",
111
+ "keys": [
112
+ "image",
113
+ "label"
114
+ ],
115
+ "spatial_axis": [
116
+ 0
117
+ ],
118
+ "prob": 0.1
119
+ },
120
+ {
121
+ "_target_": "RandFlipd",
122
+ "keys": [
123
+ "image",
124
+ "label"
125
+ ],
126
+ "spatial_axis": [
127
+ 1
128
+ ],
129
+ "prob": 0.1
130
+ },
131
+ {
132
+ "_target_": "RandFlipd",
133
+ "keys": [
134
+ "image",
135
+ "label"
136
+ ],
137
+ "spatial_axis": [
138
+ 2
139
+ ],
140
+ "prob": 0.1
141
+ },
142
+ {
143
+ "_target_": "RandRotate90d",
144
+ "keys": [
145
+ "image",
146
+ "label"
147
+ ],
148
+ "max_k": 3,
149
+ "prob": 0.1
150
+ },
151
+ {
152
+ "_target_": "RandShiftIntensityd",
153
+ "keys": "image",
154
+ "offsets": 0.1,
155
+ "prob": 0.5
156
+ }
157
+ ],
158
+ "preprocessing": {
159
+ "_target_": "Compose",
160
+ "transforms": "$@train#deterministic_transforms + @train#random_transforms"
161
+ },
162
+ "dataset": {
163
+ "_target_": "CacheDataset",
164
+ "data": "$[{'image': i, 'label': l} for i, l in zip(@images[:-9], @labels[:-9])]",
165
+ "transform": "@train#preprocessing",
166
+ "cache_rate": 1.0,
167
+ "num_workers": 4
168
+ },
169
+ "dataloader": {
170
+ "_target_": "DataLoader",
171
+ "dataset": "@train#dataset",
172
+ "batch_size": 2,
173
+ "shuffle": true,
174
+ "num_workers": 4
175
+ },
176
+ "inferer": {
177
+ "_target_": "SimpleInferer"
178
+ },
179
+ "postprocessing": {
180
+ "_target_": "Compose",
181
+ "transforms": [
182
+ {
183
+ "_target_": "Activationsd",
184
+ "keys": "pred",
185
+ "softmax": true
186
+ },
187
+ {
188
+ "_target_": "AsDiscreted",
189
+ "keys": [
190
+ "pred",
191
+ "label"
192
+ ],
193
+ "argmax": [
194
+ true,
195
+ false
196
+ ],
197
+ "to_onehot": 4
198
+ }
199
+ ]
200
+ },
201
+ "handlers": [
202
+ {
203
+ "_target_": "ValidationHandler",
204
+ "validator": "@validate#evaluator",
205
+ "epoch_level": true,
206
+ "interval": "@val_interval"
207
+ },
208
+ {
209
+ "_target_": "StatsHandler",
210
+ "tag_name": "train_loss",
211
+ "output_transform": "$monai.handlers.from_engine(['loss'], first=True)"
212
+ },
213
+ {
214
+ "_target_": "TensorBoardStatsHandler",
215
+ "log_dir": "@output_dir",
216
+ "tag_name": "train_loss",
217
+ "output_transform": "$monai.handlers.from_engine(['loss'], first=True)"
218
+ }
219
+ ],
220
+ "key_metric": {
221
+ "train_accuracy": {
222
+ "_target_": "ignite.metrics.Accuracy",
223
+ "output_transform": "$monai.handlers.from_engine(['pred', 'label'])"
224
+ }
225
+ },
226
+ "trainer": {
227
+ "_target_": "SupervisedTrainer",
228
+ "max_epochs": 1000,
229
+ "device": "@device",
230
+ "train_data_loader": "@train#dataloader",
231
+ "network": "@network",
232
+ "loss_function": "@loss",
233
+ "optimizer": "@optimizer",
234
+ "inferer": "@train#inferer",
235
+ "postprocessing": "@train#postprocessing",
236
+ "key_train_metric": "@train#key_metric",
237
+ "train_handlers": "@train#handlers",
238
+ "amp": true
239
+ }
240
+ },
241
+ "validate": {
242
+ "preprocessing": {
243
+ "_target_": "Compose",
244
+ "transforms": "%train#deterministic_transforms"
245
+ },
246
+ "dataset": {
247
+ "_target_": "CacheDataset",
248
+ "data": "$[{'image': i, 'label': l} for i, l in zip(@images[-9:], @labels[-9:])]",
249
+ "transform": "@validate#preprocessing",
250
+ "cache_rate": 1.0
251
+ },
252
+ "dataloader": {
253
+ "_target_": "DataLoader",
254
+ "dataset": "@validate#dataset",
255
+ "batch_size": 1,
256
+ "shuffle": false,
257
+ "num_workers": 4
258
+ },
259
+ "inferer": {
260
+ "_target_": "SlidingWindowInferer",
261
+ "roi_size": [
262
+ 96,
263
+ 96,
264
+ 96
265
+ ],
266
+ "sw_batch_size": 4,
267
+ "overlap": 0.5
268
+ },
269
+ "postprocessing": "%train#postprocessing",
270
+ "handlers": [
271
+ {
272
+ "_target_": "StatsHandler",
273
+ "iteration_log": false
274
+ },
275
+ {
276
+ "_target_": "TensorBoardStatsHandler",
277
+ "log_dir": "@output_dir",
278
+ "iteration_log": false
279
+ },
280
+ {
281
+ "_target_": "CheckpointSaver",
282
+ "save_dir": "@ckpt_dir",
283
+ "save_dict": {
284
+ "model": "@network"
285
+ },
286
+ "save_key_metric": true,
287
+ "key_metric_filename": "model.pt"
288
+ }
289
+ ],
290
+ "key_metric": {
291
+ "val_mean_dice": {
292
+ "_target_": "MeanDice",
293
+ "include_background": false,
294
+ "output_transform": "$monai.handlers.from_engine(['pred', 'label'])"
295
+ }
296
+ },
297
+ "additional_metrics": {
298
+ "val_accuracy": {
299
+ "_target_": "ignite.metrics.Accuracy",
300
+ "output_transform": "$monai.handlers.from_engine(['pred', 'label'])"
301
+ }
302
+ },
303
+ "evaluator": {
304
+ "_target_": "SupervisedEvaluator",
305
+ "device": "@device",
306
+ "val_data_loader": "@validate#dataloader",
307
+ "network": "@network",
308
+ "inferer": "@validate#inferer",
309
+ "postprocessing": "@validate#postprocessing",
310
+ "key_val_metric": "@validate#key_metric",
311
+ "additional_metrics": "@validate#additional_metrics",
312
+ "val_handlers": "@validate#handlers",
313
+ "amp": true
314
+ }
315
+ },
316
+ "training": [
317
+ "$monai.utils.set_determinism(seed=123)",
318
+ "$setattr(torch.backends.cudnn, 'benchmark', True)",
319
+ "$@train#trainer.run()"
320
+ ]
321
+ }
docs/README.md ADDED
@@ -0,0 +1,103 @@
1
+ # Description
2
+ A pre-trained model for training and inference of volumetric (3D) kidney substructure segmentation from contrast-enhanced CT images (arterial/portal venous phase). A training pipeline is provided to support model fine-tuning with the bundle and MONAI Label active learning.
3
+
4
+ A tutorial and model release for kidney cortex, medulla, and collecting system segmentation.
5
+
6
+ Authors: Yinchi Zhou ([email protected]) | Xin Yu ([email protected]) | Yucheng Tang ([email protected]) |
7
+
8
+
9
+ # Model Overview
10
+ A pre-trained UNEST base model [1] for volumetric (3D) renal structure segmentation using dynamic contrast-enhanced arterial- or venous-phase CT images.
11
+
12
+ ## Data
13
+ The training data comes from the [ImageVU RenalSeg dataset] collected at Vanderbilt University and Vanderbilt University Medical Center.
14
+ (The training data is not publicly available yet.)
15
+
16
+ - Target: Renal Cortex | Medulla | Pelvis Collecting System
17
+ - Task: Segmentation
18
+ - Modality: CT (Arterial | Venous phase)
19
+ - Size: 96 3D volumes
20
+
21
+
22
+ The data and a segmentation demonstration are shown below:
23
+
24
+ ![](./renal.png) <br>
25
+
26
+ ## Method and Network
27
+
28
+ The UNEST model is a 3D hierarchical transformer-based segmentation network.
29
+
30
+ Details of the architecture:
31
+ ![](./unest.png) <br>
32
+
33
+ ## Training configuration
34
+ The training was performed with at least one 16GB-memory GPU.
35
+
36
+ Actual Model Input: 96 x 96 x 96
37
+
38
+ ## Input and output formats
39
+ Input: 1 channel CT image
40
+
41
+ Output: 4 channels: 0: Background, 1: Renal Cortex, 2: Medulla, 3: Pelvicalyceal System
42
+
43
+ ## Performance
44
+ A graph showing the validation mean Dice for 5000 epochs.
45
+
46
+ ![](./val_dice.png) <br>
47
+
48
+ This model achieves the following Dice score on the validation data (our own split from the training dataset):
49
+
50
+ Mean Validation Dice = 0.8523
51
+
52
+ Note that mean dice is computed in the original spacing of the input data.
53
+
54
+ ## Commands example
55
+ Download the trained checkpoint to ./models/model.pt:
56
+
57
+
58
+ Add the scripts component: to run the workflow with customized components, PYTHONPATH should be updated to include the path to the customized components:
59
+
60
+ ```
61
+ export PYTHONPATH=$PYTHONPATH:"<path to the bundle root dir>/scripts"
62
+
63
+ ```
64
+ Execute Training:
65
+
66
+ ```
67
+ python -m monai.bundle run training --meta_file configs/metadata.json --config_file configs/train.json --logging_file configs/logging.conf
68
+ ```
69
+
70
+ Execute inference:
71
+
72
+ ```
73
+ python -m monai.bundle run evaluating --meta_file configs/metadata.json --config_file configs/inference.json --logging_file configs/logging.conf
74
+ ```
75
+
76
+
77
+ ## More example outputs
78
+
79
+ ![](./demos.png) <br>
80
+
81
+
82
+ # Disclaimer
83
+ This is an example, not to be used for diagnostic purposes.
84
+
85
+ # References
86
+ [1] Yu, Xin, Yinchi Zhou, Yucheng Tang et al. "Characterizing Renal Structures with 3D Block Aggregate Transformers." arXiv preprint arXiv:2203.02430 (2022). https://arxiv.org/pdf/2203.02430.pdf
87
+
88
+ [2] Zizhao Zhang et al. "Nested Hierarchical Transformer: Towards Accurate, Data-Efficient and Interpretable Visual Understanding." AAAI Conference on Artificial Intelligence (AAAI) 2022
89
+
90
+ # License
91
+ Copyright (c) MONAI Consortium
92
+
93
+ Licensed under the Apache License, Version 2.0 (the "License");
94
+ you may not use this file except in compliance with the License.
95
+ You may obtain a copy of the License at
96
+
97
+ http://www.apache.org/licenses/LICENSE-2.0
98
+
99
+ Unless required by applicable law or agreed to in writing, software
100
+ distributed under the License is distributed on an "AS IS" BASIS,
101
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
102
+ See the License for the specific language governing permissions and
103
+ limitations under the License.
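To complement the CLI commands in the README above, the sketch below shows direct sliding-window inference with the bundled network in Python. The import path `scripts.networks.unest.UNesT` matches the `network_def` in the configs and assumes the bundle root is on `PYTHONPATH` as the README describes; the checkpoint-unwrapping line is a hedge, since the saved file may or may not nest the weights under a `"model"` key depending on how `CheckpointSaver` wrote it.

```python
# Illustration only: sliding-window inference with the bundled UNesT network.
import torch
from monai.inferers import SlidingWindowInferer
from scripts.networks.unest import UNesT  # requires the bundle root on PYTHONPATH

net = UNesT(in_channels=1, out_channels=4)
state = torch.load("models/model.pt", map_location="cpu")
net.load_state_dict(state.get("model", state))  # handle wrapped or plain state dicts
net.eval()

# 96^3 ROI, matching configs/inference.json; the input is a dummy preprocessed volume.
inferer = SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5)
with torch.no_grad():
    logits = inferer(torch.rand(1, 1, 128, 128, 160), net)
labels = logits.argmax(dim=1, keepdim=True)  # 0 background, 1 cortex, 2 medulla, 3 pelvicalyceal
print(labels.shape)  # torch.Size([1, 1, 128, 128, 160])
```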
docs/demos.png ADDED

Git LFS Details

  • SHA256: fe4fb5b171619b0c3a1eacf404acab7bfae1b42ca7cc1991e442e6d622d1af00
  • Pointer size: 131 Bytes
  • Size of remote file: 377 kB
docs/renal.png ADDED

Git LFS Details

  • SHA256: fa598f7b3176d1570c323866d710522cfe8ca41d295e20cd2908e481e06d631d
  • Pointer size: 131 Bytes
  • Size of remote file: 162 kB
docs/unest.png ADDED
docs/val_dice.png ADDED
models/model.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8928e88771d31945c51d1b302a8448825e6f9861a543a6e1023acb9576840962
3
+ size 348887167
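The three lines above are a Git LFS pointer rather than the checkpoint itself; after `git lfs pull` (or a download via `huggingface_hub`), the roughly 349 MB file should hash to the recorded oid. A small verification sketch:

```python
# Verify the downloaded checkpoint against the sha256 recorded in the LFS pointer.
import hashlib

expected = "8928e88771d31945c51d1b302a8448825e6f9861a543a6e1023acb9576840962"
h = hashlib.sha256()
with open("models/model.pt", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        h.update(chunk)
print(h.hexdigest() == expected)
```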
scripts/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
scripts/networks/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
scripts/networks/nest/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ #!/usr/bin/env python3
2
+ from .utils import (
3
+ Conv3dSame,
4
+ DropPath,
5
+ Linear,
6
+ Mlp,
7
+ _assert,
8
+ conv3d_same,
9
+ create_conv3d,
10
+ create_pool3d,
11
+ get_padding,
12
+ get_same_padding,
13
+ pad_same,
14
+ to_ntuple,
15
+ trunc_normal_,
16
+ )
scripts/networks/nest/utils.py ADDED
@@ -0,0 +1,481 @@
1
+ #!/usr/bin/env python3
2
+
3
+
4
+ import collections.abc
5
+ import math
6
+ import warnings
7
+ from itertools import repeat
8
+ from typing import List, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ try:
15
+ from torch import _assert
16
+ except ImportError:
17
+
18
+ def _assert(condition: bool, message: str):
19
+ assert condition, message
20
+
21
+
22
+ def drop_block_2d(
23
+ x,
24
+ drop_prob: float = 0.1,
25
+ block_size: int = 7,
26
+ gamma_scale: float = 1.0,
27
+ with_noise: bool = False,
28
+ inplace: bool = False,
29
+ batchwise: bool = False,
30
+ ):
31
+ """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
32
+
33
+ DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
34
+ runs with success, but needs further validation and possibly optimization for lower runtime impact.
35
+ """
36
+ b, c, h, w = x.shape
37
+ total_size = w * h
38
+ clipped_block_size = min(block_size, min(w, h))
39
+ # seed_drop_rate, the gamma parameter
40
+ gamma = gamma_scale * drop_prob * total_size / clipped_block_size**2 / ((w - block_size + 1) * (h - block_size + 1))
41
+
42
+ # Forces the block to be inside the feature map.
43
+ w_i, h_i = torch.meshgrid(torch.arange(w).to(x.device), torch.arange(h).to(x.device))
44
+ valid_block = ((w_i >= clipped_block_size // 2) & (w_i < w - (clipped_block_size - 1) // 2)) & (
45
+ (h_i >= clipped_block_size // 2) & (h_i < h - (clipped_block_size - 1) // 2)
46
+ )
47
+ valid_block = torch.reshape(valid_block, (1, 1, h, w)).to(dtype=x.dtype)
48
+
49
+ if batchwise:
50
+ # one mask for whole batch, quite a bit faster
51
+ uniform_noise = torch.rand((1, c, h, w), dtype=x.dtype, device=x.device)
52
+ else:
53
+ uniform_noise = torch.rand_like(x)
54
+ block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
55
+ block_mask = -F.max_pool2d(
56
+ -block_mask, kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2 # block_size,
57
+ )
58
+
59
+ if with_noise:
60
+ normal_noise = torch.randn((1, c, h, w), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
61
+ if inplace:
62
+ x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
63
+ else:
64
+ x = x * block_mask + normal_noise * (1 - block_mask)
65
+ else:
66
+ normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
67
+ if inplace:
68
+ x.mul_(block_mask * normalize_scale)
69
+ else:
70
+ x = x * block_mask * normalize_scale
71
+ return x
72
+
73
+
74
+ def drop_block_fast_2d(
75
+ x: torch.Tensor,
76
+ drop_prob: float = 0.1,
77
+ block_size: int = 7,
78
+ gamma_scale: float = 1.0,
79
+ with_noise: bool = False,
80
+ inplace: bool = False,
81
+ ):
82
+ """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
83
+
84
+ DropBlock with an experimental gaussian noise option. Simplified from above without concern for valid
85
+ block mask at edges.
86
+ """
87
+ b, c, h, w = x.shape
88
+ total_size = w * h
89
+ clipped_block_size = min(block_size, min(w, h))
90
+ gamma = gamma_scale * drop_prob * total_size / clipped_block_size**2 / ((w - block_size + 1) * (h - block_size + 1))
91
+
92
+ block_mask = torch.empty_like(x).bernoulli_(gamma)
93
+ block_mask = F.max_pool2d(
94
+ block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2
95
+ )
96
+
97
+ if with_noise:
98
+ normal_noise = torch.empty_like(x).normal_()
99
+ if inplace:
100
+ x.mul_(1.0 - block_mask).add_(normal_noise * block_mask)
101
+ else:
102
+ x = x * (1.0 - block_mask) + normal_noise * block_mask
103
+ else:
104
+ block_mask = 1 - block_mask
105
+ normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)).to(dtype=x.dtype)
106
+ if inplace:
107
+ x.mul_(block_mask * normalize_scale)
108
+ else:
109
+ x = x * block_mask * normalize_scale
110
+ return x
111
+
112
+
113
+ class DropBlock2d(nn.Module):
114
+ """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf"""
115
+
116
+ def __init__(
117
+ self, drop_prob=0.1, block_size=7, gamma_scale=1.0, with_noise=False, inplace=False, batchwise=False, fast=True
118
+ ):
119
+ super(DropBlock2d, self).__init__()
120
+ self.drop_prob = drop_prob
121
+ self.gamma_scale = gamma_scale
122
+ self.block_size = block_size
123
+ self.with_noise = with_noise
124
+ self.inplace = inplace
125
+ self.batchwise = batchwise
126
+ self.fast = fast # FIXME finish comparisons of fast vs not
127
+
128
+ def forward(self, x):
129
+ if not self.training or not self.drop_prob:
130
+ return x
131
+ if self.fast:
132
+ return drop_block_fast_2d(
133
+ x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace
134
+ )
135
+ else:
136
+ return drop_block_2d(
137
+ x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise
138
+ )
139
+
140
+
141
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
142
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
143
+
144
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
145
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
146
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
147
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
148
+ 'survival rate' as the argument.
149
+
150
+ """
151
+ if drop_prob == 0.0 or not training:
152
+ return x
153
+ keep_prob = 1 - drop_prob
154
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
155
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
156
+ if keep_prob > 0.0 and scale_by_keep:
157
+ random_tensor.div_(keep_prob)
158
+ return x * random_tensor
159
+
160
+
161
+ class DropPath(nn.Module):
162
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
163
+
164
+ def __init__(self, drop_prob=None, scale_by_keep=True):
165
+ super(DropPath, self).__init__()
166
+ self.drop_prob = drop_prob
167
+ self.scale_by_keep = scale_by_keep
168
+
169
+ def forward(self, x):
170
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
171
+
172
+
173
+ def create_conv3d(in_channels, out_channels, kernel_size, **kwargs):
174
+ """Select a 2d convolution implementation based on arguments
175
+ Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv3d, or CondConv2d.
176
+
177
+ Used extensively by EfficientNet, MobileNetv3 and related networks.
178
+ """
179
+
180
+ depthwise = kwargs.pop("depthwise", False)
181
+ # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0
182
+ groups = in_channels if depthwise else kwargs.pop("groups", 1)
183
+
184
+ m = create_conv3d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
185
+ return m
186
+
187
+
188
+ def conv3d_same(
189
+ x,
190
+ weight: torch.Tensor,
191
+ bias: Optional[torch.Tensor] = None,
192
+ stride: Tuple[int, int, int] = (1, 1, 1),
193
+ padding: Tuple[int, int, int] = (0, 0, 0),
194
+ dilation: Tuple[int, int, int] = (1, 1, 1),
195
+ groups: int = 1,
196
+ ):
197
+ x = pad_same(x, weight.shape[-3:], stride, dilation)
198
+ return F.conv3d(x, weight, bias, stride, (0, 0, 0), dilation, groups)
199
+
200
+
201
+ class Conv3dSame(nn.Conv2d):
202
+ """Tensorflow like 'SAME' convolution wrapper for 2D convolutions"""
203
+
204
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
205
+ super(Conv3dSame, self).__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
206
+
207
+ def forward(self, x):
208
+ return conv3d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
209
+
210
+
211
+ def create_conv3d_pad(in_chs, out_chs, kernel_size, **kwargs):
212
+ padding = kwargs.pop("padding", "")
213
+ kwargs.setdefault("bias", False)
214
+ padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
215
+ if is_dynamic:
216
+ return Conv3dSame(in_chs, out_chs, kernel_size, **kwargs)
217
+ else:
218
+ return nn.Conv3d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
219
+
220
+
221
+ # Calculate symmetric padding for a convolution
222
+ def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
223
+ padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
224
+ return padding
225
+
226
+
227
+ # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
228
+ def get_same_padding(x: int, k: int, s: int, d: int):
229
+ return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)
230
+
231
+
232
+ # Can SAME padding for given args be done statically?
233
+ def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
234
+ return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
235
+
236
+
237
+ # Dynamically pad input x with 'SAME' padding for conv with specified args
238
+ def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1, 1), value: float = 0):
239
+ id, ih, iw = x.size()[-3:]
240
+ pad_d, pad_h, pad_w = (
241
+ get_same_padding(id, k[0], s[0], d[0]),
242
+ get_same_padding(ih, k[1], s[1], d[1]),
243
+ get_same_padding(iw, k[2], s[2], d[2]),
244
+ )
245
+ if pad_d > 0 or pad_h > 0 or pad_w > 0:
246
+ x = F.pad(
247
+ x,
248
+ [pad_d // 2, pad_d - pad_d // 2, pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
249
+ value=value,
250
+ )
251
+ return x
252
+
253
+
254
+ def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
255
+ dynamic = False
256
+ if isinstance(padding, str):
257
+ # for any string padding, the padding will be calculated for you, one of three ways
258
+ padding = padding.lower()
259
+ if padding == "same":
260
+ # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
261
+ if is_static_pad(kernel_size, **kwargs):
262
+ # static case, no extra overhead
263
+ padding = get_padding(kernel_size, **kwargs)
264
+ else:
265
+ # dynamic 'SAME' padding, has runtime/GPU memory overhead
266
+ padding = 0
267
+ dynamic = True
268
+ elif padding == "valid":
269
+ # 'VALID' padding, same as padding=0
270
+ padding = 0
271
+ else:
272
+ # Default to PyTorch style 'same'-ish symmetric padding
273
+ padding = get_padding(kernel_size, **kwargs)
274
+ return padding, dynamic
275
+
276
+
277
+ # From PyTorch internals
278
+ def _ntuple(n):
279
+ def parse(x):
280
+ if isinstance(x, collections.abc.Iterable):
281
+ return x
282
+ return tuple(repeat(x, n))
283
+
284
+ return parse
285
+
286
+
287
+ to_1tuple = _ntuple(1)
288
+ to_2tuple = _ntuple(2)
289
+ to_3tuple = _ntuple(3)
290
+ to_4tuple = _ntuple(4)
291
+ to_ntuple = _ntuple
292
+
293
+
294
+ def make_divisible(v, divisor=8, min_value=None, round_limit=0.9):
295
+ min_value = min_value or divisor
296
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
297
+ # Make sure that round down does not go down by more than 10%.
298
+ if new_v < round_limit * v:
299
+ new_v += divisor
300
+ return new_v
301
+
302
+
303
+ class Linear(nn.Linear):
304
+ r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
305
+
306
+ Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting
307
+ weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case.
308
+ """
309
+
310
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
311
+ if torch.jit.is_scripting():
312
+ bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
313
+ return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias)
314
+ else:
315
+ return F.linear(input, self.weight, self.bias)
316
+
317
+
318
+ class Mlp(nn.Module):
319
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
320
+
321
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
322
+ super().__init__()
323
+ out_features = out_features or in_features
324
+ hidden_features = hidden_features or in_features
325
+ drop_probs = to_2tuple(drop)
326
+
327
+ self.fc1 = nn.Linear(in_features, hidden_features)
328
+ self.act = act_layer()
329
+ self.drop1 = nn.Dropout(drop_probs[0])
330
+ self.fc2 = nn.Linear(hidden_features, out_features)
331
+ self.drop2 = nn.Dropout(drop_probs[1])
332
+
333
+ def forward(self, x):
334
+ x = self.fc1(x)
335
+ x = self.act(x)
336
+ x = self.drop1(x)
337
+ x = self.fc2(x)
338
+ x = self.drop2(x)
339
+ return x
340
+
341
+
342
+ def avg_pool3d_same(
343
+ x,
344
+ kernel_size: List[int],
345
+ stride: List[int],
346
+ padding: List[int] = (0, 0, 0),
347
+ ceil_mode: bool = False,
348
+ count_include_pad: bool = True,
349
+ ):
350
+ # FIXME how to deal with count_include_pad vs not for external padding?
351
+ x = pad_same(x, kernel_size, stride)
352
+ return F.avg_pool3d(x, kernel_size, stride, (0, 0, 0), ceil_mode, count_include_pad)
353
+
354
+
355
+ class AvgPool3dSame(nn.AvgPool2d):
356
+ """Tensorflow like 'SAME' wrapper for 2D average pooling"""
357
+
358
+ def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
359
+ kernel_size = to_2tuple(kernel_size)
360
+ stride = to_2tuple(stride)
361
+ super(AvgPool3dSame, self).__init__(kernel_size, stride, (0, 0, 0), ceil_mode, count_include_pad)
362
+
363
+ def forward(self, x):
364
+ x = pad_same(x, self.kernel_size, self.stride)
365
+ return F.avg_pool3d(x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
366
+
367
+
368
+ def max_pool3d_same(
369
+ x,
370
+ kernel_size: List[int],
371
+ stride: List[int],
372
+ padding: List[int] = (0, 0, 0),
373
+ dilation: List[int] = (1, 1, 1),
374
+ ceil_mode: bool = False,
375
+ ):
376
+ x = pad_same(x, kernel_size, stride, value=-float("inf"))
377
+ return F.max_pool3d(x, kernel_size, stride, (0, 0, 0), dilation, ceil_mode)
378
+
379
+
380
+ class MaxPool3dSame(nn.MaxPool2d):
381
+ """Tensorflow like 'SAME' wrapper for 3D max pooling"""
382
+
383
+ def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False):
384
+ kernel_size = to_2tuple(kernel_size)
385
+ stride = to_2tuple(stride)
386
+ dilation = to_2tuple(dilation)
387
+ super(MaxPool3dSame, self).__init__(kernel_size, stride, (0, 0, 0), dilation, ceil_mode)
388
+
389
+ def forward(self, x):
390
+ x = pad_same(x, self.kernel_size, self.stride, value=-float("inf"))
391
+ return F.max_pool3d(x, self.kernel_size, self.stride, (0, 0, 0), self.dilation, self.ceil_mode)
392
+
393
+
394
+ def create_pool3d(pool_type, kernel_size, stride=None, **kwargs):
395
+ stride = stride or kernel_size
396
+ padding = kwargs.pop("padding", "")
397
+ padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs)
398
+ if is_dynamic:
399
+ if pool_type == "avg":
400
+ return AvgPool3dSame(kernel_size, stride=stride, **kwargs)
401
+ elif pool_type == "max":
402
+ return MaxPool3dSame(kernel_size, stride=stride, **kwargs)
403
+ else:
404
+ raise AssertionError()
405
+
406
+ # assert False, f"Unsupported pool type {pool_type}"
407
+ else:
408
+ if pool_type == "avg":
409
+ return nn.AvgPool3d(kernel_size, stride=stride, padding=padding, **kwargs)
410
+ elif pool_type == "max":
411
+ return nn.MaxPool3d(kernel_size, stride=stride, padding=padding, **kwargs)
412
+ else:
413
+ raise AssertionError()
414
+
415
+ # assert False, f"Unsupported pool type {pool_type}"
416
+
417
+
418
+ def _float_to_int(x: float) -> int:
419
+ """
420
+ Symbolic tracing helper to substitute for inbuilt `int`.
421
+ Hint: Inbuilt `int` can't accept an argument of type `Proxy`
422
+ """
423
+ return int(x)
424
+
425
+
426
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
427
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
428
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
429
+ def norm_cdf(x):
430
+ # Computes standard normal cumulative distribution function
431
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
432
+
433
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
434
+ warnings.warn(
435
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
436
+ "The distribution of values may be incorrect.",
437
+ stacklevel=2,
438
+ )
439
+
440
+ with torch.no_grad():
441
+ # Values are generated by using a truncated uniform distribution and
442
+ # then using the inverse CDF for the normal distribution.
443
+ # Get upper and lower cdf values
444
+ l = norm_cdf((a - mean) / std)
445
+ u = norm_cdf((b - mean) / std)
446
+
447
+ # Uniformly fill tensor with values from [l, u], then translate to
448
+ # [2l-1, 2u-1].
449
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
450
+
451
+ # Use inverse cdf transform for normal distribution to get truncated
452
+ # standard normal
453
+ tensor.erfinv_()
454
+
455
+ # Transform to proper mean, std
456
+ tensor.mul_(std * math.sqrt(2.0))
457
+ tensor.add_(mean)
458
+
459
+ # Clamp to ensure it's in the proper range
460
+ tensor.clamp_(min=a, max=b)
461
+ return tensor
462
+
463
+
464
+ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
465
+ r"""Fills the input Tensor with values drawn from a truncated
466
+ normal distribution. The values are effectively drawn from the
467
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
468
+ with values outside :math:`[a, b]` redrawn until they are within
469
+ the bounds. The method used for generating the random values works
470
+ best when :math:`a \leq \text{mean} \leq b`.
471
+ Args:
472
+ tensor: an n-dimensional `torch.Tensor`
473
+ mean: the mean of the normal distribution
474
+ std: the standard deviation of the normal distribution
475
+ a: the minimum cutoff value
476
+ b: the maximum cutoff value
477
+ Examples:
478
+ >>> w = torch.empty(3, 5)
479
+ >>> nn.init.trunc_normal_(w)
480
+ """
481
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
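Most of utils.py consists of 3D ports of timm helpers. A tiny shape check (illustration only; assumes the bundle root is on PYTHONPATH) of the TensorFlow-style 'SAME' padding helpers used by the NesT blocks:

```python
# A stride-1 'SAME' conv keeps the spatial size; get_same_padding reports how much
# padding one dimension needs for a given kernel/stride/dilation.
import torch
from scripts.networks.nest.utils import create_conv3d, get_same_padding

conv = create_conv3d(8, 16, kernel_size=3, stride=1, padding="same")  # static-padding path
out = conv(torch.rand(2, 8, 24, 24, 24))
print(out.shape)                              # torch.Size([2, 16, 24, 24, 24])
print(get_same_padding(x=24, k=3, s=2, d=1))  # 1 -> one extra voxel of padding needed
```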
scripts/networks/nest_transformer_3D.py ADDED
@@ -0,0 +1,489 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # =========================================================================
4
+ # Adapted from https://github.com/google-research/nested-transformer.
5
+ # which has the following license...
6
+ # https://github.com/pytorch/vision/blob/main/LICENSE
7
+ #
8
+ # BSD 3-Clause License
9
+
10
+
11
+ # Redistribution and use in source and binary forms, with or without
12
+ # modification, are permitted provided that the following conditions are met:
13
+
14
+ # * Redistributions of source code must retain the above copyright notice, this
15
+ # list of conditions and the following disclaimer.
16
+
17
+ # * Redistributions in binary form must reproduce the above copyright notice,
18
+ # this list of conditions and the following disclaimer in the documentation
19
+ # and/or other materials provided with the distribution.
20
+
21
+ # * Neither the name of the copyright holder nor the names of its
22
+ # contributors may be used to endorse or promote products derived from
23
+ # this software without specific prior written permission.
24
+
25
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
28
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
29
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
30
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
31
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
32
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
33
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
34
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
35
+
36
+ """ Nested Transformer (NesT) in PyTorch
37
+ A PyTorch implementation of Aggregating Nested Transformers as described in:
38
+ 'Aggregating Nested Transformers'
39
+ - https://arxiv.org/abs/2105.12723
40
+ The official Jax code is released and available at https://github.com/google-research/nested-transformer.
41
+ The weights have been converted with convert/convert_nest_flax.py
42
+ Acknowledgments:
43
+ * The paper authors for sharing their research, code, and model weights
44
+ * Ross Wightman's existing code, on which this implementation is based
45
+ Copyright 2021 Alexander Soare
46
+
47
+ """
48
+
49
+ import collections.abc
50
+ import logging
51
+ import math
52
+ from functools import partial
53
+ from typing import Callable, Sequence
54
+
55
+ import torch
56
+ import torch.nn.functional as F
57
+ from torch import nn
58
+
59
+ from .nest import DropPath, Mlp, _assert, create_conv3d, create_pool3d, to_ntuple, trunc_normal_
60
+ from .patchEmbed3D import PatchEmbed3D
61
+
62
+ _logger = logging.getLogger(__name__)
63
+
64
+
65
+ class Attention(nn.Module):
66
+ """
67
+ This is much like `.vision_transformer.Attention` but uses *localised* self attention by accepting an input with
68
+ an extra "image block" dim
69
+ """
70
+
71
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
72
+ super().__init__()
73
+ self.num_heads = num_heads
74
+ head_dim = dim // num_heads
75
+ self.scale = head_dim**-0.5
76
+
77
+ self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
78
+ self.attn_drop = nn.Dropout(attn_drop)
79
+ self.proj = nn.Linear(dim, dim)
80
+ self.proj_drop = nn.Dropout(proj_drop)
81
+
82
+ def forward(self, x):
83
+ """
84
+ x is shape: B (batch_size), T (image blocks), N (seq length per image block), C (embed dim)
85
+ """
86
+ b, t, n, c = x.shape
87
+ # result of next line is (qkv, B, num (H)eads, T, N, (C')hannels per head)
88
+ qkv = self.qkv(x).reshape(b, t, n, 3, self.num_heads, c // self.num_heads).permute(3, 0, 4, 1, 2, 5)
89
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
90
+
91
+ attn = (q @ k.transpose(-2, -1)) * self.scale # (B, H, T, N, N)
92
+ attn = attn.softmax(dim=-1)
93
+ attn = self.attn_drop(attn)
94
+
95
+ x = (attn @ v).permute(0, 2, 3, 4, 1).reshape(b, t, n, c)
96
+ x = self.proj(x)
97
+ x = self.proj_drop(x)
98
+ return x # (B, T, N, C)
99
+
100
+
101
+ class TransformerLayer(nn.Module):
102
+ """
103
+ This is much like `.vision_transformer.Block` but:
104
+ - Called TransformerLayer here to allow for "block" as defined in the paper ("non-overlapping image blocks")
105
+ - Uses modified Attention layer that handles the "block" dimension
106
+ """
107
+
108
+ def __init__(
109
+ self,
110
+ dim,
111
+ num_heads,
112
+ mlp_ratio=4.0,
113
+ qkv_bias=False,
114
+ drop=0.0,
115
+ attn_drop=0.0,
116
+ drop_path=0.0,
117
+ act_layer=nn.GELU,
118
+ norm_layer=nn.LayerNorm,
119
+ ):
120
+ super().__init__()
121
+ self.norm1 = norm_layer(dim)
122
+ self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
123
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
124
+ self.norm2 = norm_layer(dim)
125
+ mlp_hidden_dim = int(dim * mlp_ratio)
126
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
127
+
128
+ def forward(self, x):
129
+ y = self.norm1(x)
130
+ x = x + self.drop_path(self.attn(y))
131
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
132
+ return x
133
+
134
+
135
+ class ConvPool(nn.Module):
136
+ def __init__(self, in_channels, out_channels, norm_layer, pad_type=""):
137
+ super().__init__()
138
+ self.conv = create_conv3d(in_channels, out_channels, kernel_size=3, padding=pad_type, bias=True)
139
+ self.norm = norm_layer(out_channels)
140
+ self.pool = create_pool3d("max", kernel_size=3, stride=2, padding=pad_type)
141
+
142
+ def forward(self, x):
143
+ """
144
+ x is expected to have shape (B, C, D, H, W)
145
+ """
146
+ _assert(x.shape[-3] % 2 == 0, "BlockAggregation requires even input spatial dims")
147
+ _assert(x.shape[-2] % 2 == 0, "BlockAggregation requires even input spatial dims")
148
+ _assert(x.shape[-1] % 2 == 0, "BlockAggregation requires even input spatial dims")
149
+
150
+ # print('In ConvPool x : {}'.format(x.shape))
151
+ x = self.conv(x)
152
+ # Layer norm done over channel dim only
153
+ x = self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
154
+ x = self.pool(x)
155
+ return x # (B, C, D//2, H//2, W//2)
156
+
157
+
158
+ def blockify(x, block_size: int):
159
+ """image to blocks
160
+ Args:
161
+ x (Tensor): with shape (B, D, H, W, C)
162
+ block_size (int): edge length of a single cubic block in units of D, H, W
163
+ """
164
+ b, d, h, w, c = x.shape
165
+ _assert(d % block_size == 0, "`block_size` must divide input depth evenly")
166
+ _assert(h % block_size == 0, "`block_size` must divide input height evenly")
167
+ _assert(w % block_size == 0, "`block_size` must divide input width evenly")
168
+ grid_depth = d // block_size
169
+ grid_height = h // block_size
170
+ grid_width = w // block_size
171
+ x = x.reshape(b, grid_depth, block_size, grid_height, block_size, grid_width, block_size, c)
172
+
173
+ x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).reshape(
174
+ b, grid_depth * grid_height * grid_width, -1, c
175
+ ) # shape [2, 512, 27, 128]
176
+
177
+ return x # (B, T, N, C)
178
+
179
+
180
+ # @register_notrace_function # reason: int receives Proxy
181
+ def deblockify(x, block_size: int):
182
+ """blocks to image
183
+ Args:
184
+ x (Tensor): with shape (B, T, N, C) where T is number of blocks and N is sequence size per block
185
+ block_size (int): edge length of a single cubic block in units of the desired D, H, W
186
+ """
187
+ b, t, _, c = x.shape
188
+ grid_size = round(math.pow(t, 1 / 3))
189
+ depth = height = width = grid_size * block_size
190
+ x = x.reshape(b, grid_size, grid_size, grid_size, block_size, block_size, block_size, c)
191
+
192
+ x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).reshape(b, depth, height, width, c)
193
+
194
+ return x # (B, D, H, W, C)
195
+
196
+
197
+ class NestLevel(nn.Module):
198
+ """Single hierarchical level of a Nested Transformer"""
199
+
200
+ def __init__(
201
+ self,
202
+ num_blocks,
203
+ block_size,
204
+ seq_length,
205
+ num_heads,
206
+ depth,
207
+ embed_dim,
208
+ prev_embed_dim=None,
209
+ mlp_ratio=4.0,
210
+ qkv_bias=True,
211
+ drop_rate=0.0,
212
+ attn_drop_rate=0.0,
213
+ drop_path_rates: Sequence[float] = (),
214
+ norm_layer=None,
215
+ act_layer=None,
216
+ pad_type="",
217
+ ):
218
+ super().__init__()
219
+ self.block_size = block_size
220
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_blocks, seq_length, embed_dim))
221
+
222
+ if prev_embed_dim is not None:
223
+ self.pool = ConvPool(prev_embed_dim, embed_dim, norm_layer=norm_layer, pad_type=pad_type)
224
+ else:
225
+ self.pool = nn.Identity()
226
+
227
+ # Transformer encoder
228
+ if len(drop_path_rates):
229
+ assert len(drop_path_rates) == depth, "Must provide as many drop path rates as there are transformer layers"
230
+ self.transformer_encoder = nn.Sequential(
231
+ *[
232
+ TransformerLayer(
233
+ dim=embed_dim,
234
+ num_heads=num_heads,
235
+ mlp_ratio=mlp_ratio,
236
+ qkv_bias=qkv_bias,
237
+ drop=drop_rate,
238
+ attn_drop=attn_drop_rate,
239
+ drop_path=drop_path_rates[i],
240
+ norm_layer=norm_layer,
241
+ act_layer=act_layer,
242
+ )
243
+ for i in range(depth)
244
+ ]
245
+ )
246
+
247
+ def forward(self, x):
248
+ """
249
+ expects x as (B, C, D, H, W)
250
+ """
251
+ x = self.pool(x)
252
+ x = x.permute(0, 2, 3, 4, 1) # (B, D', H', W', C), switch to channels last for transformer
253
+
254
+ x = blockify(x, self.block_size) # (B, T, N, C')
255
+ x = x + self.pos_embed
256
+
257
+ x = self.transformer_encoder(x) # (B, T, N, C')
258
+
259
+ x = deblockify(x, self.block_size) # (B, D', H', W', C') [2, 24, 24, 24, 128]
260
+ # Channel-first for block aggregation, and generally to replicate convnet feature map at each stage
261
+ return x.permute(0, 4, 1, 2, 3) # (B, C, D', H', W')
262
+
263
+
264
+ class NestTransformer3D(nn.Module):
265
+ """Nested Transformer (NesT)
266
+ A PyTorch impl of : `Aggregating Nested Transformers`
267
+ - https://arxiv.org/abs/2105.12723
268
+ """
269
+
270
+ def __init__(
271
+ self,
272
+ img_size=96,
273
+ in_chans=1,
274
+ patch_size=2,
275
+ num_levels=3,
276
+ embed_dims=(128, 256, 512),
277
+ num_heads=(4, 8, 16),
278
+ depths=(2, 2, 20),
279
+ num_classes=1000,
280
+ mlp_ratio=4.0,
281
+ qkv_bias=True,
282
+ drop_rate=0.0,
283
+ attn_drop_rate=0.0,
284
+ drop_path_rate=0.5,
285
+ norm_layer=None,
286
+ act_layer=None,
287
+ pad_type="",
288
+ weight_init="",
289
+ global_pool="avg",
290
+ ):
291
+ """
292
+ Args:
293
+ img_size (int, tuple): input image size
294
+ in_chans (int): number of input channels
295
+ patch_size (int): patch size
296
+ num_levels (int): number of block hierarchies (T_d in the paper)
297
+ embed_dims (int, tuple): embedding dimensions of each level
298
+ num_heads (int, tuple): number of attention heads for each level
299
+ depths (int, tuple): number of transformer layers for each level
300
+ num_classes (int): number of classes for classification head
301
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim for MLP of transformer layers
302
+ qkv_bias (bool): enable bias for qkv if True
303
+ drop_rate (float): dropout rate for MLP of transformer layers, MSA final projection layer, and classifier
304
+ attn_drop_rate (float): attention dropout rate
305
+ drop_path_rate (float): stochastic depth rate
306
+ norm_layer: (nn.Module): normalization layer for transformer layers
307
+ act_layer: (nn.Module): activation layer in MLP of transformer layers
308
+ pad_type: (str): type of padding to use: '' for PyTorch symmetric, 'same' for TF SAME
309
+ weight_init: (str): weight init scheme
310
+ global_pool: (str): type of pooling operation to apply to final feature map
311
+ Notes:
312
+ - Default values follow NesT-B from the original Jax code.
313
+ - `embed_dims`, `num_heads`, `depths` should be ints or tuples with length `num_levels`.
314
+ - For those following the paper, Table A1 may have errors!
315
+ - https://github.com/google-research/nested-transformer/issues/2
316
+ """
317
+ super().__init__()
318
+
319
+ for param_name in ["embed_dims", "num_heads", "depths"]:
320
+ param_value = locals()[param_name]
321
+ if isinstance(param_value, collections.abc.Sequence):
322
+ assert len(param_value) == num_levels, f"Require `len({param_name}) == num_levels`"
323
+
324
+ embed_dims = to_ntuple(num_levels)(embed_dims)
325
+ num_heads = to_ntuple(num_levels)(num_heads)
326
+ depths = to_ntuple(num_levels)(depths)
327
+ self.num_classes = num_classes
328
+ self.num_features = embed_dims[-1]
329
+ self.feature_info = []
330
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
331
+ act_layer = act_layer or nn.GELU
332
+ self.drop_rate = drop_rate
333
+ self.num_levels = num_levels
334
+ if isinstance(img_size, collections.abc.Sequence):
335
+ assert img_size[0] == img_size[1] == img_size[2], "Model only handles cubic inputs"
336
+ img_size = img_size[0]
337
+ assert img_size % patch_size == 0, "`patch_size` must divide `img_size` evenly"
338
+ self.patch_size = patch_size
339
+
340
+ # Number of blocks at each level
341
+ self.num_blocks = (8 ** torch.arange(num_levels)).flip(0).tolist()
342
+ assert (img_size // patch_size) % round(
343
+ math.pow(self.num_blocks[0], 1 / 3)
344
+ ) == 0, "First level blocks don't fit evenly. Check `img_size`, `patch_size`, and `num_levels`"
345
+
346
+ # Block edge size in units of patches
347
+ # Hint: (img_size // patch_size) gives the number of patches along each edge of the image. The cube root of self.num_blocks[0] is the
348
+ # number of blocks along edge of image
349
+ self.block_size = int((img_size // patch_size) // round(math.pow(self.num_blocks[0], 1 / 3)))
350
+
351
+ # Patch embedding
352
+ self.patch_embed = PatchEmbed3D(
353
+ img_size=[img_size, img_size, img_size],
354
+ patch_size=[patch_size, patch_size, patch_size],
355
+ in_chans=in_chans,
356
+ embed_dim=embed_dims[0],
357
+ )
358
+ self.num_patches = self.patch_embed.num_patches
359
+ self.seq_length = self.num_patches // self.num_blocks[0]
360
+ # Build up each hierarchical level
361
+ levels = []
362
+
363
+ dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
364
+ prev_dim = None
365
+ curr_stride = 4
366
+ for i in range(len(self.num_blocks)):
367
+ dim = embed_dims[i]
368
+ levels.append(
369
+ NestLevel(
370
+ self.num_blocks[i],
371
+ self.block_size,
372
+ self.seq_length,
373
+ num_heads[i],
374
+ depths[i],
375
+ dim,
376
+ prev_dim,
377
+ mlp_ratio,
378
+ qkv_bias,
379
+ drop_rate,
380
+ attn_drop_rate,
381
+ dp_rates[i],
382
+ norm_layer,
383
+ act_layer,
384
+ pad_type=pad_type,
385
+ )
386
+ )
387
+ self.feature_info += [dict(num_chs=dim, reduction=curr_stride, module=f"levels.{i}")]
388
+ prev_dim = dim
389
+ curr_stride *= 2
390
+
391
+ self.levels = nn.ModuleList([levels[i] for i in range(num_levels)])
392
+
393
+ # Final normalization layer
394
+ self.norm = norm_layer(embed_dims[-1])
395
+
396
+ self.init_weights(weight_init)
397
+
398
+ def init_weights(self, mode=""):
399
+ assert mode in ("nlhb", "")
400
+ head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
401
+ for level in self.levels:
402
+ trunc_normal_(level.pos_embed, std=0.02, a=-2, b=2)
403
+ named_apply(partial(_init_nest_weights, head_bias=head_bias), self)
404
+
405
+ @torch.jit.ignore
406
+ def no_weight_decay(self):
407
+ return {f"level.{i}.pos_embed" for i in range(len(self.levels))}
408
+
409
+ def get_classifier(self):
410
+ return self.head
411
+
412
+ def forward_features(self, x):
413
+ """x shape (B, C, D, H, W)"""
414
+ x = self.patch_embed(x)
415
+
416
+ hidden_states_out = [x]
417
+
418
+ for _, level in enumerate(self.levels):
419
+ x = level(x)
420
+ hidden_states_out.append(x)
421
+ # Layer norm done over channel dim only (to NDHWC and back)
422
+ x = self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
423
+ return x, hidden_states_out
424
+
425
+ def forward(self, x):
426
+ """x shape (B, C, D, H, W)"""
427
+ x = self.forward_features(x)
428
+
429
+ if self.drop_rate > 0.0:
430
+ x = F.dropout(x, p=self.drop_rate, training=self.training)
431
+ return x
432
+
433
+
434
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
435
+ if not depth_first and include_root:
436
+ fn(module=module, name=name)
437
+ for child_name, child_module in module.named_children():
438
+ child_name = ".".join((name, child_name)) if name else child_name
439
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
440
+ if depth_first and include_root:
441
+ fn(module=module, name=name)
442
+ return module
443
+
444
+
445
+ def _init_nest_weights(module: nn.Module, name: str = "", head_bias: float = 0.0):
446
+ """NesT weight initialization
447
+ Can replicate Jax implementation. Otherwise follows vision_transformer.py
448
+ """
449
+ if isinstance(module, nn.Linear):
450
+ if name.startswith("head"):
451
+ trunc_normal_(module.weight, std=0.02, a=-2, b=2)
452
+ nn.init.constant_(module.bias, head_bias)
453
+ else:
454
+ trunc_normal_(module.weight, std=0.02, a=-2, b=2)
455
+ if module.bias is not None:
456
+ nn.init.zeros_(module.bias)
457
+ elif isinstance(module, nn.Conv2d):
458
+ trunc_normal_(module.weight, std=0.02, a=-2, b=2)
459
+ if module.bias is not None:
460
+ nn.init.zeros_(module.bias)
461
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
462
+ nn.init.zeros_(module.bias)
463
+ nn.init.ones_(module.weight)
464
+
465
+
466
+ def resize_pos_embed(posemb, posemb_new):
467
+ """
468
+ Rescale the grid of position embeddings when loading from state_dict
469
+ Expected shape of position embeddings is (1, T, N, C), and considers only square images
470
+ """
471
+ _logger.info("Resized position embedding: %s to %s", posemb.shape, posemb_new.shape)
472
+ seq_length_old = posemb.shape[2]
473
+ num_blocks_new, seq_length_new = posemb_new.shape[1:3]
474
+ size_new = int(math.sqrt(num_blocks_new * seq_length_new))
475
+ # First change to (1, C, H, W)
476
+ posemb = deblockify(posemb, int(math.sqrt(seq_length_old))).permute(0, 3, 1, 2)
477
+ posemb = F.interpolate(posemb, size=[size_new, size_new], mode="bicubic", align_corners=False)
478
+ # Now change to new (1, T, N, C)
479
+ posemb = blockify(posemb.permute(0, 2, 3, 1), int(math.sqrt(seq_length_new)))
480
+ return posemb
481
+
482
+
483
+ def checkpoint_filter_fn(state_dict, model):
484
+ """resize positional embeddings of pretrained weights"""
485
+ pos_embed_keys = [k for k in state_dict.keys() if k.startswith("pos_embed_")]
486
+ for k in pos_embed_keys:
487
+ if state_dict[k].shape != getattr(model, k).shape:
488
+ state_dict[k] = resize_pos_embed(state_dict[k], getattr(model, k))
489
+ return state_dict
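A shape-check sketch for the backbone above (illustrative only; it assumes the bundle's `scripts` package is on the Python path and uses the same configuration that the UNesT wrapper later in this upload passes to `NestTransformer3D`):

    import torch
    from scripts.networks.nest_transformer_3D import NestTransformer3D

    backbone = NestTransformer3D(
        img_size=96, in_chans=1, patch_size=4, num_levels=3,
        embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 8),
    )
    backbone.eval()
    with torch.no_grad():
        feats, hidden = backbone(torch.randn(1, 1, 96, 96, 96))
    print(feats.shape)                  # expected: (1, 512, 6, 6, 6)
    print([h.shape for h in hidden])    # patch embedding output, then one feature map per level

`forward` returns the normalized final feature map together with the list of intermediate feature maps, which is what the UNesT decoder consumes.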
scripts/networks/patchEmbed3D.py ADDED
@@ -0,0 +1,190 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2020 - 2021 MONAI Consortium
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+
14
+
15
+ import math
16
+ from typing import Sequence, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from monai.utils import optional_import
22
+
23
+ Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
24
+
25
+
26
+ class PatchEmbeddingBlock(nn.Module):
27
+ """
28
+ A patch embedding block, based on: "Dosovitskiy et al.,
29
+ An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ in_channels: int,
35
+ img_size: Tuple[int, int, int],
36
+ patch_size: Tuple[int, int, int],
37
+ hidden_size: int,
38
+ num_heads: int,
39
+ pos_embed: str,
40
+ dropout_rate: float = 0.0,
41
+ ) -> None:
42
+ """
43
+ Args:
44
+ in_channels: dimension of input channels.
45
+ img_size: dimension of input image.
46
+ patch_size: dimension of patch size.
47
+ hidden_size: dimension of hidden layer.
48
+ num_heads: number of attention heads.
49
+ pos_embed: position embedding layer type.
50
+ dropout_rate: faction of the input units to drop.
51
+
52
+ """
53
+
54
+ super().__init__()
55
+
56
+ if not (0 <= dropout_rate <= 1):
57
+ raise AssertionError("dropout_rate should be between 0 and 1.")
58
+
59
+ if hidden_size % num_heads != 0:
60
+ raise AssertionError("hidden size should be divisible by num_heads.")
61
+
62
+ for m, p in zip(img_size, patch_size):
63
+ if m < p:
64
+ raise AssertionError("patch_size should be smaller than img_size.")
65
+
66
+ if pos_embed not in ["conv", "perceptron"]:
67
+ raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")
68
+
69
+ if pos_embed == "perceptron":
70
+ if img_size[0] % patch_size[0] != 0:
71
+ raise AssertionError("img_size should be divisible by patch_size for perceptron patch embedding.")
72
+
73
+ self.n_patches = (
74
+ (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) * (img_size[2] // patch_size[2])
75
+ )
76
+ self.patch_dim = in_channels * patch_size[0] * patch_size[1] * patch_size[2]
77
+
78
+ self.pos_embed = pos_embed
79
+ self.patch_embeddings: Union[nn.Conv3d, nn.Sequential]
80
+ if self.pos_embed == "conv":
81
+ self.patch_embeddings = nn.Conv3d(
82
+ in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size
83
+ )
84
+ elif self.pos_embed == "perceptron":
85
+ self.patch_embeddings = nn.Sequential(
86
+ Rearrange(
87
+ "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)",
88
+ p1=patch_size[0],
89
+ p2=patch_size[1],
90
+ p3=patch_size[2],
91
+ ),
92
+ nn.Linear(self.patch_dim, hidden_size),
93
+ )
94
+ self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
95
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
96
+ self.dropout = nn.Dropout(dropout_rate)
97
+ self.trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0)
98
+ self.apply(self._init_weights)
99
+
100
+ def _init_weights(self, m):
101
+ if isinstance(m, nn.Linear):
102
+ self.trunc_normal_(m.weight, mean=0.0, std=0.02, a=-2.0, b=2.0)
103
+ if isinstance(m, nn.Linear) and m.bias is not None:
104
+ nn.init.constant_(m.bias, 0)
105
+ elif isinstance(m, nn.LayerNorm):
106
+ nn.init.constant_(m.bias, 0)
107
+ nn.init.constant_(m.weight, 1.0)
108
+
109
+ def trunc_normal_(self, tensor, mean, std, a, b):
110
+ # From PyTorch official master until it's in a few official releases - RW
111
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
112
+ def norm_cdf(x):
113
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
114
+
115
+ with torch.no_grad():
116
+ l = norm_cdf((a - mean) / std)
117
+ u = norm_cdf((b - mean) / std)
118
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
119
+ tensor.erfinv_()
120
+ tensor.mul_(std * math.sqrt(2.0))
121
+ tensor.add_(mean)
122
+ tensor.clamp_(min=a, max=b)
123
+ return tensor
124
+
125
+ def forward(self, x):
126
+ if self.pos_embed == "conv":
127
+ x = self.patch_embeddings(x)
128
+ x = x.flatten(2)
129
+ x = x.transpose(-1, -2)
130
+ elif self.pos_embed == "perceptron":
131
+ x = self.patch_embeddings(x)
132
+ embeddings = x + self.position_embeddings
133
+ embeddings = self.dropout(embeddings)
134
+ return embeddings
135
+
136
+
137
+ class PatchEmbed3D(nn.Module):
138
+ """Video to Patch Embedding.
139
+
140
+ Args:
141
+ patch_size (int): Patch token size. Default: (4,4,4).
142
+ in_chans (int): Number of input volume channels. Default: 1.
143
+ embed_dim (int): Number of linear projection output channels. Default: 96.
144
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
145
+ """
146
+
147
+ def __init__(
148
+ self,
149
+ img_size: Sequence[int] = (96, 96, 96),
150
+ patch_size=(4, 4, 4),
151
+ in_chans: int = 1,
152
+ embed_dim: int = 96,
153
+ norm_layer=None,
154
+ ):
155
+ super().__init__()
156
+ self.patch_size = patch_size
157
+
158
+ self.in_chans = in_chans
159
+ self.embed_dim = embed_dim
160
+
161
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1], img_size[2] // patch_size[2])
162
+ self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
163
+
164
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
165
+
166
+ if norm_layer is not None:
167
+ self.norm = norm_layer(embed_dim)
168
+ else:
169
+ self.norm = None
170
+
171
+ def forward(self, x):
172
+ """Forward function."""
173
+ # padding
174
+ _, _, d, h, w = x.size()
175
+ if w % self.patch_size[2] != 0:
176
+ x = F.pad(x, (0, self.patch_size[2] - w % self.patch_size[2]))
177
+ if h % self.patch_size[1] != 0:
178
+ x = F.pad(x, (0, 0, 0, self.patch_size[1] - h % self.patch_size[1]))
179
+ if d % self.patch_size[0] != 0:
180
+ x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - d % self.patch_size[0]))
181
+
182
+ x = self.proj(x) # B C D Wh Ww
183
+ if self.norm is not None:
184
+ d, wh, ww = x.size(2), x.size(3), x.size(4)
185
+ x = x.flatten(2).transpose(1, 2)
186
+ x = self.norm(x)
187
+ x = x.transpose(1, 2).view(-1, self.embed_dim, d, wh, ww)
188
+ # pdb.set_trace()
189
+
190
+ return x
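A small shape sketch for `PatchEmbed3D` (illustrative, same import-path assumption as above); this is the embedding block consumed by `NestTransformer3D`, projecting non-overlapping patches to the first-level embedding dimension:

    import torch
    from scripts.networks.patchEmbed3D import PatchEmbed3D

    embed = PatchEmbed3D(img_size=(96, 96, 96), patch_size=(4, 4, 4), in_chans=1, embed_dim=128)
    with torch.no_grad():
        grid = embed(torch.randn(2, 1, 96, 96, 96))
    print(grid.shape)  # expected: (2, 128, 24, 24, 24), a channels-first grid of patch embeddings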
scripts/networks/unest.py ADDED
@@ -0,0 +1,274 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ The 3D NesT transformer-based segmentation model (UNesT)
5
+
6
+ MASI Lab, Vanderbilt University
7
+
8
+
9
+ Authors: Xin Yu, Yinchi Zhou, Yucheng Tang, Bennett Landman
10
+
11
+
12
+ The NEST code is partly from
13
+
14
+ Nested Hierarchical Transformer: Towards Accurate, Data-Efficient and
15
+ Interpretable Visual Understanding
16
+ https://arxiv.org/pdf/2105.12723.pdf
17
+
18
+
19
+ """
20
+
21
+
22
+ # limitations under the License.
23
+ from typing import Sequence, Tuple, Union
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ from monai.networks.blocks import Convolution
28
+ from monai.networks.blocks.dynunet_block import UnetOutBlock
29
+
30
+ # from scripts.networks.swin_transformer_3d import SwinTransformer3D
31
+ from scripts.networks.nest_transformer_3D import NestTransformer3D
32
+ from scripts.networks.unest_block import UNesTBlock, UNesTConvBlock, UNestUpBlock
33
+
34
+ # from monai.networks.blocks.unetr_block import UnetstrBasicBlock, UnetrPrUpBlock, UnetResBlock
35
+
36
+
37
+ class UNesT(nn.Module):
38
+ """
39
+ UNesT model implementation
40
+ """
41
+
42
+ def __init__(
43
+ self,
44
+ in_channels: int,
45
+ out_channels: int,
46
+ img_size: Sequence[int] = (96, 96, 96),
47
+ feature_size: int = 16,
48
+ patch_size: int = 2,
49
+ depths: Sequence[int] = (2, 2, 2, 2),
50
+ num_heads: Sequence[int] = (3, 6, 12, 24),
51
+ window_size: Sequence[int] = (7, 7, 7),
52
+ norm_name: Union[Tuple, str] = "instance",
53
+ conv_block: bool = False,
54
+ res_block: bool = True,
55
+ # featResBlock: bool = False,
56
+ dropout_rate: float = 0.0,
57
+ ) -> None:
58
+ """
59
+ Args:
60
+ in_channels: dimension of input channels.
61
+ out_channels: dimension of output channels.
62
+ img_size: dimension of input image.
63
+ feature_size: dimension of network feature size.
64
+ patch_size: dimension of patch size.
+ depths: number of transformer layers at each hierarchy level.
+ num_heads: number of attention heads at each hierarchy level.
+ window_size: local window size (not used by this implementation).
68
+ norm_name: feature normalization type and arguments.
69
+ conv_block: bool argument to determine if convolutional block is used.
70
+ res_block: bool argument to determine if residual block is used.
71
+ dropout_rate: fraction of the input units to drop.
72
+
73
+ """
74
+
75
+ super().__init__()
76
+
77
+ if not (0 <= dropout_rate <= 1):
78
+ raise AssertionError("dropout_rate should be between 0 and 1.")
79
+
80
+ self.embed_dim = [128, 256, 512]
81
+
82
+ self.nestViT = NestTransformer3D(
83
+ img_size=96,
84
+ in_chans=1,
85
+ patch_size=4,
86
+ num_levels=3,
87
+ embed_dims=(128, 256, 512),
88
+ num_heads=(4, 8, 16),
89
+ depths=(2, 2, 8),
90
+ num_classes=1000,
91
+ mlp_ratio=4.0,
92
+ qkv_bias=True,
93
+ drop_rate=0.0,
94
+ attn_drop_rate=0.0,
95
+ drop_path_rate=0.5,
96
+ norm_layer=None,
97
+ act_layer=None,
98
+ pad_type="",
99
+ weight_init="",
100
+ global_pool="avg",
101
+ )
102
+
103
+ self.encoder1 = UNesTConvBlock(
104
+ spatial_dims=3,
105
+ in_channels=1,
106
+ out_channels=feature_size * 2,
107
+ kernel_size=3,
108
+ stride=1,
109
+ norm_name=norm_name,
110
+ res_block=res_block,
111
+ )
112
+ self.encoder2 = UNestUpBlock(
113
+ spatial_dims=3,
114
+ in_channels=self.embed_dim[0],
115
+ out_channels=feature_size * 4,
116
+ num_layer=1,
117
+ kernel_size=3,
118
+ stride=1,
119
+ upsample_kernel_size=2,
120
+ norm_name=norm_name,
121
+ conv_block=False,
122
+ res_block=False,
123
+ )
124
+
125
+ self.encoder3 = UNesTConvBlock(
126
+ spatial_dims=3,
127
+ in_channels=self.embed_dim[0],
128
+ out_channels=8 * feature_size,
129
+ kernel_size=3,
130
+ stride=1,
131
+ norm_name=norm_name,
132
+ res_block=res_block,
133
+ )
134
+
135
+ self.encoder4 = UNesTConvBlock(
136
+ spatial_dims=3,
137
+ in_channels=self.embed_dim[1],
138
+ out_channels=16 * feature_size,
139
+ kernel_size=3,
140
+ stride=1,
141
+ norm_name=norm_name,
142
+ res_block=res_block,
143
+ )
144
+ self.decoder5 = UNesTBlock(
145
+ spatial_dims=3,
146
+ in_channels=2 * self.embed_dim[2],
147
+ out_channels=feature_size * 32,
148
+ stride=1,
149
+ kernel_size=3,
150
+ upsample_kernel_size=2,
151
+ norm_name=norm_name,
152
+ res_block=res_block,
153
+ )
154
+ self.decoder4 = UNesTBlock(
155
+ spatial_dims=3,
156
+ in_channels=self.embed_dim[2],
157
+ out_channels=feature_size * 16,
158
+ stride=1,
159
+ kernel_size=3,
160
+ upsample_kernel_size=2,
161
+ norm_name=norm_name,
162
+ res_block=res_block,
163
+ )
164
+ self.decoder3 = UNesTBlock(
165
+ spatial_dims=3,
166
+ in_channels=feature_size * 16,
167
+ out_channels=feature_size * 8,
168
+ stride=1,
169
+ kernel_size=3,
170
+ upsample_kernel_size=2,
171
+ norm_name=norm_name,
172
+ res_block=res_block,
173
+ )
174
+ self.decoder2 = UNesTBlock(
175
+ spatial_dims=3,
176
+ in_channels=feature_size * 8,
177
+ out_channels=feature_size * 4,
178
+ stride=1,
179
+ kernel_size=3,
180
+ upsample_kernel_size=2,
181
+ norm_name=norm_name,
182
+ res_block=res_block,
183
+ )
184
+
185
+ self.decoder1 = UNesTBlock(
186
+ spatial_dims=3,
187
+ in_channels=feature_size * 4,
188
+ out_channels=feature_size * 2,
189
+ stride=1,
190
+ kernel_size=3,
191
+ upsample_kernel_size=2,
192
+ norm_name=norm_name,
193
+ res_block=res_block,
194
+ )
195
+
196
+ self.encoder10 = Convolution(
197
+ spatial_dims=3,
198
+ in_channels=32 * feature_size,
199
+ out_channels=64 * feature_size,
200
+ strides=2,
201
+ adn_ordering="ADN",
202
+ dropout=0.0,
203
+ )
204
+
205
+ self.out = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels) # type: ignore
206
+
207
+ def proj_feat(self, x, hidden_size, feat_size):
208
+ x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size)
209
+ x = x.permute(0, 4, 1, 2, 3).contiguous()
210
+ return x
211
+
212
+ def load_from(self, weights):
213
+ with torch.no_grad():
214
+ # copy weights from patch embedding
215
+ for i in weights["state_dict"]:
216
+ print(i)
217
+ self.vit.patch_embedding.position_embeddings.copy_(
218
+ weights["state_dict"]["module.transformer.patch_embedding.position_embeddings_3d"]
219
+ )
220
+ self.vit.patch_embedding.cls_token.copy_(
221
+ weights["state_dict"]["module.transformer.patch_embedding.cls_token"]
222
+ )
223
+ self.vit.patch_embedding.patch_embeddings[1].weight.copy_(
224
+ weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings_3d.1.weight"]
225
+ )
226
+ self.vit.patch_embedding.patch_embeddings[1].bias.copy_(
227
+ weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings_3d.1.bias"]
228
+ )
229
+
230
+ # copy weights from encoding blocks (default: num of blocks: 12)
231
+ for bname, block in self.vit.blocks.named_children():
232
+ print(block)
233
+ block.loadFrom(weights, n_block=bname)
234
+ # last norm layer of transformer
235
+ self.vit.norm.weight.copy_(weights["state_dict"]["module.transformer.norm.weight"])
236
+ self.vit.norm.bias.copy_(weights["state_dict"]["module.transformer.norm.bias"])
237
+
238
+ def forward(self, x_in):
239
+ x, hidden_states_out = self.nestViT(x_in)
240
+
241
+ enc0 = self.encoder1(x_in) # 2, 32, 96, 96, 96
242
+
243
+ x1 = hidden_states_out[0] # 2, 128, 24, 24, 24
244
+
245
+ enc1 = self.encoder2(x1) # 2, 64, 48, 48, 48
246
+
247
+ x2 = hidden_states_out[1] # 2, 128, 24, 24, 24
248
+
249
+ enc2 = self.encoder3(x2) # 2, 128, 24, 24, 24
250
+
251
+ x3 = hidden_states_out[2] # 2, 256, 12, 12, 12
252
+
253
+ enc3 = self.encoder4(x3) # 2, 256, 12, 12, 12
254
+
255
+ x4 = hidden_states_out[3]
256
+
257
+ enc4 = x4 # 2, 512, 6, 6, 6
258
+
259
+ dec4 = x # 2, 512, 6, 6, 6
260
+
261
+ dec4 = self.encoder10(dec4) # 2, 1024, 3, 3, 3
262
+
263
+ dec3 = self.decoder5(dec4, enc4) # 2, 512, 6, 6, 6
264
+
265
+ dec2 = self.decoder4(dec3, enc3) # 2, 256, 12, 12, 12
266
+
267
+ dec1 = self.decoder3(dec2, enc2) # 2, 128, 24, 24, 24
268
+
269
+ dec0 = self.decoder2(dec1, enc1) # 2, 64, 48, 48, 48
270
+
271
+ out = self.decoder1(dec0, enc0) # 2, 32, 96, 96, 96
272
+
273
+ logits = self.out(out)
274
+ return logits
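An end-to-end sketch of the segmentation network defined above (illustrative only; `out_channels=4` is an arbitrary example label count, not the bundle's configured value):

    import torch
    from scripts.networks.unest import UNesT

    model = UNesT(in_channels=1, out_channels=4, img_size=(96, 96, 96))
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 1, 96, 96, 96))
    print(logits.shape)  # expected: (1, 4, 96, 96, 96), one channel per segmentation class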
scripts/networks/unest_block.py ADDED
@@ -0,0 +1,245 @@
1
+ #!/usr/bin/env python3
2
+
3
+ from typing import Sequence, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, get_conv_layer
8
+
9
+
10
+ class UNesTBlock(nn.Module):
11
+ """ """
12
+
13
+ def __init__(
14
+ self,
15
+ spatial_dims: int,
16
+ in_channels: int,
17
+ out_channels: int, # type: ignore
18
+ kernel_size: Union[Sequence[int], int],
19
+ stride: Union[Sequence[int], int],
20
+ upsample_kernel_size: Union[Sequence[int], int],
21
+ norm_name: Union[Tuple, str],
22
+ res_block: bool = False,
23
+ ) -> None:
24
+ """
25
+ Args:
26
+ spatial_dims: number of spatial dimensions.
27
+ in_channels: number of input channels.
28
+ out_channels: number of output channels.
29
+ kernel_size: convolution kernel size.
30
+ stride: convolution stride.
31
+ upsample_kernel_size: convolution kernel size for transposed convolution layers.
32
+ norm_name: feature normalization type and arguments.
33
+ res_block: bool argument to determine if residual block is used.
34
+
35
+ """
36
+
37
+ super(UNesTBlock, self).__init__()
38
+ upsample_stride = upsample_kernel_size
39
+ self.transp_conv = get_conv_layer(
40
+ spatial_dims,
41
+ in_channels,
42
+ out_channels,
43
+ kernel_size=upsample_kernel_size,
44
+ stride=upsample_stride,
45
+ conv_only=True,
46
+ is_transposed=True,
47
+ )
48
+
49
+ if res_block:
50
+ self.conv_block = UnetResBlock(
51
+ spatial_dims,
52
+ out_channels + out_channels,
53
+ out_channels,
54
+ kernel_size=kernel_size,
55
+ stride=1,
56
+ norm_name=norm_name,
57
+ )
58
+ else:
59
+ self.conv_block = UnetBasicBlock( # type: ignore
60
+ spatial_dims,
61
+ out_channels + out_channels,
62
+ out_channels,
63
+ kernel_size=kernel_size,
64
+ stride=1,
65
+ norm_name=norm_name,
66
+ )
67
+
68
+ def forward(self, inp, skip):
69
+ # the number of channels in skip should equal out_channels
70
+ out = self.transp_conv(inp)
71
+ # print(out.shape)
72
+ # print(skip.shape)
73
+ out = torch.cat((out, skip), dim=1)
74
+ out = self.conv_block(out)
75
+ return out
76
+
77
+
78
+ class UNestUpBlock(nn.Module):
79
+ """ """
80
+
81
+ def __init__(
82
+ self,
83
+ spatial_dims: int,
84
+ in_channels: int,
85
+ out_channels: int,
86
+ num_layer: int,
87
+ kernel_size: Union[Sequence[int], int],
88
+ stride: Union[Sequence[int], int],
89
+ upsample_kernel_size: Union[Sequence[int], int],
90
+ norm_name: Union[Tuple, str],
91
+ conv_block: bool = False,
92
+ res_block: bool = False,
93
+ ) -> None:
94
+ """
95
+ Args:
96
+ spatial_dims: number of spatial dimensions.
97
+ in_channels: number of input channels.
98
+ out_channels: number of output channels.
99
+ num_layer: number of upsampling blocks.
100
+ kernel_size: convolution kernel size.
101
+ stride: convolution stride.
102
+ upsample_kernel_size: convolution kernel size for transposed convolution layers.
103
+ norm_name: feature normalization type and arguments.
104
+ conv_block: bool argument to determine if convolutional block is used.
105
+ res_block: bool argument to determine if residual block is used.
106
+
107
+ """
108
+
109
+ super().__init__()
110
+
111
+ upsample_stride = upsample_kernel_size
112
+ self.transp_conv_init = get_conv_layer(
113
+ spatial_dims,
114
+ in_channels,
115
+ out_channels,
116
+ kernel_size=upsample_kernel_size,
117
+ stride=upsample_stride,
118
+ conv_only=True,
119
+ is_transposed=True,
120
+ )
121
+ if conv_block:
122
+ if res_block:
123
+ self.blocks = nn.ModuleList(
124
+ [
125
+ nn.Sequential(
126
+ get_conv_layer(
127
+ spatial_dims,
128
+ out_channels,
129
+ out_channels,
130
+ kernel_size=upsample_kernel_size,
131
+ stride=upsample_stride,
132
+ conv_only=True,
133
+ is_transposed=True,
134
+ ),
135
+ UnetResBlock(
136
+ spatial_dims=3,
137
+ in_channels=out_channels,
138
+ out_channels=out_channels,
139
+ kernel_size=kernel_size,
140
+ stride=stride,
141
+ norm_name=norm_name,
142
+ ),
143
+ )
144
+ for i in range(num_layer)
145
+ ]
146
+ )
147
+ else:
148
+ self.blocks = nn.ModuleList(
149
+ [
150
+ nn.Sequential(
151
+ get_conv_layer(
152
+ spatial_dims,
153
+ out_channels,
154
+ out_channels,
155
+ kernel_size=upsample_kernel_size,
156
+ stride=upsample_stride,
157
+ conv_only=True,
158
+ is_transposed=True,
159
+ ),
160
+ UnetBasicBlock(
161
+ spatial_dims=3,
162
+ in_channels=out_channels,
163
+ out_channels=out_channels,
164
+ kernel_size=kernel_size,
165
+ stride=stride,
166
+ norm_name=norm_name,
167
+ ),
168
+ )
169
+ for i in range(num_layer)
170
+ ]
171
+ )
172
+ else:
173
+ self.blocks = nn.ModuleList(
174
+ [
175
+ get_conv_layer(
176
+ spatial_dims,
177
+ out_channels,
178
+ out_channels,
179
+ kernel_size=1,
180
+ stride=1,
181
+ conv_only=True,
182
+ is_transposed=True,
183
+ )
184
+ for i in range(num_layer)
185
+ ]
186
+ )
187
+
188
+ def forward(self, x):
189
+ x = self.transp_conv_init(x)
190
+ for blk in self.blocks:
191
+ x = blk(x)
192
+ return x
193
+
194
+
195
+ class UNesTConvBlock(nn.Module):
196
+ """
197
+ UNesT block with skip connections
198
+ """
199
+
200
+ def __init__(
201
+ self,
202
+ spatial_dims: int,
203
+ in_channels: int,
204
+ out_channels: int,
205
+ kernel_size: Union[Sequence[int], int],
206
+ stride: Union[Sequence[int], int],
207
+ norm_name: Union[Tuple, str],
208
+ res_block: bool = False,
209
+ ) -> None:
210
+ """
211
+ Args:
212
+ spatial_dims: number of spatial dimensions.
213
+ in_channels: number of input channels.
214
+ out_channels: number of output channels.
215
+ kernel_size: convolution kernel size.
216
+ stride: convolution stride.
217
+ norm_name: feature normalization type and arguments.
218
+ res_block: bool argument to determine if residual block is used.
219
+
220
+ """
221
+
222
+ super().__init__()
223
+
224
+ if res_block:
225
+ self.layer = UnetResBlock(
226
+ spatial_dims=spatial_dims,
227
+ in_channels=in_channels,
228
+ out_channels=out_channels,
229
+ kernel_size=kernel_size,
230
+ stride=stride,
231
+ norm_name=norm_name,
232
+ )
233
+ else:
234
+ self.layer = UnetBasicBlock( # type: ignore
235
+ spatial_dims=spatial_dims,
236
+ in_channels=in_channels,
237
+ out_channels=out_channels,
238
+ kernel_size=kernel_size,
239
+ stride=stride,
240
+ norm_name=norm_name,
241
+ )
242
+
243
+ def forward(self, inp):
244
+ out = self.layer(inp)
245
+ return out
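A brief sketch of the decoder-block contract defined in this file (shapes chosen to match the comments in unest.py): the transposed convolution doubles the spatial size of `inp`, the result is concatenated with `skip` along the channel axis, and the conv block reduces the channels back to `out_channels`.

    import torch
    from scripts.networks.unest_block import UNesTBlock

    blk = UNesTBlock(
        spatial_dims=3, in_channels=1024, out_channels=512,
        kernel_size=3, stride=1, upsample_kernel_size=2,
        norm_name="instance", res_block=True,
    )
    inp = torch.randn(1, 1024, 3, 3, 3)   # low-resolution decoder feature
    skip = torch.randn(1, 512, 6, 6, 6)   # matching encoder feature
    with torch.no_grad():
        out = blk(inp, skip)
    print(out.shape)  # expected: (1, 512, 6, 6, 6)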