Upload 76 files
This view is limited to 50 files because it contains too many changes.
See raw diff
- src/__init__.py +6 -0
- src/core/__init__.py +9 -0
- src/core/_config.py +299 -0
- src/core/workspace.py +178 -0
- src/core/yaml_config.py +187 -0
- src/core/yaml_utils.py +126 -0
- src/data/__init__.py +20 -0
- src/data/_misc.py +62 -0
- src/data/dataloader.py +122 -0
- src/data/dataset/__init__.py +17 -0
- src/data/dataset/_dataset.py +27 -0
- src/data/dataset/cifar_dataset.py +25 -0
- src/data/dataset/coco_dataset.py +282 -0
- src/data/dataset/coco_eval.py +214 -0
- src/data/dataset/coco_utils.py +191 -0
- src/data/dataset/voc_detection.py +86 -0
- src/data/dataset/voc_eval.py +12 -0
- src/data/transforms/__init__.py +21 -0
- src/data/transforms/_transforms.py +161 -0
- src/data/transforms/container.py +99 -0
- src/data/transforms/functional.py +172 -0
- src/data/transforms/mosaic.py +83 -0
- src/data/transforms/presets.py +4 -0
- src/misc/__init__.py +9 -0
- src/misc/box_ops.py +106 -0
- src/misc/dist_utils.py +281 -0
- src/misc/lazy_loader.py +70 -0
- src/misc/logger.py +255 -0
- src/misc/profiler_utils.py +30 -0
- src/misc/visualizer.py +121 -0
- src/nn/__init__.py +16 -0
- src/nn/arch/__init__.py +7 -0
- src/nn/arch/classification.py +45 -0
- src/nn/arch/yolo.py +42 -0
- src/nn/backbone/__init__.py +17 -0
- src/nn/backbone/common.py +117 -0
- src/nn/backbone/csp_darknet.py +203 -0
- src/nn/backbone/csp_resnet.py +302 -0
- src/nn/backbone/hgnetv2.py +581 -0
- src/nn/backbone/presnet.py +263 -0
- src/nn/backbone/test_resnet.py +83 -0
- src/nn/backbone/timm_model.py +66 -0
- src/nn/backbone/torchvision_model.py +50 -0
- src/nn/backbone/utils.py +56 -0
- src/nn/criterion/__init__.py +11 -0
- src/nn/criterion/det_criterion.py +188 -0
- src/nn/postprocessor/__init__.py +6 -0
- src/nn/postprocessor/box_revert.py +66 -0
- src/nn/postprocessor/detr_postprocessor.py +86 -0
- src/nn/postprocessor/nms_postprocessor.py +86 -0
src/__init__.py
ADDED
@@ -0,0 +1,6 @@
"""
Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
"""

# for register purpose
from . import data, nn, optim, zoo
src/core/__init__.py
ADDED
@@ -0,0 +1,9 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

from ._config import BaseConfig
from .workspace import GLOBAL_CONFIG, create, register
from .yaml_config import YAMLConfig
from .yaml_utils import *
src/core/_config.py
ADDED
@@ -0,0 +1,299 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

from pathlib import Path
from typing import Callable, Dict, List

import torch
import torch.nn as nn
from torch.cuda.amp.grad_scaler import GradScaler
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter

__all__ = [
    "BaseConfig",
]


class BaseConfig(object):
    # TODO property

    def __init__(self) -> None:
        super().__init__()

        self.task: str = None

        # instance / function
        self._model: nn.Module = None
        self._postprocessor: nn.Module = None
        self._criterion: nn.Module = None
        self._optimizer: Optimizer = None
        self._lr_scheduler: LRScheduler = None
        self._lr_warmup_scheduler: LRScheduler = None
        self._train_dataloader: DataLoader = None
        self._val_dataloader: DataLoader = None
        self._ema: nn.Module = None
        self._scaler: GradScaler = None
        self._train_dataset: Dataset = None
        self._val_dataset: Dataset = None
        self._collate_fn: Callable = None
        self._evaluator: Callable[[nn.Module, DataLoader, str],] = None
        self._writer: SummaryWriter = None

        # dataset
        self.num_workers: int = 0
        self.batch_size: int = None
        self._train_batch_size: int = None
        self._val_batch_size: int = None
        self._train_shuffle: bool = None
        self._val_shuffle: bool = None

        # runtime
        self.resume: str = None
        self.tuning: str = None

        self.epochs: int = None
        self.last_epoch: int = -1

        self.use_amp: bool = False
        self.use_ema: bool = False
        self.ema_decay: float = 0.9999
        self.ema_warmups: int = 2000
        self.sync_bn: bool = False
        self.clip_max_norm: float = 0.0
        self.find_unused_parameters: bool = None

        self.seed: int = None
        self.print_freq: int = None
        self.checkpoint_freq: int = 1
        self.output_dir: str = None
        self.summary_dir: str = None
        self.device: str = ""

    @property
    def model(self) -> nn.Module:
        return self._model

    @model.setter
    def model(self, m):
        assert isinstance(m, nn.Module), f"{type(m)} != nn.Module, please check your model class"
        self._model = m

    @property
    def postprocessor(self) -> nn.Module:
        return self._postprocessor

    @postprocessor.setter
    def postprocessor(self, m):
        assert isinstance(m, nn.Module), f"{type(m)} != nn.Module, please check your model class"
        self._postprocessor = m

    @property
    def criterion(self) -> nn.Module:
        return self._criterion

    @criterion.setter
    def criterion(self, m):
        assert isinstance(m, nn.Module), f"{type(m)} != nn.Module, please check your model class"
        self._criterion = m

    @property
    def optimizer(self) -> Optimizer:
        return self._optimizer

    @optimizer.setter
    def optimizer(self, m):
        assert isinstance(
            m, Optimizer
        ), f"{type(m)} != optim.Optimizer, please check your model class"
        self._optimizer = m

    @property
    def lr_scheduler(self) -> LRScheduler:
        return self._lr_scheduler

    @lr_scheduler.setter
    def lr_scheduler(self, m):
        assert isinstance(
            m, LRScheduler
        ), f"{type(m)} != LRScheduler, please check your model class"
        self._lr_scheduler = m

    @property
    def lr_warmup_scheduler(self) -> LRScheduler:
        return self._lr_warmup_scheduler

    @lr_warmup_scheduler.setter
    def lr_warmup_scheduler(self, m):
        self._lr_warmup_scheduler = m

    @property
    def train_dataloader(self) -> DataLoader:
        if self._train_dataloader is None and self.train_dataset is not None:
            loader = DataLoader(
                self.train_dataset,
                batch_size=self.train_batch_size,
                num_workers=self.num_workers,
                collate_fn=self.collate_fn,
                shuffle=self.train_shuffle,
            )
            loader.shuffle = self.train_shuffle
            self._train_dataloader = loader

        return self._train_dataloader

    @train_dataloader.setter
    def train_dataloader(self, loader):
        self._train_dataloader = loader

    @property
    def val_dataloader(self) -> DataLoader:
        if self._val_dataloader is None and self.val_dataset is not None:
            loader = DataLoader(
                self.val_dataset,
                batch_size=self.val_batch_size,
                num_workers=self.num_workers,
                drop_last=False,
                collate_fn=self.collate_fn,
                shuffle=self.val_shuffle,
                persistent_workers=True,
            )
            loader.shuffle = self.val_shuffle
            self._val_dataloader = loader

        return self._val_dataloader

    @val_dataloader.setter
    def val_dataloader(self, loader):
        self._val_dataloader = loader

    @property
    def ema(self) -> nn.Module:
        if self._ema is None and self.use_ema and self.model is not None:
            from ..optim import ModelEMA

            self._ema = ModelEMA(self.model, self.ema_decay, self.ema_warmups)
        return self._ema

    @ema.setter
    def ema(self, obj):
        self._ema = obj

    @property
    def scaler(self) -> GradScaler:
        if self._scaler is None and self.use_amp and torch.cuda.is_available():
            self._scaler = GradScaler()
        return self._scaler

    @scaler.setter
    def scaler(self, obj: GradScaler):
        self._scaler = obj

    @property
    def val_shuffle(self) -> bool:
        if self._val_shuffle is None:
            print("warning: set default val_shuffle=False")
            return False
        return self._val_shuffle

    @val_shuffle.setter
    def val_shuffle(self, shuffle):
        assert isinstance(shuffle, bool), "shuffle must be bool"
        self._val_shuffle = shuffle

    @property
    def train_shuffle(self) -> bool:
        if self._train_shuffle is None:
            print("warning: set default train_shuffle=True")
            return True
        return self._train_shuffle

    @train_shuffle.setter
    def train_shuffle(self, shuffle):
        assert isinstance(shuffle, bool), "shuffle must be bool"
        self._train_shuffle = shuffle

    @property
    def train_batch_size(self) -> int:
        if self._train_batch_size is None and isinstance(self.batch_size, int):
            print(f"warning: set train_batch_size=batch_size={self.batch_size}")
            return self.batch_size
        return self._train_batch_size

    @train_batch_size.setter
    def train_batch_size(self, batch_size):
        assert isinstance(batch_size, int), "batch_size must be int"
        self._train_batch_size = batch_size

    @property
    def val_batch_size(self) -> int:
        if self._val_batch_size is None:
            print(f"warning: set val_batch_size=batch_size={self.batch_size}")
            return self.batch_size
        return self._val_batch_size

    @val_batch_size.setter
    def val_batch_size(self, batch_size):
        assert isinstance(batch_size, int), "batch_size must be int"
        self._val_batch_size = batch_size

    @property
    def train_dataset(self) -> Dataset:
        return self._train_dataset

    @train_dataset.setter
    def train_dataset(self, dataset):
        assert isinstance(dataset, Dataset), f"{type(dataset)} must be Dataset"
        self._train_dataset = dataset

    @property
    def val_dataset(self) -> Dataset:
        return self._val_dataset

    @val_dataset.setter
    def val_dataset(self, dataset):
        assert isinstance(dataset, Dataset), f"{type(dataset)} must be Dataset"
        self._val_dataset = dataset

    @property
    def collate_fn(self) -> Callable:
        return self._collate_fn

    @collate_fn.setter
    def collate_fn(self, fn):
        assert isinstance(fn, Callable), f"{type(fn)} must be Callable"
        self._collate_fn = fn

    @property
    def evaluator(self) -> Callable:
        return self._evaluator

    @evaluator.setter
    def evaluator(self, fn):
        assert isinstance(fn, Callable), f"{type(fn)} must be Callable"
        self._evaluator = fn

    @property
    def writer(self) -> SummaryWriter:
        if self._writer is None:
            if self.summary_dir:
                self._writer = SummaryWriter(self.summary_dir)
            elif self.output_dir:
                self._writer = SummaryWriter(Path(self.output_dir) / "summary")
        return self._writer

    @writer.setter
    def writer(self, m):
        assert isinstance(m, SummaryWriter), f"{type(m)} must be SummaryWriter"
        self._writer = m

    def __repr__(self):
        s = ""
        for k, v in self.__dict__.items():
            if not k.startswith("_"):
                s += f"{k}: {v}\n"
        return s
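
Note: BaseConfig wires everything through lazy properties — each getter returns a cached private attribute and, where possible, builds a default on first access (a DataLoader from the dataset, an EMA wrapper, a GradScaler under AMP). A minimal usage sketch, not part of the uploaded files; the attribute values here are made up:

# Hypothetical usage of BaseConfig's lazy properties (illustrative only).
from src.core import BaseConfig

cfg = BaseConfig()
cfg.batch_size = 8      # train/val batch sizes fall back to this, with a printed warning
cfg.use_amp = True      # cfg.scaler is built lazily, and only if CUDA is available
print(cfg)              # __repr__ lists the public (non-underscore) attributes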
src/core/workspace.py
ADDED
@@ -0,0 +1,178 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

import functools
import importlib
import inspect
from collections import defaultdict
from typing import Any, Dict, List, Optional

GLOBAL_CONFIG = defaultdict(dict)


def register(dct: Any = GLOBAL_CONFIG, name=None, force=False):
    """
    dct:
        if dct is Dict, register foo into dct as key-value pair
        if dct is Clas, register as modules attibute
    force
        whether force register.
    """

    def decorator(foo):
        register_name = foo.__name__ if name is None else name
        if not force:
            if inspect.isclass(dct):
                assert not hasattr(dct, foo.__name__), f"module {dct.__name__} has {foo.__name__}"
            else:
                assert foo.__name__ not in dct, f"{foo.__name__} has been already registered"

        if inspect.isfunction(foo):

            @functools.wraps(foo)
            def wrap_func(*args, **kwargs):
                return foo(*args, **kwargs)

            if isinstance(dct, dict):
                dct[foo.__name__] = wrap_func
            elif inspect.isclass(dct):
                setattr(dct, foo.__name__, wrap_func)
            else:
                raise AttributeError("")
            return wrap_func

        elif inspect.isclass(foo):
            dct[register_name] = extract_schema(foo)

        else:
            raise ValueError(f"Do not support {type(foo)} register")

        return foo

    return decorator


def extract_schema(module: type):
    """
    Args:
        module (type),
    Return:
        Dict,
    """
    argspec = inspect.getfullargspec(module.__init__)
    arg_names = [arg for arg in argspec.args if arg != "self"]
    num_defualts = len(argspec.defaults) if argspec.defaults is not None else 0
    num_requires = len(arg_names) - num_defualts

    schame = dict()
    schame["_name"] = module.__name__
    schame["_pymodule"] = importlib.import_module(module.__module__)
    schame["_inject"] = getattr(module, "__inject__", [])
    schame["_share"] = getattr(module, "__share__", [])
    schame["_kwargs"] = {}
    for i, name in enumerate(arg_names):
        if name in schame["_share"]:
            assert i >= num_requires, "share config must have default value."
            value = argspec.defaults[i - num_requires]

        elif i >= num_requires:
            value = argspec.defaults[i - num_requires]

        else:
            value = None

        schame[name] = value
        schame["_kwargs"][name] = value

    return schame


def create(type_or_name, global_cfg=GLOBAL_CONFIG, **kwargs):
    """ """
    assert type(type_or_name) in (type, str), "create should be modules or name."

    name = type_or_name if isinstance(type_or_name, str) else type_or_name.__name__

    if name in global_cfg:
        if hasattr(global_cfg[name], "__dict__"):
            return global_cfg[name]
    else:
        raise ValueError("The module {} is not registered".format(name))

    cfg = global_cfg[name]

    if isinstance(cfg, dict) and "type" in cfg:
        _cfg: dict = global_cfg[cfg["type"]]
        # clean args
        _keys = [k for k in _cfg.keys() if not k.startswith("_")]
        for _arg in _keys:
            del _cfg[_arg]
        _cfg.update(_cfg["_kwargs"])  # restore default args
        _cfg.update(cfg)  # load config args
        _cfg.update(kwargs)  # TODO recive extra kwargs
        name = _cfg.pop("type")  # pop extra key `type` (from cfg)

        return create(name, global_cfg)

    module = getattr(cfg["_pymodule"], name)
    module_kwargs = {}
    module_kwargs.update(cfg)

    # shared var
    for k in cfg["_share"]:
        if k in global_cfg:
            module_kwargs[k] = global_cfg[k]
        else:
            module_kwargs[k] = cfg[k]

    # inject
    for k in cfg["_inject"]:
        _k = cfg[k]

        if _k is None:
            continue

        if isinstance(_k, str):
            if _k not in global_cfg:
                raise ValueError(f"Missing inject config of {_k}.")

            _cfg = global_cfg[_k]

            if isinstance(_cfg, dict):
                module_kwargs[k] = create(_cfg["_name"], global_cfg)
            else:
                module_kwargs[k] = _cfg

        elif isinstance(_k, dict):
            if "type" not in _k.keys():
                raise ValueError("Missing inject for `type` style.")

            _type = str(_k["type"])
            if _type not in global_cfg:
                raise ValueError(f"Missing {_type} in inspect stage.")

            # TODO
            _cfg: dict = global_cfg[_type]
            # clean args
            _keys = [k for k in _cfg.keys() if not k.startswith("_")]
            for _arg in _keys:
                del _cfg[_arg]
            _cfg.update(_cfg["_kwargs"])  # restore default values
            _cfg.update(_k)  # load config args
            name = _cfg.pop("type")  # pop extra key (`type` from _k)
            module_kwargs[k] = create(name, global_cfg)

        else:
            raise ValueError(f"Inject does not support {_k}")

    # TODO hard code
    module_kwargs = {k: v for k, v in module_kwargs.items() if not k.startswith("_")}

    # TODO for **kwargs
    # extra_args = set(module_kwargs.keys()) - set(arg_names)
    # if len(extra_args) > 0:
    #     raise RuntimeError(f'Error: unknown args {extra_args} for {module}')

    return module(**module_kwargs)
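
Note: register() stores a per-class schema (constructor defaults plus the class's __inject__/__share__ lists) in GLOBAL_CONFIG, and create() instantiates a registered name from that schema. A small sketch of the intended flow, using a made-up FakeBackbone class that is not part of this upload:

# Illustrative only: register a class, then build it by name from the registry.
from src.core.workspace import GLOBAL_CONFIG, create, register

@register()
class FakeBackbone:  # hypothetical class for demonstration
    def __init__(self, depth=50, pretrained=False):
        self.depth = depth
        self.pretrained = pretrained

print(GLOBAL_CONFIG["FakeBackbone"]["depth"])  # 50, captured by extract_schema
model = create("FakeBackbone")                 # instantiated from the stored defaults
print(model.depth)                             # 50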
src/core/yaml_config.py
ADDED
@@ -0,0 +1,187 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

import copy
import re

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from ._config import BaseConfig
from .workspace import create
from .yaml_utils import load_config, merge_config, merge_dict


class YAMLConfig(BaseConfig):
    def __init__(self, cfg_path: str, **kwargs) -> None:
        super().__init__()

        cfg = load_config(cfg_path)
        cfg = merge_dict(cfg, kwargs)

        self.yaml_cfg = copy.deepcopy(cfg)

        for k in super().__dict__:
            if not k.startswith("_") and k in cfg:
                self.__dict__[k] = cfg[k]

    @property
    def global_cfg(self):
        return merge_config(self.yaml_cfg, inplace=False, overwrite=False)

    @property
    def model(self) -> torch.nn.Module:
        if self._model is None and "model" in self.yaml_cfg:
            self._model = create(self.yaml_cfg["model"], self.global_cfg)
        return super().model

    @property
    def postprocessor(self) -> torch.nn.Module:
        if self._postprocessor is None and "postprocessor" in self.yaml_cfg:
            self._postprocessor = create(self.yaml_cfg["postprocessor"], self.global_cfg)
        return super().postprocessor

    @property
    def criterion(self) -> torch.nn.Module:
        if self._criterion is None and "criterion" in self.yaml_cfg:
            self._criterion = create(self.yaml_cfg["criterion"], self.global_cfg)
        return super().criterion

    @property
    def optimizer(self) -> optim.Optimizer:
        if self._optimizer is None and "optimizer" in self.yaml_cfg:
            params = self.get_optim_params(self.yaml_cfg["optimizer"], self.model)
            self._optimizer = create("optimizer", self.global_cfg, params=params)
        return super().optimizer

    @property
    def lr_scheduler(self) -> optim.lr_scheduler.LRScheduler:
        if self._lr_scheduler is None and "lr_scheduler" in self.yaml_cfg:
            self._lr_scheduler = create("lr_scheduler", self.global_cfg, optimizer=self.optimizer)
            print(f"Initial lr: {self._lr_scheduler.get_last_lr()}")
        return super().lr_scheduler

    @property
    def lr_warmup_scheduler(self) -> optim.lr_scheduler.LRScheduler:
        if self._lr_warmup_scheduler is None and "lr_warmup_scheduler" in self.yaml_cfg:
            self._lr_warmup_scheduler = create(
                "lr_warmup_scheduler", self.global_cfg, lr_scheduler=self.lr_scheduler
            )
        return super().lr_warmup_scheduler

    @property
    def train_dataloader(self) -> DataLoader:
        if self._train_dataloader is None and "train_dataloader" in self.yaml_cfg:
            self._train_dataloader = self.build_dataloader("train_dataloader")
        return super().train_dataloader

    @property
    def val_dataloader(self) -> DataLoader:
        if self._val_dataloader is None and "val_dataloader" in self.yaml_cfg:
            self._val_dataloader = self.build_dataloader("val_dataloader")
        return super().val_dataloader

    @property
    def ema(self) -> torch.nn.Module:
        if self._ema is None and self.yaml_cfg.get("use_ema", False):
            self._ema = create("ema", self.global_cfg, model=self.model)
        return super().ema

    @property
    def scaler(self):
        if self._scaler is None and self.yaml_cfg.get("use_amp", False):
            self._scaler = create("scaler", self.global_cfg)
        return super().scaler

    @property
    def evaluator(self):
        if self._evaluator is None and "evaluator" in self.yaml_cfg:
            if self.yaml_cfg["evaluator"]["type"] == "CocoEvaluator":
                from ..data import get_coco_api_from_dataset

                base_ds = get_coco_api_from_dataset(self.val_dataloader.dataset)
                self._evaluator = create("evaluator", self.global_cfg, coco_gt=base_ds)
            else:
                raise NotImplementedError(f"{self.yaml_cfg['evaluator']['type']}")
        return super().evaluator

    @property
    def use_wandb(self) -> bool:
        return self.yaml_cfg.get("use_wandb", False)

    @staticmethod
    def get_optim_params(cfg: dict, model: nn.Module):
        """
        E.g.:
            ^(?=.*a)(?=.*b).*$  means including a and b
            ^(?=.*(?:a|b)).*$   means including a or b
            ^(?=.*a)(?!.*b).*$  means including a, but not b
        """
        assert "type" in cfg, ""
        cfg = copy.deepcopy(cfg)

        if "params" not in cfg:
            return model.parameters()

        assert isinstance(cfg["params"], list), ""

        param_groups = []
        visited = []
        for pg in cfg["params"]:
            pattern = pg["params"]
            params = {
                k: v
                for k, v in model.named_parameters()
                if v.requires_grad and len(re.findall(pattern, k)) > 0
            }
            pg["params"] = params.values()
            param_groups.append(pg)
            visited.extend(list(params.keys()))
            # print(params.keys())

        names = [k for k, v in model.named_parameters() if v.requires_grad]

        if len(visited) < len(names):
            unseen = set(names) - set(visited)
            params = {k: v for k, v in model.named_parameters() if v.requires_grad and k in unseen}
            param_groups.append({"params": params.values()})
            visited.extend(list(params.keys()))
            # print(params.keys())

        assert len(visited) == len(names), ""

        return param_groups

    @staticmethod
    def get_rank_batch_size(cfg):
        """compute batch size for per rank if total_batch_size is provided."""
        assert ("total_batch_size" in cfg or "batch_size" in cfg) and not (
            "total_batch_size" in cfg and "batch_size" in cfg
        ), "`batch_size` or `total_batch_size` should be choosed one"

        total_batch_size = cfg.get("total_batch_size", None)
        if total_batch_size is None:
            bs = cfg.get("batch_size")
        else:
            from ..misc import dist_utils

            assert (
                total_batch_size % dist_utils.get_world_size() == 0
            ), "total_batch_size should be divisible by world size"
            bs = total_batch_size // dist_utils.get_world_size()
        return bs

    def build_dataloader(self, name: str):
        bs = self.get_rank_batch_size(self.yaml_cfg[name])
        global_cfg = self.global_cfg
        if "total_batch_size" in global_cfg[name]:
            # pop unexpected key for dataloader init
            _ = global_cfg[name].pop("total_batch_size")
        print(f"building {name} with batch_size={bs}...")
        loader = create(name, global_cfg, batch_size=bs)
        loader.shuffle = self.yaml_cfg[name].get("shuffle", False)
        return loader
src/core/yaml_utils.py
ADDED
@@ -0,0 +1,126 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

import copy
import os
from typing import Any, Dict, List, Optional

import yaml

from .workspace import GLOBAL_CONFIG

__all__ = [
    "load_config",
    "merge_config",
    "merge_dict",
    "parse_cli",
]


INCLUDE_KEY = "__include__"


def load_config(file_path, cfg=dict()):
    """load config"""
    _, ext = os.path.splitext(file_path)
    assert ext in [".yml", ".yaml"], "only support yaml files"

    with open(file_path) as f:
        file_cfg = yaml.load(f, Loader=yaml.Loader)
        if file_cfg is None:
            return {}

    if INCLUDE_KEY in file_cfg:
        base_yamls = list(file_cfg[INCLUDE_KEY])
        for base_yaml in base_yamls:
            if base_yaml.startswith("~"):
                base_yaml = os.path.expanduser(base_yaml)

            if not base_yaml.startswith("/"):
                base_yaml = os.path.join(os.path.dirname(file_path), base_yaml)

            with open(base_yaml) as f:
                base_cfg = load_config(base_yaml, cfg)
                merge_dict(cfg, base_cfg)

    return merge_dict(cfg, file_cfg)


def merge_dict(dct, another_dct, inplace=True) -> Dict:
    """merge another_dct into dct"""

    def _merge(dct, another) -> Dict:
        for k in another:
            if k in dct and isinstance(dct[k], dict) and isinstance(another[k], dict):
                _merge(dct[k], another[k])
            else:
                dct[k] = another[k]

        return dct

    if not inplace:
        dct = copy.deepcopy(dct)

    return _merge(dct, another_dct)


def dictify(s: str, v: Any) -> Dict:
    if "." not in s:
        return {s: v}
    key, rest = s.split(".", 1)
    return {key: dictify(rest, v)}


def parse_cli(nargs: List[str]) -> Dict:
    """
    parse command-line arguments
    convert `a.c=3 b=10` to `{'a': {'c': 3}, 'b': 10}`
    """
    cfg = {}
    if nargs is None or len(nargs) == 0:
        return cfg

    for s in nargs:
        s = s.strip()
        k, v = s.split("=", 1)
        d = dictify(k, yaml.load(v, Loader=yaml.Loader))
        cfg = merge_dict(cfg, d)

    return cfg


def merge_config(cfg, another_cfg=GLOBAL_CONFIG, inplace: bool = False, overwrite: bool = False):
    """
    Merge another_cfg into cfg, return the merged config

    Example:

        cfg1 = load_config('./dfine_r18vd_6x_coco.yml')
        cfg1 = merge_config(cfg, inplace=True)

        cfg2 = load_config('./dfine_r50vd_6x_coco.yml')
        cfg2 = merge_config(cfg2, inplace=True)

        model1 = create(cfg1['model'], cfg1)
        model2 = create(cfg2['model'], cfg2)
    """

    def _merge(dct, another):
        for k in another:
            if k not in dct:
                dct[k] = another[k]

            elif isinstance(dct[k], dict) and isinstance(another[k], dict):
                _merge(dct[k], another[k])

            elif overwrite:
                dct[k] = another[k]

        return cfg

    if not inplace:
        cfg = copy.deepcopy(cfg)

    return _merge(cfg, another_cfg)
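
Note: parse_cli turns dotted key=value overrides into nested dicts, with values parsed by the YAML loader (so numbers and booleans come back typed), and merge_dict folds them into an existing config. A quick sketch of the expected behaviour (the override values are made up):

# Illustrative only: dotted CLI overrides become nested dicts.
from src.core.yaml_utils import merge_dict, parse_cli

overrides = parse_cli(["optimizer.lr=0.0001", "use_amp=true"])
# {'optimizer': {'lr': 0.0001}, 'use_amp': True}

cfg = {"optimizer": {"type": "AdamW", "lr": 0.001}}
cfg = merge_dict(cfg, overrides)
# {'optimizer': {'type': 'AdamW', 'lr': 0.0001}, 'use_amp': True}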
src/data/__init__.py
ADDED
@@ -0,0 +1,20 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

from ._misc import convert_to_tv_tensor
from .dataloader import *
from .dataset import *
from .transforms import *


# def set_epoch(self, epoch) -> None:
#     self.epoch = epoch
# def _set_epoch_func(datasets):
#     """Add `set_epoch` for datasets
#     """
#     from ..core import register
#     for ds in datasets:
#         register(ds)(set_epoch)
# _set_epoch_func([CIFAR10, VOCDetection, CocoDetection])
src/data/_misc.py
ADDED
@@ -0,0 +1,62 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

import importlib.metadata

from torch import Tensor

if "0.15.2" in importlib.metadata.version("torchvision"):
    import torchvision

    torchvision.disable_beta_transforms_warning()

    from torchvision.datapoints import BoundingBox as BoundingBoxes
    from torchvision.datapoints import BoundingBoxFormat, Image, Mask, Video
    from torchvision.transforms.v2 import SanitizeBoundingBox as SanitizeBoundingBoxes

    _boxes_keys = ["format", "spatial_size"]

elif "0.17" > importlib.metadata.version("torchvision") >= "0.16":
    import torchvision

    torchvision.disable_beta_transforms_warning()

    from torchvision.transforms.v2 import SanitizeBoundingBoxes
    from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video

    _boxes_keys = ["format", "canvas_size"]

elif importlib.metadata.version("torchvision") >= "0.17":
    import torchvision
    from torchvision.transforms.v2 import SanitizeBoundingBoxes
    from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video

    _boxes_keys = ["format", "canvas_size"]

else:
    raise RuntimeError("Please make sure torchvision version >= 0.15.2")


def convert_to_tv_tensor(tensor: Tensor, key: str, box_format="xyxy", spatial_size=None) -> Tensor:
    """
    Args:
        tensor (Tensor): input tensor
        key (str): transform to key

    Return:
        Dict[str, TV_Tensor]
    """
    assert key in (
        "boxes",
        "masks",
    ), "Only support 'boxes' and 'masks'"

    if key == "boxes":
        box_format = getattr(BoundingBoxFormat, box_format.upper())
        _kwargs = dict(zip(_boxes_keys, [box_format, spatial_size]))
        return BoundingBoxes(tensor, **_kwargs)

    if key == "masks":
        return Mask(tensor)
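
Note: convert_to_tv_tensor wraps plain tensors in the torchvision tv_tensor types, using _boxes_keys to pick the version-dependent keyword (spatial_size on torchvision 0.15.2, canvas_size on 0.16+). A minimal sketch with made-up values:

# Illustrative only: wrap raw tensors as torchvision tv_tensors.
import torch
from src.data._misc import convert_to_tv_tensor

boxes = torch.tensor([[10.0, 20.0, 110.0, 220.0]])  # one xyxy box in absolute pixels
tv_boxes = convert_to_tv_tensor(
    boxes, key="boxes", box_format="xyxy", spatial_size=(480, 640)  # (h, w) of the image
)
tv_masks = convert_to_tv_tensor(torch.zeros(1, 480, 640, dtype=torch.uint8), key="masks")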
src/data/dataloader.py
ADDED
@@ -0,0 +1,122 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

import random
from functools import partial

import torch
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
import torchvision.transforms.v2 as VT
from torch.utils.data import default_collate
from torchvision.transforms.v2 import InterpolationMode
from torchvision.transforms.v2 import functional as VF

from ..core import register

torchvision.disable_beta_transforms_warning()


__all__ = [
    "DataLoader",
    "BaseCollateFunction",
    "BatchImageCollateFunction",
    "batch_image_collate_fn",
]


@register()
class DataLoader(data.DataLoader):
    __inject__ = ["dataset", "collate_fn"]

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        for n in ["dataset", "batch_size", "num_workers", "drop_last", "collate_fn"]:
            format_string += "\n"
            format_string += "    {0}: {1}".format(n, getattr(self, n))
        format_string += "\n)"
        return format_string

    def set_epoch(self, epoch):
        self._epoch = epoch
        self.dataset.set_epoch(epoch)
        self.collate_fn.set_epoch(epoch)

    @property
    def epoch(self):
        return self._epoch if hasattr(self, "_epoch") else -1

    @property
    def shuffle(self):
        return self._shuffle

    @shuffle.setter
    def shuffle(self, shuffle):
        assert isinstance(shuffle, bool), "shuffle must be a boolean"
        self._shuffle = shuffle


@register()
def batch_image_collate_fn(items):
    """only batch image"""
    return torch.cat([x[0][None] for x in items], dim=0), [x[1] for x in items]


class BaseCollateFunction(object):
    def set_epoch(self, epoch):
        self._epoch = epoch

    @property
    def epoch(self):
        return self._epoch if hasattr(self, "_epoch") else -1

    def __call__(self, items):
        raise NotImplementedError("")


def generate_scales(base_size, base_size_repeat):
    scale_repeat = (base_size - int(base_size * 0.75 / 32) * 32) // 32
    scales = [int(base_size * 0.75 / 32) * 32 + i * 32 for i in range(scale_repeat)]
    scales += [base_size] * base_size_repeat
    scales += [int(base_size * 1.25 / 32) * 32 - i * 32 for i in range(scale_repeat)]
    return scales


@register()
class BatchImageCollateFunction(BaseCollateFunction):
    def __init__(
        self,
        stop_epoch=None,
        ema_restart_decay=0.9999,
        base_size=640,
        base_size_repeat=None,
    ) -> None:
        super().__init__()
        self.base_size = base_size
        self.scales = (
            generate_scales(base_size, base_size_repeat) if base_size_repeat is not None else None
        )
        self.stop_epoch = stop_epoch if stop_epoch is not None else 100000000
        self.ema_restart_decay = ema_restart_decay
        # self.interpolation = interpolation

    def __call__(self, items):
        images = torch.cat([x[0][None] for x in items], dim=0)
        targets = [x[1] for x in items]

        if self.scales is not None and self.epoch < self.stop_epoch:
            # sz = random.choice(self.scales)
            # sz = [sz] if isinstance(sz, int) else list(sz)
            # VF.resize(inpt, sz, interpolation=self.interpolation)

            sz = random.choice(self.scales)
            images = F.interpolate(images, size=sz)
            if "masks" in targets[0]:
                for tg in targets:
                    tg["masks"] = F.interpolate(tg["masks"], size=sz, mode="nearest")
                raise NotImplementedError("")

        return images, targets
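
Note: generate_scales builds the pool that BatchImageCollateFunction samples from for multi-scale training: sizes step by 32 from 0.75x to 1.25x of base_size, with base_size itself repeated base_size_repeat times so it is drawn more often. For example, with base_size=640 and base_size_repeat=3 this yields:

# Illustrative only: the scale pool for base_size=640, base_size_repeat=3.
from src.data.dataloader import generate_scales

print(generate_scales(640, 3))
# [480, 512, 544, 576, 608, 640, 640, 640, 800, 768, 736, 704, 672]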
src/data/dataset/__init__.py
ADDED
@@ -0,0 +1,17 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

# from ._dataset import DetDataset
from .cifar_dataset import CIFAR10
from .coco_dataset import (
    CocoDetection,
    mscoco_category2label,
    mscoco_category2name,
    mscoco_label2category,
)
from .coco_eval import CocoEvaluator
from .coco_utils import get_coco_api_from_dataset
from .voc_detection import VOCDetection
from .voc_eval import VOCEvaluator
src/data/dataset/_dataset.py
ADDED
@@ -0,0 +1,27 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

import torch
import torch.utils.data as data


class DetDataset(data.Dataset):
    def __getitem__(self, index):
        img, target = self.load_item(index)
        if self.transforms is not None:
            img, target, _ = self.transforms(img, target, self)
        return img, target

    def load_item(self, index):
        raise NotImplementedError(
            "Please implement this function to return item before `transforms`."
        )

    def set_epoch(self, epoch) -> None:
        self._epoch = epoch

    @property
    def epoch(self):
        return self._epoch if hasattr(self, "_epoch") else -1
src/data/dataset/cifar_dataset.py
ADDED
@@ -0,0 +1,25 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

from typing import Callable, Optional

import torchvision

from ...core import register


@register()
class CIFAR10(torchvision.datasets.CIFAR10):
    __inject__ = ["transform", "target_transform"]

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, train, transform, target_transform, download)
src/data/dataset/coco_dataset.py
ADDED
@@ -0,0 +1,282 @@
"""
Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py

Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

import faster_coco_eval
import faster_coco_eval.core.mask as coco_mask
import torch
import torch.utils.data
import torchvision
import os
from PIL import Image

from ...core import register
from .._misc import convert_to_tv_tensor
from ._dataset import DetDataset

torchvision.disable_beta_transforms_warning()
faster_coco_eval.init_as_pycocotools()
Image.MAX_IMAGE_PIXELS = None

__all__ = ["CocoDetection"]


@register()
class CocoDetection(torchvision.datasets.CocoDetection, DetDataset):
    __inject__ = [
        "transforms",
    ]
    __share__ = ["remap_mscoco_category"]

    def __init__(
        self, img_folder, ann_file, transforms, return_masks=False, remap_mscoco_category=False
    ):
        super(CocoDetection, self).__init__(img_folder, ann_file)
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks)
        self.img_folder = img_folder
        self.ann_file = ann_file
        self.return_masks = return_masks
        self.remap_mscoco_category = remap_mscoco_category

    def __getitem__(self, idx):
        img, target = self.load_item(idx)
        if self._transforms is not None:
            img, target, _ = self._transforms(img, target, self)
        return img, target

    def load_item(self, idx):
        image, target = super(CocoDetection, self).__getitem__(idx)
        image_id = self.ids[idx]
        image_path = os.path.join(self.img_folder, self.coco.loadImgs(image_id)[0]["file_name"])
        target = {"image_id": image_id, "image_path": image_path, "annotations": target}

        if self.remap_mscoco_category:
            image, target = self.prepare(image, target, category2label=mscoco_category2label)
        else:
            image, target = self.prepare(image, target)

        target["idx"] = torch.tensor([idx])

        if "boxes" in target:
            target["boxes"] = convert_to_tv_tensor(
                target["boxes"], key="boxes", spatial_size=image.size[::-1]
            )

        if "masks" in target:
            target["masks"] = convert_to_tv_tensor(target["masks"], key="masks")

        return image, target

    def extra_repr(self) -> str:
        s = f" img_folder: {self.img_folder}\n ann_file: {self.ann_file}\n"
        s += f" return_masks: {self.return_masks}\n"
        if hasattr(self, "_transforms") and self._transforms is not None:
            s += f" transforms:\n {repr(self._transforms)}"
        if hasattr(self, "_preset") and self._preset is not None:
            s += f" preset:\n {repr(self._preset)}"
        return s

    @property
    def categories(
        self,
    ):
        return self.coco.dataset["categories"]

    @property
    def category2name(
        self,
    ):
        return {cat["id"]: cat["name"] for cat in self.categories}

    @property
    def category2label(
        self,
    ):
        return {cat["id"]: i for i, cat in enumerate(self.categories)}

    @property
    def label2category(
        self,
    ):
        return {i: cat["id"] for i, cat in enumerate(self.categories)}


def convert_coco_poly_to_mask(segmentations, height, width):
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        mask = mask.any(dim=2)
        masks.append(mask)
    if masks:
        masks = torch.stack(masks, dim=0)
    else:
        masks = torch.zeros((0, height, width), dtype=torch.uint8)
    return masks


class ConvertCocoPolysToMask(object):
    def __init__(self, return_masks=False):
        self.return_masks = return_masks

    def __call__(self, image: Image.Image, target, **kwargs):
        w, h = image.size

        image_id = target["image_id"]
        image_id = torch.tensor([image_id])

        image_path = target["image_path"]

        anno = target["annotations"]

        anno = [obj for obj in anno if "iscrowd" not in obj or obj["iscrowd"] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        category2label = kwargs.get("category2label", None)
        if category2label is not None:
            labels = [category2label[obj["category_id"]] for obj in anno]
        else:
            labels = [obj["category_id"] for obj in anno]

        labels = torch.tensor(labels, dtype=torch.int64)

        if self.return_masks:
            segmentations = [obj["segmentation"] for obj in anno]
            masks = convert_coco_poly_to_mask(segmentations, h, w)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = [obj["keypoints"] for obj in anno]
            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        labels = labels[keep]
        if self.return_masks:
            masks = masks[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        if self.return_masks:
            target["masks"] = masks
        target["image_id"] = image_id
        target["image_path"] = image_path
        if keypoints is not None:
            target["keypoints"] = keypoints

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]

        target["orig_size"] = torch.as_tensor([int(w), int(h)])
        # target["size"] = torch.as_tensor([int(w), int(h)])

        return image, target


mscoco_category2name = {
    1: "person",
    2: "bicycle",
    3: "car",
    4: "motorcycle",
    5: "airplane",
    6: "bus",
    7: "train",
    8: "truck",
    9: "boat",
    10: "traffic light",
    11: "fire hydrant",
    13: "stop sign",
    14: "parking meter",
    15: "bench",
    16: "bird",
    17: "cat",
    18: "dog",
    19: "horse",
    20: "sheep",
    21: "cow",
    22: "elephant",
    23: "bear",
    24: "zebra",
    25: "giraffe",
    27: "backpack",
    28: "umbrella",
    31: "handbag",
    32: "tie",
    33: "suitcase",
    34: "frisbee",
    35: "skis",
    36: "snowboard",
    37: "sports ball",
    38: "kite",
    39: "baseball bat",
    40: "baseball glove",
    41: "skateboard",
    42: "surfboard",
    43: "tennis racket",
    44: "bottle",
    46: "wine glass",
    47: "cup",
    48: "fork",
    49: "knife",
    50: "spoon",
    51: "bowl",
    52: "banana",
    53: "apple",
    54: "sandwich",
    55: "orange",
    56: "broccoli",
    57: "carrot",
    58: "hot dog",
    59: "pizza",
    60: "donut",
    61: "cake",
    62: "chair",
    63: "couch",
    64: "potted plant",
    65: "bed",
    67: "dining table",
    70: "toilet",
    72: "tv",
    73: "laptop",
    74: "mouse",
    75: "remote",
    76: "keyboard",
    77: "cell phone",
    78: "microwave",
    79: "oven",
    80: "toaster",
    81: "sink",
    82: "refrigerator",
    84: "book",
    85: "clock",
    86: "vase",
    87: "scissors",
    88: "teddy bear",
    89: "hair drier",
    90: "toothbrush",
}

mscoco_category2label = {k: i for i, k in enumerate(mscoco_category2name.keys())}
mscoco_label2category = {v: k for k, v in mscoco_category2label.items()}
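
Note: raw COCO category ids run from 1 to 90 with gaps, so with remap_mscoco_category=True the dataset remaps them to contiguous 0-79 labels via mscoco_category2label; mscoco_label2category inverts the mapping for evaluation and export. For example:

# Illustrative only: contiguous label <-> COCO category id mapping.
from src.data.dataset.coco_dataset import (
    mscoco_category2label,
    mscoco_category2name,
    mscoco_label2category,
)

print(mscoco_category2label[1])   # 0   (category 1 is "person")
print(mscoco_category2label[90])  # 79  (category 90 is "toothbrush")
print(mscoco_label2category[79])  # 90
print(mscoco_category2name[90])   # "toothbrush"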
src/data/dataset/coco_eval.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
COCO evaluator that works in distributed mode.
Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py
The difference is that there is less copy-pasting from pycocotools
in the end of the file, as python3 can suppress prints with contextlib
"""

import contextlib
import copy
import os

import faster_coco_eval.core.mask as mask_util
import numpy as np
import torch
from faster_coco_eval import COCO, COCOeval_faster

from ...core import register
from ...misc import dist_utils

__all__ = [
    "CocoEvaluator",
]


@register()
class CocoEvaluator(object):
    def __init__(self, coco_gt, iou_types):
        assert isinstance(iou_types, (list, tuple))
        coco_gt = copy.deepcopy(coco_gt)
        self.coco_gt: COCO = coco_gt
        self.iou_types = iou_types

        self.coco_eval = {}
        for iou_type in iou_types:
            self.coco_eval[iou_type] = COCOeval_faster(
                coco_gt, iouType=iou_type, print_function=print, separate_eval=True
            )

        self.img_ids = []
        self.eval_imgs = {k: [] for k in iou_types}

    def cleanup(self):
        self.coco_eval = {}
        for iou_type in self.iou_types:
            self.coco_eval[iou_type] = COCOeval_faster(
                self.coco_gt, iouType=iou_type, print_function=print, separate_eval=True
            )
        self.img_ids = []
        self.eval_imgs = {k: [] for k in self.iou_types}

    def update(self, predictions):
        img_ids = list(np.unique(list(predictions.keys())))
        self.img_ids.extend(img_ids)

        for iou_type in self.iou_types:
            results = self.prepare(predictions, iou_type)
            coco_eval = self.coco_eval[iou_type]

            # suppress pycocotools prints
            with open(os.devnull, "w") as devnull:
                with contextlib.redirect_stdout(devnull):
                    coco_dt = self.coco_gt.loadRes(results) if results else COCO()
                    coco_eval.cocoDt = coco_dt
                    coco_eval.params.imgIds = list(img_ids)
                    coco_eval.evaluate()

            self.eval_imgs[iou_type].append(
                np.array(coco_eval._evalImgs_cpp).reshape(
                    len(coco_eval.params.catIds),
                    len(coco_eval.params.areaRng),
                    len(coco_eval.params.imgIds),
                )
            )

    def synchronize_between_processes(self):
        for iou_type in self.iou_types:
            img_ids, eval_imgs = merge(self.img_ids, self.eval_imgs[iou_type])

            coco_eval = self.coco_eval[iou_type]
            coco_eval.params.imgIds = img_ids
            coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
            coco_eval._evalImgs_cpp = eval_imgs

    def accumulate(self):
        for coco_eval in self.coco_eval.values():
            coco_eval.accumulate()

    def summarize(self):
        for iou_type, coco_eval in self.coco_eval.items():
            print("IoU metric: {}".format(iou_type))
            coco_eval.summarize()

    def prepare(self, predictions, iou_type):
        if iou_type == "bbox":
            return self.prepare_for_coco_detection(predictions)
        elif iou_type == "segm":
            return self.prepare_for_coco_segmentation(predictions)
        elif iou_type == "keypoints":
            return self.prepare_for_coco_keypoint(predictions)
        else:
            raise ValueError("Unknown iou type {}".format(iou_type))

    def prepare_for_coco_detection(self, predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            boxes = prediction["boxes"]
            boxes = convert_to_xywh(boxes).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "bbox": box,
                        "score": scores[k],
                    }
                    for k, box in enumerate(boxes)
                ]
            )
        return coco_results

    def prepare_for_coco_segmentation(self, predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            scores = prediction["scores"]
            labels = prediction["labels"]
            masks = prediction["masks"]

            masks = masks > 0.5

            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()

            rles = [
                mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
                for mask in masks
            ]
            for rle in rles:
                rle["counts"] = rle["counts"].decode("utf-8")

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "segmentation": rle,
                        "score": scores[k],
                    }
                    for k, rle in enumerate(rles)
                ]
            )
        return coco_results

    def prepare_for_coco_keypoint(self, predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            boxes = prediction["boxes"]
            boxes = convert_to_xywh(boxes).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()
            keypoints = prediction["keypoints"]
            keypoints = keypoints.flatten(start_dim=1).tolist()

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "keypoints": keypoint,
                        "score": scores[k],
                    }
                    for k, keypoint in enumerate(keypoints)
                ]
            )
        return coco_results


def convert_to_xywh(boxes):
    xmin, ymin, xmax, ymax = boxes.unbind(1)
    return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)


def merge(img_ids, eval_imgs):
    all_img_ids = dist_utils.all_gather(img_ids)
    all_eval_imgs = dist_utils.all_gather(eval_imgs)

    merged_img_ids = []
    for p in all_img_ids:
        merged_img_ids.extend(p)

    merged_eval_imgs = []
    for p in all_eval_imgs:
        merged_eval_imgs.extend(p)

    merged_img_ids = np.array(merged_img_ids)
    merged_eval_imgs = np.concatenate(merged_eval_imgs, axis=2).ravel()
    # merged_eval_imgs = np.array(merged_eval_imgs).T.ravel()

    # keep only unique (and in sorted order) images
    merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)

    return merged_img_ids.tolist(), merged_eval_imgs.tolist()
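A hedged usage sketch (not part of the upload) of how CocoEvaluator is typically driven from a validation loop; `coco_gt`, `model` and `data_loader` are assumed to exist, and predictions are keyed by image id with "boxes" / "scores" / "labels" tensors, as the prepare_* methods expect.

evaluator = CocoEvaluator(coco_gt, iou_types=["bbox"])
for images, targets in data_loader:
    outputs = model(images)  # assumed: one dict of boxes/scores/labels per image
    predictions = {t["image_id"].item(): o for t, o in zip(targets, outputs)}
    evaluator.update(predictions)
evaluator.synchronize_between_processes()  # gathers per-rank results via dist_utils.all_gather
evaluator.accumulate()
evaluator.summarize()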
src/data/dataset/coco_utils.py
ADDED
@@ -0,0 +1,191 @@
"""
copy and modified https://github.com/pytorch/vision/blob/main/references/detection/coco_utils.py

Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

import faster_coco_eval.core.mask as coco_mask
import torch
import torch.utils.data
import torchvision
import torchvision.transforms.functional as TVF
from faster_coco_eval import COCO


def convert_coco_poly_to_mask(segmentations, height, width):
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        mask = mask.any(dim=2)
        masks.append(mask)
    if masks:
        masks = torch.stack(masks, dim=0)
    else:
        masks = torch.zeros((0, height, width), dtype=torch.uint8)
    return masks


class ConvertCocoPolysToMask:
    def __call__(self, image, target):
        w, h = image.size

        image_id = target["image_id"]

        anno = target["annotations"]

        anno = [obj for obj in anno if obj["iscrowd"] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        classes = [obj["category_id"] for obj in anno]
        classes = torch.tensor(classes, dtype=torch.int64)

        segmentations = [obj["segmentation"] for obj in anno]
        masks = convert_coco_poly_to_mask(segmentations, h, w)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = [obj["keypoints"] for obj in anno]
            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        masks = masks[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        target["masks"] = masks
        target["image_id"] = image_id
        if keypoints is not None:
            target["keypoints"] = keypoints

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])
        target["area"] = area
        target["iscrowd"] = iscrowd

        return image, target


def _coco_remove_images_without_annotations(dataset, cat_list=None):
    def _has_only_empty_bbox(anno):
        return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)

    def _count_visible_keypoints(anno):
        return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)

    min_keypoints_per_image = 10

    def _has_valid_annotation(anno):
        # if it's empty, there is no annotation
        if len(anno) == 0:
            return False
        # if all boxes have close to zero area, there is no annotation
        if _has_only_empty_bbox(anno):
            return False
        # keypoints task have a slight different criteria for considering
        # if an annotation is valid
        if "keypoints" not in anno[0]:
            return True
        # for keypoint detection tasks, only consider valid images those
        # containing at least min_keypoints_per_image
        if _count_visible_keypoints(anno) >= min_keypoints_per_image:
            return True
        return False

    ids = []
    for ds_idx, img_id in enumerate(dataset.ids):
        ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
        anno = dataset.coco.loadAnns(ann_ids)
        if cat_list:
            anno = [obj for obj in anno if obj["category_id"] in cat_list]
        if _has_valid_annotation(anno):
            ids.append(ds_idx)

    dataset = torch.utils.data.Subset(dataset, ids)
    return dataset


def convert_to_coco_api(ds):
    coco_ds = COCO()
    # annotation IDs need to start at 1, not 0, see torchvision issue #1530
    ann_id = 1
    dataset = {"images": [], "categories": [], "annotations": []}
    categories = set()
    for img_idx in range(len(ds)):
        # find better way to get target
        # targets = ds.get_annotations(img_idx)
        # img, targets = ds[img_idx]

        img, targets = ds.load_item(img_idx)
        width, height = img.size

        image_id = targets["image_id"].item()
        img_dict = {}
        img_dict["id"] = image_id
        img_dict["width"] = width
        img_dict["height"] = height
        dataset["images"].append(img_dict)
        bboxes = targets["boxes"].clone()
        bboxes[:, 2:] -= bboxes[:, :2]  # xyxy -> xywh
        bboxes = bboxes.tolist()
        labels = targets["labels"].tolist()
        areas = targets["area"].tolist()
        iscrowd = targets["iscrowd"].tolist()
        if "masks" in targets:
            masks = targets["masks"]
            # make masks Fortran contiguous for coco_mask
            masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
        if "keypoints" in targets:
            keypoints = targets["keypoints"]
            keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()
        num_objs = len(bboxes)
        for i in range(num_objs):
            ann = {}
            ann["image_id"] = image_id
            ann["bbox"] = bboxes[i]
            ann["category_id"] = labels[i]
            categories.add(labels[i])
            ann["area"] = areas[i]
            ann["iscrowd"] = iscrowd[i]
            ann["id"] = ann_id
            if "masks" in targets:
                ann["segmentation"] = coco_mask.encode(masks[i].numpy())
            if "keypoints" in targets:
                ann["keypoints"] = keypoints[i]
                ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3])
            dataset["annotations"].append(ann)
            ann_id += 1
    dataset["categories"] = [{"id": i} for i in sorted(categories)]
    coco_ds.dataset = dataset
    coco_ds.createIndex()
    return coco_ds


def get_coco_api_from_dataset(dataset):
    # FIXME: This is... awful?
    for _ in range(10):
        if isinstance(dataset, torchvision.datasets.CocoDetection):
            break
        if isinstance(dataset, torch.utils.data.Subset):
            dataset = dataset.dataset
    if isinstance(dataset, torchvision.datasets.CocoDetection):
        return dataset.coco
    return convert_to_coco_api(dataset)
src/data/dataset/voc_detection.py
ADDED
@@ -0,0 +1,86 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

import os
from typing import Callable, Optional

import torch
import torchvision
import torchvision.transforms.functional as TVF
from PIL import Image

try:
    from defusedxml.ElementTree import parse as ET_parse
except ImportError:
    from xml.etree.ElementTree import parse as ET_parse

from ...core import register
from .._misc import convert_to_tv_tensor
from ._dataset import DetDataset


@register()
class VOCDetection(torchvision.datasets.VOCDetection, DetDataset):
    __inject__ = [
        "transforms",
    ]

    def __init__(
        self,
        root: str,
        ann_file: str = "trainval.txt",
        label_file: str = "label_list.txt",
        transforms: Optional[Callable] = None,
    ):
        with open(os.path.join(root, ann_file), "r") as f:
            lines = [x.strip() for x in f.readlines()]
            lines = [x.split(" ") for x in lines]

        self.images = [os.path.join(root, lin[0]) for lin in lines]
        self.targets = [os.path.join(root, lin[1]) for lin in lines]
        assert len(self.images) == len(self.targets)

        with open(os.path.join(root, label_file), "r") as f:
            labels = f.readlines()
            labels = [lab.strip() for lab in labels]

        self.transforms = transforms
        self.labels_map = {lab: i for i, lab in enumerate(labels)}

    def __getitem__(self, index: int):
        image, target = self.load_item(index)
        if self.transforms is not None:
            image, target, _ = self.transforms(image, target, self)
        # target["orig_size"] = torch.tensor(TVF.get_image_size(image))
        return image, target

    def load_item(self, index: int):
        image = Image.open(self.images[index]).convert("RGB")
        target = self.parse_voc_xml(ET_parse(self.targets[index]).getroot())

        output = {}
        output["image_id"] = torch.tensor([index])
        for k in ["area", "boxes", "labels", "iscrowd"]:
            output[k] = []

        for blob in target["annotation"]["object"]:
            box = [float(v) for v in blob["bndbox"].values()]
            output["boxes"].append(box)
            output["labels"].append(blob["name"])
            output["area"].append((box[2] - box[0]) * (box[3] - box[1]))
            output["iscrowd"].append(0)

        w, h = image.size
        boxes = torch.tensor(output["boxes"]) if len(output["boxes"]) > 0 else torch.zeros(0, 4)
        output["boxes"] = convert_to_tv_tensor(
            boxes, "boxes", box_format="xyxy", spatial_size=[h, w]
        )
        output["labels"] = torch.tensor([self.labels_map[lab] for lab in output["labels"]])
        output["area"] = torch.tensor(output["area"])
        output["iscrowd"] = torch.tensor(output["iscrowd"])
        output["orig_size"] = torch.tensor([w, h])

        return image, output
src/data/dataset/voc_eval.py
ADDED
@@ -0,0 +1,12 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

import torch
import torchvision


class VOCEvaluator(object):
    def __init__(self) -> None:
        pass
src/data/transforms/__init__.py
ADDED
@@ -0,0 +1,21 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

from ._transforms import (
    ConvertBoxes,
    ConvertPILImage,
    EmptyTransform,
    Normalize,
    PadToSize,
    RandomCrop,
    RandomHorizontalFlip,
    RandomIoUCrop,
    RandomPhotometricDistort,
    RandomZoomOut,
    Resize,
    SanitizeBoundingBoxes,
)
from .container import Compose
from .mosaic import Mosaic
src/data/transforms/_transforms.py
ADDED
@@ -0,0 +1,161 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

from typing import Any, Dict, List, Optional

import PIL
import PIL.Image
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms.v2 as T
import torchvision.transforms.v2.functional as F

from ...core import register
from .._misc import (
    BoundingBoxes,
    Image,
    Mask,
    SanitizeBoundingBoxes,
    Video,
    _boxes_keys,
    convert_to_tv_tensor,
)

torchvision.disable_beta_transforms_warning()


RandomPhotometricDistort = register()(T.RandomPhotometricDistort)
RandomZoomOut = register()(T.RandomZoomOut)
RandomHorizontalFlip = register()(T.RandomHorizontalFlip)
Resize = register()(T.Resize)
# ToImageTensor = register()(T.ToImageTensor)
# ConvertDtype = register()(T.ConvertDtype)
# PILToTensor = register()(T.PILToTensor)
SanitizeBoundingBoxes = register(name="SanitizeBoundingBoxes")(SanitizeBoundingBoxes)
RandomCrop = register()(T.RandomCrop)
Normalize = register()(T.Normalize)


@register()
class EmptyTransform(T.Transform):
    def __init__(
        self,
    ) -> None:
        super().__init__()

    def forward(self, *inputs):
        inputs = inputs if len(inputs) > 1 else inputs[0]
        return inputs


@register()
class PadToSize(T.Pad):
    _transformed_types = (
        PIL.Image.Image,
        Image,
        Video,
        Mask,
        BoundingBoxes,
    )

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        sp = F.get_spatial_size(flat_inputs[0])
        h, w = self.size[1] - sp[0], self.size[0] - sp[1]
        self.padding = [0, 0, w, h]
        return dict(padding=self.padding)

    def __init__(self, size, fill=0, padding_mode="constant") -> None:
        if isinstance(size, int):
            size = (size, size)
        self.size = size
        super().__init__(0, fill, padding_mode)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        fill = self._fill[type(inpt)]
        padding = params["padding"]
        return F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)  # type: ignore[arg-type]

    def __call__(self, *inputs: Any) -> Any:
        outputs = super().forward(*inputs)
        if len(outputs) > 1 and isinstance(outputs[1], dict):
            outputs[1]["padding"] = torch.tensor(self.padding)
        return outputs


@register()
class RandomIoUCrop(T.RandomIoUCrop):
    def __init__(
        self,
        min_scale: float = 0.3,
        max_scale: float = 1,
        min_aspect_ratio: float = 0.5,
        max_aspect_ratio: float = 2,
        sampler_options: Optional[List[float]] = None,
        trials: int = 40,
        p: float = 1.0,
    ):
        super().__init__(
            min_scale, max_scale, min_aspect_ratio, max_aspect_ratio, sampler_options, trials
        )
        self.p = p

    def __call__(self, *inputs: Any) -> Any:
        if torch.rand(1) >= self.p:
            return inputs if len(inputs) > 1 else inputs[0]

        return super().forward(*inputs)


@register()
class ConvertBoxes(T.Transform):
    _transformed_types = (BoundingBoxes,)

    def __init__(self, fmt="", normalize=False) -> None:
        super().__init__()
        self.fmt = fmt
        self.normalize = normalize

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return self._transform(inpt, params)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        spatial_size = getattr(inpt, _boxes_keys[1])
        if self.fmt:
            in_fmt = inpt.format.value.lower()
            inpt = torchvision.ops.box_convert(inpt, in_fmt=in_fmt, out_fmt=self.fmt.lower())
            inpt = convert_to_tv_tensor(
                inpt, key="boxes", box_format=self.fmt.upper(), spatial_size=spatial_size
            )

        if self.normalize:
            inpt = inpt / torch.tensor(spatial_size[::-1]).tile(2)[None]

        return inpt


@register()
class ConvertPILImage(T.Transform):
    _transformed_types = (PIL.Image.Image,)

    def __init__(self, dtype="float32", scale=True) -> None:
        super().__init__()
        self.dtype = dtype
        self.scale = scale

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return self._transform(inpt, params)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        inpt = F.pil_to_tensor(inpt)
        if self.dtype == "float32":
            inpt = inpt.float()

        if self.scale:
            inpt = inpt / 255.0

        inpt = Image(inpt)

        return inpt
src/data/transforms/container.py
ADDED
@@ -0,0 +1,99 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms.v2 as T

from ...core import GLOBAL_CONFIG, register
from ._transforms import EmptyTransform

torchvision.disable_beta_transforms_warning()


@register()
class Compose(T.Compose):
    def __init__(self, ops, policy=None) -> None:
        transforms = []
        if ops is not None:
            for op in ops:
                if isinstance(op, dict):
                    name = op.pop("type")
                    transform = getattr(
                        GLOBAL_CONFIG[name]["_pymodule"], GLOBAL_CONFIG[name]["_name"]
                    )(**op)
                    transforms.append(transform)
                    op["type"] = name

                elif isinstance(op, nn.Module):
                    transforms.append(op)

                else:
                    raise ValueError("")
        else:
            transforms = [
                EmptyTransform(),
            ]

        super().__init__(transforms=transforms)

        if policy is None:
            policy = {"name": "default"}

        self.policy = policy
        self.global_samples = 0

    def forward(self, *inputs: Any) -> Any:
        return self.get_forward(self.policy["name"])(*inputs)

    def get_forward(self, name):
        forwards = {
            "default": self.default_forward,
            "stop_epoch": self.stop_epoch_forward,
            "stop_sample": self.stop_sample_forward,
        }
        return forwards[name]

    def default_forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        for transform in self.transforms:
            sample = transform(sample)
        return sample

    def stop_epoch_forward(self, *inputs: Any):
        sample = inputs if len(inputs) > 1 else inputs[0]
        dataset = sample[-1]
        cur_epoch = dataset.epoch
        policy_ops = self.policy["ops"]
        policy_epoch = self.policy["epoch"]

        for transform in self.transforms:
            if type(transform).__name__ in policy_ops and cur_epoch >= policy_epoch:
                pass
            else:
                sample = transform(sample)

        return sample

    def stop_sample_forward(self, *inputs: Any):
        sample = inputs if len(inputs) > 1 else inputs[0]
        dataset = sample[-1]

        cur_epoch = dataset.epoch
        policy_ops = self.policy["ops"]
        policy_sample = self.policy["sample"]

        for transform in self.transforms:
            if type(transform).__name__ in policy_ops and self.global_samples >= policy_sample:
                pass
            else:
                sample = transform(sample)

        self.global_samples += 1

        return sample
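A hedged sketch (not part of the upload) of building this registered Compose from a dict-style ops list, the way a YAML config would hand it in; the type names must already be registered (they are in _transforms.py above), and the epoch/size values are purely illustrative. The `stop_epoch` policy disables the listed ops once `dataset.epoch` reaches the configured epoch.

transforms = Compose(
    ops=[
        {"type": "RandomPhotometricDistort", "p": 0.5},
        {"type": "RandomHorizontalFlip"},
        {"type": "Resize", "size": [640, 640]},
        {"type": "ConvertPILImage", "dtype": "float32", "scale": True},
    ],
    policy={"name": "stop_epoch", "epoch": 71, "ops": ["RandomPhotometricDistort"]},
)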
src/data/transforms/functional.py
ADDED
@@ -0,0 +1,172 @@
from typing import List, Optional

import torch

# needed due to empty tensor bug in pytorch and torchvision 0.5
import torchvision
import torchvision.transforms.functional as F
from packaging import version
from torch import Tensor

if version.parse(torchvision.__version__) < version.parse("0.7"):
    from torchvision.ops import _new_empty_tensor
    from torchvision.ops.misc import _output_size


def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
    """
    Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
    This will eventually be supported natively by PyTorch, and this
    class can go away.
    """
    if version.parse(torchvision.__version__) < version.parse("0.7"):
        if input.numel() > 0:
            return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)

        output_shape = _output_size(2, input, size, scale_factor)
        output_shape = list(input.shape[:-2]) + list(output_shape)
        return _new_empty_tensor(input, output_shape)
    else:
        return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)


def crop(image, target, region):
    cropped_image = F.crop(image, *region)

    target = target.copy()
    i, j, h, w = region

    # should we do something wrt the original size?
    target["size"] = torch.tensor([h, w])

    fields = ["labels", "area", "iscrowd"]

    if "boxes" in target:
        boxes = target["boxes"]
        max_size = torch.as_tensor([w, h], dtype=torch.float32)
        cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
        cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
        cropped_boxes = cropped_boxes.clamp(min=0)
        area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
        target["boxes"] = cropped_boxes.reshape(-1, 4)
        target["area"] = area
        fields.append("boxes")

    if "masks" in target:
        # FIXME should we update the area here if there are no boxes?
        target["masks"] = target["masks"][:, i : i + h, j : j + w]
        fields.append("masks")

    # remove elements for which the boxes or masks that have zero area
    if "boxes" in target or "masks" in target:
        # favor boxes selection when defining which elements to keep
        # this is compatible with previous implementation
        if "boxes" in target:
            cropped_boxes = target["boxes"].reshape(-1, 2, 2)
            keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
        else:
            keep = target["masks"].flatten(1).any(1)

        for field in fields:
            target[field] = target[field][keep]

    return cropped_image, target


def hflip(image, target):
    flipped_image = F.hflip(image)

    w, h = image.size

    target = target.copy()
    if "boxes" in target:
        boxes = target["boxes"]
        boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor(
            [w, 0, w, 0]
        )
        target["boxes"] = boxes

    if "masks" in target:
        target["masks"] = target["masks"].flip(-1)

    return flipped_image, target


def resize(image, target, size, max_size=None):
    # size can be min_size (scalar) or (w, h) tuple

    def get_size_with_aspect_ratio(image_size, size, max_size=None):
        w, h = image_size
        if max_size is not None:
            min_original_size = float(min((w, h)))
            max_original_size = float(max((w, h)))
            if max_original_size / min_original_size * size > max_size:
                size = int(round(max_size * min_original_size / max_original_size))

        if (w <= h and w == size) or (h <= w and h == size):
            return (h, w)

        if w < h:
            ow = size
            oh = int(size * h / w)
        else:
            oh = size
            ow = int(size * w / h)

        # r = min(size / min(h, w), max_size / max(h, w))
        # ow = int(w * r)
        # oh = int(h * r)

        return (oh, ow)

    def get_size(image_size, size, max_size=None):
        if isinstance(size, (list, tuple)):
            return size[::-1]
        else:
            return get_size_with_aspect_ratio(image_size, size, max_size)

    size = get_size(image.size, size, max_size)
    rescaled_image = F.resize(image, size)

    if target is None:
        return rescaled_image, None

    ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
    ratio_width, ratio_height = ratios

    target = target.copy()
    if "boxes" in target:
        boxes = target["boxes"]
        scaled_boxes = boxes * torch.as_tensor(
            [ratio_width, ratio_height, ratio_width, ratio_height]
        )
        target["boxes"] = scaled_boxes

    if "area" in target:
        area = target["area"]
        scaled_area = area * (ratio_width * ratio_height)
        target["area"] = scaled_area

    h, w = size
    target["size"] = torch.tensor([h, w])

    if "masks" in target:
        target["masks"] = (
            interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5
        )

    return rescaled_image, target


def pad(image, target, padding):
    # assumes that we only pad on the bottom right corners
    padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
    if target is None:
        return padded_image, None
    target = target.copy()
    # should we do something wrt the original size?
    target["size"] = torch.tensor(padded_image.size[::-1])
    if "masks" in target:
        target["masks"] = torch.nn.functional.pad(target["masks"], (0, padding[0], 0, padding[1]))
    return padded_image, target
src/data/transforms/mosaic.py
ADDED
@@ -0,0 +1,83 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

import random

import torch
import torchvision
import torchvision.transforms.v2 as T
import torchvision.transforms.v2.functional as F
from PIL import Image

from ...core import register
from .._misc import convert_to_tv_tensor

torchvision.disable_beta_transforms_warning()


@register()
class Mosaic(T.Transform):
    def __init__(
        self,
        size,
        max_size=None,
    ) -> None:
        super().__init__()
        self.resize = T.Resize(size=size, max_size=max_size)
        self.crop = T.RandomCrop(size=max_size if max_size else size)

        # TODO add arg `output_size` for affine`
        # self.random_perspective = T.RandomPerspective(distortion_scale=0.5, p=1., )
        self.random_affine = T.RandomAffine(
            degrees=0, translate=(0.1, 0.1), scale=(0.5, 1.5), fill=114
        )

    def forward(self, *inputs):
        inputs = inputs if len(inputs) > 1 else inputs[0]
        image, target, dataset = inputs

        images = []
        targets = []
        indices = random.choices(range(len(dataset)), k=3)
        for i in indices:
            image, target = dataset.load_item(i)
            image, target = self.resize(image, target)
            images.append(image)
            targets.append(target)

        h, w = F.get_spatial_size(images[0])
        offset = [[0, 0], [w, 0], [0, h], [w, h]]
        image = Image.new(mode=images[0].mode, size=(w * 2, h * 2), color=0)
        for i, im in enumerate(images):
            image.paste(im, offset[i])

        offset = torch.tensor([[0, 0], [w, 0], [0, h], [w, h]]).repeat(1, 2)
        target = {}
        for k in targets[0]:
            if k == "boxes":
                v = [t[k] + offset[i] for i, t in enumerate(targets)]
            else:
                v = [t[k] for t in targets]

            if isinstance(v[0], torch.Tensor):
                v = torch.cat(v, dim=0)

            target[k] = v

        if "boxes" in target:
            # target['boxes'] = target['boxes'].clamp(0, 640 * 2 - 1)
            w, h = image.size
            target["boxes"] = convert_to_tv_tensor(
                target["boxes"], "boxes", box_format="xyxy", spatial_size=[h, w]
            )

        if "masks" in target:
            target["masks"] = convert_to_tv_tensor(target["masks"], "masks")

        image, target = self.random_affine(image, target)
        # image, target = self.resize(image, target)
        image, target = self.crop(image, target)

        return image, target, dataset
src/data/transforms/presets.py
ADDED
@@ -0,0 +1,4 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
src/misc/__init__.py
ADDED
@@ -0,0 +1,9 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

from .dist_utils import setup_print, setup_seed
from .logger import *
from .profiler_utils import stats
from .visualizer import *
src/misc/box_ops.py
ADDED
@@ -0,0 +1,106 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

from typing import List, Tuple

import torch
import torchvision
from torch import Tensor


def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    return torchvision.ops.generalized_box_iou(boxes1, boxes2)


# elementwise
def elementwise_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
    """
    Args:
        boxes1, [N, 4]
        boxes2, [N, 4]
    Returns:
        iou, [N, ]
        union, [N, ]
    """
    area1 = torchvision.ops.box_area(boxes1)  # [N, ]
    area2 = torchvision.ops.box_area(boxes2)  # [N, ]
    lt = torch.max(boxes1[:, :2], boxes2[:, :2])  # [N, 2]
    rb = torch.min(boxes1[:, 2:], boxes2[:, 2:])  # [N, 2]
    wh = (rb - lt).clamp(min=0)  # [N, 2]
    inter = wh[:, 0] * wh[:, 1]  # [N, ]
    union = area1 + area2 - inter
    iou = inter / union
    return iou, union


def elementwise_generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
    """
    Args:
        boxes1, [N, 4] with [x1, y1, x2, y2]
        boxes2, [N, 4] with [x1, y1, x2, y2]
    Returns:
        giou, [N, ]
    """
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    iou, union = elementwise_box_iou(boxes1, boxes2)
    lt = torch.min(boxes1[:, :2], boxes2[:, :2])  # [N, 2]
    rb = torch.max(boxes1[:, 2:], boxes2[:, 2:])  # [N, 2]
    wh = (rb - lt).clamp(min=0)  # [N, 2]
    area = wh[:, 0] * wh[:, 1]
    return iou - (area - union) / area


def check_point_inside_box(points: Tensor, boxes: Tensor, eps=1e-9) -> Tensor:
    """
    Args:
        points, [K, 2], (x, y)
        boxes, [N, 4], (x1, y1, x2, y2)
    Returns:
        Tensor (bool), [K, N]
    """
    x, y = [p.unsqueeze(-1) for p in points.unbind(-1)]
    x1, y1, x2, y2 = [x.unsqueeze(0) for x in boxes.unbind(-1)]

    l = x - x1
    t = y - y1
    r = x2 - x
    b = y2 - y

    ltrb = torch.stack([l, t, r, b], dim=-1)
    mask = ltrb.min(dim=-1).values > eps

    return mask


def point_box_distance(points: Tensor, boxes: Tensor) -> Tensor:
    """
    Args:
        boxes, [N, 4], (x1, y1, x2, y2)
        points, [N, 2], (x, y)
    Returns:
        Tensor (N, 4), (l, t, r, b)
    """
    x1y1, x2y2 = torch.split(boxes, 2, dim=-1)
    lt = points - x1y1
    rb = x2y2 - points
    return torch.concat([lt, rb], dim=-1)


def point_distance_box(points: Tensor, distances: Tensor) -> Tensor:
    """
    Args:
        points (Tensor), [N, 2], (x, y)
        distances (Tensor), [N, 4], (l, t, r, b)
    Returns:
        boxes (Tensor), (N, 4), (x1, y1, x2, y2)
    """
    lt, rb = torch.split(distances, 2, dim=-1)
    x1y1 = -lt + points
    x2y2 = rb + points
    boxes = torch.concat([x1y1, x2y2], dim=-1)
    return boxes
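A minimal usage sketch (not part of the upload) exercising the elementwise helpers above, assuming they are imported from src.misc.box_ops; with paired boxes you get one IoU/GIoU value per row rather than an NxM matrix.

import torch
from src.misc.box_ops import elementwise_box_iou, elementwise_generalized_box_iou

b1 = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 15.0, 15.0]])
b2 = torch.tensor([[0.0, 0.0, 10.0, 10.0], [0.0, 0.0, 10.0, 10.0]])
iou, union = elementwise_box_iou(b1, b2)        # iou ~ [1.0000, 0.1429], union ~ [100., 175.]
giou = elementwise_generalized_box_iou(b1, b2)  # giou <= iou; equal when the boxes coincide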
src/misc/dist_utils.py
ADDED
@@ -0,0 +1,281 @@
"""
reference
- https://github.com/pytorch/vision/blob/main/references/detection/utils.py
- https://github.com/facebookresearch/detr/blob/master/util/misc.py#L406

Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

import atexit
import os
import random
import time

import numpy as np
import torch
import torch.backends.cudnn
import torch.distributed
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DataParallel as DP
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DistributedSampler

# from torch.utils.data.dataloader import DataLoader
from ..data import DataLoader


def setup_distributed(
    print_rank: int = 0,
    print_method: str = "builtin",
    seed: int = None,
):
    """
    env setup
    args:
        print_rank,
        print_method, (builtin, rich)
        seed,
    """
    try:
        # https://pytorch.org/docs/stable/elastic/run.html
        RANK = int(os.getenv("RANK", -1))
        LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1))
        WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))

        # torch.distributed.init_process_group(backend=backend, init_method='env://')
        torch.distributed.init_process_group(init_method="env://")
        torch.distributed.barrier()

        rank = torch.distributed.get_rank()
        torch.cuda.set_device(rank)
        torch.cuda.empty_cache()
        enabled_dist = True
        if get_rank() == print_rank:
            print("Initialized distributed mode...")

    except Exception:
        enabled_dist = False
        print("Not init distributed mode.")

    setup_print(get_rank() == print_rank, method=print_method)
    if seed is not None:
        setup_seed(seed)

    return enabled_dist


def setup_print(is_main, method="builtin"):
    """This function disables printing when not in master process"""
    import builtins as __builtin__

    if method == "builtin":
        builtin_print = __builtin__.print

    elif method == "rich":
        import rich

        builtin_print = rich.print

    else:
        raise AttributeError("")

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_main or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_available_and_initialized():
    if not torch.distributed.is_available():
        return False
    if not torch.distributed.is_initialized():
        return False
    return True


@atexit.register
def cleanup():
    """cleanup distributed environment"""
    if is_dist_available_and_initialized():
        torch.distributed.barrier()
        torch.distributed.destroy_process_group()


def get_rank():
    if not is_dist_available_and_initialized():
        return 0
    return torch.distributed.get_rank()


def get_world_size():
    if not is_dist_available_and_initialized():
        return 1
    return torch.distributed.get_world_size()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def warp_model(
    model: torch.nn.Module,
    sync_bn: bool = False,
    dist_mode: str = "ddp",
    find_unused_parameters: bool = False,
    compile: bool = False,
    compile_mode: str = "reduce-overhead",
    **kwargs,
):
    if is_dist_available_and_initialized():
        rank = get_rank()
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model) if sync_bn else model
        if dist_mode == "dp":
            model = DP(model, device_ids=[rank], output_device=rank)
        elif dist_mode == "ddp":
            model = DDP(
                model,
                device_ids=[rank],
                output_device=rank,
                find_unused_parameters=find_unused_parameters,
            )
        else:
            raise AttributeError("")

    if compile:
        model = torch.compile(model, mode=compile_mode)

    return model


def de_model(model):
    return de_parallel(de_complie(model))


def warp_loader(loader, shuffle=False):
    if is_dist_available_and_initialized():
        sampler = DistributedSampler(loader.dataset, shuffle=shuffle)
        loader = DataLoader(
            loader.dataset,
            loader.batch_size,
            sampler=sampler,
            drop_last=loader.drop_last,
            collate_fn=loader.collate_fn,
            pin_memory=loader.pin_memory,
            num_workers=loader.num_workers,
        )
    return loader


def is_parallel(model) -> bool:
    # Returns True if model is of type DP or DDP
    return type(model) in (
        torch.nn.parallel.DataParallel,
        torch.nn.parallel.DistributedDataParallel,
    )


def de_parallel(model) -> nn.Module:
    # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
    return model.module if is_parallel(model) else model


def reduce_dict(data, avg=True):
    """
    Args
        data dict: input, {k: v, ...}
        avg bool: true
    """
    world_size = get_world_size()
    if world_size < 2:
        return data

    with torch.no_grad():
        keys, values = [], []
        for k in sorted(data.keys()):
            keys.append(k)
            values.append(data[k])

        values = torch.stack(values, dim=0)
        torch.distributed.all_reduce(values)

        if avg is True:
            values /= world_size

        return {k: v for k, v in zip(keys, values)}


def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]
    data_list = [None] * world_size
    torch.distributed.all_gather_object(data_list, data)
    return data_list


def sync_time():
    """sync_time"""
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    return time.time()


def setup_seed(seed: int, deterministic=False):
    """setup_seed for reproducibility
    torch.manual_seed(3407) is all you need. https://arxiv.org/abs/2109.08203
    """
    seed = seed + get_rank()
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # memory will be large when setting deterministic to True
    if torch.backends.cudnn.is_available() and deterministic:
        torch.backends.cudnn.deterministic = True


# for torch.compile
def check_compile():
    import warnings

    import torch

    gpu_ok = False
    if torch.cuda.is_available():
        device_cap = torch.cuda.get_device_capability()
        if device_cap in ((7, 0), (8, 0), (9, 0)):
            gpu_ok = True
    if not gpu_ok:
        warnings.warn(
            "GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower than expected."
        )
    return gpu_ok


def is_compile(model):
    import torch._dynamo

    return type(model) in (torch._dynamo.OptimizedModule,)


def de_complie(model):
    return model._orig_mod if is_compile(model) else model
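A hedged sketch (not part of the upload) of the typical call order for these helpers when launched with torchrun; `model`, `loader` and the loss tensors are assumed to exist, and the argument values are illustrative only.

enabled = setup_distributed(print_rank=0, print_method="builtin", seed=0)
model = warp_model(model, sync_bn=True, dist_mode="ddp", find_unused_parameters=False)
loader = warp_loader(loader, shuffle=True)
# ... training loop ...
losses = reduce_dict({"loss_bbox": loss_bbox, "loss_cls": loss_cls})  # averaged across ranks
save_on_master(de_model(model).state_dict(), "checkpoint.pth")  # rank 0 writes the checkpoint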
src/misc/lazy_loader.py
ADDED
@@ -0,0 +1,70 @@
"""
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/lazy_loader.py
"""

import importlib
import types


class LazyLoader(types.ModuleType):
    """Lazily import a module, mainly to avoid pulling in large dependencies.

    `paddle`, and `ffmpeg` are examples of modules that are large and not always
    needed, and this allows them to only be loaded when they are used.
    """

    # The lint error here is incorrect.
    def __init__(self, local_name, parent_module_globals, name, warning=None):
        self._local_name = local_name
        self._parent_module_globals = parent_module_globals
        self._warning = warning

        # These members allows doctest correctly process this module member without
        # triggering self._load(). self._load() mutates parant_module_globals and
        # triggers a dict mutated during iteration error from doctest.py.
        # - for from_module()
        self.__module__ = name.rsplit(".", 1)[0]
        # - for is_routine()
        self.__wrapped__ = None

        super(LazyLoader, self).__init__(name)

    def _load(self):
        """Load the module and insert it into the parent's globals."""
        # Import the target module and insert it into the parent's namespace
        module = importlib.import_module(self.__name__)
        self._parent_module_globals[self._local_name] = module

        # Emit a warning if one was specified
        if self._warning:
            # logging.warning(self._warning)
            # Make sure to only warn once.
            self._warning = None

        # Update this object's dict so that if someone keeps a reference to the
        # LazyLoader, lookups are efficient (__getattr__ is only called on lookups
        # that fail).
        self.__dict__.update(module.__dict__)

        return module

    def __getattr__(self, item):
        module = self._load()
        return getattr(module, item)

    def __repr__(self):
        # Carefully to not trigger _load, since repr may be called in very
        # sensitive places.
        return f"<LazyLoader {self.__name__} as {self._local_name}>"

    def __dir__(self):
        module = self._load()
        return dir(module)


# import paddle.nn as nn
# nn = LazyLoader("nn", globals(), "paddle.nn")

# class M(nn.Layer):
#     def __init__(self) -> None:
#         super().__init__()
src/misc/logger.py
ADDED
@@ -0,0 +1,255 @@
1 |
+
"""
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
3 |
+
https://github.com/facebookresearch/detr/blob/main/util/misc.py
|
4 |
+
Mostly copy-paste from torchvision references.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import datetime
|
8 |
+
import pickle
|
9 |
+
import time
|
10 |
+
from collections import defaultdict, deque
|
11 |
+
from typing import Dict
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.distributed as tdist
|
15 |
+
|
16 |
+
from .dist_utils import get_world_size, is_dist_available_and_initialized
|
17 |
+
|
18 |
+
|
19 |
+
class SmoothedValue(object):
|
20 |
+
"""Track a series of values and provide access to smoothed values over a
|
21 |
+
window or the global series average.
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(self, window_size=20, fmt=None):
|
25 |
+
if fmt is None:
|
26 |
+
fmt = "{median:.4f} ({global_avg:.4f})"
|
27 |
+
self.deque = deque(maxlen=window_size)
|
28 |
+
self.total = 0.0
|
29 |
+
self.count = 0
|
30 |
+
self.fmt = fmt
|
31 |
+
|
32 |
+
def update(self, value, n=1):
|
33 |
+
self.deque.append(value)
|
34 |
+
self.count += n
|
35 |
+
self.total += value * n
|
36 |
+
|
37 |
+
def synchronize_between_processes(self):
|
38 |
+
"""
|
39 |
+
Warning: does not synchronize the deque!
|
40 |
+
"""
|
41 |
+
if not is_dist_available_and_initialized():
|
42 |
+
return
|
43 |
+
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
|
44 |
+
tdist.barrier()
|
45 |
+
tdist.all_reduce(t)
|
46 |
+
t = t.tolist()
|
47 |
+
self.count = int(t[0])
|
48 |
+
self.total = t[1]
|
49 |
+
|
50 |
+
@property
|
51 |
+
def median(self):
|
52 |
+
d = torch.tensor(list(self.deque))
|
53 |
+
return d.median().item()
|
54 |
+
|
55 |
+
@property
|
56 |
+
def avg(self):
|
57 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
58 |
+
return d.mean().item()
|
59 |
+
|
60 |
+
@property
|
61 |
+
def global_avg(self):
|
62 |
+
return self.total / self.count
|
63 |
+
|
64 |
+
@property
|
65 |
+
def max(self):
|
66 |
+
return max(self.deque)
|
67 |
+
|
68 |
+
@property
|
69 |
+
def value(self):
|
70 |
+
return self.deque[-1]
|
71 |
+
|
72 |
+
def __str__(self):
|
73 |
+
return self.fmt.format(
|
74 |
+
median=self.median,
|
75 |
+
avg=self.avg,
|
76 |
+
global_avg=self.global_avg,
|
77 |
+
max=self.max,
|
78 |
+
value=self.value,
|
79 |
+
)
|
80 |
+
|
81 |
+
|
82 |
+
def all_gather(data):
|
83 |
+
"""
|
84 |
+
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
85 |
+
Args:
|
86 |
+
data: any picklable object
|
87 |
+
Returns:
|
88 |
+
list[data]: list of data gathered from each rank
|
89 |
+
"""
|
90 |
+
world_size = get_world_size()
|
91 |
+
if world_size == 1:
|
92 |
+
return [data]
|
93 |
+
|
94 |
+
# serialized to a Tensor
|
95 |
+
buffer = pickle.dumps(data)
|
96 |
+
storage = torch.ByteStorage.from_buffer(buffer)
|
97 |
+
tensor = torch.ByteTensor(storage).to("cuda")
|
98 |
+
|
99 |
+
# obtain Tensor size of each rank
|
100 |
+
local_size = torch.tensor([tensor.numel()], device="cuda")
|
101 |
+
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
|
102 |
+
tdist.all_gather(size_list, local_size)
|
103 |
+
size_list = [int(size.item()) for size in size_list]
|
104 |
+
max_size = max(size_list)
|
105 |
+
|
106 |
+
# receiving Tensor from all ranks
|
107 |
+
# we pad the tensor because torch all_gather does not support
|
108 |
+
# gathering tensors of different shapes
|
109 |
+
tensor_list = []
|
110 |
+
for _ in size_list:
|
111 |
+
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
|
112 |
+
if local_size != max_size:
|
113 |
+
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
|
114 |
+
tensor = torch.cat((tensor, padding), dim=0)
|
115 |
+
tdist.all_gather(tensor_list, tensor)
|
116 |
+
|
117 |
+
data_list = []
|
118 |
+
for size, tensor in zip(size_list, tensor_list):
|
119 |
+
buffer = tensor.cpu().numpy().tobytes()[:size]
|
120 |
+
data_list.append(pickle.loads(buffer))
|
121 |
+
|
122 |
+
return data_list
|
123 |
+
|
124 |
+
|
125 |
+
def reduce_dict(input_dict, average=True) -> Dict[str, torch.Tensor]:
|
126 |
+
"""
|
127 |
+
Args:
|
128 |
+
input_dict (dict): all the values will be reduced
|
129 |
+
average (bool): whether to do average or sum
|
130 |
+
Reduce the values in the dictionary from all processes so that all processes
|
131 |
+
have the averaged results. Returns a dict with the same fields as
|
132 |
+
input_dict, after reduction.
|
133 |
+
"""
|
134 |
+
world_size = get_world_size()
|
135 |
+
if world_size < 2:
|
136 |
+
return input_dict
|
137 |
+
with torch.no_grad():
|
138 |
+
names = []
|
139 |
+
values = []
|
140 |
+
# sort the keys so that they are consistent across processes
|
141 |
+
for k in sorted(input_dict.keys()):
|
142 |
+
names.append(k)
|
143 |
+
values.append(input_dict[k])
|
144 |
+
values = torch.stack(values, dim=0)
|
145 |
+
tdist.all_reduce(values)
|
146 |
+
if average:
|
147 |
+
values /= world_size
|
148 |
+
reduced_dict = {k: v for k, v in zip(names, values)}
|
149 |
+
return reduced_dict
|
150 |
+
|
151 |
+
|
152 |
+
class MetricLogger(object):
|
153 |
+
def __init__(self, delimiter="\t"):
|
154 |
+
self.meters = defaultdict(SmoothedValue)
|
155 |
+
self.delimiter = delimiter
|
156 |
+
|
157 |
+
def update(self, **kwargs):
|
158 |
+
for k, v in kwargs.items():
|
159 |
+
if isinstance(v, torch.Tensor):
|
160 |
+
v = v.item()
|
161 |
+
assert isinstance(v, (float, int))
|
162 |
+
self.meters[k].update(v)
|
163 |
+
|
164 |
+
def __getattr__(self, attr):
|
165 |
+
if attr in self.meters:
|
166 |
+
return self.meters[attr]
|
167 |
+
if attr in self.__dict__:
|
168 |
+
return self.__dict__[attr]
|
169 |
+
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
|
170 |
+
|
171 |
+
def __str__(self):
|
172 |
+
loss_str = []
|
173 |
+
for name, meter in self.meters.items():
|
174 |
+
loss_str.append("{}: {}".format(name, str(meter)))
|
175 |
+
return self.delimiter.join(loss_str)
|
176 |
+
|
177 |
+
def synchronize_between_processes(self):
|
178 |
+
for meter in self.meters.values():
|
179 |
+
meter.synchronize_between_processes()
|
180 |
+
|
181 |
+
def add_meter(self, name, meter):
|
182 |
+
self.meters[name] = meter
|
183 |
+
|
184 |
+
def log_every(self, iterable, print_freq, header=None):
|
185 |
+
i = 0
|
186 |
+
if not header:
|
187 |
+
header = ""
|
188 |
+
start_time = time.time()
|
189 |
+
end = time.time()
|
190 |
+
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
191 |
+
data_time = SmoothedValue(fmt="{avg:.4f}")
|
192 |
+
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
193 |
+
if torch.cuda.is_available():
|
194 |
+
log_msg = self.delimiter.join(
|
195 |
+
[
|
196 |
+
header,
|
197 |
+
"[{0" + space_fmt + "}/{1}]",
|
198 |
+
"eta: {eta}",
|
199 |
+
"{meters}",
|
200 |
+
"time: {time}",
|
201 |
+
"data: {data}",
|
202 |
+
"max mem: {memory:.0f}",
|
203 |
+
]
|
204 |
+
)
|
205 |
+
else:
|
206 |
+
log_msg = self.delimiter.join(
|
207 |
+
[
|
208 |
+
header,
|
209 |
+
"[{0" + space_fmt + "}/{1}]",
|
210 |
+
"eta: {eta}",
|
211 |
+
"{meters}",
|
212 |
+
"time: {time}",
|
213 |
+
"data: {data}",
|
214 |
+
]
|
215 |
+
)
|
216 |
+
MB = 1024.0 * 1024.0
|
217 |
+
for obj in iterable:
|
218 |
+
data_time.update(time.time() - end)
|
219 |
+
yield obj
|
220 |
+
iter_time.update(time.time() - end)
|
221 |
+
if i % print_freq == 0 or i == len(iterable) - 1:
|
222 |
+
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
223 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
224 |
+
if torch.cuda.is_available():
|
225 |
+
print(
|
226 |
+
log_msg.format(
|
227 |
+
i,
|
228 |
+
len(iterable),
|
229 |
+
eta=eta_string,
|
230 |
+
meters=str(self),
|
231 |
+
time=str(iter_time),
|
232 |
+
data=str(data_time),
|
233 |
+
memory=torch.cuda.max_memory_allocated() / MB,
|
234 |
+
)
|
235 |
+
)
|
236 |
+
else:
|
237 |
+
print(
|
238 |
+
log_msg.format(
|
239 |
+
i,
|
240 |
+
len(iterable),
|
241 |
+
eta=eta_string,
|
242 |
+
meters=str(self),
|
243 |
+
time=str(iter_time),
|
244 |
+
data=str(data_time),
|
245 |
+
)
|
246 |
+
)
|
247 |
+
i += 1
|
248 |
+
end = time.time()
|
249 |
+
total_time = time.time() - start_time
|
250 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
251 |
+
print(
|
252 |
+
"{} Total time: {} ({:.4f} s / it)".format(
|
253 |
+
header, total_time_str, total_time / len(iterable)
|
254 |
+
)
|
255 |
+
)
|
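A sketch of how SmoothedValue and MetricLogger are typically driven in a training loop (hypothetical data and loss; the import path assumes this upload's layout):

import torch
from src.misc.logger import MetricLogger, SmoothedValue  # assumed import path

logger = MetricLogger(delimiter="  ")
logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))

batches = [torch.randn(4, 3, 64, 64) for _ in range(20)]      # stand-in dataloader
for batch in logger.log_every(batches, print_freq=5, header="Epoch: [0]"):
    loss = batch.pow(2).mean()                                 # stand-in loss
    logger.update(loss=loss.item(), lr=1e-4)

logger.synchronize_between_processes()                         # no-op without torch.distributed
print("Averaged stats:", logger)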
src/misc/profiler_utils.py
ADDED
@@ -0,0 +1,30 @@
"""
Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
"""

import copy
from typing import Tuple

from calflops import calculate_flops


def stats(
    cfg,
    input_shape: Tuple = (1, 3, 640, 640),
) -> Tuple[int, dict]:
    base_size = cfg.train_dataloader.collate_fn.base_size
    input_shape = (1, 3, base_size, base_size)

    model_for_info = copy.deepcopy(cfg.model).deploy()

    flops, macs, _ = calculate_flops(
        model=model_for_info,
        input_shape=input_shape,
        output_as_string=True,
        output_precision=4,
        print_detailed=False,
    )
    params = sum(p.numel() for p in model_for_info.parameters())
    del model_for_info

    return params, {"Model FLOPs:%s   MACs:%s   Params:%s" % (flops, macs, params)}
src/misc/visualizer.py
ADDED
@@ -0,0 +1,121 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

import PIL
import numpy as np
import torch
import torch.utils.data
import torchvision
from typing import List, Dict

torchvision.disable_beta_transforms_warning()

__all__ = ["show_sample", "save_samples"]


def save_samples(samples: torch.Tensor, targets: List[Dict], output_dir: str, split: str, normalized: bool, box_fmt: str):
    """
    normalized: whether the boxes are normalized to [0, 1]
    box_fmt: 'xyxy', 'xywh', 'cxcywh', D-FINE uses 'cxcywh' for training, 'xyxy' for validation
    """
    from torchvision.transforms.functional import to_pil_image
    from torchvision.ops import box_convert
    from pathlib import Path
    from PIL import ImageDraw, ImageFont
    import os

    os.makedirs(Path(output_dir) / Path(f"{split}_samples"), exist_ok=True)
    # Predefined colors (standard color names recognized by PIL)
    BOX_COLORS = [
        "red", "blue", "green", "orange", "purple",
        "cyan", "magenta", "yellow", "lime", "pink",
        "teal", "lavender", "brown", "beige", "maroon",
        "navy", "olive", "coral", "turquoise", "gold"
    ]

    LABEL_TEXT_COLOR = "white"

    font = ImageFont.load_default()
    font.size = 32

    for i, (sample, target) in enumerate(zip(samples, targets)):
        sample_visualization = sample.clone().cpu()
        target_boxes = target["boxes"].clone().cpu()
        target_labels = target["labels"].clone().cpu()
        target_image_id = target["image_id"].item()
        target_image_path = target["image_path"]
        target_image_path_stem = Path(target_image_path).stem

        sample_visualization = to_pil_image(sample_visualization)
        sample_visualization_w, sample_visualization_h = sample_visualization.size

        # normalized to pixel space
        if normalized:
            target_boxes[:, 0] = target_boxes[:, 0] * sample_visualization_w
            target_boxes[:, 2] = target_boxes[:, 2] * sample_visualization_w
            target_boxes[:, 1] = target_boxes[:, 1] * sample_visualization_h
            target_boxes[:, 3] = target_boxes[:, 3] * sample_visualization_h

        # any box format -> xyxy
        target_boxes = box_convert(target_boxes, in_fmt=box_fmt, out_fmt="xyxy")

        # clip to image size
        target_boxes[:, 0] = torch.clamp(target_boxes[:, 0], 0, sample_visualization_w)
        target_boxes[:, 1] = torch.clamp(target_boxes[:, 1], 0, sample_visualization_h)
        target_boxes[:, 2] = torch.clamp(target_boxes[:, 2], 0, sample_visualization_w)
        target_boxes[:, 3] = torch.clamp(target_boxes[:, 3], 0, sample_visualization_h)

        target_boxes = target_boxes.numpy().astype(np.int32)
        target_labels = target_labels.numpy().astype(np.int32)

        draw = ImageDraw.Draw(sample_visualization)

        # draw target boxes
        for box, label in zip(target_boxes, target_labels):
            x1, y1, x2, y2 = box

            # Select color based on class ID
            box_color = BOX_COLORS[int(label) % len(BOX_COLORS)]

            # Draw box (thick)
            draw.rectangle([x1, y1, x2, y2], outline=box_color, width=3)

            label_text = f"{label}"

            # Measure text size
            text_width, text_height = draw.textbbox((0, 0), label_text, font=font)[2:4]

            # Draw text background
            padding = 2
            draw.rectangle(
                [x1, y1 - text_height - padding * 2, x1 + text_width + padding * 2, y1],
                fill=box_color
            )

            # Draw text (LABEL_TEXT_COLOR)
            draw.text((x1 + padding, y1 - text_height - padding), label_text,
                      fill=LABEL_TEXT_COLOR, font=font)

        save_path = Path(output_dir) / f"{split}_samples" / f"{target_image_id}_{target_image_path_stem}.webp"
        sample_visualization.save(save_path)


def show_sample(sample):
    """for coco dataset/dataloader"""
    import matplotlib.pyplot as plt
    from torchvision.transforms.v2 import functional as F
    from torchvision.utils import draw_bounding_boxes

    image, target = sample
    if isinstance(image, PIL.Image.Image):
        image = F.to_image_tensor(image)

    image = F.convert_dtype(image, torch.uint8)
    annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3)

    fig, ax = plt.subplots()
    ax.imshow(annotated_image.permute(1, 2, 0).numpy())
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    fig.tight_layout()
    fig.show()
    plt.show()
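A hypothetical call illustrating the tensors and target keys save_samples expects, with normalized cxcywh boxes as used for training batches (the import path, file names, and values are made up for illustration):

import torch
from src.misc.visualizer import save_samples  # assumed import path

samples = torch.rand(2, 3, 320, 320)                    # two fake images in [0, 1]
targets = [
    {
        "boxes": torch.tensor([[0.5, 0.5, 0.4, 0.3]]),  # cx, cy, w, h, normalized
        "labels": torch.tensor([1]),
        "image_id": torch.tensor([i]),
        "image_path": f"fake_{i}.jpg",
    }
    for i in range(2)
]
save_samples(samples, targets, output_dir="./vis", split="train", normalized=True, box_fmt="cxcywh")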
src/nn/__init__.py
ADDED
@@ -0,0 +1,16 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

from .arch import *

#
from .backbone import *
from .backbone import (
    FrozenBatchNorm2d,
    freeze_batch_norm2d,
    get_activation,
)
from .criterion import *
from .postprocessor import *
src/nn/arch/__init__.py
ADDED
@@ -0,0 +1,7 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

from .classification import ClassHead, Classification
from .yolo import YOLO
src/nn/arch/classification.py
ADDED
@@ -0,0 +1,45 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

import torch
import torch.nn as nn

from ...core import register

__all__ = ["Classification", "ClassHead"]


@register()
class Classification(torch.nn.Module):
    __inject__ = ["backbone", "head"]

    def __init__(self, backbone: nn.Module, head: nn.Module = None):
        super().__init__()

        self.backbone = backbone
        self.head = head

    def forward(self, x):
        x = self.backbone(x)

        if self.head is not None:
            x = self.head(x)

        return x


@register()
class ClassHead(nn.Module):
    def __init__(self, hidden_dim, num_classes):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.proj = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = x[0] if isinstance(x, (list, tuple)) else x
        x = self.pool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.proj(x)
        return x
src/nn/arch/yolo.py
ADDED
@@ -0,0 +1,42 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

import torch

from ...core import register

__all__ = [
    "YOLO",
]


@register()
class YOLO(torch.nn.Module):
    __inject__ = [
        "backbone",
        "neck",
        "head",
    ]

    def __init__(self, backbone: torch.nn.Module, neck, head):
        super().__init__()
        self.backbone = backbone
        self.neck = neck
        self.head = head

    def forward(self, x, **kwargs):
        x = self.backbone(x)
        x = self.neck(x)
        x = self.head(x)
        return x

    def deploy(
        self,
    ):
        self.eval()
        for m in self.modules():
            if m is not self and hasattr(m, "deploy"):
                m.deploy()
        return self
src/nn/backbone/__init__.py
ADDED
@@ -0,0 +1,17 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

from .common import (
    FrozenBatchNorm2d,
    freeze_batch_norm2d,
    get_activation,
)
from .csp_darknet import CSPPAN, CSPDarkNet
from .csp_resnet import CSPResNet
from .hgnetv2 import HGNetv2
from .presnet import PResNet
from .test_resnet import MResNet
from .timm_model import TimmModel
from .torchvision_model import TorchVisionModel
src/nn/backbone/common.py
ADDED
@@ -0,0 +1,117 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

import torch
import torch.nn as nn


class ConvNormLayer(nn.Module):
    def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None):
        super().__init__()
        self.conv = nn.Conv2d(
            ch_in,
            ch_out,
            kernel_size,
            stride,
            padding=(kernel_size - 1) // 2 if padding is None else padding,
            bias=bias,
        )
        self.norm = nn.BatchNorm2d(ch_out)
        self.act = nn.Identity() if act is None else get_activation(act)

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))


class FrozenBatchNorm2d(nn.Module):
    """copy and modified from https://github.com/facebookresearch/detr/blob/master/models/backbone.py
    BatchNorm2d where the batch statistics and the affine parameters are fixed.
    Copy-paste from torchvision.misc.ops with added eps before rqsrt,
    without which any other models than torchvision.models.resnet[18,34,50,101]
    produce nans.
    """

    def __init__(self, num_features, eps=1e-5):
        super(FrozenBatchNorm2d, self).__init__()
        n = num_features
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))
        self.eps = eps
        self.num_features = n

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        num_batches_tracked_key = prefix + "num_batches_tracked"
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        scale = w * (rv + self.eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias

    def extra_repr(self):
        return "{num_features}, eps={eps}".format(**self.__dict__)


def freeze_batch_norm2d(module: nn.Module) -> nn.Module:
    if isinstance(module, nn.BatchNorm2d):
        module = FrozenBatchNorm2d(module.num_features)
    else:
        for name, child in module.named_children():
            _child = freeze_batch_norm2d(child)
            if _child is not child:
                setattr(module, name, _child)
    return module


def get_activation(act: str, inplace: bool = True):
    """get activation"""
    if act is None:
        return nn.Identity()

    elif isinstance(act, nn.Module):
        return act

    act = act.lower()

    if act == "silu" or act == "swish":
        m = nn.SiLU()

    elif act == "relu":
        m = nn.ReLU()

    elif act == "leaky_relu":
        m = nn.LeakyReLU()

    elif act == "gelu":
        m = nn.GELU()

    elif act == "hardsigmoid":
        m = nn.Hardsigmoid()

    else:
        raise RuntimeError("")

    if hasattr(m, "inplace"):
        m.inplace = inplace

    return m
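A quick equivalence check (a sketch, not part of the upload): FrozenBatchNorm2d folds eval-mode batch norm, y = (x - running_mean) / sqrt(running_var + eps) * weight + bias, into a per-channel scale and bias, so with identical buffers it matches nn.BatchNorm2d in eval mode. The import path is an assumption about this upload's layout:

import torch
import torch.nn as nn
from src.nn.backbone.common import FrozenBatchNorm2d  # assumed import path

bn = nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1, 1)          # give the stats non-trivial values
bn.running_var.uniform_(0.5, 2.0)

fbn = FrozenBatchNorm2d(8)
fbn.weight.copy_(bn.weight.detach())     # copy affine params and stats into the frozen copy
fbn.bias.copy_(bn.bias.detach())
fbn.running_mean.copy_(bn.running_mean)
fbn.running_var.copy_(bn.running_var)

x = torch.randn(2, 8, 16, 16)
with torch.no_grad():
    assert torch.allclose(bn(x), fbn(x), atol=1e-5)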
src/nn/backbone/csp_darknet.py
ADDED
@@ -0,0 +1,203 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

import math
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...core import register
from .common import get_activation


def autopad(k, p=None):
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
    return p


def make_divisible(c, d):
    return math.ceil(c / d) * d


class Conv(nn.Module):
    def __init__(self, cin, cout, k=1, s=1, p=None, g=1, act="silu") -> None:
        super().__init__()
        self.conv = nn.Conv2d(cin, cout, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(cout)
        self.act = get_activation(act, inplace=True)

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))


class Bottleneck(nn.Module):
    # Standard bottleneck
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, act="silu"):
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1, act=act)
        self.cv2 = Conv(c_, c2, 3, 1, g=g, act=act)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))


class C3(nn.Module):
    # CSP Bottleneck with 3 convolutions
    def __init__(
        self, c1, c2, n=1, shortcut=True, g=1, e=0.5, act="silu"
    ):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1, act=act)
        self.cv2 = Conv(c1, c_, 1, 1, act=act)
        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0, act=act) for _ in range(n)))
        self.cv3 = Conv(2 * c_, c2, 1, act=act)

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))


class SPPF(nn.Module):
    # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
    def __init__(self, c1, c2, k=5, act="silu"):  # equivalent to SPP(k=(5, 9, 13))
        super().__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1, act=act)
        self.cv2 = Conv(c_ * 4, c2, 1, 1, act=act)
        self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

    def forward(self, x):
        x = self.cv1(x)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")  # suppress torch 1.9.0 max_pool2d() warning
            y1 = self.m(x)
            y2 = self.m(y1)
            return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))


@register()
class CSPDarkNet(nn.Module):
    __share__ = ["depth_multi", "width_multi"]

    def __init__(
        self,
        in_channels=3,
        width_multi=1.0,
        depth_multi=1.0,
        return_idx=[2, 3, -1],
        act="silu",
    ) -> None:
        super().__init__()

        channels = [64, 128, 256, 512, 1024]
        channels = [make_divisible(c * width_multi, 8) for c in channels]

        depths = [3, 6, 9, 3]
        depths = [max(round(d * depth_multi), 1) for d in depths]

        self.layers = nn.ModuleList([Conv(in_channels, channels[0], 6, 2, 2, act=act)])
        for i, (c, d) in enumerate(zip(channels, depths), 1):
            layer = nn.Sequential(
                *[Conv(c, channels[i], 3, 2, act=act), C3(channels[i], channels[i], n=d, act=act)]
            )
            self.layers.append(layer)

        self.layers.append(SPPF(channels[-1], channels[-1], k=5, act=act))

        self.return_idx = return_idx
        self.out_channels = [channels[i] for i in self.return_idx]
        self.strides = [[2, 4, 8, 16, 32][i] for i in self.return_idx]
        self.depths = depths
        self.act = act

    def forward(self, x):
        outputs = []
        for _, m in enumerate(self.layers):
            x = m(x)
            outputs.append(x)

        return [outputs[i] for i in self.return_idx]


@register()
class CSPPAN(nn.Module):
    """
    P5 ---> 1x1  ---------------------------------> concat --> c3 --> det
             | up                                      | conv /2
    P4 ---> concat ---> c3 ---> 1x1  -->  concat ---> c3 -----------> det
             | up                           | conv /2
    P3 -----------------------> concat ---> c3 ---------------------> det
    """

    __share__ = [
        "depth_multi",
    ]

    def __init__(self, in_channels=[256, 512, 1024], depth_multi=1.0, act="silu") -> None:
        super().__init__()
        depth = max(round(3 * depth_multi), 1)

        self.out_channels = in_channels
        self.fpn_stems = nn.ModuleList(
            [
                Conv(cin, cout, 1, 1, act=act)
                for cin, cout in zip(in_channels[::-1], in_channels[::-1][1:])
            ]
        )
        self.fpn_csps = nn.ModuleList(
            [
                C3(cin, cout, depth, False, act=act)
                for cin, cout in zip(in_channels[::-1], in_channels[::-1][1:])
            ]
        )

        self.pan_stems = nn.ModuleList([Conv(c, c, 3, 2, act=act) for c in in_channels[:-1]])
        self.pan_csps = nn.ModuleList([C3(c, c, depth, False, act=act) for c in in_channels[1:]])

    def forward(self, feats):
        fpn_feats = []
        for i, feat in enumerate(feats[::-1]):
            if i == 0:
                feat = self.fpn_stems[i](feat)
                fpn_feats.append(feat)
            else:
                _feat = F.interpolate(fpn_feats[-1], scale_factor=2, mode="nearest")
                feat = torch.concat([_feat, feat], dim=1)
                feat = self.fpn_csps[i - 1](feat)
                if i < len(self.fpn_stems):
                    feat = self.fpn_stems[i](feat)
                fpn_feats.append(feat)

        pan_feats = []
        for i, feat in enumerate(fpn_feats[::-1]):
            if i == 0:
                pan_feats.append(feat)
            else:
                _feat = self.pan_stems[i - 1](pan_feats[-1])
                feat = torch.concat([_feat, feat], dim=1)
                feat = self.pan_csps[i - 1](feat)
                pan_feats.append(feat)

        return pan_feats


if __name__ == "__main__":
    data = torch.rand(1, 3, 320, 640)

    width_multi = 0.75
    depth_multi = 0.33

    m = CSPDarkNet(3, width_multi=width_multi, depth_multi=depth_multi, act="silu")
    outputs = m(data)
    print([o.shape for o in outputs])

    m = CSPPAN(in_channels=m.out_channels, depth_multi=depth_multi, act="silu")
    outputs = m(outputs)
    print([o.shape for o in outputs])
src/nn/backbone/csp_resnet.py
ADDED
@@ -0,0 +1,302 @@
"""
https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.6/ppdet/modeling/backbones/cspresnet.py

Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...core import register
from .common import get_activation

__all__ = ["CSPResNet"]


donwload_url = {
    "s": "https://github.com/lyuwenyu/storage/releases/download/v0.1/CSPResNetb_s_pretrained_from_paddle.pth",
    "m": "https://github.com/lyuwenyu/storage/releases/download/v0.1/CSPResNetb_m_pretrained_from_paddle.pth",
    "l": "https://github.com/lyuwenyu/storage/releases/download/v0.1/CSPResNetb_l_pretrained_from_paddle.pth",
    "x": "https://github.com/lyuwenyu/storage/releases/download/v0.1/CSPResNetb_x_pretrained_from_paddle.pth",
}


class ConvBNLayer(nn.Module):
    def __init__(self, ch_in, ch_out, filter_size=3, stride=1, groups=1, padding=0, act=None):
        super().__init__()
        self.conv = nn.Conv2d(
            ch_in, ch_out, filter_size, stride, padding, groups=groups, bias=False
        )
        self.bn = nn.BatchNorm2d(ch_out)
        self.act = get_activation(act)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        return x


class RepVggBlock(nn.Module):
    def __init__(self, ch_in, ch_out, act="relu", alpha: bool = False):
        super().__init__()
        self.ch_in = ch_in
        self.ch_out = ch_out
        self.conv1 = ConvBNLayer(ch_in, ch_out, 3, stride=1, padding=1, act=None)
        self.conv2 = ConvBNLayer(ch_in, ch_out, 1, stride=1, padding=0, act=None)
        self.act = get_activation(act)

        if alpha:
            self.alpha = nn.Parameter(
                torch.ones(
                    1,
                )
            )
        else:
            self.alpha = None

    def forward(self, x):
        if hasattr(self, "conv"):
            y = self.conv(x)
        else:
            if self.alpha:
                y = self.conv1(x) + self.alpha * self.conv2(x)
            else:
                y = self.conv1(x) + self.conv2(x)
        y = self.act(y)
        return y

    def convert_to_deploy(self):
        if not hasattr(self, "conv"):
            self.conv = nn.Conv2d(self.ch_in, self.ch_out, 3, 1, padding=1)

        kernel, bias = self.get_equivalent_kernel_bias()
        self.conv.weight.data = kernel
        self.conv.bias.data = bias

    def get_equivalent_kernel_bias(self):
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)

        if self.alpha:
            return kernel3x3 + self.alpha * self._pad_1x1_to_3x3_tensor(
                kernel1x1
            ), bias3x3 + self.alpha * bias1x1
        else:
            return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1), bias3x3 + bias1x1

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        if kernel1x1 is None:
            return 0
        else:
            return F.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch: ConvBNLayer):
        if branch is None:
            return 0, 0
        kernel = branch.conv.weight
        running_mean = branch.norm.running_mean
        running_var = branch.norm.running_var
        gamma = branch.norm.weight
        beta = branch.norm.bias
        eps = branch.norm.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std


class BasicBlock(nn.Module):
    def __init__(self, ch_in, ch_out, act="relu", shortcut=True, use_alpha=False):
        super().__init__()
        assert ch_in == ch_out
        self.conv1 = ConvBNLayer(ch_in, ch_out, 3, stride=1, padding=1, act=act)
        self.conv2 = RepVggBlock(ch_out, ch_out, act=act, alpha=use_alpha)
        self.shortcut = shortcut

    def forward(self, x):
        y = self.conv1(x)
        y = self.conv2(y)
        if self.shortcut:
            return x + y
        else:
            return y


class EffectiveSELayer(nn.Module):
    """Effective Squeeze-Excitation
    From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
    """

    def __init__(self, channels, act="hardsigmoid"):
        super(EffectiveSELayer, self).__init__()
        self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
        self.act = get_activation(act)

    def forward(self, x: torch.Tensor):
        x_se = x.mean((2, 3), keepdim=True)
        x_se = self.fc(x_se)
        x_se = self.act(x_se)
        return x * x_se


class CSPResStage(nn.Module):
    def __init__(self, block_fn, ch_in, ch_out, n, stride, act="relu", attn="eca", use_alpha=False):
        super().__init__()
        ch_mid = (ch_in + ch_out) // 2
        if stride == 2:
            self.conv_down = ConvBNLayer(ch_in, ch_mid, 3, stride=2, padding=1, act=act)
        else:
            self.conv_down = None
        self.conv1 = ConvBNLayer(ch_mid, ch_mid // 2, 1, act=act)
        self.conv2 = ConvBNLayer(ch_mid, ch_mid // 2, 1, act=act)
        self.blocks = nn.Sequential(
            *[
                block_fn(ch_mid // 2, ch_mid // 2, act=act, shortcut=True, use_alpha=use_alpha)
                for i in range(n)
            ]
        )
        if attn:
            self.attn = EffectiveSELayer(ch_mid, act="hardsigmoid")
        else:
            self.attn = None

        self.conv3 = ConvBNLayer(ch_mid, ch_out, 1, act=act)

    def forward(self, x):
        if self.conv_down is not None:
            x = self.conv_down(x)
        y1 = self.conv1(x)
        y2 = self.blocks(self.conv2(x))
        y = torch.concat([y1, y2], dim=1)
        if self.attn is not None:
            y = self.attn(y)
        y = self.conv3(y)
        return y


@register()
class CSPResNet(nn.Module):
    layers = [3, 6, 6, 3]
    channels = [64, 128, 256, 512, 1024]
    model_cfg = {
        "s": {
            "depth_mult": 0.33,
            "width_mult": 0.50,
        },
        "m": {
            "depth_mult": 0.67,
            "width_mult": 0.75,
        },
        "l": {
            "depth_mult": 1.00,
            "width_mult": 1.00,
        },
        "x": {
            "depth_mult": 1.33,
            "width_mult": 1.25,
        },
    }

    def __init__(
        self,
        name: str,
        act="silu",
        return_idx=[1, 2, 3],
        use_large_stem=True,
        use_alpha=False,
        pretrained=False,
    ):
        super().__init__()
        depth_mult = self.model_cfg[name]["depth_mult"]
        width_mult = self.model_cfg[name]["width_mult"]

        channels = [max(round(c * width_mult), 1) for c in self.channels]
        layers = [max(round(l * depth_mult), 1) for l in self.layers]
        act = get_activation(act)

        if use_large_stem:
            self.stem = nn.Sequential(
                OrderedDict(
                    [
                        (
                            "conv1",
                            ConvBNLayer(3, channels[0] // 2, 3, stride=2, padding=1, act=act),
                        ),
                        (
                            "conv2",
                            ConvBNLayer(
                                channels[0] // 2, channels[0] // 2, 3, stride=1, padding=1, act=act
                            ),
                        ),
                        (
                            "conv3",
                            ConvBNLayer(
                                channels[0] // 2, channels[0], 3, stride=1, padding=1, act=act
                            ),
                        ),
                    ]
                )
            )
        else:
            self.stem = nn.Sequential(
                OrderedDict(
                    [
                        (
                            "conv1",
                            ConvBNLayer(3, channels[0] // 2, 3, stride=2, padding=1, act=act),
                        ),
                        (
                            "conv2",
                            ConvBNLayer(
                                channels[0] // 2, channels[0], 3, stride=1, padding=1, act=act
                            ),
                        ),
                    ]
                )
            )

        n = len(channels) - 1
        self.stages = nn.Sequential(
            OrderedDict(
                [
                    (
                        str(i),
                        CSPResStage(
                            BasicBlock,
                            channels[i],
                            channels[i + 1],
                            layers[i],
                            2,
                            act=act,
                            use_alpha=use_alpha,
                        ),
                    )
                    for i in range(n)
                ]
            )
        )

        self._out_channels = channels[1:]
        self._out_strides = [4 * 2**i for i in range(n)]
        self.return_idx = return_idx

        if pretrained:
            if isinstance(pretrained, bool) or "http" in pretrained:
                state = torch.hub.load_state_dict_from_url(donwload_url[name], map_location="cpu")
            else:
                state = torch.load(pretrained, map_location="cpu")
            self.load_state_dict(state)
            print(f"Load CSPResNet_{name} state_dict")

    def forward(self, x):
        x = self.stem(x)
        outs = []
        for idx, stage in enumerate(self.stages):
            x = stage(x)
            if idx in self.return_idx:
                outs.append(x)

        return outs
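The RepVggBlock fusion above relies on the identity that a 1x1 convolution equals a 3x3 convolution whose kernel is the 1x1 weight zero-padded (the job of _pad_1x1_to_3x3_tensor), provided the 3x3 conv uses padding=1. A small standalone check of that identity (a sketch, not part of the upload):

import torch
import torch.nn.functional as F

w1 = torch.randn(8, 4, 1, 1)        # 1x1 kernel
w3 = F.pad(w1, [1, 1, 1, 1])        # zero-padded to 3x3, original value at the center
x = torch.randn(2, 4, 16, 16)

y1 = F.conv2d(x, w1)                # 1x1 conv, padding=0
y3 = F.conv2d(x, w3, padding=1)     # equivalent 3x3 conv
assert torch.allclose(y1, y3, atol=1e-6)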
src/nn/backbone/hgnetv2.py
ADDED
@@ -0,0 +1,581 @@
"""
reference
- https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py

Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
"""

import logging
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...core import register
from .common import FrozenBatchNorm2d

# Constants for initialization
kaiming_normal_ = nn.init.kaiming_normal_
zeros_ = nn.init.zeros_
ones_ = nn.init.ones_

__all__ = ["HGNetv2"]


def safe_barrier():
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()
    else:
        pass


def safe_get_rank():
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    else:
        return 0


class LearnableAffineBlock(nn.Module):
    def __init__(self, scale_value=1.0, bias_value=0.0):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor([scale_value]), requires_grad=True)
        self.bias = nn.Parameter(torch.tensor([bias_value]), requires_grad=True)

    def forward(self, x):
        return self.scale * x + self.bias


class ConvBNAct(nn.Module):
    def __init__(
        self,
        in_chs,
        out_chs,
        kernel_size,
        stride=1,
        groups=1,
        padding="",
        use_act=True,
        use_lab=False,
    ):
        super().__init__()
        self.use_act = use_act
        self.use_lab = use_lab
        if padding == "same":
            self.conv = nn.Sequential(
                nn.ZeroPad2d([0, 1, 0, 1]),
                nn.Conv2d(in_chs, out_chs, kernel_size, stride, groups=groups, bias=False),
            )
        else:
            self.conv = nn.Conv2d(
                in_chs,
                out_chs,
                kernel_size,
                stride,
                padding=(kernel_size - 1) // 2,
                groups=groups,
                bias=False,
            )
        self.bn = nn.BatchNorm2d(out_chs)
        if self.use_act:
            self.act = nn.ReLU()
        else:
            self.act = nn.Identity()
        if self.use_act and self.use_lab:
            self.lab = LearnableAffineBlock()
        else:
            self.lab = nn.Identity()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        x = self.lab(x)
        return x


class LightConvBNAct(nn.Module):
    def __init__(
        self,
        in_chs,
        out_chs,
        kernel_size,
        groups=1,
        use_lab=False,
    ):
        super().__init__()
        self.conv1 = ConvBNAct(
            in_chs,
            out_chs,
            kernel_size=1,
            use_act=False,
            use_lab=use_lab,
        )
        self.conv2 = ConvBNAct(
            out_chs,
            out_chs,
            kernel_size=kernel_size,
            groups=out_chs,
            use_act=True,
            use_lab=use_lab,
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class StemBlock(nn.Module):
    # for HGNetv2
    def __init__(self, in_chs, mid_chs, out_chs, use_lab=False):
        super().__init__()
        self.stem1 = ConvBNAct(
            in_chs,
            mid_chs,
            kernel_size=3,
            stride=2,
            use_lab=use_lab,
        )
        self.stem2a = ConvBNAct(
            mid_chs,
            mid_chs // 2,
            kernel_size=2,
            stride=1,
            use_lab=use_lab,
        )
        self.stem2b = ConvBNAct(
            mid_chs // 2,
            mid_chs,
            kernel_size=2,
            stride=1,
            use_lab=use_lab,
        )
        self.stem3 = ConvBNAct(
            mid_chs * 2,
            mid_chs,
            kernel_size=3,
            stride=2,
            use_lab=use_lab,
        )
        self.stem4 = ConvBNAct(
            mid_chs,
            out_chs,
            kernel_size=1,
            stride=1,
            use_lab=use_lab,
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=1, ceil_mode=True)

    def forward(self, x):
        x = self.stem1(x)
        x = F.pad(x, (0, 1, 0, 1))
        x2 = self.stem2a(x)
        x2 = F.pad(x2, (0, 1, 0, 1))
        x2 = self.stem2b(x2)
        x1 = self.pool(x)
        x = torch.cat([x1, x2], dim=1)
        x = self.stem3(x)
        x = self.stem4(x)
        return x


class EseModule(nn.Module):
    def __init__(self, chs):
        super().__init__()
        self.conv = nn.Conv2d(
            chs,
            chs,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        identity = x
        x = x.mean((2, 3), keepdim=True)
        x = self.conv(x)
        x = self.sigmoid(x)
        return torch.mul(identity, x)


class HG_Block(nn.Module):
    def __init__(
        self,
        in_chs,
        mid_chs,
        out_chs,
        layer_num,
        kernel_size=3,
        residual=False,
        light_block=False,
        use_lab=False,
        agg="ese",
        drop_path=0.0,
    ):
        super().__init__()
        self.residual = residual

        self.layers = nn.ModuleList()
        for i in range(layer_num):
            if light_block:
                self.layers.append(
                    LightConvBNAct(
                        in_chs if i == 0 else mid_chs,
                        mid_chs,
                        kernel_size=kernel_size,
                        use_lab=use_lab,
                    )
                )
            else:
                self.layers.append(
                    ConvBNAct(
                        in_chs if i == 0 else mid_chs,
                        mid_chs,
                        kernel_size=kernel_size,
                        stride=1,
                        use_lab=use_lab,
                    )
                )

        # feature aggregation
        total_chs = in_chs + layer_num * mid_chs
        if agg == "se":
            aggregation_squeeze_conv = ConvBNAct(
                total_chs,
                out_chs // 2,
                kernel_size=1,
                stride=1,
                use_lab=use_lab,
            )
            aggregation_excitation_conv = ConvBNAct(
                out_chs // 2,
                out_chs,
                kernel_size=1,
                stride=1,
                use_lab=use_lab,
            )
            self.aggregation = nn.Sequential(
                aggregation_squeeze_conv,
                aggregation_excitation_conv,
            )
        else:
            aggregation_conv = ConvBNAct(
                total_chs,
                out_chs,
                kernel_size=1,
                stride=1,
                use_lab=use_lab,
            )
            att = EseModule(out_chs)
            self.aggregation = nn.Sequential(
                aggregation_conv,
                att,
            )

        self.drop_path = nn.Dropout(drop_path) if drop_path else nn.Identity()

    def forward(self, x):
        identity = x
        output = [x]
        for layer in self.layers:
            x = layer(x)
            output.append(x)
        x = torch.cat(output, dim=1)
        x = self.aggregation(x)
        if self.residual:
            x = self.drop_path(x) + identity
        return x


class HG_Stage(nn.Module):
    def __init__(
        self,
        in_chs,
        mid_chs,
        out_chs,
        block_num,
        layer_num,
        downsample=True,
        light_block=False,
        kernel_size=3,
        use_lab=False,
        agg="se",
        drop_path=0.0,
    ):
        super().__init__()
        self.downsample = downsample
        if downsample:
            self.downsample = ConvBNAct(
                in_chs,
                in_chs,
                kernel_size=3,
                stride=2,
                groups=in_chs,
                use_act=False,
                use_lab=use_lab,
            )
        else:
            self.downsample = nn.Identity()

        blocks_list = []
        for i in range(block_num):
            blocks_list.append(
                HG_Block(
                    in_chs if i == 0 else out_chs,
                    mid_chs,
                    out_chs,
                    layer_num,
                    residual=False if i == 0 else True,
                    kernel_size=kernel_size,
                    light_block=light_block,
                    use_lab=use_lab,
                    agg=agg,
                    drop_path=drop_path[i] if isinstance(drop_path, (list, tuple)) else drop_path,
                )
            )
        self.blocks = nn.Sequential(*blocks_list)

    def forward(self, x):
        x = self.downsample(x)
        x = self.blocks(x)
        return x


@register()
class HGNetv2(nn.Module):
    """
    HGNetV2
    Args:
        stem_channels: list. Number of channels for the stem block.
        stage_type: str. The stage configuration of HGNet. such as the number of channels, stride, etc.
        use_lab: boolean. Whether to use LearnableAffineBlock in network.
        lr_mult_list: list. Control the learning rate of different stages.
    Returns:
        model: nn.Layer. Specific HGNetV2 model depends on args.
    """

    arch_configs = {
        "B0": {
            "stem_channels": [3, 16, 16],
            "stage_config": {
                # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
                "stage1": [16, 16, 64, 1, False, False, 3, 3],
                "stage2": [64, 32, 256, 1, True, False, 3, 3],
                "stage3": [256, 64, 512, 2, True, True, 5, 3],
                "stage4": [512, 128, 1024, 1, True, True, 5, 3],
            },
            "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B0_stage1.pth",
        },
        "B1": {
            "stem_channels": [3, 24, 32],
            "stage_config": {
                # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
                "stage1": [32, 32, 64, 1, False, False, 3, 3],
                "stage2": [64, 48, 256, 1, True, False, 3, 3],
                "stage3": [256, 96, 512, 2, True, True, 5, 3],
                "stage4": [512, 192, 1024, 1, True, True, 5, 3],
            },
            "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B1_stage1.pth",
        },
        "B2": {
            "stem_channels": [3, 24, 32],
            "stage_config": {
                # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
                "stage1": [32, 32, 96, 1, False, False, 3, 4],
                "stage2": [96, 64, 384, 1, True, False, 3, 4],
                "stage3": [384, 128, 768, 3, True, True, 5, 4],
                "stage4": [768, 256, 1536, 1, True, True, 5, 4],
            },
            "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B2_stage1.pth",
        },
        "B3": {
            "stem_channels": [3, 24, 32],
            "stage_config": {
                # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
                "stage1": [32, 32, 128, 1, False, False, 3, 5],
                "stage2": [128, 64, 512, 1, True, False, 3, 5],
                "stage3": [512, 128, 1024, 3, True, True, 5, 5],
                "stage4": [1024, 256, 2048, 1, True, True, 5, 5],
            },
            "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B3_stage1.pth",
        },
        "B4": {
            "stem_channels": [3, 32, 48],
            "stage_config": {
                # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
                "stage1": [48, 48, 128, 1, False, False, 3, 6],
                "stage2": [128, 96, 512, 1, True, False, 3, 6],
                "stage3": [512, 192, 1024, 3, True, True, 5, 6],
                "stage4": [1024, 384, 2048, 1, True, True, 5, 6],
            },
            "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B4_stage1.pth",
        },
        "B5": {
            "stem_channels": [3, 32, 64],
            "stage_config": {
                # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
                "stage1": [64, 64, 128, 1, False, False, 3, 6],
                "stage2": [128, 128, 512, 2, True, False, 3, 6],
                "stage3": [512, 256, 1024, 5, True, True, 5, 6],
                "stage4": [1024, 512, 2048, 2, True, True, 5, 6],
            },
            "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B5_stage1.pth",
        },
        "B6": {
            "stem_channels": [3, 48, 96],
            "stage_config": {
                # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
                "stage1": [96, 96, 192, 2, False, False, 3, 6],
                "stage2": [192, 192, 512, 3, True, False, 3, 6],
                "stage3": [512, 384, 1024, 6, True, True, 5, 6],
                "stage4": [1024, 768, 2048, 3, True, True, 5, 6],
            },
            "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B6_stage1.pth",
        },
    }

    def __init__(
        self,
        name,
        use_lab=False,
        return_idx=[1, 2, 3],
        freeze_stem_only=True,
        freeze_at=0,
        freeze_norm=True,
        pretrained=True,
        local_model_dir="weight/hgnetv2/",
    ):
        super().__init__()
        self.use_lab = use_lab
        self.return_idx = return_idx

        stem_channels = self.arch_configs[name]["stem_channels"]
        stage_config = self.arch_configs[name]["stage_config"]
        download_url = self.arch_configs[name]["url"]

        self._out_strides = [4, 8, 16, 32]
        self._out_channels = [stage_config[k][2] for k in stage_config]

        # stem
        self.stem = StemBlock(
            in_chs=stem_channels[0],
            mid_chs=stem_channels[1],
            out_chs=stem_channels[2],
            use_lab=use_lab,
        )

        # stages
|
468 |
+
self.stages = nn.ModuleList()
|
469 |
+
for i, k in enumerate(stage_config):
|
470 |
+
(
|
471 |
+
in_channels,
|
472 |
+
mid_channels,
|
473 |
+
out_channels,
|
474 |
+
block_num,
|
475 |
+
downsample,
|
476 |
+
light_block,
|
477 |
+
kernel_size,
|
478 |
+
layer_num,
|
479 |
+
) = stage_config[k]
|
480 |
+
self.stages.append(
|
481 |
+
HG_Stage(
|
482 |
+
in_channels,
|
483 |
+
mid_channels,
|
484 |
+
out_channels,
|
485 |
+
block_num,
|
486 |
+
layer_num,
|
487 |
+
downsample,
|
488 |
+
light_block,
|
489 |
+
kernel_size,
|
490 |
+
use_lab,
|
491 |
+
)
|
492 |
+
)
|
493 |
+
|
494 |
+
if freeze_at >= 0:
|
495 |
+
self._freeze_parameters(self.stem)
|
496 |
+
if not freeze_stem_only:
|
497 |
+
for i in range(min(freeze_at + 1, len(self.stages))):
|
498 |
+
self._freeze_parameters(self.stages[i])
|
499 |
+
|
500 |
+
if freeze_norm:
|
501 |
+
self._freeze_norm(self)
|
502 |
+
|
503 |
+
if pretrained:
|
504 |
+
RED, GREEN, RESET = "\033[91m", "\033[92m", "\033[0m"
|
505 |
+
try:
|
506 |
+
model_path = local_model_dir + "PPHGNetV2_" + name + "_stage1.pth"
|
507 |
+
if os.path.exists(model_path):
|
508 |
+
state = torch.load(model_path, map_location="cpu")
|
509 |
+
print(f"Loaded stage1 {name} HGNetV2 from local file.")
|
510 |
+
else:
|
511 |
+
# If the file doesn't exist locally, download from the URL
|
512 |
+
if safe_get_rank() == 0:
|
513 |
+
print(
|
514 |
+
GREEN
|
515 |
+
+ "If the pretrained HGNetV2 can't be downloaded automatically. Please check your network connection."
|
516 |
+
+ RESET
|
517 |
+
)
|
518 |
+
print(
|
519 |
+
GREEN
|
520 |
+
+ "Please check your network connection. Or download the model manually from "
|
521 |
+
+ RESET
|
522 |
+
+ f"{download_url}"
|
523 |
+
+ GREEN
|
524 |
+
+ " to "
|
525 |
+
+ RESET
|
526 |
+
+ f"{local_model_dir}."
|
527 |
+
+ RESET
|
528 |
+
)
|
529 |
+
state = torch.hub.load_state_dict_from_url(
|
530 |
+
download_url, map_location="cpu", model_dir=local_model_dir
|
531 |
+
)
|
532 |
+
safe_barrier()
|
533 |
+
else:
|
534 |
+
safe_barrier()
|
535 |
+
state = torch.load(local_model_dir)
|
536 |
+
|
537 |
+
print(f"Loaded stage1 {name} HGNetV2 from URL.")
|
538 |
+
|
539 |
+
self.load_state_dict(state)
|
540 |
+
|
541 |
+
except (Exception, KeyboardInterrupt) as e:
|
542 |
+
if safe_get_rank() == 0:
|
543 |
+
print(f"{str(e)}")
|
544 |
+
logging.error(
|
545 |
+
RED + "CRITICAL WARNING: Failed to load pretrained HGNetV2 model" + RESET
|
546 |
+
)
|
547 |
+
logging.error(
|
548 |
+
GREEN
|
549 |
+
+ "Please check your network connection. Or download the model manually from "
|
550 |
+
+ RESET
|
551 |
+
+ f"{download_url}"
|
552 |
+
+ GREEN
|
553 |
+
+ " to "
|
554 |
+
+ RESET
|
555 |
+
+ f"{local_model_dir}."
|
556 |
+
+ RESET
|
557 |
+
)
|
558 |
+
exit()
|
559 |
+
|
560 |
+
def _freeze_norm(self, m: nn.Module):
|
561 |
+
if isinstance(m, nn.BatchNorm2d):
|
562 |
+
m = FrozenBatchNorm2d(m.num_features)
|
563 |
+
else:
|
564 |
+
for name, child in m.named_children():
|
565 |
+
_child = self._freeze_norm(child)
|
566 |
+
if _child is not child:
|
567 |
+
setattr(m, name, _child)
|
568 |
+
return m
|
569 |
+
|
570 |
+
def _freeze_parameters(self, m: nn.Module):
|
571 |
+
for p in m.parameters():
|
572 |
+
p.requires_grad = False
|
573 |
+
|
574 |
+
def forward(self, x):
|
575 |
+
x = self.stem(x)
|
576 |
+
outs = []
|
577 |
+
for idx, stage in enumerate(self.stages):
|
578 |
+
x = stage(x)
|
579 |
+
if idx in self.return_idx:
|
580 |
+
outs.append(x)
|
581 |
+
return outs
|
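A minimal usage sketch of the backbone above (illustration only, not one of the uploaded files), assuming the repository root is on PYTHONPATH so that `src.nn.backbone.hgnetv2` is importable, and with pretrained-weight loading disabled to avoid the download path:

import torch
from src.nn.backbone.hgnetv2 import HGNetv2

# Build the B0 variant; return_idx=[1, 2, 3] selects the stride-8/16/32 stages.
backbone = HGNetv2(name="B0", return_idx=[1, 2, 3], pretrained=False)
feats = backbone(torch.rand(1, 3, 640, 640))
for f in feats:
    print(f.shape)  # [1, 256, 80, 80], [1, 512, 40, 40], [1, 1024, 20, 20] for B0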
src/nn/backbone/presnet.py
ADDED
@@ -0,0 +1,263 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...core import register
from .common import FrozenBatchNorm2d, get_activation

__all__ = ["PResNet"]


ResNet_cfg = {
    18: [2, 2, 2, 2],
    34: [3, 4, 6, 3],
    50: [3, 4, 6, 3],
    101: [3, 4, 23, 3],
    # 152: [3, 8, 36, 3],
}


download_url = {
    18: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet18_vd_pretrained_from_paddle.pth",
    34: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet34_vd_pretrained_from_paddle.pth",
    50: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet50_vd_ssld_v2_pretrained_from_paddle.pth",
    101: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet101_vd_ssld_pretrained_from_paddle.pth",
}


class ConvNormLayer(nn.Module):
    def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None):
        super().__init__()
        self.conv = nn.Conv2d(
            ch_in,
            ch_out,
            kernel_size,
            stride,
            padding=(kernel_size - 1) // 2 if padding is None else padding,
            bias=bias,
        )
        self.norm = nn.BatchNorm2d(ch_out)
        self.act = get_activation(act)

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"):
        super().__init__()

        self.shortcut = shortcut

        if not shortcut:
            if variant == "d" and stride == 2:
                self.short = nn.Sequential(
                    OrderedDict(
                        [
                            ("pool", nn.AvgPool2d(2, 2, 0, ceil_mode=True)),
                            ("conv", ConvNormLayer(ch_in, ch_out, 1, 1)),
                        ]
                    )
                )
            else:
                self.short = ConvNormLayer(ch_in, ch_out, 1, stride)

        self.branch2a = ConvNormLayer(ch_in, ch_out, 3, stride, act=act)
        self.branch2b = ConvNormLayer(ch_out, ch_out, 3, 1, act=None)
        self.act = nn.Identity() if act is None else get_activation(act)

    def forward(self, x):
        out = self.branch2a(x)
        out = self.branch2b(out)
        if self.shortcut:
            short = x
        else:
            short = self.short(x)

        out = out + short
        out = self.act(out)

        return out


class BottleNeck(nn.Module):
    expansion = 4

    def __init__(self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"):
        super().__init__()

        if variant == "a":
            stride1, stride2 = stride, 1
        else:
            stride1, stride2 = 1, stride

        width = ch_out

        self.branch2a = ConvNormLayer(ch_in, width, 1, stride1, act=act)
        self.branch2b = ConvNormLayer(width, width, 3, stride2, act=act)
        self.branch2c = ConvNormLayer(width, ch_out * self.expansion, 1, 1)

        self.shortcut = shortcut
        if not shortcut:
            if variant == "d" and stride == 2:
                self.short = nn.Sequential(
                    OrderedDict(
                        [
                            ("pool", nn.AvgPool2d(2, 2, 0, ceil_mode=True)),
                            ("conv", ConvNormLayer(ch_in, ch_out * self.expansion, 1, 1)),
                        ]
                    )
                )
            else:
                self.short = ConvNormLayer(ch_in, ch_out * self.expansion, 1, stride)

        self.act = nn.Identity() if act is None else get_activation(act)

    def forward(self, x):
        out = self.branch2a(x)
        out = self.branch2b(out)
        out = self.branch2c(out)

        if self.shortcut:
            short = x
        else:
            short = self.short(x)

        out = out + short
        out = self.act(out)

        return out


class Blocks(nn.Module):
    def __init__(self, block, ch_in, ch_out, count, stage_num, act="relu", variant="b"):
        super().__init__()

        self.blocks = nn.ModuleList()
        for i in range(count):
            self.blocks.append(
                block(
                    ch_in,
                    ch_out,
                    stride=2 if i == 0 and stage_num != 2 else 1,
                    shortcut=False if i == 0 else True,
                    variant=variant,
                    act=act,
                )
            )

            if i == 0:
                ch_in = ch_out * block.expansion

    def forward(self, x):
        out = x
        for block in self.blocks:
            out = block(out)
        return out


@register()
class PResNet(nn.Module):
    def __init__(
        self,
        depth,
        variant="d",
        num_stages=4,
        return_idx=[0, 1, 2, 3],
        act="relu",
        freeze_at=-1,
        freeze_norm=True,
        pretrained=False,
    ):
        super().__init__()

        block_nums = ResNet_cfg[depth]
        ch_in = 64
        if variant in ["c", "d"]:
            conv_def = [
                [3, ch_in // 2, 3, 2, "conv1_1"],
                [ch_in // 2, ch_in // 2, 3, 1, "conv1_2"],
                [ch_in // 2, ch_in, 3, 1, "conv1_3"],
            ]
        else:
            conv_def = [[3, ch_in, 7, 2, "conv1_1"]]

        self.conv1 = nn.Sequential(
            OrderedDict(
                [
                    (name, ConvNormLayer(cin, cout, k, s, act=act))
                    for cin, cout, k, s, name in conv_def
                ]
            )
        )

        ch_out_list = [64, 128, 256, 512]
        block = BottleNeck if depth >= 50 else BasicBlock

        _out_channels = [block.expansion * v for v in ch_out_list]
        _out_strides = [4, 8, 16, 32]

        self.res_layers = nn.ModuleList()
        for i in range(num_stages):
            stage_num = i + 2
            self.res_layers.append(
                Blocks(
                    block, ch_in, ch_out_list[i], block_nums[i], stage_num, act=act, variant=variant
                )
            )
            ch_in = _out_channels[i]

        self.return_idx = return_idx
        self.out_channels = [_out_channels[_i] for _i in return_idx]
        self.out_strides = [_out_strides[_i] for _i in return_idx]

        if freeze_at >= 0:
            self._freeze_parameters(self.conv1)
            for i in range(min(freeze_at, num_stages)):
                self._freeze_parameters(self.res_layers[i])

        if freeze_norm:
            self._freeze_norm(self)

        if pretrained:
            if isinstance(pretrained, bool) or "http" in pretrained:
                state = torch.hub.load_state_dict_from_url(
                    download_url[depth], map_location="cpu", model_dir="weight"
                )
            else:
                state = torch.load(pretrained, map_location="cpu")
            self.load_state_dict(state)
            print(f"Load PResNet{depth} state_dict")

    def _freeze_parameters(self, m: nn.Module):
        for p in m.parameters():
            p.requires_grad = False

    def _freeze_norm(self, m: nn.Module):
        if isinstance(m, nn.BatchNorm2d):
            m = FrozenBatchNorm2d(m.num_features)
        else:
            for name, child in m.named_children():
                _child = self._freeze_norm(child)
                if _child is not child:
                    setattr(m, name, _child)
        return m

    def forward(self, x):
        conv1 = self.conv1(x)
        x = F.max_pool2d(conv1, kernel_size=3, stride=2, padding=1)
        outs = []
        for idx, stage in enumerate(self.res_layers):
            x = stage(x)
            if idx in self.return_idx:
                outs.append(x)
        return outs
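For comparison, a hedged sketch of driving the PResNet backbone above in the same way (illustration only; it assumes the `src` package is importable and skips pretrained weights):

import torch
from src.nn.backbone.presnet import PResNet

backbone = PResNet(depth=50, variant="d", return_idx=[1, 2, 3], pretrained=False)
print(backbone.out_channels, backbone.out_strides)  # [512, 1024, 2048] [8, 16, 32]
feats = backbone(torch.rand(1, 3, 640, 640))
for f in feats:
    print(f.shape)  # [1, 512, 80, 80], [1, 1024, 40, 40], [1, 2048, 20, 20]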
src/nn/backbone/test_resnet.py
ADDED
@@ -0,0 +1,83 @@
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...core import register


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()

        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class _ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


@register()
class MResNet(nn.Module):
    def __init__(self, num_classes=10, num_blocks=[2, 2, 2, 2]) -> None:
        super().__init__()
        self.model = _ResNet(BasicBlock, num_blocks, num_classes)

    def forward(self, x):
        return self.model(x)
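This test backbone is a CIFAR-style classifier: three stride-2 stages followed by a 4x4 average pool, so it expects 32x32 inputs. A small sketch (illustration only, same import-path assumption as above):

import torch
from src.nn.backbone.test_resnet import MResNet

model = MResNet(num_classes=10, num_blocks=[2, 2, 2, 2])
logits = model(torch.rand(8, 3, 32, 32))
print(logits.shape)  # torch.Size([8, 10])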
src/nn/backbone/timm_model.py
ADDED
@@ -0,0 +1,66 @@
"""Copyright(c) 2023 lyuwenyu. All Rights Reserved.

https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055#0583
"""

import torch
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

from ...core import register
from .utils import IntermediateLayerGetter


@register()
class TimmModel(torch.nn.Module):
    def __init__(
        self, name, return_layers, pretrained=False, exportable=True, features_only=True, **kwargs
    ) -> None:
        super().__init__()

        import timm

        model = timm.create_model(
            name,
            pretrained=pretrained,
            exportable=exportable,
            features_only=features_only,
            **kwargs,
        )
        # nodes, _ = get_graph_node_names(model)
        # print(nodes)
        # features = {'': ''}
        # model = create_feature_extractor(model, return_nodes=features)

        assert set(return_layers).issubset(
            model.feature_info.module_name()
        ), f"return_layers should be a subset of {model.feature_info.module_name()}"

        # self.model = model
        self.model = IntermediateLayerGetter(model, return_layers)

        return_idx = [model.feature_info.module_name().index(name) for name in return_layers]
        self.strides = [model.feature_info.reduction()[i] for i in return_idx]
        self.channels = [model.feature_info.channels()[i] for i in return_idx]
        self.return_idx = return_idx
        self.return_layers = return_layers

    def forward(self, x: torch.Tensor):
        outputs = self.model(x)
        # outputs = [outputs[i] for i in self.return_idx]
        return outputs


if __name__ == "__main__":
    model = TimmModel(name="resnet34", return_layers=["layer2", "layer3"])
    data = torch.rand(1, 3, 640, 640)
    outputs = model(data)

    for output in outputs:
        print(output.shape)

    """
    model:
        type: TimmModel
        name: resnet34
        return_layers: ['layer2', 'layer4']
    """
src/nn/backbone/torchvision_model.py
ADDED
@@ -0,0 +1,50 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

import torch
import torchvision

from ...core import register
from .utils import IntermediateLayerGetter

__all__ = ["TorchVisionModel"]


@register()
class TorchVisionModel(torch.nn.Module):
    def __init__(self, name, return_layers, weights=None, **kwargs) -> None:
        super().__init__()

        if weights is not None:
            weights = getattr(torchvision.models.get_model_weights(name), weights)

        model = torchvision.models.get_model(name, weights=weights, **kwargs)

        # TODO hard code.
        if hasattr(model, "features"):
            model = IntermediateLayerGetter(model.features, return_layers)
        else:
            model = IntermediateLayerGetter(model, return_layers)

        self.model = model

    def forward(self, x):
        return self.model(x)


# TorchVisionModel('swin_t', return_layers=['5', '7'])
# TorchVisionModel('resnet34', return_layers=['layer2','layer3', 'layer4'])

# TorchVisionModel:
#     name: swin_t
#     return_layers: ['5', '7']
#     weights: DEFAULT


# model:
#     type: TorchVisionModel
#     name: resnet34
#     return_layers: ['layer2','layer3', 'layer4']
#     weights: DEFAULT
src/nn/backbone/utils.py
ADDED
@@ -0,0 +1,56 @@
"""
https://github.com/pytorch/vision/blob/main/torchvision/models/_utils.py

Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

from collections import OrderedDict
from typing import Dict, List

import torch.nn as nn


class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model

    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.
    """

    _version = 3

    def __init__(self, model: nn.Module, return_layers: List[str]) -> None:
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError(
                "return_layers are not present in model. {}".format(
                    [name for name, _ in model.named_children()]
                )
            )
        orig_return_layers = return_layers
        return_layers = {str(k): str(k) for k in return_layers}
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break

        super().__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x):
        outputs = []
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                outputs.append(x)

        return outputs
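The docstring's "direct children only" caveat is the main thing to keep in mind when picking `return_layers`. A short sketch against a stock torchvision ResNet-18 (illustration only; the torchvision model choice is an assumption, not something this upload prescribes):

import torch
import torchvision
from src.nn.backbone.utils import IntermediateLayerGetter

resnet = torchvision.models.resnet18(weights=None)
getter = IntermediateLayerGetter(resnet, return_layers=["layer2", "layer3"])
feats = getter(torch.rand(1, 3, 224, 224))
print([f.shape for f in feats])  # [1, 128, 28, 28] and [1, 256, 14, 14]
# Asking for a nested module such as "layer3.1" raises ValueError: only direct children qualify.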
src/nn/criterion/__init__.py
ADDED
@@ -0,0 +1,11 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

import torch.nn as nn

from ...core import register
from .det_criterion import DetCriterion

CrossEntropyLoss = register()(nn.CrossEntropyLoss)
src/nn/criterion/det_criterion.py
ADDED
@@ -0,0 +1,188 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

import torch
import torch.distributed
import torch.nn.functional as F
import torchvision

from ...core import register
from ...misc import box_ops, dist_utils


@register()
class DetCriterion(torch.nn.Module):
    """Default Detection Criterion"""

    __share__ = ["num_classes"]
    __inject__ = ["matcher"]

    def __init__(
        self,
        losses,
        weight_dict,
        num_classes=80,
        alpha=0.75,
        gamma=2.0,
        box_fmt="cxcywh",
        matcher=None,
    ):
        """
        Args:
            losses (list[str]): requested losses, support ['boxes', 'vfl', 'focal']
            weight_dict (dict[str, float]): corresponding losses weight, including
                ['loss_bbox', 'loss_giou', 'loss_vfl', 'loss_focal']
            box_fmt (str): input box format, 'cxcywh' or 'xyxy'
            matcher (Matcher): matcher used to match source to target
        """
        super().__init__()
        self.losses = losses
        self.weight_dict = weight_dict
        self.alpha = alpha
        self.gamma = gamma
        self.num_classes = num_classes
        self.box_fmt = box_fmt
        assert matcher is not None, ""
        self.matcher = matcher

    def forward(self, outputs, targets, **kwargs):
        """
        Args:
            outputs: Dict[Tensor], 'pred_boxes', 'pred_logits', 'meta'.
            targets, List[Dict[str, Tensor]], len(targets) == batch_size.
            kwargs, store other information such as current epoch id.
        Return:
            losses, Dict[str, Tensor]
        """
        matched = self.matcher(outputs, targets)
        values = matched["values"]
        indices = matched["indices"]
        num_boxes = self._get_positive_nums(indices)

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            l_dict = self.get_loss(loss, outputs, targets, indices, num_boxes)
            l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
            losses.update(l_dict)
        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def _get_positive_nums(self, indices):
        # number of positive samples
        num_pos = sum(len(i) for (i, _) in indices)
        num_pos = torch.as_tensor([num_pos], dtype=torch.float32, device=indices[0][0].device)
        if dist_utils.is_dist_available_and_initialized():
            torch.distributed.all_reduce(num_pos)
        num_pos = torch.clamp(num_pos / dist_utils.get_world_size(), min=1).item()
        return num_pos

    def loss_labels_focal(self, outputs, targets, indices, num_boxes):
        assert "pred_logits" in outputs
        src_logits = outputs["pred_logits"]

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][j] for t, (_, j) in zip(targets, indices)])
        target_classes = torch.full(
            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
        )
        target_classes[idx] = target_classes_o

        target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1].to(
            src_logits.dtype
        )
        loss = torchvision.ops.sigmoid_focal_loss(
            src_logits, target, self.alpha, self.gamma, reduction="none"
        )
        loss = loss.sum() / num_boxes
        return {"loss_focal": loss}

    def loss_labels_vfl(self, outputs, targets, indices, num_boxes):
        assert "pred_boxes" in outputs
        idx = self._get_src_permutation_idx(indices)

        src_boxes = outputs["pred_boxes"][idx]
        target_boxes = torch.cat([t["boxes"][j] for t, (_, j) in zip(targets, indices)], dim=0)

        src_boxes = torchvision.ops.box_convert(src_boxes, in_fmt=self.box_fmt, out_fmt="xyxy")
        target_boxes = torchvision.ops.box_convert(
            target_boxes, in_fmt=self.box_fmt, out_fmt="xyxy"
        )
        iou, _ = box_ops.elementwise_box_iou(src_boxes.detach(), target_boxes)

        src_logits: torch.Tensor = outputs["pred_logits"]
        target_classes_o = torch.cat([t["labels"][j] for t, (_, j) in zip(targets, indices)])
        target_classes = torch.full(
            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
        )
        target_classes[idx] = target_classes_o
        target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1]

        target_score_o = torch.zeros_like(target_classes, dtype=src_logits.dtype)
        target_score_o[idx] = iou.to(src_logits.dtype)
        target_score = target_score_o.unsqueeze(-1) * target

        src_score = F.sigmoid(src_logits.detach())
        weight = self.alpha * src_score.pow(self.gamma) * (1 - target) + target_score

        loss = F.binary_cross_entropy_with_logits(
            src_logits, target_score, weight=weight, reduction="none"
        )
        loss = loss.sum() / num_boxes
        return {"loss_vfl": loss}

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        assert "pred_boxes" in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs["pred_boxes"][idx]
        target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)

        losses = {}
        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")
        losses["loss_bbox"] = loss_bbox.sum() / num_boxes

        src_boxes = torchvision.ops.box_convert(src_boxes, in_fmt=self.box_fmt, out_fmt="xyxy")
        target_boxes = torchvision.ops.box_convert(
            target_boxes, in_fmt=self.box_fmt, out_fmt="xyxy"
        )
        loss_giou = 1 - box_ops.elementwise_generalized_box_iou(src_boxes, target_boxes)
        losses["loss_giou"] = loss_giou.sum() / num_boxes
        return losses

    def loss_boxes_giou(self, outputs, targets, indices, num_boxes):
        assert "pred_boxes" in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs["pred_boxes"][idx]
        target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)

        losses = {}
        src_boxes = torchvision.ops.box_convert(src_boxes, in_fmt=self.box_fmt, out_fmt="xyxy")
        target_boxes = torchvision.ops.box_convert(
            target_boxes, in_fmt=self.box_fmt, out_fmt="xyxy"
        )
        loss_giou = 1 - box_ops.elementwise_generalized_box_iou(src_boxes, target_boxes)
        losses["loss_giou"] = loss_giou.sum() / num_boxes
        return losses

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        loss_map = {
            "boxes": self.loss_boxes,
            "giou": self.loss_boxes_giou,
            "vfl": self.loss_labels_vfl,
            "focal": self.loss_labels_focal,
        }
        assert loss in loss_map, f"do you really want to compute {loss} loss?"
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
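A hedged sketch of how the criterion combines `losses` and `weight_dict` (illustration only): `_ToyMatcher` below is a hypothetical stand-in for the matcher that the config system normally injects, and simply pairs query 0 with target 0 in every image.

import torch
from src.nn.criterion.det_criterion import DetCriterion

class _ToyMatcher(torch.nn.Module):
    def forward(self, outputs, targets):
        # match query 0 of every image to target 0 of that image
        indices = [(torch.tensor([0]), torch.tensor([0])) for _ in targets]
        return {"indices": indices, "values": None}

criterion = DetCriterion(
    losses=["boxes"],
    weight_dict={"loss_bbox": 5.0, "loss_giou": 2.0},
    matcher=_ToyMatcher(),
)
outputs = {
    "pred_logits": torch.randn(2, 10, 80),
    "pred_boxes": torch.rand(2, 10, 4) * 0.4 + 0.3,  # normalized cxcywh
}
targets = [
    {"labels": torch.tensor([3]), "boxes": torch.rand(1, 4) * 0.4 + 0.3} for _ in range(2)
]
print(criterion(outputs, targets))  # {'loss_bbox': tensor(...), 'loss_giou': tensor(...)}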
src/nn/postprocessor/__init__.py
ADDED
@@ -0,0 +1,6 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

from .nms_postprocessor import DetNMSPostProcessor
src/nn/postprocessor/box_revert.py
ADDED
@@ -0,0 +1,66 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

from enum import Enum

import torch
import torchvision
from torch import Tensor


class BoxProcessFormat(Enum):
    """Box process format

    Available formats are
    * ``RESIZE``
    * ``RESIZE_KEEP_RATIO``
    * ``RESIZE_KEEP_RATIO_PADDING``
    """

    RESIZE = 1
    RESIZE_KEEP_RATIO = 2
    RESIZE_KEEP_RATIO_PADDING = 3


def box_revert(
    boxes: Tensor,
    orig_sizes: Tensor = None,
    eval_sizes: Tensor = None,
    inpt_sizes: Tensor = None,
    inpt_padding: Tensor = None,
    normalized: bool = True,
    in_fmt: str = "cxcywh",
    out_fmt: str = "xyxy",
    process_fmt=BoxProcessFormat.RESIZE,
) -> Tensor:
    """
    Args:
        boxes (Tensor), [N, :, 4], pred boxes in `in_fmt` layout.
        inpt_sizes (Tensor), [N, 2], (w, h). input sizes.
        orig_sizes (Tensor), [N, 2], (w, h). origin sizes.
        inpt_padding (Tensor), [N, 2], (w_pad, h_pad, ...).
        (inpt_sizes + inpt_padding) == eval_sizes
    """
    assert in_fmt in ("cxcywh", "xyxy"), ""

    if normalized and eval_sizes is not None:
        boxes = boxes * eval_sizes.repeat(1, 2).unsqueeze(1)

    if inpt_padding is not None:
        if in_fmt == "xyxy":
            boxes -= inpt_padding[:, :2].repeat(1, 2).unsqueeze(1)
        elif in_fmt == "cxcywh":
            boxes[..., :2] -= inpt_padding[:, :2].repeat(1, 2).unsqueeze(1)

    if orig_sizes is not None:
        orig_sizes = orig_sizes.repeat(1, 2).unsqueeze(1)
        if inpt_sizes is not None:
            inpt_sizes = inpt_sizes.repeat(1, 2).unsqueeze(1)
            boxes = boxes * (orig_sizes / inpt_sizes)
        else:
            boxes = boxes * orig_sizes

    boxes = torchvision.ops.box_convert(boxes, in_fmt=in_fmt, out_fmt=out_fmt)
    return boxes
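A worked example of `box_revert` under the plain-resize convention (illustration only, with hand-picked numbers): a normalized cxcywh box on a 640x640 eval canvas is mapped back to pixel xyxy coordinates on a 1280x720 original image.

import torch
from src.nn.postprocessor.box_revert import box_revert

pred = torch.tensor([[[0.5, 0.5, 0.2, 0.4]]])   # [batch, num_boxes, 4], normalized cxcywh
eval_sizes = torch.tensor([[640, 640]])          # (w, h) the network saw
orig_sizes = torch.tensor([[1280, 720]])         # (w, h) of the original image
boxes = box_revert(
    pred,
    orig_sizes=orig_sizes,
    eval_sizes=eval_sizes,
    inpt_sizes=eval_sizes,   # no keep-ratio resize or padding in this simple case
    normalized=True,
    in_fmt="cxcywh",
    out_fmt="xyxy",
)
print(boxes)  # tensor([[[512., 216., 768., 504.]]])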
src/nn/postprocessor/detr_postprocessor.py
ADDED
@@ -0,0 +1,86 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

__all__ = ["DetDETRPostProcessor"]

from .box_revert import BoxProcessFormat, box_revert


def mod(a, b):
    out = a - a // b * b
    return out


class DetDETRPostProcessor(nn.Module):
    def __init__(
        self,
        num_classes=80,
        use_focal_loss=True,
        num_top_queries=300,
        box_process_format=BoxProcessFormat.RESIZE,
    ) -> None:
        super().__init__()
        self.use_focal_loss = use_focal_loss
        self.num_top_queries = num_top_queries
        self.num_classes = int(num_classes)
        self.box_process_format = box_process_format
        self.deploy_mode = False

    def extra_repr(self) -> str:
        return f"use_focal_loss={self.use_focal_loss}, num_classes={self.num_classes}, num_top_queries={self.num_top_queries}"

    def forward(self, outputs, **kwargs):
        logits, boxes = outputs["pred_logits"], outputs["pred_boxes"]

        if self.use_focal_loss:
            scores = F.sigmoid(logits)
            scores, index = torch.topk(scores.flatten(1), self.num_top_queries, dim=-1)
            labels = index % self.num_classes
            # labels = mod(index, self.num_classes) # for tensorrt
            index = index // self.num_classes
            boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))

        else:
            scores = F.softmax(logits, dim=-1)[:, :, :-1]
            scores, labels = scores.max(dim=-1)
            if scores.shape[1] > self.num_top_queries:
                scores, index = torch.topk(scores, self.num_top_queries, dim=-1)
                labels = torch.gather(labels, dim=1, index=index)
                boxes = torch.gather(
                    boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1])
                )

        if kwargs is not None:
            boxes = box_revert(
                boxes,
                in_fmt="cxcywh",
                out_fmt="xyxy",
                process_fmt=self.box_process_format,
                normalized=True,
                **kwargs,
            )

        # TODO for onnx export
        if self.deploy_mode:
            return labels, boxes, scores

        results = []
        for lab, box, sco in zip(labels, boxes, scores):
            result = dict(labels=lab, boxes=box, scores=sco)
            results.append(result)

        return results

    def deploy(
        self,
    ):
        self.eval()
        self.deploy_mode = True
        return self
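A hedged decode sketch for the post-processor above (illustration only): with `use_focal_loss=True` it takes the top-k scores over the flattened query-class grid, recovers labels via the modulo trick, and rescales boxes through `box_revert`; here only `orig_sizes` is passed, which corresponds to the plain-resize case.

import torch
from src.nn.postprocessor.detr_postprocessor import DetDETRPostProcessor

post = DetDETRPostProcessor(num_classes=80, num_top_queries=10)
outputs = {
    "pred_logits": torch.randn(2, 300, 80),
    "pred_boxes": torch.rand(2, 300, 4) * 0.4 + 0.3,  # normalized cxcywh
}
results = post(outputs, orig_sizes=torch.tensor([[1280, 720], [640, 480]]))
print(results[0]["labels"].shape, results[0]["boxes"].shape, results[0]["scores"].shape)
# torch.Size([10]) torch.Size([10, 4]) torch.Size([10])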
src/nn/postprocessor/nms_postprocessor.py
ADDED
@@ -0,0 +1,86 @@
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

from typing import Dict

import torch
import torch.distributed
import torch.nn.functional as F
import torchvision
from torch import Tensor

from ...core import register

__all__ = [
    "DetNMSPostProcessor",
]


@register()
class DetNMSPostProcessor(torch.nn.Module):
    def __init__(
        self,
        iou_threshold=0.7,
        score_threshold=0.01,
        keep_topk=300,
        box_fmt="cxcywh",
        logit_fmt="sigmoid",
    ) -> None:
        super().__init__()
        self.iou_threshold = iou_threshold
        self.score_threshold = score_threshold
        self.keep_topk = keep_topk
        self.box_fmt = box_fmt.lower()
        self.logit_fmt = logit_fmt.lower()
        self.logit_func = getattr(F, self.logit_fmt, None)
        self.deploy_mode = False

    def forward(self, outputs: Dict[str, Tensor], orig_target_sizes: Tensor):
        logits, boxes = outputs["pred_logits"], outputs["pred_boxes"]
        pred_boxes = torchvision.ops.box_convert(boxes, in_fmt=self.box_fmt, out_fmt="xyxy")
        pred_boxes *= orig_target_sizes.repeat(1, 2).unsqueeze(1)

        values, pred_labels = torch.max(logits, dim=-1)

        if self.logit_func:
            pred_scores = self.logit_func(values)
        else:
            pred_scores = values

        # TODO for onnx export
        if self.deploy_mode:
            blobs = {
                "pred_labels": pred_labels,
                "pred_boxes": pred_boxes,
                "pred_scores": pred_scores,
            }
            return blobs

        results = []
        for i in range(logits.shape[0]):
            score_keep = pred_scores[i] > self.score_threshold
            pred_box = pred_boxes[i][score_keep]
            pred_label = pred_labels[i][score_keep]
            pred_score = pred_scores[i][score_keep]

            keep = torchvision.ops.batched_nms(pred_box, pred_score, pred_label, self.iou_threshold)
            keep = keep[: self.keep_topk]

            blob = {
                "labels": pred_label[keep],
                "boxes": pred_box[keep],
                "scores": pred_score[keep],
            }

            results.append(blob)

        return results

    def deploy(
        self,
    ):
        self.eval()
        self.deploy_mode = True
        return self
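Finally, a usage sketch of the NMS post-processor (illustration only, same import assumptions as the earlier sketches): it expects raw logits plus normalized cxcywh boxes together with the per-image original sizes, and returns per-image dicts after class-wise NMS.

import torch
from src.nn.postprocessor import DetNMSPostProcessor

post = DetNMSPostProcessor(iou_threshold=0.7, score_threshold=0.01, keep_topk=300)
outputs = {
    "pred_logits": torch.randn(1, 300, 80),
    "pred_boxes": torch.rand(1, 300, 4) * 0.4 + 0.3,  # normalized cxcywh
}
orig_target_sizes = torch.tensor([[1280, 720]])  # (w, h) per image
results = post(outputs, orig_target_sizes)
print({k: v.shape for k, v in results[0].items()})  # box count depends on NMS and thresholds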