developer0hye committed on
Commit e85fecb · verified · 1 Parent(s): 8765dbd

Upload 76 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. src/__init__.py +6 -0
  2. src/core/__init__.py +9 -0
  3. src/core/_config.py +299 -0
  4. src/core/workspace.py +178 -0
  5. src/core/yaml_config.py +187 -0
  6. src/core/yaml_utils.py +126 -0
  7. src/data/__init__.py +20 -0
  8. src/data/_misc.py +62 -0
  9. src/data/dataloader.py +122 -0
  10. src/data/dataset/__init__.py +17 -0
  11. src/data/dataset/_dataset.py +27 -0
  12. src/data/dataset/cifar_dataset.py +25 -0
  13. src/data/dataset/coco_dataset.py +282 -0
  14. src/data/dataset/coco_eval.py +214 -0
  15. src/data/dataset/coco_utils.py +191 -0
  16. src/data/dataset/voc_detection.py +86 -0
  17. src/data/dataset/voc_eval.py +12 -0
  18. src/data/transforms/__init__.py +21 -0
  19. src/data/transforms/_transforms.py +161 -0
  20. src/data/transforms/container.py +99 -0
  21. src/data/transforms/functional.py +172 -0
  22. src/data/transforms/mosaic.py +83 -0
  23. src/data/transforms/presets.py +4 -0
  24. src/misc/__init__.py +9 -0
  25. src/misc/box_ops.py +106 -0
  26. src/misc/dist_utils.py +281 -0
  27. src/misc/lazy_loader.py +70 -0
  28. src/misc/logger.py +255 -0
  29. src/misc/profiler_utils.py +30 -0
  30. src/misc/visualizer.py +121 -0
  31. src/nn/__init__.py +16 -0
  32. src/nn/arch/__init__.py +7 -0
  33. src/nn/arch/classification.py +45 -0
  34. src/nn/arch/yolo.py +42 -0
  35. src/nn/backbone/__init__.py +17 -0
  36. src/nn/backbone/common.py +117 -0
  37. src/nn/backbone/csp_darknet.py +203 -0
  38. src/nn/backbone/csp_resnet.py +302 -0
  39. src/nn/backbone/hgnetv2.py +581 -0
  40. src/nn/backbone/presnet.py +263 -0
  41. src/nn/backbone/test_resnet.py +83 -0
  42. src/nn/backbone/timm_model.py +66 -0
  43. src/nn/backbone/torchvision_model.py +50 -0
  44. src/nn/backbone/utils.py +56 -0
  45. src/nn/criterion/__init__.py +11 -0
  46. src/nn/criterion/det_criterion.py +188 -0
  47. src/nn/postprocessor/__init__.py +6 -0
  48. src/nn/postprocessor/box_revert.py +66 -0
  49. src/nn/postprocessor/detr_postprocessor.py +86 -0
  50. src/nn/postprocessor/nms_postprocessor.py +86 -0
src/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ """
2
+ Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
3
+ """
4
+
5
+ # for register purpose
6
+ from . import data, nn, optim, zoo
src/core/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ from ._config import BaseConfig
7
+ from .workspace import GLOBAL_CONFIG, create, register
8
+ from .yaml_config import YAMLConfig
9
+ from .yaml_utils import *
src/core/_config.py ADDED
@@ -0,0 +1,299 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ from pathlib import Path
7
+ from typing import Callable, Dict, List
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.cuda.amp.grad_scaler import GradScaler
12
+ from torch.optim import Optimizer
13
+ from torch.optim.lr_scheduler import LRScheduler
14
+ from torch.utils.data import DataLoader, Dataset
15
+ from torch.utils.tensorboard import SummaryWriter
16
+
17
+ __all__ = [
18
+ "BaseConfig",
19
+ ]
20
+
21
+
22
+ class BaseConfig(object):
23
+ # TODO property
24
+
25
+ def __init__(self) -> None:
26
+ super().__init__()
27
+
28
+ self.task: str = None
29
+
30
+ # instance / function
31
+ self._model: nn.Module = None
32
+ self._postprocessor: nn.Module = None
33
+ self._criterion: nn.Module = None
34
+ self._optimizer: Optimizer = None
35
+ self._lr_scheduler: LRScheduler = None
36
+ self._lr_warmup_scheduler: LRScheduler = None
37
+ self._train_dataloader: DataLoader = None
38
+ self._val_dataloader: DataLoader = None
39
+ self._ema: nn.Module = None
40
+ self._scaler: GradScaler = None
41
+ self._train_dataset: Dataset = None
42
+ self._val_dataset: Dataset = None
43
+ self._collate_fn: Callable = None
44
+ self._evaluator: Callable[[nn.Module, DataLoader, str],] = None
45
+ self._writer: SummaryWriter = None
46
+
47
+ # dataset
48
+ self.num_workers: int = 0
49
+ self.batch_size: int = None
50
+ self._train_batch_size: int = None
51
+ self._val_batch_size: int = None
52
+ self._train_shuffle: bool = None
53
+ self._val_shuffle: bool = None
54
+
55
+ # runtime
56
+ self.resume: str = None
57
+ self.tuning: str = None
58
+
59
+ self.epochs: int = None
60
+ self.last_epoch: int = -1
61
+
62
+ self.use_amp: bool = False
63
+ self.use_ema: bool = False
64
+ self.ema_decay: float = 0.9999
65
+ self.ema_warmups: int = 2000
66
+ self.sync_bn: bool = False
67
+ self.clip_max_norm: float = 0.0
68
+ self.find_unused_parameters: bool = None
69
+
70
+ self.seed: int = None
71
+ self.print_freq: int = None
72
+ self.checkpoint_freq: int = 1
73
+ self.output_dir: str = None
74
+ self.summary_dir: str = None
75
+ self.device: str = ""
76
+
77
+ @property
78
+ def model(self) -> nn.Module:
79
+ return self._model
80
+
81
+ @model.setter
82
+ def model(self, m):
83
+ assert isinstance(m, nn.Module), f"{type(m)} != nn.Module, please check your model class"
84
+ self._model = m
85
+
86
+ @property
87
+ def postprocessor(self) -> nn.Module:
88
+ return self._postprocessor
89
+
90
+ @postprocessor.setter
91
+ def postprocessor(self, m):
92
+ assert isinstance(m, nn.Module), f"{type(m)} != nn.Module, please check your model class"
93
+ self._postprocessor = m
94
+
95
+ @property
96
+ def criterion(self) -> nn.Module:
97
+ return self._criterion
98
+
99
+ @criterion.setter
100
+ def criterion(self, m):
101
+ assert isinstance(m, nn.Module), f"{type(m)} != nn.Module, please check your model class"
102
+ self._criterion = m
103
+
104
+ @property
105
+ def optimizer(self) -> Optimizer:
106
+ return self._optimizer
107
+
108
+ @optimizer.setter
109
+ def optimizer(self, m):
110
+ assert isinstance(
111
+ m, Optimizer
112
+ ), f"{type(m)} != optim.Optimizer, please check your model class"
113
+ self._optimizer = m
114
+
115
+ @property
116
+ def lr_scheduler(self) -> LRScheduler:
117
+ return self._lr_scheduler
118
+
119
+ @lr_scheduler.setter
120
+ def lr_scheduler(self, m):
121
+ assert isinstance(
122
+ m, LRScheduler
123
+ ), f"{type(m)} != LRScheduler, please check your model class"
124
+ self._lr_scheduler = m
125
+
126
+ @property
127
+ def lr_warmup_scheduler(self) -> LRScheduler:
128
+ return self._lr_warmup_scheduler
129
+
130
+ @lr_warmup_scheduler.setter
131
+ def lr_warmup_scheduler(self, m):
132
+ self._lr_warmup_scheduler = m
133
+
134
+ @property
135
+ def train_dataloader(self) -> DataLoader:
136
+ if self._train_dataloader is None and self.train_dataset is not None:
137
+ loader = DataLoader(
138
+ self.train_dataset,
139
+ batch_size=self.train_batch_size,
140
+ num_workers=self.num_workers,
141
+ collate_fn=self.collate_fn,
142
+ shuffle=self.train_shuffle,
143
+ )
144
+ loader.shuffle = self.train_shuffle
145
+ self._train_dataloader = loader
146
+
147
+ return self._train_dataloader
148
+
149
+ @train_dataloader.setter
150
+ def train_dataloader(self, loader):
151
+ self._train_dataloader = loader
152
+
153
+ @property
154
+ def val_dataloader(self) -> DataLoader:
155
+ if self._val_dataloader is None and self.val_dataset is not None:
156
+ loader = DataLoader(
157
+ self.val_dataset,
158
+ batch_size=self.val_batch_size,
159
+ num_workers=self.num_workers,
160
+ drop_last=False,
161
+ collate_fn=self.collate_fn,
162
+ shuffle=self.val_shuffle,
163
+ persistent_workers=True,
164
+ )
165
+ loader.shuffle = self.val_shuffle
166
+ self._val_dataloader = loader
167
+
168
+ return self._val_dataloader
169
+
170
+ @val_dataloader.setter
171
+ def val_dataloader(self, loader):
172
+ self._val_dataloader = loader
173
+
174
+ @property
175
+ def ema(self) -> nn.Module:
176
+ if self._ema is None and self.use_ema and self.model is not None:
177
+ from ..optim import ModelEMA
178
+
179
+ self._ema = ModelEMA(self.model, self.ema_decay, self.ema_warmups)
180
+ return self._ema
181
+
182
+ @ema.setter
183
+ def ema(self, obj):
184
+ self._ema = obj
185
+
186
+ @property
187
+ def scaler(self) -> GradScaler:
188
+ if self._scaler is None and self.use_amp and torch.cuda.is_available():
189
+ self._scaler = GradScaler()
190
+ return self._scaler
191
+
192
+ @scaler.setter
193
+ def scaler(self, obj: GradScaler):
194
+ self._scaler = obj
195
+
196
+ @property
197
+ def val_shuffle(self) -> bool:
198
+ if self._val_shuffle is None:
199
+ print("warning: set default val_shuffle=False")
200
+ return False
201
+ return self._val_shuffle
202
+
203
+ @val_shuffle.setter
204
+ def val_shuffle(self, shuffle):
205
+ assert isinstance(shuffle, bool), "shuffle must be bool"
206
+ self._val_shuffle = shuffle
207
+
208
+ @property
209
+ def train_shuffle(self) -> bool:
210
+ if self._train_shuffle is None:
211
+ print("warning: set default train_shuffle=True")
212
+ return True
213
+ return self._train_shuffle
214
+
215
+ @train_shuffle.setter
216
+ def train_shuffle(self, shuffle):
217
+ assert isinstance(shuffle, bool), "shuffle must be bool"
218
+ self._train_shuffle = shuffle
219
+
220
+ @property
221
+ def train_batch_size(self) -> int:
222
+ if self._train_batch_size is None and isinstance(self.batch_size, int):
223
+ print(f"warning: set train_batch_size=batch_size={self.batch_size}")
224
+ return self.batch_size
225
+ return self._train_batch_size
226
+
227
+ @train_batch_size.setter
228
+ def train_batch_size(self, batch_size):
229
+ assert isinstance(batch_size, int), "batch_size must be int"
230
+ self._train_batch_size = batch_size
231
+
232
+ @property
233
+ def val_batch_size(self) -> int:
234
+ if self._val_batch_size is None:
235
+ print(f"warning: set val_batch_size=batch_size={self.batch_size}")
236
+ return self.batch_size
237
+ return self._val_batch_size
238
+
239
+ @val_batch_size.setter
240
+ def val_batch_size(self, batch_size):
241
+ assert isinstance(batch_size, int), "batch_size must be int"
242
+ self._val_batch_size = batch_size
243
+
244
+ @property
245
+ def train_dataset(self) -> Dataset:
246
+ return self._train_dataset
247
+
248
+ @train_dataset.setter
249
+ def train_dataset(self, dataset):
250
+ assert isinstance(dataset, Dataset), f"{type(dataset)} must be Dataset"
251
+ self._train_dataset = dataset
252
+
253
+ @property
254
+ def val_dataset(self) -> Dataset:
255
+ return self._val_dataset
256
+
257
+ @val_dataset.setter
258
+ def val_dataset(self, dataset):
259
+ assert isinstance(dataset, Dataset), f"{type(dataset)} must be Dataset"
260
+ self._val_dataset = dataset
261
+
262
+ @property
263
+ def collate_fn(self) -> Callable:
264
+ return self._collate_fn
265
+
266
+ @collate_fn.setter
267
+ def collate_fn(self, fn):
268
+ assert isinstance(fn, Callable), f"{type(fn)} must be Callable"
269
+ self._collate_fn = fn
270
+
271
+ @property
272
+ def evaluator(self) -> Callable:
273
+ return self._evaluator
274
+
275
+ @evaluator.setter
276
+ def evaluator(self, fn):
277
+ assert isinstance(fn, Callable), f"{type(fn)} must be Callable"
278
+ self._evaluator = fn
279
+
280
+ @property
281
+ def writer(self) -> SummaryWriter:
282
+ if self._writer is None:
283
+ if self.summary_dir:
284
+ self._writer = SummaryWriter(self.summary_dir)
285
+ elif self.output_dir:
286
+ self._writer = SummaryWriter(Path(self.output_dir) / "summary")
287
+ return self._writer
288
+
289
+ @writer.setter
290
+ def writer(self, m):
291
+ assert isinstance(m, SummaryWriter), f"{type(m)} must be SummaryWriter"
292
+ self._writer = m
293
+
294
+ def __repr__(self):
295
+ s = ""
296
+ for k, v in self.__dict__.items():
297
+ if not k.startswith("_"):
298
+ s += f"{k}: {v}\n"
299
+ return s
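For illustration, a minimal sketch of how the lazy properties of BaseConfig behave. The toy dataset and values below are placeholders and not part of this commit; the package is assumed to be importable as src:

import torch
from torch.utils.data import TensorDataset, default_collate
from src.core import BaseConfig

cfg = BaseConfig()
cfg.train_dataset = TensorDataset(torch.rand(8, 3, 32, 32), torch.zeros(8, dtype=torch.long))
cfg.batch_size = 4                    # train_batch_size falls back to this (with a printed warning)
cfg.collate_fn = default_collate

loader = cfg.train_dataloader         # built lazily on first access; train_shuffle defaults to True
print(loader.batch_size)              # 4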
src/core/workspace.py ADDED
@@ -0,0 +1,178 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ import functools
7
+ import importlib
8
+ import inspect
9
+ from collections import defaultdict
10
+ from typing import Any, Dict, List, Optional
11
+
12
+ GLOBAL_CONFIG = defaultdict(dict)
13
+
14
+
15
+ def register(dct: Any = GLOBAL_CONFIG, name=None, force=False):
16
+ """
17
+ dct:
18
+ if dct is Dict, register foo into dct as key-value pair
19
+ if dct is Clas, register as modules attibute
20
+ force
21
+ whether force register.
22
+ """
23
+
24
+ def decorator(foo):
25
+ register_name = foo.__name__ if name is None else name
26
+ if not force:
27
+ if inspect.isclass(dct):
28
+ assert not hasattr(dct, foo.__name__), f"module {dct.__name__} has {foo.__name__}"
29
+ else:
30
+ assert foo.__name__ not in dct, f"{foo.__name__} has been already registered"
31
+
32
+ if inspect.isfunction(foo):
33
+
34
+ @functools.wraps(foo)
35
+ def wrap_func(*args, **kwargs):
36
+ return foo(*args, **kwargs)
37
+
38
+ if isinstance(dct, dict):
39
+ dct[foo.__name__] = wrap_func
40
+ elif inspect.isclass(dct):
41
+ setattr(dct, foo.__name__, wrap_func)
42
+ else:
43
+ raise AttributeError("")
44
+ return wrap_func
45
+
46
+ elif inspect.isclass(foo):
47
+ dct[register_name] = extract_schema(foo)
48
+
49
+ else:
50
+ raise ValueError(f"Do not support {type(foo)} register")
51
+
52
+ return foo
53
+
54
+ return decorator
55
+
56
+
57
+ def extract_schema(module: type):
58
+ """
59
+ Args:
60
+ module (type),
61
+ Return:
62
+ Dict,
63
+ """
64
+ argspec = inspect.getfullargspec(module.__init__)
65
+ arg_names = [arg for arg in argspec.args if arg != "self"]
66
+ num_defualts = len(argspec.defaults) if argspec.defaults is not None else 0
67
+ num_requires = len(arg_names) - num_defualts
68
+
69
+ schame = dict()
70
+ schame["_name"] = module.__name__
71
+ schame["_pymodule"] = importlib.import_module(module.__module__)
72
+ schame["_inject"] = getattr(module, "__inject__", [])
73
+ schame["_share"] = getattr(module, "__share__", [])
74
+ schame["_kwargs"] = {}
75
+ for i, name in enumerate(arg_names):
76
+ if name in schame["_share"]:
77
+ assert i >= num_requires, "share config must have default value."
78
+ value = argspec.defaults[i - num_requires]
79
+
80
+ elif i >= num_requires:
81
+ value = argspec.defaults[i - num_requires]
82
+
83
+ else:
84
+ value = None
85
+
86
+ schame[name] = value
87
+ schame["_kwargs"][name] = value
88
+
89
+ return schame
90
+
91
+
92
+ def create(type_or_name, global_cfg=GLOBAL_CONFIG, **kwargs):
93
+ """ """
94
+ assert type(type_or_name) in (type, str), "create should be modules or name."
95
+
96
+ name = type_or_name if isinstance(type_or_name, str) else type_or_name.__name__
97
+
98
+ if name in global_cfg:
99
+ if hasattr(global_cfg[name], "__dict__"):
100
+ return global_cfg[name]
101
+ else:
102
+ raise ValueError("The module {} is not registered".format(name))
103
+
104
+ cfg = global_cfg[name]
105
+
106
+ if isinstance(cfg, dict) and "type" in cfg:
107
+ _cfg: dict = global_cfg[cfg["type"]]
108
+ # clean args
109
+ _keys = [k for k in _cfg.keys() if not k.startswith("_")]
110
+ for _arg in _keys:
111
+ del _cfg[_arg]
112
+ _cfg.update(_cfg["_kwargs"]) # restore default args
113
+ _cfg.update(cfg) # load config args
114
+ _cfg.update(kwargs) # TODO recive extra kwargs
115
+ name = _cfg.pop("type") # pop extra key `type` (from cfg)
116
+
117
+ return create(name, global_cfg)
118
+
119
+ module = getattr(cfg["_pymodule"], name)
120
+ module_kwargs = {}
121
+ module_kwargs.update(cfg)
122
+
123
+ # shared var
124
+ for k in cfg["_share"]:
125
+ if k in global_cfg:
126
+ module_kwargs[k] = global_cfg[k]
127
+ else:
128
+ module_kwargs[k] = cfg[k]
129
+
130
+ # inject
131
+ for k in cfg["_inject"]:
132
+ _k = cfg[k]
133
+
134
+ if _k is None:
135
+ continue
136
+
137
+ if isinstance(_k, str):
138
+ if _k not in global_cfg:
139
+ raise ValueError(f"Missing inject config of {_k}.")
140
+
141
+ _cfg = global_cfg[_k]
142
+
143
+ if isinstance(_cfg, dict):
144
+ module_kwargs[k] = create(_cfg["_name"], global_cfg)
145
+ else:
146
+ module_kwargs[k] = _cfg
147
+
148
+ elif isinstance(_k, dict):
149
+ if "type" not in _k.keys():
150
+ raise ValueError("Missing inject for `type` style.")
151
+
152
+ _type = str(_k["type"])
153
+ if _type not in global_cfg:
154
+ raise ValueError(f"Missing {_type} in inspect stage.")
155
+
156
+ # TODO
157
+ _cfg: dict = global_cfg[_type]
158
+ # clean args
159
+ _keys = [k for k in _cfg.keys() if not k.startswith("_")]
160
+ for _arg in _keys:
161
+ del _cfg[_arg]
162
+ _cfg.update(_cfg["_kwargs"]) # restore default values
163
+ _cfg.update(_k) # load config args
164
+ name = _cfg.pop("type") # pop extra key (`type` from _k)
165
+ module_kwargs[k] = create(name, global_cfg)
166
+
167
+ else:
168
+ raise ValueError(f"Inject does not support {_k}")
169
+
170
+ # TODO hard code
171
+ module_kwargs = {k: v for k, v in module_kwargs.items() if not k.startswith("_")}
172
+
173
+ # TODO for **kwargs
174
+ # extra_args = set(module_kwargs.keys()) - set(arg_names)
175
+ # if len(extra_args) > 0:
176
+ # raise RuntimeError(f'Error: unknown args {extra_args} for {module}')
177
+
178
+ return module(**module_kwargs)
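As a usage sketch of the register/create pair above: registering a class records a schema of its constructor defaults in GLOBAL_CONFIG, and create() instantiates it from that schema. The TinyHead class below is hypothetical and only for illustration:

import torch.nn as nn
from src.core import GLOBAL_CONFIG, create, register

@register()
class TinyHead(nn.Module):                      # hypothetical module, not part of this commit
    def __init__(self, hidden_dim=256, num_classes=80):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, num_classes)

print(GLOBAL_CONFIG["TinyHead"]["_kwargs"])     # {'hidden_dim': 256, 'num_classes': 80}
head = create("TinyHead")                       # instantiated with the recorded defaults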
src/core/yaml_config.py ADDED
@@ -0,0 +1,187 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ import copy
7
+ import re
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.optim as optim
12
+ from torch.utils.data import DataLoader
13
+
14
+ from ._config import BaseConfig
15
+ from .workspace import create
16
+ from .yaml_utils import load_config, merge_config, merge_dict
17
+
18
+
19
+ class YAMLConfig(BaseConfig):
20
+ def __init__(self, cfg_path: str, **kwargs) -> None:
21
+ super().__init__()
22
+
23
+ cfg = load_config(cfg_path)
24
+ cfg = merge_dict(cfg, kwargs)
25
+
26
+ self.yaml_cfg = copy.deepcopy(cfg)
27
+
28
+ for k in super().__dict__:
29
+ if not k.startswith("_") and k in cfg:
30
+ self.__dict__[k] = cfg[k]
31
+
32
+ @property
33
+ def global_cfg(self):
34
+ return merge_config(self.yaml_cfg, inplace=False, overwrite=False)
35
+
36
+ @property
37
+ def model(self) -> torch.nn.Module:
38
+ if self._model is None and "model" in self.yaml_cfg:
39
+ self._model = create(self.yaml_cfg["model"], self.global_cfg)
40
+ return super().model
41
+
42
+ @property
43
+ def postprocessor(self) -> torch.nn.Module:
44
+ if self._postprocessor is None and "postprocessor" in self.yaml_cfg:
45
+ self._postprocessor = create(self.yaml_cfg["postprocessor"], self.global_cfg)
46
+ return super().postprocessor
47
+
48
+ @property
49
+ def criterion(self) -> torch.nn.Module:
50
+ if self._criterion is None and "criterion" in self.yaml_cfg:
51
+ self._criterion = create(self.yaml_cfg["criterion"], self.global_cfg)
52
+ return super().criterion
53
+
54
+ @property
55
+ def optimizer(self) -> optim.Optimizer:
56
+ if self._optimizer is None and "optimizer" in self.yaml_cfg:
57
+ params = self.get_optim_params(self.yaml_cfg["optimizer"], self.model)
58
+ self._optimizer = create("optimizer", self.global_cfg, params=params)
59
+ return super().optimizer
60
+
61
+ @property
62
+ def lr_scheduler(self) -> optim.lr_scheduler.LRScheduler:
63
+ if self._lr_scheduler is None and "lr_scheduler" in self.yaml_cfg:
64
+ self._lr_scheduler = create("lr_scheduler", self.global_cfg, optimizer=self.optimizer)
65
+ print(f"Initial lr: {self._lr_scheduler.get_last_lr()}")
66
+ return super().lr_scheduler
67
+
68
+ @property
69
+ def lr_warmup_scheduler(self) -> optim.lr_scheduler.LRScheduler:
70
+ if self._lr_warmup_scheduler is None and "lr_warmup_scheduler" in self.yaml_cfg:
71
+ self._lr_warmup_scheduler = create(
72
+ "lr_warmup_scheduler", self.global_cfg, lr_scheduler=self.lr_scheduler
73
+ )
74
+ return super().lr_warmup_scheduler
75
+
76
+ @property
77
+ def train_dataloader(self) -> DataLoader:
78
+ if self._train_dataloader is None and "train_dataloader" in self.yaml_cfg:
79
+ self._train_dataloader = self.build_dataloader("train_dataloader")
80
+ return super().train_dataloader
81
+
82
+ @property
83
+ def val_dataloader(self) -> DataLoader:
84
+ if self._val_dataloader is None and "val_dataloader" in self.yaml_cfg:
85
+ self._val_dataloader = self.build_dataloader("val_dataloader")
86
+ return super().val_dataloader
87
+
88
+ @property
89
+ def ema(self) -> torch.nn.Module:
90
+ if self._ema is None and self.yaml_cfg.get("use_ema", False):
91
+ self._ema = create("ema", self.global_cfg, model=self.model)
92
+ return super().ema
93
+
94
+ @property
95
+ def scaler(self):
96
+ if self._scaler is None and self.yaml_cfg.get("use_amp", False):
97
+ self._scaler = create("scaler", self.global_cfg)
98
+ return super().scaler
99
+
100
+ @property
101
+ def evaluator(self):
102
+ if self._evaluator is None and "evaluator" in self.yaml_cfg:
103
+ if self.yaml_cfg["evaluator"]["type"] == "CocoEvaluator":
104
+ from ..data import get_coco_api_from_dataset
105
+
106
+ base_ds = get_coco_api_from_dataset(self.val_dataloader.dataset)
107
+ self._evaluator = create("evaluator", self.global_cfg, coco_gt=base_ds)
108
+ else:
109
+ raise NotImplementedError(f"{self.yaml_cfg['evaluator']['type']}")
110
+ return super().evaluator
111
+
112
+ @property
113
+ def use_wandb(self) -> bool:
114
+ return self.yaml_cfg.get("use_wandb", False)
115
+
116
+ @staticmethod
117
+ def get_optim_params(cfg: dict, model: nn.Module):
118
+ """
119
+ E.g.:
120
+ ^(?=.*a)(?=.*b).*$ means including a and b
121
+ ^(?=.*(?:a|b)).*$ means including a or b
122
+ ^(?=.*a)(?!.*b).*$ means including a, but not b
123
+ """
124
+ assert "type" in cfg, ""
125
+ cfg = copy.deepcopy(cfg)
126
+
127
+ if "params" not in cfg:
128
+ return model.parameters()
129
+
130
+ assert isinstance(cfg["params"], list), ""
131
+
132
+ param_groups = []
133
+ visited = []
134
+ for pg in cfg["params"]:
135
+ pattern = pg["params"]
136
+ params = {
137
+ k: v
138
+ for k, v in model.named_parameters()
139
+ if v.requires_grad and len(re.findall(pattern, k)) > 0
140
+ }
141
+ pg["params"] = params.values()
142
+ param_groups.append(pg)
143
+ visited.extend(list(params.keys()))
144
+ # print(params.keys())
145
+
146
+ names = [k for k, v in model.named_parameters() if v.requires_grad]
147
+
148
+ if len(visited) < len(names):
149
+ unseen = set(names) - set(visited)
150
+ params = {k: v for k, v in model.named_parameters() if v.requires_grad and k in unseen}
151
+ param_groups.append({"params": params.values()})
152
+ visited.extend(list(params.keys()))
153
+ # print(params.keys())
154
+
155
+ assert len(visited) == len(names), ""
156
+
157
+ return param_groups
158
+
159
+ @staticmethod
160
+ def get_rank_batch_size(cfg):
161
+ """compute batch size for per rank if total_batch_size is provided."""
162
+ assert ("total_batch_size" in cfg or "batch_size" in cfg) and not (
163
+ "total_batch_size" in cfg and "batch_size" in cfg
164
+ ), "`batch_size` or `total_batch_size` should be choosed one"
165
+
166
+ total_batch_size = cfg.get("total_batch_size", None)
167
+ if total_batch_size is None:
168
+ bs = cfg.get("batch_size")
169
+ else:
170
+ from ..misc import dist_utils
171
+
172
+ assert (
173
+ total_batch_size % dist_utils.get_world_size() == 0
174
+ ), "total_batch_size should be divisible by world size"
175
+ bs = total_batch_size // dist_utils.get_world_size()
176
+ return bs
177
+
178
+ def build_dataloader(self, name: str):
179
+ bs = self.get_rank_batch_size(self.yaml_cfg[name])
180
+ global_cfg = self.global_cfg
181
+ if "total_batch_size" in global_cfg[name]:
182
+ # pop unexpected key for dataloader init
183
+ _ = global_cfg[name].pop("total_batch_size")
184
+ print(f"building {name} with batch_size={bs}...")
185
+ loader = create(name, global_cfg, batch_size=bs)
186
+ loader.shuffle = self.yaml_cfg[name].get("shuffle", False)
187
+ return loader
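A typical entry point is to build everything lazily from a YAML file through the properties above; the config path and the override below are illustrative placeholders:

from src.core import YAMLConfig

cfg = YAMLConfig("configs/dfine/dfine_hgnetv2_l_coco.yml", use_amp=True)   # placeholder path

model = cfg.model                    # created from yaml_cfg["model"] via the workspace registry
optimizer = cfg.optimizer            # parameter groups resolved by get_optim_params
train_loader = cfg.train_dataloader  # per-rank batch size resolved by get_rank_batch_size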
src/core/yaml_utils.py ADDED
@@ -0,0 +1,126 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ import copy
7
+ import os
8
+ from typing import Any, Dict, List, Optional
9
+
10
+ import yaml
11
+
12
+ from .workspace import GLOBAL_CONFIG
13
+
14
+ __all__ = [
15
+ "load_config",
16
+ "merge_config",
17
+ "merge_dict",
18
+ "parse_cli",
19
+ ]
20
+
21
+
22
+ INCLUDE_KEY = "__include__"
23
+
24
+
25
+ def load_config(file_path, cfg=dict()):
26
+ """load config"""
27
+ _, ext = os.path.splitext(file_path)
28
+ assert ext in [".yml", ".yaml"], "only support yaml files"
29
+
30
+ with open(file_path) as f:
31
+ file_cfg = yaml.load(f, Loader=yaml.Loader)
32
+ if file_cfg is None:
33
+ return {}
34
+
35
+ if INCLUDE_KEY in file_cfg:
36
+ base_yamls = list(file_cfg[INCLUDE_KEY])
37
+ for base_yaml in base_yamls:
38
+ if base_yaml.startswith("~"):
39
+ base_yaml = os.path.expanduser(base_yaml)
40
+
41
+ if not base_yaml.startswith("/"):
42
+ base_yaml = os.path.join(os.path.dirname(file_path), base_yaml)
43
+
44
+ with open(base_yaml) as f:
45
+ base_cfg = load_config(base_yaml, cfg)
46
+ merge_dict(cfg, base_cfg)
47
+
48
+ return merge_dict(cfg, file_cfg)
49
+
50
+
51
+ def merge_dict(dct, another_dct, inplace=True) -> Dict:
52
+ """merge another_dct into dct"""
53
+
54
+ def _merge(dct, another) -> Dict:
55
+ for k in another:
56
+ if k in dct and isinstance(dct[k], dict) and isinstance(another[k], dict):
57
+ _merge(dct[k], another[k])
58
+ else:
59
+ dct[k] = another[k]
60
+
61
+ return dct
62
+
63
+ if not inplace:
64
+ dct = copy.deepcopy(dct)
65
+
66
+ return _merge(dct, another_dct)
67
+
68
+
69
+ def dictify(s: str, v: Any) -> Dict:
70
+ if "." not in s:
71
+ return {s: v}
72
+ key, rest = s.split(".", 1)
73
+ return {key: dictify(rest, v)}
74
+
75
+
76
+ def parse_cli(nargs: List[str]) -> Dict:
77
+ """
78
+ parse command-line arguments
79
+ convert `a.c=3 b=10` to `{'a': {'c': 3}, 'b': 10}`
80
+ """
81
+ cfg = {}
82
+ if nargs is None or len(nargs) == 0:
83
+ return cfg
84
+
85
+ for s in nargs:
86
+ s = s.strip()
87
+ k, v = s.split("=", 1)
88
+ d = dictify(k, yaml.load(v, Loader=yaml.Loader))
89
+ cfg = merge_dict(cfg, d)
90
+
91
+ return cfg
92
+
93
+
94
+ def merge_config(cfg, another_cfg=GLOBAL_CONFIG, inplace: bool = False, overwrite: bool = False):
95
+ """
96
+ Merge another_cfg into cfg, return the merged config
97
+
98
+ Example:
99
+
100
+ cfg1 = load_config('./dfine_r18vd_6x_coco.yml')
101
+ cfg1 = merge_config(cfg, inplace=True)
102
+
103
+ cfg2 = load_config('./dfine_r50vd_6x_coco.yml')
104
+ cfg2 = merge_config(cfg2, inplace=True)
105
+
106
+ model1 = create(cfg1['model'], cfg1)
107
+ model2 = create(cfg2['model'], cfg2)
108
+ """
109
+
110
+ def _merge(dct, another):
111
+ for k in another:
112
+ if k not in dct:
113
+ dct[k] = another[k]
114
+
115
+ elif isinstance(dct[k], dict) and isinstance(another[k], dict):
116
+ _merge(dct[k], another[k])
117
+
118
+ elif overwrite:
119
+ dct[k] = another[k]
120
+
121
+ return cfg
122
+
123
+ if not inplace:
124
+ cfg = copy.deepcopy(cfg)
125
+
126
+ return _merge(cfg, another_cfg)
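The helpers above compose as the parse_cli docstring suggests; a small sketch (the keys are made up for illustration):

from src.core.yaml_utils import merge_dict, parse_cli

overrides = parse_cli(["optimizer.lr=0.0001", "use_amp=True"])
# {'optimizer': {'lr': 0.0001}, 'use_amp': True}

base = {"optimizer": {"lr": 0.01, "weight_decay": 0.0001}}
merged = merge_dict(base, overrides)
# {'optimizer': {'lr': 0.0001, 'weight_decay': 0.0001}, 'use_amp': True}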
src/data/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ from ._misc import convert_to_tv_tensor
7
+ from .dataloader import *
8
+ from .dataset import *
9
+ from .transforms import *
10
+
11
+
12
+ # def set_epoch(self, epoch) -> None:
13
+ # self.epoch = epoch
14
+ # def _set_epoch_func(datasets):
15
+ # """Add `set_epoch` for datasets
16
+ # """
17
+ # from ..core import register
18
+ # for ds in datasets:
19
+ # register(ds)(set_epoch)
20
+ # _set_epoch_func([CIFAR10, VOCDetection, CocoDetection])
src/data/_misc.py ADDED
@@ -0,0 +1,62 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ import importlib.metadata
7
+
8
+ from torch import Tensor
9
+
10
+ if "0.15.2" in importlib.metadata.version("torchvision"):
11
+ import torchvision
12
+
13
+ torchvision.disable_beta_transforms_warning()
14
+
15
+ from torchvision.datapoints import BoundingBox as BoundingBoxes
16
+ from torchvision.datapoints import BoundingBoxFormat, Image, Mask, Video
17
+ from torchvision.transforms.v2 import SanitizeBoundingBox as SanitizeBoundingBoxes
18
+
19
+ _boxes_keys = ["format", "spatial_size"]
20
+
21
+ elif "0.17" > importlib.metadata.version("torchvision") >= "0.16":
22
+ import torchvision
23
+
24
+ torchvision.disable_beta_transforms_warning()
25
+
26
+ from torchvision.transforms.v2 import SanitizeBoundingBoxes
27
+ from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
28
+
29
+ _boxes_keys = ["format", "canvas_size"]
30
+
31
+ elif importlib.metadata.version("torchvision") >= "0.17":
32
+ import torchvision
33
+ from torchvision.transforms.v2 import SanitizeBoundingBoxes
34
+ from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
35
+
36
+ _boxes_keys = ["format", "canvas_size"]
37
+
38
+ else:
39
+ raise RuntimeError("Please make sure torchvision version >= 0.15.2")
40
+
41
+
42
+ def convert_to_tv_tensor(tensor: Tensor, key: str, box_format="xyxy", spatial_size=None) -> Tensor:
43
+ """
44
+ Args:
45
+ tensor (Tensor): input tensor
46
+ key (str): transform to key
47
+
48
+ Return:
49
+ Dict[str, TV_Tensor]
50
+ """
51
+ assert key in (
52
+ "boxes",
53
+ "masks",
54
+ ), "Only support 'boxes' and 'masks'"
55
+
56
+ if key == "boxes":
57
+ box_format = getattr(BoundingBoxFormat, box_format.upper())
58
+ _kwargs = dict(zip(_boxes_keys, [box_format, spatial_size]))
59
+ return BoundingBoxes(tensor, **_kwargs)
60
+
61
+ if key == "masks":
62
+ return Mask(tensor)
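A short sketch of convert_to_tv_tensor; depending on the installed torchvision version, the second box field is named spatial_size or canvas_size, which the module resolves through _boxes_keys:

import torch
from src.data import convert_to_tv_tensor

boxes = torch.tensor([[10.0, 20.0, 50.0, 80.0]])    # a single xyxy box
tv_boxes = convert_to_tv_tensor(boxes, key="boxes", box_format="xyxy", spatial_size=(480, 640))
# tv_boxes is a BoundingBoxes tv_tensor carrying its format and image size alongside the data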
src/data/dataloader.py ADDED
@@ -0,0 +1,122 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ import random
7
+ from functools import partial
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torch.utils.data as data
12
+ import torchvision
13
+ import torchvision.transforms.v2 as VT
14
+ from torch.utils.data import default_collate
15
+ from torchvision.transforms.v2 import InterpolationMode
16
+ from torchvision.transforms.v2 import functional as VF
17
+
18
+ from ..core import register
19
+
20
+ torchvision.disable_beta_transforms_warning()
21
+
22
+
23
+ __all__ = [
24
+ "DataLoader",
25
+ "BaseCollateFunction",
26
+ "BatchImageCollateFunction",
27
+ "batch_image_collate_fn",
28
+ ]
29
+
30
+
31
+ @register()
32
+ class DataLoader(data.DataLoader):
33
+ __inject__ = ["dataset", "collate_fn"]
34
+
35
+ def __repr__(self) -> str:
36
+ format_string = self.__class__.__name__ + "("
37
+ for n in ["dataset", "batch_size", "num_workers", "drop_last", "collate_fn"]:
38
+ format_string += "\n"
39
+ format_string += " {0}: {1}".format(n, getattr(self, n))
40
+ format_string += "\n)"
41
+ return format_string
42
+
43
+ def set_epoch(self, epoch):
44
+ self._epoch = epoch
45
+ self.dataset.set_epoch(epoch)
46
+ self.collate_fn.set_epoch(epoch)
47
+
48
+ @property
49
+ def epoch(self):
50
+ return self._epoch if hasattr(self, "_epoch") else -1
51
+
52
+ @property
53
+ def shuffle(self):
54
+ return self._shuffle
55
+
56
+ @shuffle.setter
57
+ def shuffle(self, shuffle):
58
+ assert isinstance(shuffle, bool), "shuffle must be a boolean"
59
+ self._shuffle = shuffle
60
+
61
+
62
+ @register()
63
+ def batch_image_collate_fn(items):
64
+ """only batch image"""
65
+ return torch.cat([x[0][None] for x in items], dim=0), [x[1] for x in items]
66
+
67
+
68
+ class BaseCollateFunction(object):
69
+ def set_epoch(self, epoch):
70
+ self._epoch = epoch
71
+
72
+ @property
73
+ def epoch(self):
74
+ return self._epoch if hasattr(self, "_epoch") else -1
75
+
76
+ def __call__(self, items):
77
+ raise NotImplementedError("")
78
+
79
+
80
+ def generate_scales(base_size, base_size_repeat):
81
+ scale_repeat = (base_size - int(base_size * 0.75 / 32) * 32) // 32
82
+ scales = [int(base_size * 0.75 / 32) * 32 + i * 32 for i in range(scale_repeat)]
83
+ scales += [base_size] * base_size_repeat
84
+ scales += [int(base_size * 1.25 / 32) * 32 - i * 32 for i in range(scale_repeat)]
85
+ return scales
86
+
87
+
88
+ @register()
89
+ class BatchImageCollateFunction(BaseCollateFunction):
90
+ def __init__(
91
+ self,
92
+ stop_epoch=None,
93
+ ema_restart_decay=0.9999,
94
+ base_size=640,
95
+ base_size_repeat=None,
96
+ ) -> None:
97
+ super().__init__()
98
+ self.base_size = base_size
99
+ self.scales = (
100
+ generate_scales(base_size, base_size_repeat) if base_size_repeat is not None else None
101
+ )
102
+ self.stop_epoch = stop_epoch if stop_epoch is not None else 100000000
103
+ self.ema_restart_decay = ema_restart_decay
104
+ # self.interpolation = interpolation
105
+
106
+ def __call__(self, items):
107
+ images = torch.cat([x[0][None] for x in items], dim=0)
108
+ targets = [x[1] for x in items]
109
+
110
+ if self.scales is not None and self.epoch < self.stop_epoch:
111
+ # sz = random.choice(self.scales)
112
+ # sz = [sz] if isinstance(sz, int) else list(sz)
113
+ # VF.resize(inpt, sz, interpolation=self.interpolation)
114
+
115
+ sz = random.choice(self.scales)
116
+ images = F.interpolate(images, size=sz)
117
+ if "masks" in targets[0]:
118
+ for tg in targets:
119
+ tg["masks"] = F.interpolate(tg["masks"], size=sz, mode="nearest")
120
+ raise NotImplementedError("")
121
+
122
+ return images, targets
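For illustration, a sketch of multi-scale collation with BatchImageCollateFunction; the toy dataset and the numeric choices are placeholders. generate_scales(640, 3) yields multiples of 32 between 480 and 800, with the base size 640 repeated three times:

import torch
import torch.utils.data as data
from src.data import BatchImageCollateFunction

class ToyDetDataset(data.Dataset):                   # placeholder dataset of (image, target) pairs
    def __len__(self):
        return 16
    def __getitem__(self, i):
        return torch.rand(3, 640, 640), {"boxes": torch.zeros(0, 4), "labels": torch.zeros(0, dtype=torch.long)}

collate = BatchImageCollateFunction(base_size=640, base_size_repeat=3, stop_epoch=72)
collate.set_epoch(0)                                 # below stop_epoch, so each batch is resized to a random scale
loader = data.DataLoader(ToyDetDataset(), batch_size=4, collate_fn=collate)
images, targets = next(iter(loader))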
src/data/dataset/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ # from ._dataset import DetDataset
7
+ from .cifar_dataset import CIFAR10
8
+ from .coco_dataset import (
9
+ CocoDetection,
10
+ mscoco_category2label,
11
+ mscoco_category2name,
12
+ mscoco_label2category,
13
+ )
14
+ from .coco_eval import CocoEvaluator
15
+ from .coco_utils import get_coco_api_from_dataset
16
+ from .voc_detection import VOCDetection
17
+ from .voc_eval import VOCEvaluator
src/data/dataset/_dataset.py ADDED
@@ -0,0 +1,27 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ import torch
7
+ import torch.utils.data as data
8
+
9
+
10
+ class DetDataset(data.Dataset):
11
+ def __getitem__(self, index):
12
+ img, target = self.load_item(index)
13
+ if self.transforms is not None:
14
+ img, target, _ = self.transforms(img, target, self)
15
+ return img, target
16
+
17
+ def load_item(self, index):
18
+ raise NotImplementedError(
19
+ "Please implement this function to return item before `transforms`."
20
+ )
21
+
22
+ def set_epoch(self, epoch) -> None:
23
+ self._epoch = epoch
24
+
25
+ @property
26
+ def epoch(self):
27
+ return self._epoch if hasattr(self, "_epoch") else -1
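The DetDataset base above only fixes the load_item / __getitem__ contract; a hypothetical subclass (for illustration only) could look like this:

class ListDetDataset(DetDataset):
    def __init__(self, samples, transforms=None):
        self.samples = samples              # list of (image, target) pairs prepared elsewhere
        self.transforms = transforms

    def __len__(self):
        return len(self.samples)

    def load_item(self, index):
        # return the raw pair; DetDataset.__getitem__ applies self.transforms afterwards
        return self.samples[index]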
src/data/dataset/cifar_dataset.py ADDED
@@ -0,0 +1,25 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ from typing import Callable, Optional
7
+
8
+ import torchvision
9
+
10
+ from ...core import register
11
+
12
+
13
+ @register()
14
+ class CIFAR10(torchvision.datasets.CIFAR10):
15
+ __inject__ = ["transform", "target_transform"]
16
+
17
+ def __init__(
18
+ self,
19
+ root: str,
20
+ train: bool = True,
21
+ transform: Optional[Callable] = None,
22
+ target_transform: Optional[Callable] = None,
23
+ download: bool = False,
24
+ ) -> None:
25
+ super().__init__(root, train, transform, target_transform, download)
src/data/dataset/coco_dataset.py ADDED
@@ -0,0 +1,282 @@
1
+ """
2
+ Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3
+ Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
4
+
5
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
6
+ """
7
+
8
+ import faster_coco_eval
9
+ import faster_coco_eval.core.mask as coco_mask
10
+ import torch
11
+ import torch.utils.data
12
+ import torchvision
13
+ import os
14
+ from PIL import Image
15
+
16
+ from ...core import register
17
+ from .._misc import convert_to_tv_tensor
18
+ from ._dataset import DetDataset
19
+
20
+ torchvision.disable_beta_transforms_warning()
21
+ faster_coco_eval.init_as_pycocotools()
22
+ Image.MAX_IMAGE_PIXELS = None
23
+
24
+ __all__ = ["CocoDetection"]
25
+
26
+
27
+ @register()
28
+ class CocoDetection(torchvision.datasets.CocoDetection, DetDataset):
29
+ __inject__ = [
30
+ "transforms",
31
+ ]
32
+ __share__ = ["remap_mscoco_category"]
33
+
34
+ def __init__(
35
+ self, img_folder, ann_file, transforms, return_masks=False, remap_mscoco_category=False
36
+ ):
37
+ super(CocoDetection, self).__init__(img_folder, ann_file)
38
+ self._transforms = transforms
39
+ self.prepare = ConvertCocoPolysToMask(return_masks)
40
+ self.img_folder = img_folder
41
+ self.ann_file = ann_file
42
+ self.return_masks = return_masks
43
+ self.remap_mscoco_category = remap_mscoco_category
44
+
45
+ def __getitem__(self, idx):
46
+ img, target = self.load_item(idx)
47
+ if self._transforms is not None:
48
+ img, target, _ = self._transforms(img, target, self)
49
+ return img, target
50
+
51
+ def load_item(self, idx):
52
+ image, target = super(CocoDetection, self).__getitem__(idx)
53
+ image_id = self.ids[idx]
54
+ image_path = os.path.join(self.img_folder, self.coco.loadImgs(image_id)[0]["file_name"])
55
+ target = {"image_id": image_id, "image_path": image_path, "annotations": target}
56
+
57
+ if self.remap_mscoco_category:
58
+ image, target = self.prepare(image, target, category2label=mscoco_category2label)
59
+ else:
60
+ image, target = self.prepare(image, target)
61
+
62
+ target["idx"] = torch.tensor([idx])
63
+
64
+ if "boxes" in target:
65
+ target["boxes"] = convert_to_tv_tensor(
66
+ target["boxes"], key="boxes", spatial_size=image.size[::-1]
67
+ )
68
+
69
+ if "masks" in target:
70
+ target["masks"] = convert_to_tv_tensor(target["masks"], key="masks")
71
+
72
+ return image, target
73
+
74
+ def extra_repr(self) -> str:
75
+ s = f" img_folder: {self.img_folder}\n ann_file: {self.ann_file}\n"
76
+ s += f" return_masks: {self.return_masks}\n"
77
+ if hasattr(self, "_transforms") and self._transforms is not None:
78
+ s += f" transforms:\n {repr(self._transforms)}"
79
+ if hasattr(self, "_preset") and self._preset is not None:
80
+ s += f" preset:\n {repr(self._preset)}"
81
+ return s
82
+
83
+ @property
84
+ def categories(
85
+ self,
86
+ ):
87
+ return self.coco.dataset["categories"]
88
+
89
+ @property
90
+ def category2name(
91
+ self,
92
+ ):
93
+ return {cat["id"]: cat["name"] for cat in self.categories}
94
+
95
+ @property
96
+ def category2label(
97
+ self,
98
+ ):
99
+ return {cat["id"]: i for i, cat in enumerate(self.categories)}
100
+
101
+ @property
102
+ def label2category(
103
+ self,
104
+ ):
105
+ return {i: cat["id"] for i, cat in enumerate(self.categories)}
106
+
107
+
108
+ def convert_coco_poly_to_mask(segmentations, height, width):
109
+ masks = []
110
+ for polygons in segmentations:
111
+ rles = coco_mask.frPyObjects(polygons, height, width)
112
+ mask = coco_mask.decode(rles)
113
+ if len(mask.shape) < 3:
114
+ mask = mask[..., None]
115
+ mask = torch.as_tensor(mask, dtype=torch.uint8)
116
+ mask = mask.any(dim=2)
117
+ masks.append(mask)
118
+ if masks:
119
+ masks = torch.stack(masks, dim=0)
120
+ else:
121
+ masks = torch.zeros((0, height, width), dtype=torch.uint8)
122
+ return masks
123
+
124
+
125
+ class ConvertCocoPolysToMask(object):
126
+ def __init__(self, return_masks=False):
127
+ self.return_masks = return_masks
128
+
129
+ def __call__(self, image: Image.Image, target, **kwargs):
130
+ w, h = image.size
131
+
132
+ image_id = target["image_id"]
133
+ image_id = torch.tensor([image_id])
134
+
135
+ image_path = target["image_path"]
136
+
137
+ anno = target["annotations"]
138
+
139
+ anno = [obj for obj in anno if "iscrowd" not in obj or obj["iscrowd"] == 0]
140
+
141
+ boxes = [obj["bbox"] for obj in anno]
142
+ # guard against no boxes via resizing
143
+ boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
144
+ boxes[:, 2:] += boxes[:, :2]
145
+ boxes[:, 0::2].clamp_(min=0, max=w)
146
+ boxes[:, 1::2].clamp_(min=0, max=h)
147
+
148
+ category2label = kwargs.get("category2label", None)
149
+ if category2label is not None:
150
+ labels = [category2label[obj["category_id"]] for obj in anno]
151
+ else:
152
+ labels = [obj["category_id"] for obj in anno]
153
+
154
+ labels = torch.tensor(labels, dtype=torch.int64)
155
+
156
+ if self.return_masks:
157
+ segmentations = [obj["segmentation"] for obj in anno]
158
+ masks = convert_coco_poly_to_mask(segmentations, h, w)
159
+
160
+ keypoints = None
161
+ if anno and "keypoints" in anno[0]:
162
+ keypoints = [obj["keypoints"] for obj in anno]
163
+ keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
164
+ num_keypoints = keypoints.shape[0]
165
+ if num_keypoints:
166
+ keypoints = keypoints.view(num_keypoints, -1, 3)
167
+
168
+ keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
169
+ boxes = boxes[keep]
170
+ labels = labels[keep]
171
+ if self.return_masks:
172
+ masks = masks[keep]
173
+ if keypoints is not None:
174
+ keypoints = keypoints[keep]
175
+
176
+ target = {}
177
+ target["boxes"] = boxes
178
+ target["labels"] = labels
179
+ if self.return_masks:
180
+ target["masks"] = masks
181
+ target["image_id"] = image_id
182
+ target["image_path"] = image_path
183
+ if keypoints is not None:
184
+ target["keypoints"] = keypoints
185
+
186
+ # for conversion to coco api
187
+ area = torch.tensor([obj["area"] for obj in anno])
188
+ iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
189
+ target["area"] = area[keep]
190
+ target["iscrowd"] = iscrowd[keep]
191
+
192
+ target["orig_size"] = torch.as_tensor([int(w), int(h)])
193
+ # target["size"] = torch.as_tensor([int(w), int(h)])
194
+
195
+ return image, target
196
+
197
+
198
+ mscoco_category2name = {
199
+ 1: "person",
200
+ 2: "bicycle",
201
+ 3: "car",
202
+ 4: "motorcycle",
203
+ 5: "airplane",
204
+ 6: "bus",
205
+ 7: "train",
206
+ 8: "truck",
207
+ 9: "boat",
208
+ 10: "traffic light",
209
+ 11: "fire hydrant",
210
+ 13: "stop sign",
211
+ 14: "parking meter",
212
+ 15: "bench",
213
+ 16: "bird",
214
+ 17: "cat",
215
+ 18: "dog",
216
+ 19: "horse",
217
+ 20: "sheep",
218
+ 21: "cow",
219
+ 22: "elephant",
220
+ 23: "bear",
221
+ 24: "zebra",
222
+ 25: "giraffe",
223
+ 27: "backpack",
224
+ 28: "umbrella",
225
+ 31: "handbag",
226
+ 32: "tie",
227
+ 33: "suitcase",
228
+ 34: "frisbee",
229
+ 35: "skis",
230
+ 36: "snowboard",
231
+ 37: "sports ball",
232
+ 38: "kite",
233
+ 39: "baseball bat",
234
+ 40: "baseball glove",
235
+ 41: "skateboard",
236
+ 42: "surfboard",
237
+ 43: "tennis racket",
238
+ 44: "bottle",
239
+ 46: "wine glass",
240
+ 47: "cup",
241
+ 48: "fork",
242
+ 49: "knife",
243
+ 50: "spoon",
244
+ 51: "bowl",
245
+ 52: "banana",
246
+ 53: "apple",
247
+ 54: "sandwich",
248
+ 55: "orange",
249
+ 56: "broccoli",
250
+ 57: "carrot",
251
+ 58: "hot dog",
252
+ 59: "pizza",
253
+ 60: "donut",
254
+ 61: "cake",
255
+ 62: "chair",
256
+ 63: "couch",
257
+ 64: "potted plant",
258
+ 65: "bed",
259
+ 67: "dining table",
260
+ 70: "toilet",
261
+ 72: "tv",
262
+ 73: "laptop",
263
+ 74: "mouse",
264
+ 75: "remote",
265
+ 76: "keyboard",
266
+ 77: "cell phone",
267
+ 78: "microwave",
268
+ 79: "oven",
269
+ 80: "toaster",
270
+ 81: "sink",
271
+ 82: "refrigerator",
272
+ 84: "book",
273
+ 85: "clock",
274
+ 86: "vase",
275
+ 87: "scissors",
276
+ 88: "teddy bear",
277
+ 89: "hair drier",
278
+ 90: "toothbrush",
279
+ }
280
+
281
+ mscoco_category2label = {k: i for i, k in enumerate(mscoco_category2name.keys())}
282
+ mscoco_label2category = {v: k for k, v in mscoco_category2label.items()}
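A usage sketch for CocoDetection; the paths are placeholders, and remap_mscoco_category=True maps the sparse COCO category ids above to contiguous training labels via mscoco_category2label:

from src.data import CocoDetection, mscoco_label2category

dataset = CocoDetection(
    img_folder="data/coco/val2017",                                   # placeholder paths
    ann_file="data/coco/annotations/instances_val2017.json",
    transforms=None,
    remap_mscoco_category=True,
)
image, target = dataset[0]
# target["boxes"] is a BoundingBoxes tv_tensor in xyxy format clamped to the image;
# target["labels"] holds contiguous ids that mscoco_label2category maps back to COCO category ids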
src/data/dataset/coco_eval.py ADDED
@@ -0,0 +1,214 @@
1
+ """
2
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3
+ COCO evaluator that works in distributed mode.
4
+ Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py
5
+ The difference is that there is less copy-pasting from pycocotools
6
+ in the end of the file, as python3 can suppress prints with contextlib
7
+ """
8
+
9
+ import contextlib
10
+ import copy
11
+ import os
12
+
13
+ import faster_coco_eval.core.mask as mask_util
14
+ import numpy as np
15
+ import torch
16
+ from faster_coco_eval import COCO, COCOeval_faster
17
+
18
+ from ...core import register
19
+ from ...misc import dist_utils
20
+
21
+ __all__ = [
22
+ "CocoEvaluator",
23
+ ]
24
+
25
+
26
+ @register()
27
+ class CocoEvaluator(object):
28
+ def __init__(self, coco_gt, iou_types):
29
+ assert isinstance(iou_types, (list, tuple))
30
+ coco_gt = copy.deepcopy(coco_gt)
31
+ self.coco_gt: COCO = coco_gt
32
+ self.iou_types = iou_types
33
+
34
+ self.coco_eval = {}
35
+ for iou_type in iou_types:
36
+ self.coco_eval[iou_type] = COCOeval_faster(
37
+ coco_gt, iouType=iou_type, print_function=print, separate_eval=True
38
+ )
39
+
40
+ self.img_ids = []
41
+ self.eval_imgs = {k: [] for k in iou_types}
42
+
43
+ def cleanup(self):
44
+ self.coco_eval = {}
45
+ for iou_type in self.iou_types:
46
+ self.coco_eval[iou_type] = COCOeval_faster(
47
+ self.coco_gt, iouType=iou_type, print_function=print, separate_eval=True
48
+ )
49
+ self.img_ids = []
50
+ self.eval_imgs = {k: [] for k in self.iou_types}
51
+
52
+ def update(self, predictions):
53
+ img_ids = list(np.unique(list(predictions.keys())))
54
+ self.img_ids.extend(img_ids)
55
+
56
+ for iou_type in self.iou_types:
57
+ results = self.prepare(predictions, iou_type)
58
+ coco_eval = self.coco_eval[iou_type]
59
+
60
+ # suppress pycocotools prints
61
+ with open(os.devnull, "w") as devnull:
62
+ with contextlib.redirect_stdout(devnull):
63
+ coco_dt = self.coco_gt.loadRes(results) if results else COCO()
64
+ coco_eval.cocoDt = coco_dt
65
+ coco_eval.params.imgIds = list(img_ids)
66
+ coco_eval.evaluate()
67
+
68
+ self.eval_imgs[iou_type].append(
69
+ np.array(coco_eval._evalImgs_cpp).reshape(
70
+ len(coco_eval.params.catIds),
71
+ len(coco_eval.params.areaRng),
72
+ len(coco_eval.params.imgIds),
73
+ )
74
+ )
75
+
76
+ def synchronize_between_processes(self):
77
+ for iou_type in self.iou_types:
78
+ img_ids, eval_imgs = merge(self.img_ids, self.eval_imgs[iou_type])
79
+
80
+ coco_eval = self.coco_eval[iou_type]
81
+ coco_eval.params.imgIds = img_ids
82
+ coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
83
+ coco_eval._evalImgs_cpp = eval_imgs
84
+
85
+ def accumulate(self):
86
+ for coco_eval in self.coco_eval.values():
87
+ coco_eval.accumulate()
88
+
89
+ def summarize(self):
90
+ for iou_type, coco_eval in self.coco_eval.items():
91
+ print("IoU metric: {}".format(iou_type))
92
+ coco_eval.summarize()
93
+
94
+ def prepare(self, predictions, iou_type):
95
+ if iou_type == "bbox":
96
+ return self.prepare_for_coco_detection(predictions)
97
+ elif iou_type == "segm":
98
+ return self.prepare_for_coco_segmentation(predictions)
99
+ elif iou_type == "keypoints":
100
+ return self.prepare_for_coco_keypoint(predictions)
101
+ else:
102
+ raise ValueError("Unknown iou type {}".format(iou_type))
103
+
104
+ def prepare_for_coco_detection(self, predictions):
105
+ coco_results = []
106
+ for original_id, prediction in predictions.items():
107
+ if len(prediction) == 0:
108
+ continue
109
+
110
+ boxes = prediction["boxes"]
111
+ boxes = convert_to_xywh(boxes).tolist()
112
+ scores = prediction["scores"].tolist()
113
+ labels = prediction["labels"].tolist()
114
+
115
+ coco_results.extend(
116
+ [
117
+ {
118
+ "image_id": original_id,
119
+ "category_id": labels[k],
120
+ "bbox": box,
121
+ "score": scores[k],
122
+ }
123
+ for k, box in enumerate(boxes)
124
+ ]
125
+ )
126
+ return coco_results
127
+
128
+ def prepare_for_coco_segmentation(self, predictions):
129
+ coco_results = []
130
+ for original_id, prediction in predictions.items():
131
+ if len(prediction) == 0:
132
+ continue
133
+
134
+ scores = prediction["scores"]
135
+ labels = prediction["labels"]
136
+ masks = prediction["masks"]
137
+
138
+ masks = masks > 0.5
139
+
140
+ scores = prediction["scores"].tolist()
141
+ labels = prediction["labels"].tolist()
142
+
143
+ rles = [
144
+ mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
145
+ for mask in masks
146
+ ]
147
+ for rle in rles:
148
+ rle["counts"] = rle["counts"].decode("utf-8")
149
+
150
+ coco_results.extend(
151
+ [
152
+ {
153
+ "image_id": original_id,
154
+ "category_id": labels[k],
155
+ "segmentation": rle,
156
+ "score": scores[k],
157
+ }
158
+ for k, rle in enumerate(rles)
159
+ ]
160
+ )
161
+ return coco_results
162
+
163
+ def prepare_for_coco_keypoint(self, predictions):
164
+ coco_results = []
165
+ for original_id, prediction in predictions.items():
166
+ if len(prediction) == 0:
167
+ continue
168
+
169
+ boxes = prediction["boxes"]
170
+ boxes = convert_to_xywh(boxes).tolist()
171
+ scores = prediction["scores"].tolist()
172
+ labels = prediction["labels"].tolist()
173
+ keypoints = prediction["keypoints"]
174
+ keypoints = keypoints.flatten(start_dim=1).tolist()
175
+
176
+ coco_results.extend(
177
+ [
178
+ {
179
+ "image_id": original_id,
180
+ "category_id": labels[k],
181
+ "keypoints": keypoint,
182
+ "score": scores[k],
183
+ }
184
+ for k, keypoint in enumerate(keypoints)
185
+ ]
186
+ )
187
+ return coco_results
188
+
189
+
190
+ def convert_to_xywh(boxes):
191
+ xmin, ymin, xmax, ymax = boxes.unbind(1)
192
+ return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)
193
+
194
+
195
+ def merge(img_ids, eval_imgs):
196
+ all_img_ids = dist_utils.all_gather(img_ids)
197
+ all_eval_imgs = dist_utils.all_gather(eval_imgs)
198
+
199
+ merged_img_ids = []
200
+ for p in all_img_ids:
201
+ merged_img_ids.extend(p)
202
+
203
+ merged_eval_imgs = []
204
+ for p in all_eval_imgs:
205
+ merged_eval_imgs.extend(p)
206
+
207
+ merged_img_ids = np.array(merged_img_ids)
208
+ merged_eval_imgs = np.concatenate(merged_eval_imgs, axis=2).ravel()
209
+ # merged_eval_imgs = np.array(merged_eval_imgs).T.ravel()
210
+
211
+ # keep only unique (and in sorted order) images
212
+ merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
213
+
214
+ return merged_img_ids.tolist(), merged_eval_imgs.tolist()
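A minimal single-process sketch of the evaluation loop implied by the class above; val_dataset, val_loader, and the model/postprocess pipeline are placeholders, and predictions map image ids to dicts with xyxy boxes, scores, and labels as expected by prepare_for_coco_detection:

from src.data import CocoEvaluator, get_coco_api_from_dataset

base_ds = get_coco_api_from_dataset(val_dataset)             # placeholder dataset
evaluator = CocoEvaluator(coco_gt=base_ds, iou_types=["bbox"])

for images, targets in val_loader:                           # placeholder dataloader
    results = postprocess(model(images), targets)            # placeholder model + postprocessor
    predictions = {t["image_id"].item(): r for t, r in zip(targets, results)}
    evaluator.update(predictions)

evaluator.synchronize_between_processes()
evaluator.accumulate()
evaluator.summarize()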
src/data/dataset/coco_utils.py ADDED
@@ -0,0 +1,191 @@
1
+ """
2
+ copy and modified https://github.com/pytorch/vision/blob/main/references/detection/coco_utils.py
3
+
4
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
5
+ """
6
+
7
+ import faster_coco_eval.core.mask as coco_mask
8
+ import torch
9
+ import torch.utils.data
10
+ import torchvision
11
+ import torchvision.transforms.functional as TVF
12
+ from faster_coco_eval import COCO
13
+
14
+
15
+ def convert_coco_poly_to_mask(segmentations, height, width):
16
+ masks = []
17
+ for polygons in segmentations:
18
+ rles = coco_mask.frPyObjects(polygons, height, width)
19
+ mask = coco_mask.decode(rles)
20
+ if len(mask.shape) < 3:
21
+ mask = mask[..., None]
22
+ mask = torch.as_tensor(mask, dtype=torch.uint8)
23
+ mask = mask.any(dim=2)
24
+ masks.append(mask)
25
+ if masks:
26
+ masks = torch.stack(masks, dim=0)
27
+ else:
28
+ masks = torch.zeros((0, height, width), dtype=torch.uint8)
29
+ return masks
30
+
31
+
32
+ class ConvertCocoPolysToMask:
33
+ def __call__(self, image, target):
34
+ w, h = image.size
35
+
36
+ image_id = target["image_id"]
37
+
38
+ anno = target["annotations"]
39
+
40
+ anno = [obj for obj in anno if obj["iscrowd"] == 0]
41
+
42
+ boxes = [obj["bbox"] for obj in anno]
43
+ # guard against no boxes via resizing
44
+ boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
45
+ boxes[:, 2:] += boxes[:, :2]
46
+ boxes[:, 0::2].clamp_(min=0, max=w)
47
+ boxes[:, 1::2].clamp_(min=0, max=h)
48
+
49
+ classes = [obj["category_id"] for obj in anno]
50
+ classes = torch.tensor(classes, dtype=torch.int64)
51
+
52
+ segmentations = [obj["segmentation"] for obj in anno]
53
+ masks = convert_coco_poly_to_mask(segmentations, h, w)
54
+
55
+ keypoints = None
56
+ if anno and "keypoints" in anno[0]:
57
+ keypoints = [obj["keypoints"] for obj in anno]
58
+ keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
59
+ num_keypoints = keypoints.shape[0]
60
+ if num_keypoints:
61
+ keypoints = keypoints.view(num_keypoints, -1, 3)
62
+
63
+ keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
64
+ boxes = boxes[keep]
65
+ classes = classes[keep]
66
+ masks = masks[keep]
67
+ if keypoints is not None:
68
+ keypoints = keypoints[keep]
69
+
70
+ target = {}
71
+ target["boxes"] = boxes
72
+ target["labels"] = classes
73
+ target["masks"] = masks
74
+ target["image_id"] = image_id
75
+ if keypoints is not None:
76
+ target["keypoints"] = keypoints
77
+
78
+ # for conversion to coco api
79
+ area = torch.tensor([obj["area"] for obj in anno])
80
+ iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])
81
+ target["area"] = area
82
+ target["iscrowd"] = iscrowd
83
+
84
+ return image, target
85
+
86
+
87
+ def _coco_remove_images_without_annotations(dataset, cat_list=None):
88
+ def _has_only_empty_bbox(anno):
89
+ return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
90
+
91
+ def _count_visible_keypoints(anno):
92
+ return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
93
+
94
+ min_keypoints_per_image = 10
95
+
96
+ def _has_valid_annotation(anno):
97
+ # if it's empty, there is no annotation
98
+ if len(anno) == 0:
99
+ return False
100
+ # if all boxes have close to zero area, there is no annotation
101
+ if _has_only_empty_bbox(anno):
102
+ return False
103
+ # keypoint detection tasks have slightly different criteria for deciding
104
+ # whether an annotation is valid
105
+ if "keypoints" not in anno[0]:
106
+ return True
107
+ # for keypoint detection tasks, only consider valid images those
108
+ # containing at least min_keypoints_per_image
109
+ if _count_visible_keypoints(anno) >= min_keypoints_per_image:
110
+ return True
111
+ return False
112
+
113
+ ids = []
114
+ for ds_idx, img_id in enumerate(dataset.ids):
115
+ ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
116
+ anno = dataset.coco.loadAnns(ann_ids)
117
+ if cat_list:
118
+ anno = [obj for obj in anno if obj["category_id"] in cat_list]
119
+ if _has_valid_annotation(anno):
120
+ ids.append(ds_idx)
121
+
122
+ dataset = torch.utils.data.Subset(dataset, ids)
123
+ return dataset
124
+
125
+
126
+ def convert_to_coco_api(ds):
127
+ coco_ds = COCO()
128
+ # annotation IDs need to start at 1, not 0, see torchvision issue #1530
129
+ ann_id = 1
130
+ dataset = {"images": [], "categories": [], "annotations": []}
131
+ categories = set()
132
+ for img_idx in range(len(ds)):
133
+ # find better way to get target
134
+ # targets = ds.get_annotations(img_idx)
135
+ # img, targets = ds[img_idx]
136
+
137
+ img, targets = ds.load_item(img_idx)
138
+ width, height = img.size
139
+
140
+ image_id = targets["image_id"].item()
141
+ img_dict = {}
142
+ img_dict["id"] = image_id
143
+ img_dict["width"] = width
144
+ img_dict["height"] = height
145
+ dataset["images"].append(img_dict)
146
+ bboxes = targets["boxes"].clone()
147
+ bboxes[:, 2:] -= bboxes[:, :2] # xyxy -> xywh
148
+ bboxes = bboxes.tolist()
149
+ labels = targets["labels"].tolist()
150
+ areas = targets["area"].tolist()
151
+ iscrowd = targets["iscrowd"].tolist()
152
+ if "masks" in targets:
153
+ masks = targets["masks"]
154
+ # make masks Fortran contiguous for coco_mask
155
+ masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
156
+ if "keypoints" in targets:
157
+ keypoints = targets["keypoints"]
158
+ keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()
159
+ num_objs = len(bboxes)
160
+ for i in range(num_objs):
161
+ ann = {}
162
+ ann["image_id"] = image_id
163
+ ann["bbox"] = bboxes[i]
164
+ ann["category_id"] = labels[i]
165
+ categories.add(labels[i])
166
+ ann["area"] = areas[i]
167
+ ann["iscrowd"] = iscrowd[i]
168
+ ann["id"] = ann_id
169
+ if "masks" in targets:
170
+ ann["segmentation"] = coco_mask.encode(masks[i].numpy())
171
+ if "keypoints" in targets:
172
+ ann["keypoints"] = keypoints[i]
173
+ ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3])
174
+ dataset["annotations"].append(ann)
175
+ ann_id += 1
176
+ dataset["categories"] = [{"id": i} for i in sorted(categories)]
177
+ coco_ds.dataset = dataset
178
+ coco_ds.createIndex()
179
+ return coco_ds
180
+
181
+
182
+ def get_coco_api_from_dataset(dataset):
183
+ # FIXME: This is... awful?
184
+ for _ in range(10):
185
+ if isinstance(dataset, torchvision.datasets.CocoDetection):
186
+ break
187
+ if isinstance(dataset, torch.utils.data.Subset):
188
+ dataset = dataset.dataset
189
+ if isinstance(dataset, torchvision.datasets.CocoDetection):
190
+ return dataset.coco
191
+ return convert_to_coco_api(dataset)
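For orientation, a minimal usage sketch of the helpers above (not part of the commit); `train_dataset` is a placeholder for a CocoDetection-style dataset built elsewhere in this repo, and only the standard COCO query API of the returned index is exercised.

    from src.data.dataset.coco_utils import get_coco_api_from_dataset

    # `train_dataset` is a placeholder for a dataset built elsewhere in the repo
    coco_gt = get_coco_api_from_dataset(train_dataset)

    # the returned object mirrors pycocotools' COCO index
    img_ids = coco_gt.getImgIds()
    ann_ids = coco_gt.getAnnIds(imgIds=img_ids[:1])
    print(len(img_ids), "images;", len(coco_gt.loadAnns(ann_ids)), "annotations in the first image")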
src/data/dataset/voc_detection.py ADDED
@@ -0,0 +1,86 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ import os
7
+ from typing import Callable, Optional
8
+
9
+ import torch
10
+ import torchvision
11
+ import torchvision.transforms.functional as TVF
12
+ from PIL import Image
14
+
15
+ try:
16
+ from defusedxml.ElementTree import parse as ET_parse
17
+ except ImportError:
18
+ from xml.etree.ElementTree import parse as ET_parse
19
+
20
+ from ...core import register
21
+ from .._misc import convert_to_tv_tensor
22
+ from ._dataset import DetDataset
23
+
24
+
25
+ @register()
26
+ class VOCDetection(torchvision.datasets.VOCDetection, DetDataset):
27
+ __inject__ = [
28
+ "transforms",
29
+ ]
30
+
31
+ def __init__(
32
+ self,
33
+ root: str,
34
+ ann_file: str = "trainval.txt",
35
+ label_file: str = "label_list.txt",
36
+ transforms: Optional[Callable] = None,
37
+ ):
38
+ with open(os.path.join(root, ann_file), "r") as f:
39
+ lines = [x.strip() for x in f.readlines()]
40
+ lines = [x.split(" ") for x in lines]
41
+
42
+ self.images = [os.path.join(root, lin[0]) for lin in lines]
43
+ self.targets = [os.path.join(root, lin[1]) for lin in lines]
44
+ assert len(self.images) == len(self.targets)
45
+
46
+ with open(os.path.join(root, label_file), "r") as f:
47
+ labels = f.readlines()
48
+ labels = [lab.strip() for lab in labels]
49
+
50
+ self.transforms = transforms
51
+ self.labels_map = {lab: i for i, lab in enumerate(labels)}
52
+
53
+ def __getitem__(self, index: int):
54
+ image, target = self.load_item(index)
55
+ if self.transforms is not None:
56
+ image, target, _ = self.transforms(image, target, self)
57
+ # target["orig_size"] = torch.tensor(TVF.get_image_size(image))
58
+ return image, target
59
+
60
+ def load_item(self, index: int):
61
+ image = Image.open(self.images[index]).convert("RGB")
62
+ target = self.parse_voc_xml(ET_parse(self.targets[index]).getroot())
63
+
64
+ output = {}
65
+ output["image_id"] = torch.tensor([index])
66
+ for k in ["area", "boxes", "labels", "iscrowd"]:
67
+ output[k] = []
68
+
69
+ for blob in target["annotation"]["object"]:
70
+ box = [float(v) for v in blob["bndbox"].values()]
71
+ output["boxes"].append(box)
72
+ output["labels"].append(blob["name"])
73
+ output["area"].append((box[2] - box[0]) * (box[3] - box[1]))
74
+ output["iscrowd"].append(0)
75
+
76
+ w, h = image.size
77
+ boxes = torch.tensor(output["boxes"]) if len(output["boxes"]) > 0 else torch.zeros(0, 4)
78
+ output["boxes"] = convert_to_tv_tensor(
79
+ boxes, "boxes", box_format="xyxy", spatial_size=[h, w]
80
+ )
81
+ output["labels"] = torch.tensor([self.labels_map[lab] for lab in output["labels"]])
82
+ output["area"] = torch.tensor(output["area"])
83
+ output["iscrowd"] = torch.tensor(output["iscrowd"])
84
+ output["orig_size"] = torch.tensor([w, h])
85
+
86
+ return image, output
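The on-disk layout this loader expects can be inferred from `__init__` above; the file contents, paths, and class names below are hypothetical examples, not part of the commit.

    # trainval.txt   -> "<image path> <annotation xml path>" per line, relative to root
    #   JPEGImages/000005.jpg Annotations/000005.xml
    # label_list.txt -> one class name per line (hypothetical labels)
    #   aeroplane
    #   bicycle

    from src.data.dataset.voc_detection import VOCDetection

    dataset = VOCDetection(root="VOCdevkit/VOC2007/",
                           ann_file="trainval.txt",
                           label_file="label_list.txt")
    image, target = dataset[0]   # target: boxes, labels, area, iscrowd, orig_size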
src/data/dataset/voc_eval.py ADDED
@@ -0,0 +1,12 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ import torch
7
+ import torchvision
8
+
9
+
10
+ class VOCEvaluator(object):
11
+ def __init__(self) -> None:
12
+ pass
src/data/transforms/__init__.py ADDED
@@ -0,0 +1,21 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ from ._transforms import (
7
+ ConvertBoxes,
8
+ ConvertPILImage,
9
+ EmptyTransform,
10
+ Normalize,
11
+ PadToSize,
12
+ RandomCrop,
13
+ RandomHorizontalFlip,
14
+ RandomIoUCrop,
15
+ RandomPhotometricDistort,
16
+ RandomZoomOut,
17
+ Resize,
18
+ SanitizeBoundingBoxes,
19
+ )
20
+ from .container import Compose
21
+ from .mosaic import Mosaic
src/data/transforms/_transforms.py ADDED
@@ -0,0 +1,161 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ from typing import Any, Dict, List, Optional
7
+
8
+ import PIL
9
+ import PIL.Image
10
+ import torch
11
+ import torch.nn as nn
12
+ import torchvision
13
+ import torchvision.transforms.v2 as T
14
+ import torchvision.transforms.v2.functional as F
15
+
16
+ from ...core import register
17
+ from .._misc import (
18
+ BoundingBoxes,
19
+ Image,
20
+ Mask,
21
+ SanitizeBoundingBoxes,
22
+ Video,
23
+ _boxes_keys,
24
+ convert_to_tv_tensor,
25
+ )
26
+
27
+ torchvision.disable_beta_transforms_warning()
28
+
29
+
30
+ RandomPhotometricDistort = register()(T.RandomPhotometricDistort)
31
+ RandomZoomOut = register()(T.RandomZoomOut)
32
+ RandomHorizontalFlip = register()(T.RandomHorizontalFlip)
33
+ Resize = register()(T.Resize)
34
+ # ToImageTensor = register()(T.ToImageTensor)
35
+ # ConvertDtype = register()(T.ConvertDtype)
36
+ # PILToTensor = register()(T.PILToTensor)
37
+ SanitizeBoundingBoxes = register(name="SanitizeBoundingBoxes")(SanitizeBoundingBoxes)
38
+ RandomCrop = register()(T.RandomCrop)
39
+ Normalize = register()(T.Normalize)
40
+
41
+
42
+ @register()
43
+ class EmptyTransform(T.Transform):
44
+ def __init__(
45
+ self,
46
+ ) -> None:
47
+ super().__init__()
48
+
49
+ def forward(self, *inputs):
50
+ inputs = inputs if len(inputs) > 1 else inputs[0]
51
+ return inputs
52
+
53
+
54
+ @register()
55
+ class PadToSize(T.Pad):
56
+ _transformed_types = (
57
+ PIL.Image.Image,
58
+ Image,
59
+ Video,
60
+ Mask,
61
+ BoundingBoxes,
62
+ )
63
+
64
+ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
65
+ sp = F.get_spatial_size(flat_inputs[0])
66
+ h, w = self.size[1] - sp[0], self.size[0] - sp[1]
67
+ self.padding = [0, 0, w, h]
68
+ return dict(padding=self.padding)
69
+
70
+ def __init__(self, size, fill=0, padding_mode="constant") -> None:
71
+ if isinstance(size, int):
72
+ size = (size, size)
73
+ self.size = size
74
+ super().__init__(0, fill, padding_mode)
75
+
76
+ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
77
+ fill = self._fill[type(inpt)]
78
+ padding = params["padding"]
79
+ return F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type]
80
+
81
+ def __call__(self, *inputs: Any) -> Any:
82
+ outputs = super().forward(*inputs)
83
+ if len(outputs) > 1 and isinstance(outputs[1], dict):
84
+ outputs[1]["padding"] = torch.tensor(self.padding)
85
+ return outputs
86
+
87
+
88
+ @register()
89
+ class RandomIoUCrop(T.RandomIoUCrop):
90
+ def __init__(
91
+ self,
92
+ min_scale: float = 0.3,
93
+ max_scale: float = 1,
94
+ min_aspect_ratio: float = 0.5,
95
+ max_aspect_ratio: float = 2,
96
+ sampler_options: Optional[List[float]] = None,
97
+ trials: int = 40,
98
+ p: float = 1.0,
99
+ ):
100
+ super().__init__(
101
+ min_scale, max_scale, min_aspect_ratio, max_aspect_ratio, sampler_options, trials
102
+ )
103
+ self.p = p
104
+
105
+ def __call__(self, *inputs: Any) -> Any:
106
+ if torch.rand(1) >= self.p:
107
+ return inputs if len(inputs) > 1 else inputs[0]
108
+
109
+ return super().forward(*inputs)
110
+
111
+
112
+ @register()
113
+ class ConvertBoxes(T.Transform):
114
+ _transformed_types = (BoundingBoxes,)
115
+
116
+ def __init__(self, fmt="", normalize=False) -> None:
117
+ super().__init__()
118
+ self.fmt = fmt
119
+ self.normalize = normalize
120
+
121
+ def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
122
+ return self._transform(inpt, params)
123
+
124
+ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
125
+ spatial_size = getattr(inpt, _boxes_keys[1])
126
+ if self.fmt:
127
+ in_fmt = inpt.format.value.lower()
128
+ inpt = torchvision.ops.box_convert(inpt, in_fmt=in_fmt, out_fmt=self.fmt.lower())
129
+ inpt = convert_to_tv_tensor(
130
+ inpt, key="boxes", box_format=self.fmt.upper(), spatial_size=spatial_size
131
+ )
132
+
133
+ if self.normalize:
134
+ inpt = inpt / torch.tensor(spatial_size[::-1]).tile(2)[None]
135
+
136
+ return inpt
137
+
138
+
139
+ @register()
140
+ class ConvertPILImage(T.Transform):
141
+ _transformed_types = (PIL.Image.Image,)
142
+
143
+ def __init__(self, dtype="float32", scale=True) -> None:
144
+ super().__init__()
145
+ self.dtype = dtype
146
+ self.scale = scale
147
+
148
+ def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
149
+ return self._transform(inpt, params)
150
+
151
+ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
152
+ inpt = F.pil_to_tensor(inpt)
153
+ if self.dtype == "float32":
154
+ inpt = inpt.float()
155
+
156
+ if self.scale:
157
+ inpt = inpt / 255.0
158
+
159
+ inpt = Image(inpt)
160
+
161
+ return inpt
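A hedged sketch of the two custom transforms above used standalone, assuming a recent torchvision with the tv_tensors API and the repository root on PYTHONPATH; the argument values are illustrative only.

    import torch
    from PIL import Image as PILImage
    from torchvision.tv_tensors import BoundingBoxes

    from src.data.transforms import ConvertBoxes, ConvertPILImage

    # one xyxy box on a 480x640 (h x w) canvas
    boxes = BoundingBoxes(torch.tensor([[10.0, 20.0, 110.0, 220.0]]),
                          format="XYXY", canvas_size=(480, 640))
    boxes_norm = ConvertBoxes(fmt="cxcywh", normalize=True)(boxes)  # cxcywh / (w, h, w, h)

    # PIL -> float tensor image in [0, 1]
    img = ConvertPILImage(dtype="float32", scale=True)(PILImage.new("RGB", (640, 480)))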
src/data/transforms/container.py ADDED
@@ -0,0 +1,99 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ from typing import Any, Dict, List, Optional
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torchvision
11
+ import torchvision.transforms.v2 as T
12
+
13
+ from ...core import GLOBAL_CONFIG, register
14
+ from ._transforms import EmptyTransform
15
+
16
+ torchvision.disable_beta_transforms_warning()
17
+
18
+
19
+ @register()
20
+ class Compose(T.Compose):
21
+ def __init__(self, ops, policy=None) -> None:
22
+ transforms = []
23
+ if ops is not None:
24
+ for op in ops:
25
+ if isinstance(op, dict):
26
+ name = op.pop("type")
27
+ transform = getattr(
28
+ GLOBAL_CONFIG[name]["_pymodule"], GLOBAL_CONFIG[name]["_name"]
29
+ )(**op)
30
+ transforms.append(transform)
31
+ op["type"] = name
32
+
33
+ elif isinstance(op, nn.Module):
34
+ transforms.append(op)
35
+
36
+ else:
37
+ raise ValueError("")
38
+ else:
39
+ transforms = [
40
+ EmptyTransform(),
41
+ ]
42
+
43
+ super().__init__(transforms=transforms)
44
+
45
+ if policy is None:
46
+ policy = {"name": "default"}
47
+
48
+ self.policy = policy
49
+ self.global_samples = 0
50
+
51
+ def forward(self, *inputs: Any) -> Any:
52
+ return self.get_forward(self.policy["name"])(*inputs)
53
+
54
+ def get_forward(self, name):
55
+ forwards = {
56
+ "default": self.default_forward,
57
+ "stop_epoch": self.stop_epoch_forward,
58
+ "stop_sample": self.stop_sample_forward,
59
+ }
60
+ return forwards[name]
61
+
62
+ def default_forward(self, *inputs: Any) -> Any:
63
+ sample = inputs if len(inputs) > 1 else inputs[0]
64
+ for transform in self.transforms:
65
+ sample = transform(sample)
66
+ return sample
67
+
68
+ def stop_epoch_forward(self, *inputs: Any):
69
+ sample = inputs if len(inputs) > 1 else inputs[0]
70
+ dataset = sample[-1]
71
+ cur_epoch = dataset.epoch
72
+ policy_ops = self.policy["ops"]
73
+ policy_epoch = self.policy["epoch"]
74
+
75
+ for transform in self.transforms:
76
+ if type(transform).__name__ in policy_ops and cur_epoch >= policy_epoch:
77
+ pass
78
+ else:
79
+ sample = transform(sample)
80
+
81
+ return sample
82
+
83
+ def stop_sample_forward(self, *inputs: Any):
84
+ sample = inputs if len(inputs) > 1 else inputs[0]
85
+ dataset = sample[-1]
86
+
87
+ cur_epoch = dataset.epoch
88
+ policy_ops = self.policy["ops"]
89
+ policy_sample = self.policy["sample"]
90
+
91
+ for transform in self.transforms:
92
+ if type(transform).__name__ in policy_ops and self.global_samples >= policy_sample:
93
+ pass
94
+ else:
95
+ sample = transform(sample)
96
+
97
+ self.global_samples += 1
98
+
99
+ return sample
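An illustrative configuration for `Compose` and its stop_epoch policy, written as Python rather than YAML; the op arguments and epoch value are assumptions, and the `ops` entries are resolved by their "type" key through GLOBAL_CONFIG exactly as in `__init__` above.

    from src.data.transforms import Compose

    transforms = Compose(
        ops=[
            {"type": "RandomPhotometricDistort", "p": 0.5},
            {"type": "RandomZoomOut", "fill": 0},
            {"type": "RandomHorizontalFlip"},
            {"type": "ConvertPILImage", "dtype": "float32", "scale": True},
        ],
        # the listed ops are skipped once dataset.epoch reaches `epoch`
        policy={"name": "stop_epoch", "epoch": 71,
                "ops": ["RandomPhotometricDistort", "RandomZoomOut"]},
    )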
src/data/transforms/functional.py ADDED
@@ -0,0 +1,172 @@
1
+ from typing import List, Optional
2
+
3
+ import torch
4
+
5
+ # needed due to empty tensor bug in pytorch and torchvision 0.5
6
+ import torchvision
7
+ import torchvision.transforms.functional as F
8
+ from packaging import version
9
+ from torch import Tensor
10
+
11
+ if version.parse(torchvision.__version__) < version.parse("0.7"):
12
+ from torchvision.ops import _new_empty_tensor
13
+ from torchvision.ops.misc import _output_size
14
+
15
+
16
+ def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
17
+ # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
18
+ """
19
+ Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
20
+ This will eventually be supported natively by PyTorch, and this
21
+ class can go away.
22
+ """
23
+ if version.parse(torchvision.__version__) < version.parse("0.7"):
24
+ if input.numel() > 0:
25
+ return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
26
+
27
+ output_shape = _output_size(2, input, size, scale_factor)
28
+ output_shape = list(input.shape[:-2]) + list(output_shape)
29
+ return _new_empty_tensor(input, output_shape)
30
+ else:
31
+ return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
32
+
33
+
34
+ def crop(image, target, region):
35
+ cropped_image = F.crop(image, *region)
36
+
37
+ target = target.copy()
38
+ i, j, h, w = region
39
+
40
+ # should we do something wrt the original size?
41
+ target["size"] = torch.tensor([h, w])
42
+
43
+ fields = ["labels", "area", "iscrowd"]
44
+
45
+ if "boxes" in target:
46
+ boxes = target["boxes"]
47
+ max_size = torch.as_tensor([w, h], dtype=torch.float32)
48
+ cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
49
+ cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
50
+ cropped_boxes = cropped_boxes.clamp(min=0)
51
+ area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
52
+ target["boxes"] = cropped_boxes.reshape(-1, 4)
53
+ target["area"] = area
54
+ fields.append("boxes")
55
+
56
+ if "masks" in target:
57
+ # FIXME should we update the area here if there are no boxes?
58
+ target["masks"] = target["masks"][:, i : i + h, j : j + w]
59
+ fields.append("masks")
60
+
61
+ # remove elements for which the boxes or masks that have zero area
62
+ if "boxes" in target or "masks" in target:
63
+ # favor boxes selection when defining which elements to keep
64
+ # this is compatible with previous implementation
65
+ if "boxes" in target:
66
+ cropped_boxes = target["boxes"].reshape(-1, 2, 2)
67
+ keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
68
+ else:
69
+ keep = target["masks"].flatten(1).any(1)
70
+
71
+ for field in fields:
72
+ target[field] = target[field][keep]
73
+
74
+ return cropped_image, target
75
+
76
+
77
+ def hflip(image, target):
78
+ flipped_image = F.hflip(image)
79
+
80
+ w, h = image.size
81
+
82
+ target = target.copy()
83
+ if "boxes" in target:
84
+ boxes = target["boxes"]
85
+ boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor(
86
+ [w, 0, w, 0]
87
+ )
88
+ target["boxes"] = boxes
89
+
90
+ if "masks" in target:
91
+ target["masks"] = target["masks"].flip(-1)
92
+
93
+ return flipped_image, target
94
+
95
+
96
+ def resize(image, target, size, max_size=None):
97
+ # size can be min_size (scalar) or (w, h) tuple
98
+
99
+ def get_size_with_aspect_ratio(image_size, size, max_size=None):
100
+ w, h = image_size
101
+ if max_size is not None:
102
+ min_original_size = float(min((w, h)))
103
+ max_original_size = float(max((w, h)))
104
+ if max_original_size / min_original_size * size > max_size:
105
+ size = int(round(max_size * min_original_size / max_original_size))
106
+
107
+ if (w <= h and w == size) or (h <= w and h == size):
108
+ return (h, w)
109
+
110
+ if w < h:
111
+ ow = size
112
+ oh = int(size * h / w)
113
+ else:
114
+ oh = size
115
+ ow = int(size * w / h)
116
+
117
+ # r = min(size / min(h, w), max_size / max(h, w))
118
+ # ow = int(w * r)
119
+ # oh = int(h * r)
120
+
121
+ return (oh, ow)
122
+
123
+ def get_size(image_size, size, max_size=None):
124
+ if isinstance(size, (list, tuple)):
125
+ return size[::-1]
126
+ else:
127
+ return get_size_with_aspect_ratio(image_size, size, max_size)
128
+
129
+ size = get_size(image.size, size, max_size)
130
+ rescaled_image = F.resize(image, size)
131
+
132
+ if target is None:
133
+ return rescaled_image, None
134
+
135
+ ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
136
+ ratio_width, ratio_height = ratios
137
+
138
+ target = target.copy()
139
+ if "boxes" in target:
140
+ boxes = target["boxes"]
141
+ scaled_boxes = boxes * torch.as_tensor(
142
+ [ratio_width, ratio_height, ratio_width, ratio_height]
143
+ )
144
+ target["boxes"] = scaled_boxes
145
+
146
+ if "area" in target:
147
+ area = target["area"]
148
+ scaled_area = area * (ratio_width * ratio_height)
149
+ target["area"] = scaled_area
150
+
151
+ h, w = size
152
+ target["size"] = torch.tensor([h, w])
153
+
154
+ if "masks" in target:
155
+ target["masks"] = (
156
+ interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5
157
+ )
158
+
159
+ return rescaled_image, target
160
+
161
+
162
+ def pad(image, target, padding):
163
+ # assumes that we only pad on the bottom right corners
164
+ padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
165
+ if target is None:
166
+ return padded_image, None
167
+ target = target.copy()
168
+ # should we do something wrt the original size?
169
+ target["size"] = torch.tensor(padded_image.size[::-1])
170
+ if "masks" in target:
171
+ target["masks"] = torch.nn.functional.pad(target["masks"], (0, padding[0], 0, padding[1]))
172
+ return padded_image, target
src/data/transforms/mosaic.py ADDED
@@ -0,0 +1,83 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ import random
7
+
8
+ import torch
9
+ import torchvision
10
+ import torchvision.transforms.v2 as T
11
+ import torchvision.transforms.v2.functional as F
12
+ from PIL import Image
13
+
14
+ from ...core import register
15
+ from .._misc import convert_to_tv_tensor
16
+
17
+ torchvision.disable_beta_transforms_warning()
18
+
19
+
20
+ @register()
21
+ class Mosaic(T.Transform):
22
+ def __init__(
23
+ self,
24
+ size,
25
+ max_size=None,
26
+ ) -> None:
27
+ super().__init__()
28
+ self.resize = T.Resize(size=size, max_size=max_size)
29
+ self.crop = T.RandomCrop(size=max_size if max_size else size)
30
+
31
+ # TODO add arg `output_size` for affine`
32
+ # self.random_perspective = T.RandomPerspective(distortion_scale=0.5, p=1., )
33
+ self.random_affine = T.RandomAffine(
34
+ degrees=0, translate=(0.1, 0.1), scale=(0.5, 1.5), fill=114
35
+ )
36
+
37
+ def forward(self, *inputs):
38
+ inputs = inputs if len(inputs) > 1 else inputs[0]
39
+ image, target, dataset = inputs
40
+
41
+ images = []
42
+ targets = []
43
+ indices = random.choices(range(len(dataset)), k=3)
44
+ for i in indices:
45
+ image, target = dataset.load_item(i)
46
+ image, target = self.resize(image, target)
47
+ images.append(image)
48
+ targets.append(target)
49
+
50
+ h, w = F.get_spatial_size(images[0])
51
+ offset = [[0, 0], [w, 0], [0, h], [w, h]]
52
+ image = Image.new(mode=images[0].mode, size=(w * 2, h * 2), color=0)
53
+ for i, im in enumerate(images):
54
+ image.paste(im, offset[i])
55
+
56
+ offset = torch.tensor([[0, 0], [w, 0], [0, h], [w, h]]).repeat(1, 2)
57
+ target = {}
58
+ for k in targets[0]:
59
+ if k == "boxes":
60
+ v = [t[k] + offset[i] for i, t in enumerate(targets)]
61
+ else:
62
+ v = [t[k] for t in targets]
63
+
64
+ if isinstance(v[0], torch.Tensor):
65
+ v = torch.cat(v, dim=0)
66
+
67
+ target[k] = v
68
+
69
+ if "boxes" in target:
70
+ # target['boxes'] = target['boxes'].clamp(0, 640 * 2 - 1)
71
+ w, h = image.size
72
+ target["boxes"] = convert_to_tv_tensor(
73
+ target["boxes"], "boxes", box_format="xyxy", spatial_size=[h, w]
74
+ )
75
+
76
+ if "masks" in target:
77
+ target["masks"] = convert_to_tv_tensor(target["masks"], "masks")
78
+
79
+ image, target = self.random_affine(image, target)
80
+ # image, target = self.resize(image, target)
81
+ image, target = self.crop(image, target)
82
+
83
+ return image, target, dataset
src/data/transforms/presets.py ADDED
@@ -0,0 +1,4 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
src/misc/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ from .dist_utils import setup_print, setup_seed
7
+ from .logger import *
8
+ from .profiler_utils import stats
9
+ from .visualizer import *
src/misc/box_ops.py ADDED
@@ -0,0 +1,106 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ from typing import List, Tuple
7
+
8
+ import torch
9
+ import torchvision
10
+ from torch import Tensor
11
+
12
+
13
+ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
14
+ assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
15
+ assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
16
+ return torchvision.ops.generalized_box_iou(boxes1, boxes2)
17
+
18
+
19
+ # elementwise
20
+ def elementwise_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]:
21
+ """
22
+ Args:
23
+ boxes1, [N, 4]
24
+ boxes2, [N, 4]
25
+ Returns:
26
+ iou, [N, ]
27
+ union, [N, ]
28
+ """
29
+ area1 = torchvision.ops.box_area(boxes1) # [N, ]
30
+ area2 = torchvision.ops.box_area(boxes2) # [N, ]
31
+ lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # [N, 2]
32
+ rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # [N, 2]
33
+ wh = (rb - lt).clamp(min=0) # [N, 2]
34
+ inter = wh[:, 0] * wh[:, 1] # [N, ]
35
+ union = area1 + area2 - inter
36
+ iou = inter / union
37
+ return iou, union
38
+
39
+
40
+ def elementwise_generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
41
+ """
42
+ Args:
43
+ boxes1, [N, 4] with [x1, y1, x2, y2]
44
+ boxes2, [N, 4] with [x1, y1, x2, y2]
45
+ Returns:
46
+ giou, [N, ]
47
+ """
48
+ assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
49
+ assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
50
+ iou, union = elementwise_box_iou(boxes1, boxes2)
51
+ lt = torch.min(boxes1[:, :2], boxes2[:, :2]) # [N, 2]
52
+ rb = torch.max(boxes1[:, 2:], boxes2[:, 2:]) # [N, 2]
53
+ wh = (rb - lt).clamp(min=0) # [N, 2]
54
+ area = wh[:, 0] * wh[:, 1]
55
+ return iou - (area - union) / area
56
+
57
+
58
+ def check_point_inside_box(points: Tensor, boxes: Tensor, eps=1e-9) -> Tensor:
59
+ """
60
+ Args:
61
+ points, [K, 2], (x, y)
62
+ boxes, [N, 4], (x1, y1, x2, y2)
63
+ Returns:
64
+ Tensor (bool), [K, N]
65
+ """
66
+ x, y = [p.unsqueeze(-1) for p in points.unbind(-1)]
67
+ x1, y1, x2, y2 = [x.unsqueeze(0) for x in boxes.unbind(-1)]
68
+
69
+ l = x - x1
70
+ t = y - y1
71
+ r = x2 - x
72
+ b = y2 - y
73
+
74
+ ltrb = torch.stack([l, t, r, b], dim=-1)
75
+ mask = ltrb.min(dim=-1).values > eps
76
+
77
+ return mask
78
+
79
+
80
+ def point_box_distance(points: Tensor, boxes: Tensor) -> Tensor:
81
+ """
82
+ Args:
83
+ boxes, [N, 4], (x1, y1, x2, y2)
84
+ points, [N, 2], (x, y)
85
+ Returns:
86
+ Tensor (N, 4), (l, t, r, b)
87
+ """
88
+ x1y1, x2y2 = torch.split(boxes, 2, dim=-1)
89
+ lt = points - x1y1
90
+ rb = x2y2 - points
91
+ return torch.concat([lt, rb], dim=-1)
92
+
93
+
94
+ def point_distance_box(points: Tensor, distances: Tensor) -> Tensor:
95
+ """
96
+ Args:
97
+ points (Tensor), [N, 2], (x, y)
98
+ distances (Tensor), [N, 4], (l, t, r, b)
99
+ Returns:
100
+ boxes (Tensor), (N, 4), (x1, y1, x2, y2)
101
+ """
102
+ lt, rb = torch.split(distances, 2, dim=-1)
103
+ x1y1 = -lt + points
104
+ x2y2 = rb + points
105
+ boxes = torch.concat([x1y1, x2y2], dim=-1)
106
+ return boxes
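A quick self-contained check of the two point/box helpers above (a sketch, not part of the commit): point_box_distance and point_distance_box invert each other for the same anchor points.

    import torch

    from src.misc.box_ops import point_box_distance, point_distance_box

    boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0],
                          [5.0, 5.0, 20.0, 30.0]])
    points = torch.tensor([[2.0, 3.0],
                           [10.0, 10.0]])

    ltrb = point_box_distance(points, boxes)      # (l, t, r, b) per row
    restored = point_distance_box(points, ltrb)   # back to (x1, y1, x2, y2)
    assert torch.allclose(restored, boxes)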
src/misc/dist_utils.py ADDED
@@ -0,0 +1,281 @@
1
+ """
2
+ reference
3
+ - https://github.com/pytorch/vision/blob/main/references/detection/utils.py
4
+ - https://github.com/facebookresearch/detr/blob/master/util/misc.py#L406
5
+
6
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
7
+ """
8
+
9
+ import atexit
10
+ import os
11
+ import random
12
+ import time
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.backends.cudnn
17
+ import torch.distributed
18
+ import torch.nn as nn
19
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
20
+ from torch.nn.parallel import DataParallel as DP
21
+ from torch.nn.parallel import DistributedDataParallel as DDP
22
+ from torch.utils.data import DistributedSampler
23
+
24
+ # from torch.utils.data.dataloader import DataLoader
25
+ from ..data import DataLoader
26
+
27
+
28
+ def setup_distributed(
29
+ print_rank: int = 0,
30
+ print_method: str = "builtin",
31
+ seed: int = None,
32
+ ):
33
+ """
34
+ env setup
35
+ args:
36
+ print_rank,
37
+ print_method, (builtin, rich)
38
+ seed,
39
+ """
40
+ try:
41
+ # https://pytorch.org/docs/stable/elastic/run.html
42
+ RANK = int(os.getenv("RANK", -1))
43
+ LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1))
44
+ WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))
45
+
46
+ # torch.distributed.init_process_group(backend=backend, init_method='env://')
47
+ torch.distributed.init_process_group(init_method="env://")
48
+ torch.distributed.barrier()
49
+
50
+ rank = torch.distributed.get_rank()
51
+ torch.cuda.set_device(rank)
52
+ torch.cuda.empty_cache()
53
+ enabled_dist = True
54
+ if get_rank() == print_rank:
55
+ print("Initialized distributed mode...")
56
+
57
+ except Exception:
58
+ enabled_dist = False
59
+ print("Not init distributed mode.")
60
+
61
+ setup_print(get_rank() == print_rank, method=print_method)
62
+ if seed is not None:
63
+ setup_seed(seed)
64
+
65
+ return enabled_dist
66
+
67
+
68
+ def setup_print(is_main, method="builtin"):
69
+ """This function disables printing when not in master process"""
70
+ import builtins as __builtin__
71
+
72
+ if method == "builtin":
73
+ builtin_print = __builtin__.print
74
+
75
+ elif method == "rich":
76
+ import rich
77
+
78
+ builtin_print = rich.print
79
+
80
+ else:
81
+ raise AttributeError("")
82
+
83
+ def print(*args, **kwargs):
84
+ force = kwargs.pop("force", False)
85
+ if is_main or force:
86
+ builtin_print(*args, **kwargs)
87
+
88
+ __builtin__.print = print
89
+
90
+
91
+ def is_dist_available_and_initialized():
92
+ if not torch.distributed.is_available():
93
+ return False
94
+ if not torch.distributed.is_initialized():
95
+ return False
96
+ return True
97
+
98
+
99
+ @atexit.register
100
+ def cleanup():
101
+ """cleanup distributed environment"""
102
+ if is_dist_available_and_initialized():
103
+ torch.distributed.barrier()
104
+ torch.distributed.destroy_process_group()
105
+
106
+
107
+ def get_rank():
108
+ if not is_dist_available_and_initialized():
109
+ return 0
110
+ return torch.distributed.get_rank()
111
+
112
+
113
+ def get_world_size():
114
+ if not is_dist_available_and_initialized():
115
+ return 1
116
+ return torch.distributed.get_world_size()
117
+
118
+
119
+ def is_main_process():
120
+ return get_rank() == 0
121
+
122
+
123
+ def save_on_master(*args, **kwargs):
124
+ if is_main_process():
125
+ torch.save(*args, **kwargs)
126
+
127
+
128
+ def warp_model(
129
+ model: torch.nn.Module,
130
+ sync_bn: bool = False,
131
+ dist_mode: str = "ddp",
132
+ find_unused_parameters: bool = False,
133
+ compile: bool = False,
134
+ compile_mode: str = "reduce-overhead",
135
+ **kwargs,
136
+ ):
137
+ if is_dist_available_and_initialized():
138
+ rank = get_rank()
139
+ model = nn.SyncBatchNorm.convert_sync_batchnorm(model) if sync_bn else model
140
+ if dist_mode == "dp":
141
+ model = DP(model, device_ids=[rank], output_device=rank)
142
+ elif dist_mode == "ddp":
143
+ model = DDP(
144
+ model,
145
+ device_ids=[rank],
146
+ output_device=rank,
147
+ find_unused_parameters=find_unused_parameters,
148
+ )
149
+ else:
150
+ raise AttributeError("")
151
+
152
+ if compile:
153
+ model = torch.compile(model, mode=compile_mode)
154
+
155
+ return model
156
+
157
+
158
+ def de_model(model):
159
+ return de_parallel(de_complie(model))
160
+
161
+
162
+ def warp_loader(loader, shuffle=False):
163
+ if is_dist_available_and_initialized():
164
+ sampler = DistributedSampler(loader.dataset, shuffle=shuffle)
165
+ loader = DataLoader(
166
+ loader.dataset,
167
+ loader.batch_size,
168
+ sampler=sampler,
169
+ drop_last=loader.drop_last,
170
+ collate_fn=loader.collate_fn,
171
+ pin_memory=loader.pin_memory,
172
+ num_workers=loader.num_workers,
173
+ )
174
+ return loader
175
+
176
+
177
+ def is_parallel(model) -> bool:
178
+ # Returns True if model is of type DP or DDP
179
+ return type(model) in (
180
+ torch.nn.parallel.DataParallel,
181
+ torch.nn.parallel.DistributedDataParallel,
182
+ )
183
+
184
+
185
+ def de_parallel(model) -> nn.Module:
186
+ # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
187
+ return model.module if is_parallel(model) else model
188
+
189
+
190
+ def reduce_dict(data, avg=True):
191
+ """
192
+ Args
193
+ data dict: input, {k: v, ...}
194
+ avg bool: true
195
+ """
196
+ world_size = get_world_size()
197
+ if world_size < 2:
198
+ return data
199
+
200
+ with torch.no_grad():
201
+ keys, values = [], []
202
+ for k in sorted(data.keys()):
203
+ keys.append(k)
204
+ values.append(data[k])
205
+
206
+ values = torch.stack(values, dim=0)
207
+ torch.distributed.all_reduce(values)
208
+
209
+ if avg is True:
210
+ values /= world_size
211
+
212
+ return {k: v for k, v in zip(keys, values)}
213
+
214
+
215
+ def all_gather(data):
216
+ """
217
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
218
+ Args:
219
+ data: any picklable object
220
+ Returns:
221
+ list[data]: list of data gathered from each rank
222
+ """
223
+ world_size = get_world_size()
224
+ if world_size == 1:
225
+ return [data]
226
+ data_list = [None] * world_size
227
+ torch.distributed.all_gather_object(data_list, data)
228
+ return data_list
229
+
230
+
231
+ def sync_time():
232
+ """sync_time"""
233
+ if torch.cuda.is_available():
234
+ torch.cuda.synchronize()
235
+
236
+ return time.time()
237
+
238
+
239
+ def setup_seed(seed: int, deterministic=False):
240
+ """setup_seed for reproducibility
241
+ torch.manual_seed(3407) is all you need. https://arxiv.org/abs/2109.08203
242
+ """
243
+ seed = seed + get_rank()
244
+ random.seed(seed)
245
+ np.random.seed(seed)
246
+ torch.manual_seed(seed)
247
+
248
+ if torch.cuda.is_available():
249
+ torch.cuda.manual_seed_all(seed)
250
+
251
+ # memory will be large when setting deterministic to True
252
+ if torch.backends.cudnn.is_available() and deterministic:
253
+ torch.backends.cudnn.deterministic = True
254
+
255
+
256
+ # for torch.compile
257
+ def check_compile():
258
+ import warnings
259
+
260
+ import torch
261
+
262
+ gpu_ok = False
263
+ if torch.cuda.is_available():
264
+ device_cap = torch.cuda.get_device_capability()
265
+ if device_cap in ((7, 0), (8, 0), (9, 0)):
266
+ gpu_ok = True
267
+ if not gpu_ok:
268
+ warnings.warn(
269
+ "GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower " "than expected."
270
+ )
271
+ return gpu_ok
272
+
273
+
274
+ def is_compile(model):
275
+ import torch._dynamo
276
+
277
+ return type(model) in (torch._dynamo.OptimizedModule,)
278
+
279
+
280
+ def de_complie(model):
281
+ return model._orig_mod if is_compile(model) else model
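A hedged end-to-end sketch of how these helpers are usually combined under torchrun; `build_model`, `build_dataloader`, and the argument values are placeholders, not part of the commit.

    # launch: torchrun --nproc_per_node=4 train.py
    import torch

    from src.misc.dist_utils import de_model, setup_distributed, warp_loader, warp_model

    setup_distributed(print_rank=0, print_method="builtin", seed=0)

    model = build_model().cuda()        # placeholder factory
    model = warp_model(model, sync_bn=True, dist_mode="ddp", find_unused_parameters=False)
    dataloader = warp_loader(build_dataloader(), shuffle=True)  # placeholder factory

    # ... training loop ...

    torch.save(de_model(model).state_dict(), "checkpoint.pth")  # unwrap DDP / torch.compile first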
src/misc/lazy_loader.py ADDED
@@ -0,0 +1,70 @@
1
+ """
2
+ https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/lazy_loader.py
3
+ """
4
+
5
+ import importlib
6
+ import types
7
+
8
+
9
+ class LazyLoader(types.ModuleType):
10
+ """Lazily import a module, mainly to avoid pulling in large dependencies.
11
+
12
+ `paddle`, and `ffmpeg` are examples of modules that are large and not always
13
+ needed, and this allows them to only be loaded when they are used.
14
+ """
15
+
16
+ # The lint error here is incorrect.
17
+ def __init__(self, local_name, parent_module_globals, name, warning=None):
18
+ self._local_name = local_name
19
+ self._parent_module_globals = parent_module_globals
20
+ self._warning = warning
21
+
22
+ # These members allow doctest to correctly process this module member without
23
+ # triggering self._load(). self._load() mutates parent_module_globals and
24
+ # triggers a dict mutated during iteration error from doctest.py.
25
+ # - for from_module()
26
+ self.__module__ = name.rsplit(".", 1)[0]
27
+ # - for is_routine()
28
+ self.__wrapped__ = None
29
+
30
+ super(LazyLoader, self).__init__(name)
31
+
32
+ def _load(self):
33
+ """Load the module and insert it into the parent's globals."""
34
+ # Import the target module and insert it into the parent's namespace
35
+ module = importlib.import_module(self.__name__)
36
+ self._parent_module_globals[self._local_name] = module
37
+
38
+ # Emit a warning if one was specified
39
+ if self._warning:
40
+ # logging.warning(self._warning)
41
+ # Make sure to only warn once.
42
+ self._warning = None
43
+
44
+ # Update this object's dict so that if someone keeps a reference to the
45
+ # LazyLoader, lookups are efficient (__getattr__ is only called on lookups
46
+ # that fail).
47
+ self.__dict__.update(module.__dict__)
48
+
49
+ return module
50
+
51
+ def __getattr__(self, item):
52
+ module = self._load()
53
+ return getattr(module, item)
54
+
55
+ def __repr__(self):
56
+ # Be careful not to trigger _load, since repr may be called in very
57
+ # sensitive places.
58
+ return f"<LazyLoader {self.__name__} as {self._local_name}>"
59
+
60
+ def __dir__(self):
61
+ module = self._load()
62
+ return dir(module)
63
+
64
+
65
+ # import paddle.nn as nn
66
+ # nn = LazyLoader("nn", globals(), "paddle.nn")
67
+
68
+ # class M(nn.Layer):
69
+ # def __init__(self) -> None:
70
+ # super().__init__()
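Extending the commented paddle example above, a minimal sketch with a standard-library module; the alias is arbitrary.

    from src.misc.lazy_loader import LazyLoader

    json = LazyLoader("json", globals(), "json")

    print(repr(json))             # <LazyLoader json as json>; nothing imported yet
    print(json.dumps({"a": 1}))   # first attribute access triggers the real import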
src/misc/logger.py ADDED
@@ -0,0 +1,255 @@
1
+ """
2
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3
+ https://github.com/facebookresearch/detr/blob/main/util/misc.py
4
+ Mostly copy-paste from torchvision references.
5
+ """
6
+
7
+ import datetime
8
+ import pickle
9
+ import time
10
+ from collections import defaultdict, deque
11
+ from typing import Dict
12
+
13
+ import torch
14
+ import torch.distributed as tdist
15
+
16
+ from .dist_utils import get_world_size, is_dist_available_and_initialized
17
+
18
+
19
+ class SmoothedValue(object):
20
+ """Track a series of values and provide access to smoothed values over a
21
+ window or the global series average.
22
+ """
23
+
24
+ def __init__(self, window_size=20, fmt=None):
25
+ if fmt is None:
26
+ fmt = "{median:.4f} ({global_avg:.4f})"
27
+ self.deque = deque(maxlen=window_size)
28
+ self.total = 0.0
29
+ self.count = 0
30
+ self.fmt = fmt
31
+
32
+ def update(self, value, n=1):
33
+ self.deque.append(value)
34
+ self.count += n
35
+ self.total += value * n
36
+
37
+ def synchronize_between_processes(self):
38
+ """
39
+ Warning: does not synchronize the deque!
40
+ """
41
+ if not is_dist_available_and_initialized():
42
+ return
43
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
44
+ tdist.barrier()
45
+ tdist.all_reduce(t)
46
+ t = t.tolist()
47
+ self.count = int(t[0])
48
+ self.total = t[1]
49
+
50
+ @property
51
+ def median(self):
52
+ d = torch.tensor(list(self.deque))
53
+ return d.median().item()
54
+
55
+ @property
56
+ def avg(self):
57
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
58
+ return d.mean().item()
59
+
60
+ @property
61
+ def global_avg(self):
62
+ return self.total / self.count
63
+
64
+ @property
65
+ def max(self):
66
+ return max(self.deque)
67
+
68
+ @property
69
+ def value(self):
70
+ return self.deque[-1]
71
+
72
+ def __str__(self):
73
+ return self.fmt.format(
74
+ median=self.median,
75
+ avg=self.avg,
76
+ global_avg=self.global_avg,
77
+ max=self.max,
78
+ value=self.value,
79
+ )
80
+
81
+
82
+ def all_gather(data):
83
+ """
84
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
85
+ Args:
86
+ data: any picklable object
87
+ Returns:
88
+ list[data]: list of data gathered from each rank
89
+ """
90
+ world_size = get_world_size()
91
+ if world_size == 1:
92
+ return [data]
93
+
94
+ # serialized to a Tensor
95
+ buffer = pickle.dumps(data)
96
+ storage = torch.ByteStorage.from_buffer(buffer)
97
+ tensor = torch.ByteTensor(storage).to("cuda")
98
+
99
+ # obtain Tensor size of each rank
100
+ local_size = torch.tensor([tensor.numel()], device="cuda")
101
+ size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
102
+ tdist.all_gather(size_list, local_size)
103
+ size_list = [int(size.item()) for size in size_list]
104
+ max_size = max(size_list)
105
+
106
+ # receiving Tensor from all ranks
107
+ # we pad the tensor because torch all_gather does not support
108
+ # gathering tensors of different shapes
109
+ tensor_list = []
110
+ for _ in size_list:
111
+ tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
112
+ if local_size != max_size:
113
+ padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
114
+ tensor = torch.cat((tensor, padding), dim=0)
115
+ tdist.all_gather(tensor_list, tensor)
116
+
117
+ data_list = []
118
+ for size, tensor in zip(size_list, tensor_list):
119
+ buffer = tensor.cpu().numpy().tobytes()[:size]
120
+ data_list.append(pickle.loads(buffer))
121
+
122
+ return data_list
123
+
124
+
125
+ def reduce_dict(input_dict, average=True) -> Dict[str, torch.Tensor]:
126
+ """
127
+ Args:
128
+ input_dict (dict): all the values will be reduced
129
+ average (bool): whether to do average or sum
130
+ Reduce the values in the dictionary from all processes so that all processes
131
+ have the averaged results. Returns a dict with the same fields as
132
+ input_dict, after reduction.
133
+ """
134
+ world_size = get_world_size()
135
+ if world_size < 2:
136
+ return input_dict
137
+ with torch.no_grad():
138
+ names = []
139
+ values = []
140
+ # sort the keys so that they are consistent across processes
141
+ for k in sorted(input_dict.keys()):
142
+ names.append(k)
143
+ values.append(input_dict[k])
144
+ values = torch.stack(values, dim=0)
145
+ tdist.all_reduce(values)
146
+ if average:
147
+ values /= world_size
148
+ reduced_dict = {k: v for k, v in zip(names, values)}
149
+ return reduced_dict
150
+
151
+
152
+ class MetricLogger(object):
153
+ def __init__(self, delimiter="\t"):
154
+ self.meters = defaultdict(SmoothedValue)
155
+ self.delimiter = delimiter
156
+
157
+ def update(self, **kwargs):
158
+ for k, v in kwargs.items():
159
+ if isinstance(v, torch.Tensor):
160
+ v = v.item()
161
+ assert isinstance(v, (float, int))
162
+ self.meters[k].update(v)
163
+
164
+ def __getattr__(self, attr):
165
+ if attr in self.meters:
166
+ return self.meters[attr]
167
+ if attr in self.__dict__:
168
+ return self.__dict__[attr]
169
+ raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
170
+
171
+ def __str__(self):
172
+ loss_str = []
173
+ for name, meter in self.meters.items():
174
+ loss_str.append("{}: {}".format(name, str(meter)))
175
+ return self.delimiter.join(loss_str)
176
+
177
+ def synchronize_between_processes(self):
178
+ for meter in self.meters.values():
179
+ meter.synchronize_between_processes()
180
+
181
+ def add_meter(self, name, meter):
182
+ self.meters[name] = meter
183
+
184
+ def log_every(self, iterable, print_freq, header=None):
185
+ i = 0
186
+ if not header:
187
+ header = ""
188
+ start_time = time.time()
189
+ end = time.time()
190
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
191
+ data_time = SmoothedValue(fmt="{avg:.4f}")
192
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
193
+ if torch.cuda.is_available():
194
+ log_msg = self.delimiter.join(
195
+ [
196
+ header,
197
+ "[{0" + space_fmt + "}/{1}]",
198
+ "eta: {eta}",
199
+ "{meters}",
200
+ "time: {time}",
201
+ "data: {data}",
202
+ "max mem: {memory:.0f}",
203
+ ]
204
+ )
205
+ else:
206
+ log_msg = self.delimiter.join(
207
+ [
208
+ header,
209
+ "[{0" + space_fmt + "}/{1}]",
210
+ "eta: {eta}",
211
+ "{meters}",
212
+ "time: {time}",
213
+ "data: {data}",
214
+ ]
215
+ )
216
+ MB = 1024.0 * 1024.0
217
+ for obj in iterable:
218
+ data_time.update(time.time() - end)
219
+ yield obj
220
+ iter_time.update(time.time() - end)
221
+ if i % print_freq == 0 or i == len(iterable) - 1:
222
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
223
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
224
+ if torch.cuda.is_available():
225
+ print(
226
+ log_msg.format(
227
+ i,
228
+ len(iterable),
229
+ eta=eta_string,
230
+ meters=str(self),
231
+ time=str(iter_time),
232
+ data=str(data_time),
233
+ memory=torch.cuda.max_memory_allocated() / MB,
234
+ )
235
+ )
236
+ else:
237
+ print(
238
+ log_msg.format(
239
+ i,
240
+ len(iterable),
241
+ eta=eta_string,
242
+ meters=str(self),
243
+ time=str(iter_time),
244
+ data=str(data_time),
245
+ )
246
+ )
247
+ i += 1
248
+ end = time.time()
249
+ total_time = time.time() - start_time
250
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
251
+ print(
252
+ "{} Total time: {} ({:.4f} s / it)".format(
253
+ header, total_time_str, total_time / len(iterable)
254
+ )
255
+ )
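A short usage sketch for MetricLogger and SmoothedValue inside a training loop; `dataloader`, `train_step`, and `optimizer` are placeholders, not part of the commit.

    from src.misc.logger import MetricLogger, SmoothedValue

    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))

    header = "Epoch: [0]"
    for samples, targets in metric_logger.log_every(dataloader, print_freq=10, header=header):
        loss = train_step(samples, targets)                      # placeholder step
        metric_logger.update(loss=loss, lr=optimizer.param_groups[0]["lr"])

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)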
src/misc/profiler_utils.py ADDED
@@ -0,0 +1,30 @@
1
+ """
2
+ Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
3
+ """
4
+
5
+ import copy
6
+ from typing import Tuple
7
+
8
+ from calflops import calculate_flops
9
+
10
+
11
+ def stats(
12
+ cfg,
13
+ input_shape: Tuple = (1, 3, 640, 640),
14
+ ) -> Tuple[int, dict]:
15
+ base_size = cfg.train_dataloader.collate_fn.base_size
16
+ input_shape = (1, 3, base_size, base_size)
17
+
18
+ model_for_info = copy.deepcopy(cfg.model).deploy()
19
+
20
+ flops, macs, _ = calculate_flops(
21
+ model=model_for_info,
22
+ input_shape=input_shape,
23
+ output_as_string=True,
24
+ output_precision=4,
25
+ print_detailed=False,
26
+ )
27
+ params = sum(p.numel() for p in model_for_info.parameters())
28
+ del model_for_info
29
+
30
+ return params, {"Model FLOPs:%s MACs:%s Params:%s" % (flops, macs, params)}
src/misc/visualizer.py ADDED
@@ -0,0 +1,121 @@
1
+ """ "
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ import PIL
7
+ import numpy as np
8
+ import torch
9
+ import torch.utils.data
10
+ import torchvision
11
+ from typing import List, Dict
12
+
13
+ torchvision.disable_beta_transforms_warning()
14
+
15
+ __all__ = ["show_sample", "save_samples"]
16
+
17
+ def save_samples(samples: torch.Tensor, targets: List[Dict], output_dir: str, split: str, normalized: bool, box_fmt: str):
18
+ '''
19
+ normalized: whether the boxes are normalized to [0, 1]
20
+ box_fmt: 'xyxy', 'xywh', 'cxcywh', D-FINE uses 'cxcywh' for training, 'xyxy' for validation
21
+ '''
22
+ from torchvision.transforms.functional import to_pil_image
23
+ from torchvision.ops import box_convert
24
+ from pathlib import Path
25
+ from PIL import ImageDraw, ImageFont
26
+ import os
27
+
28
+ os.makedirs(Path(output_dir) / Path(f"{split}_samples"), exist_ok=True)
29
+ # Predefined colors (standard color names recognized by PIL)
30
+ BOX_COLORS = [
31
+ "red", "blue", "green", "orange", "purple",
32
+ "cyan", "magenta", "yellow", "lime", "pink",
33
+ "teal", "lavender", "brown", "beige", "maroon",
34
+ "navy", "olive", "coral", "turquoise", "gold"
35
+ ]
36
+
37
+ LABEL_TEXT_COLOR = "white"
38
+
39
+ font = ImageFont.load_default()
40
+ font.size = 32
41
+
42
+ for i, (sample, target) in enumerate(zip(samples, targets)):
43
+ sample_visualization = sample.clone().cpu()
44
+ target_boxes = target["boxes"].clone().cpu()
45
+ target_labels = target["labels"].clone().cpu()
46
+ target_image_id = target["image_id"].item()
47
+ target_image_path = target["image_path"]
48
+ target_image_path_stem = Path(target_image_path).stem
49
+
50
+ sample_visualization = to_pil_image(sample_visualization)
51
+ sample_visualization_w, sample_visualization_h = sample_visualization.size
52
+
53
+ # normalized to pixel space
54
+ if normalized:
55
+ target_boxes[:, 0] = target_boxes[:, 0] * sample_visualization_w
56
+ target_boxes[:, 2] = target_boxes[:, 2] * sample_visualization_w
57
+ target_boxes[:, 1] = target_boxes[:, 1] * sample_visualization_h
58
+ target_boxes[:, 3] = target_boxes[:, 3] * sample_visualization_h
59
+
60
+ # any box format -> xyxy
61
+ target_boxes = box_convert(target_boxes, in_fmt=box_fmt, out_fmt="xyxy")
62
+
63
+ # clip to image size
64
+ target_boxes[:, 0] = torch.clamp(target_boxes[:, 0], 0, sample_visualization_w)
65
+ target_boxes[:, 1] = torch.clamp(target_boxes[:, 1], 0, sample_visualization_h)
66
+ target_boxes[:, 2] = torch.clamp(target_boxes[:, 2], 0, sample_visualization_w)
67
+ target_boxes[:, 3] = torch.clamp(target_boxes[:, 3], 0, sample_visualization_h)
68
+
69
+ target_boxes = target_boxes.numpy().astype(np.int32)
70
+ target_labels = target_labels.numpy().astype(np.int32)
71
+
72
+ draw = ImageDraw.Draw(sample_visualization)
73
+
74
+ # draw target boxes
75
+ for box, label in zip(target_boxes, target_labels):
76
+ x1, y1, x2, y2 = box
77
+
78
+ # Select color based on class ID
79
+ box_color = BOX_COLORS[int(label) % len(BOX_COLORS)]
80
+
81
+ # Draw box (thick)
82
+ draw.rectangle([x1, y1, x2, y2], outline=box_color, width=3)
83
+
84
+ label_text = f"{label}"
85
+
86
+ # Measure text size
87
+ text_width, text_height = draw.textbbox((0, 0), label_text, font=font)[2:4]
88
+
89
+ # Draw text background
90
+ padding = 2
91
+ draw.rectangle(
92
+ [x1, y1 - text_height - padding * 2, x1 + text_width + padding * 2, y1],
93
+ fill=box_color
94
+ )
95
+
96
+ # Draw text (LABEL_TEXT_COLOR)
97
+ draw.text((x1 + padding, y1 - text_height - padding), label_text,
98
+ fill=LABEL_TEXT_COLOR, font=font)
99
+
100
+ save_path = Path(output_dir) / f"{split}_samples" / f"{target_image_id}_{target_image_path_stem}.webp"
101
+ sample_visualization.save(save_path)
102
+
103
+ def show_sample(sample):
104
+ """for coco dataset/dataloader"""
105
+ import matplotlib.pyplot as plt
106
+ from torchvision.transforms.v2 import functional as F
107
+ from torchvision.utils import draw_bounding_boxes
108
+
109
+ image, target = sample
110
+ if isinstance(image, PIL.Image.Image):
111
+ image = F.to_image_tensor(image)
112
+
113
+ image = F.convert_dtype(image, torch.uint8)
114
+ annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3)
115
+
116
+ fig, ax = plt.subplots()
117
+ ax.imshow(annotated_image.permute(1, 2, 0).numpy())
118
+ ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
119
+ fig.tight_layout()
120
+ fig.show()
121
+ plt.show()
src/nn/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ from .arch import *
7
+
8
+ #
9
+ from .backbone import *
10
+ from .backbone import (
11
+ FrozenBatchNorm2d,
12
+ freeze_batch_norm2d,
13
+ get_activation,
14
+ )
15
+ from .criterion import *
16
+ from .postprocessor import *
src/nn/arch/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ from .classification import ClassHead, Classification
7
+ from .yolo import YOLO
src/nn/arch/classification.py ADDED
@@ -0,0 +1,45 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from ...core import register
10
+
11
+ __all__ = ["Classification", "ClassHead"]
12
+
13
+
14
+ @register()
15
+ class Classification(torch.nn.Module):
16
+ __inject__ = ["backbone", "head"]
17
+
18
+ def __init__(self, backbone: nn.Module, head: nn.Module = None):
19
+ super().__init__()
20
+
21
+ self.backbone = backbone
22
+ self.head = head
23
+
24
+ def forward(self, x):
25
+ x = self.backbone(x)
26
+
27
+ if self.head is not None:
28
+ x = self.head(x)
29
+
30
+ return x
31
+
32
+
33
+ @register()
34
+ class ClassHead(nn.Module):
35
+ def __init__(self, hidden_dim, num_classes):
36
+ super().__init__()
37
+ self.pool = nn.AdaptiveAvgPool2d(1)
38
+ self.proj = nn.Linear(hidden_dim, num_classes)
39
+
40
+ def forward(self, x):
41
+ x = x[0] if isinstance(x, (list, tuple)) else x
42
+ x = self.pool(x)
43
+ x = x.reshape(x.shape[0], -1)
44
+ x = self.proj(x)
45
+ return x
src/nn/arch/yolo.py ADDED
@@ -0,0 +1,42 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ import torch
7
+
8
+ from ...core import register
9
+
10
+ __all__ = [
11
+ "YOLO",
12
+ ]
13
+
14
+
15
+ @register()
16
+ class YOLO(torch.nn.Module):
17
+ __inject__ = [
18
+ "backbone",
19
+ "neck",
20
+ "head",
21
+ ]
22
+
23
+ def __init__(self, backbone: torch.nn.Module, neck, head):
24
+ super().__init__()
25
+ self.backbone = backbone
26
+ self.neck = neck
27
+ self.head = head
28
+
29
+ def forward(self, x, **kwargs):
30
+ x = self.backbone(x)
31
+ x = self.neck(x)
32
+ x = self.head(x)
33
+ return x
34
+
35
+ def deploy(
36
+ self,
37
+ ):
38
+ self.eval()
39
+ for m in self.modules():
40
+ if m is not self and hasattr(m, "deploy"):
41
+ m.deploy()
42
+ return self
src/nn/backbone/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ from .common import (
7
+ FrozenBatchNorm2d,
8
+ freeze_batch_norm2d,
9
+ get_activation,
10
+ )
11
+ from .csp_darknet import CSPPAN, CSPDarkNet
12
+ from .csp_resnet import CSPResNet
13
+ from .hgnetv2 import HGNetv2
14
+ from .presnet import PResNet
15
+ from .test_resnet import MResNet
16
+ from .timm_model import TimmModel
17
+ from .torchvision_model import TorchVisionModel
src/nn/backbone/common.py ADDED
@@ -0,0 +1,117 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ class ConvNormLayer(nn.Module):
11
+ def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None):
12
+ super().__init__()
13
+ self.conv = nn.Conv2d(
14
+ ch_in,
15
+ ch_out,
16
+ kernel_size,
17
+ stride,
18
+ padding=(kernel_size - 1) // 2 if padding is None else padding,
19
+ bias=bias,
20
+ )
21
+ self.norm = nn.BatchNorm2d(ch_out)
22
+ self.act = nn.Identity() if act is None else get_activation(act)
23
+
24
+ def forward(self, x):
25
+ return self.act(self.norm(self.conv(x)))
26
+
27
+
28
+ class FrozenBatchNorm2d(nn.Module):
29
+ """copy and modified from https://github.com/facebookresearch/detr/blob/master/models/backbone.py
30
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
31
+ Copy-paste from torchvision.misc.ops with added eps before rsqrt,
32
+ without which any models other than torchvision.models.resnet[18,34,50,101]
33
+ produce nans.
34
+ """
35
+
36
+ def __init__(self, num_features, eps=1e-5):
37
+ super(FrozenBatchNorm2d, self).__init__()
38
+ n = num_features
39
+ self.register_buffer("weight", torch.ones(n))
40
+ self.register_buffer("bias", torch.zeros(n))
41
+ self.register_buffer("running_mean", torch.zeros(n))
42
+ self.register_buffer("running_var", torch.ones(n))
43
+ self.eps = eps
44
+ self.num_features = n
45
+
46
+ def _load_from_state_dict(
47
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
48
+ ):
49
+ num_batches_tracked_key = prefix + "num_batches_tracked"
50
+ if num_batches_tracked_key in state_dict:
51
+ del state_dict[num_batches_tracked_key]
52
+
53
+ super(FrozenBatchNorm2d, self)._load_from_state_dict(
54
+ state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
55
+ )
56
+
57
+ def forward(self, x):
58
+ # move reshapes to the beginning
59
+ # to make it fuser-friendly
60
+ w = self.weight.reshape(1, -1, 1, 1)
61
+ b = self.bias.reshape(1, -1, 1, 1)
62
+ rv = self.running_var.reshape(1, -1, 1, 1)
63
+ rm = self.running_mean.reshape(1, -1, 1, 1)
64
+ scale = w * (rv + self.eps).rsqrt()
65
+ bias = b - rm * scale
66
+ return x * scale + bias
67
+
68
+ def extra_repr(self):
69
+ return "{num_features}, eps={eps}".format(**self.__dict__)
70
+
71
+
72
+ def freeze_batch_norm2d(module: nn.Module) -> nn.Module:
73
+ if isinstance(module, nn.BatchNorm2d):
74
+ module = FrozenBatchNorm2d(module.num_features)
75
+ else:
76
+ for name, child in module.named_children():
77
+ _child = freeze_batch_norm2d(child)
78
+ if _child is not child:
79
+ setattr(module, name, _child)
80
+ return module
81
+
82
+
83
+ def get_activation(act: str, inplace: bool = True):
84
+ """get activation"""
85
+ if act is None:
86
+ return nn.Identity()
87
+
88
+ elif isinstance(act, nn.Module):
89
+ return act
90
+
91
+ act = act.lower()
92
+
93
+ if act == "silu" or act == "swish":
94
+ m = nn.SiLU()
95
+
96
+ elif act == "relu":
97
+ m = nn.ReLU()
98
+
99
+ elif act == "leaky_relu":
100
+ m = nn.LeakyReLU()
101
+
105
+ elif act == "gelu":
106
+ m = nn.GELU()
107
+
108
+ elif act == "hardsigmoid":
109
+ m = nn.Hardsigmoid()
110
+
111
+ else:
112
+ raise RuntimeError(f"Unsupported activation: {act}")
113
+
114
+ if hasattr(m, "inplace"):
115
+ m.inplace = inplace
116
+
117
+ return m
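A short sketch of the two helpers above (illustrative, using the names defined in this file): get_activation resolves a string or an existing module, and freeze_batch_norm2d recursively swaps nn.BatchNorm2d children for FrozenBatchNorm2d.

    import torch.nn as nn

    act = get_activation("silu")  # -> nn.SiLU with inplace=True

    m = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    m = freeze_batch_norm2d(m)
    assert isinstance(m[1], FrozenBatchNorm2d)  # stats and affine are now fixed buffers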
src/nn/backbone/csp_darknet.py ADDED
@@ -0,0 +1,203 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ import math
7
+ import warnings
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from ...core import register
14
+ from .common import get_activation
15
+
16
+
17
+ def autopad(k, p=None):
18
+ if p is None:
19
+ p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
20
+ return p
21
+
22
+
23
+ def make_divisible(c, d):
24
+ return math.ceil(c / d) * d
25
+
26
+
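For orientation, this is the scaling arithmetic that CSPDarkNet below applies with these helpers, worked out for the width_multi=0.75 / depth_multi=0.33 pair used in the demo at the end of this file (illustrative sketch):

    channels = [make_divisible(c * 0.75, 8) for c in [64, 128, 256, 512, 1024]]
    depths = [max(round(d * 0.33), 1) for d in [3, 6, 9, 3]]
    print(channels, depths)  # [48, 96, 192, 384, 768] [1, 2, 3, 1]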
27
+ class Conv(nn.Module):
28
+ def __init__(self, cin, cout, k=1, s=1, p=None, g=1, act="silu") -> None:
29
+ super().__init__()
30
+ self.conv = nn.Conv2d(cin, cout, k, s, autopad(k, p), groups=g, bias=False)
31
+ self.bn = nn.BatchNorm2d(cout)
32
+ self.act = get_activation(act, inplace=True)
33
+
34
+ def forward(self, x):
35
+ return self.act(self.bn(self.conv(x)))
36
+
37
+
38
+ class Bottleneck(nn.Module):
39
+ # Standard bottleneck
40
+ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, act="silu"):
41
+ super().__init__()
42
+ c_ = int(c2 * e) # hidden channels
43
+ self.cv1 = Conv(c1, c_, 1, 1, act=act)
44
+ self.cv2 = Conv(c_, c2, 3, 1, g=g, act=act)
45
+ self.add = shortcut and c1 == c2
46
+
47
+ def forward(self, x):
48
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
49
+
50
+
51
+ class C3(nn.Module):
52
+ # CSP Bottleneck with 3 convolutions
53
+ def __init__(
54
+ self, c1, c2, n=1, shortcut=True, g=1, e=0.5, act="silu"
55
+ ): # ch_in, ch_out, number, shortcut, groups, expansion
56
+ super().__init__()
57
+ c_ = int(c2 * e) # hidden channels
58
+ self.cv1 = Conv(c1, c_, 1, 1, act=act)
59
+ self.cv2 = Conv(c1, c_, 1, 1, act=act)
60
+ self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0, act=act) for _ in range(n)))
61
+ self.cv3 = Conv(2 * c_, c2, 1, act=act)
62
+
63
+ def forward(self, x):
64
+ return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
65
+
66
+
67
+ class SPPF(nn.Module):
68
+ # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
69
+ def __init__(self, c1, c2, k=5, act="silu"): # equivalent to SPP(k=(5, 9, 13))
70
+ super().__init__()
71
+ c_ = c1 // 2 # hidden channels
72
+ self.cv1 = Conv(c1, c_, 1, 1, act=act)
73
+ self.cv2 = Conv(c_ * 4, c2, 1, 1, act=act)
74
+ self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
75
+
76
+ def forward(self, x):
77
+ x = self.cv1(x)
78
+ with warnings.catch_warnings():
79
+ warnings.simplefilter("ignore") # suppress torch 1.9.0 max_pool2d() warning
80
+ y1 = self.m(x)
81
+ y2 = self.m(y1)
82
+ return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))
83
+
84
+
85
+ @register()
86
+ class CSPDarkNet(nn.Module):
87
+ __share__ = ["depth_multi", "width_multi"]
88
+
89
+ def __init__(
90
+ self,
91
+ in_channels=3,
92
+ width_multi=1.0,
93
+ depth_multi=1.0,
94
+ return_idx=[2, 3, -1],
95
+ act="silu",
96
+ ) -> None:
97
+ super().__init__()
98
+
99
+ channels = [64, 128, 256, 512, 1024]
100
+ channels = [make_divisible(c * width_multi, 8) for c in channels]
101
+
102
+ depths = [3, 6, 9, 3]
103
+ depths = [max(round(d * depth_multi), 1) for d in depths]
104
+
105
+ self.layers = nn.ModuleList([Conv(in_channels, channels[0], 6, 2, 2, act=act)])
106
+ for i, (c, d) in enumerate(zip(channels, depths), 1):
107
+ layer = nn.Sequential(
108
+ *[Conv(c, channels[i], 3, 2, act=act), C3(channels[i], channels[i], n=d, act=act)]
109
+ )
110
+ self.layers.append(layer)
111
+
112
+ self.layers.append(SPPF(channels[-1], channels[-1], k=5, act=act))
113
+
114
+ self.return_idx = return_idx
115
+ self.out_channels = [channels[i] for i in self.return_idx]
116
+ self.strides = [[2, 4, 8, 16, 32][i] for i in self.return_idx]
117
+ self.depths = depths
118
+ self.act = act
119
+
120
+ def forward(self, x):
121
+ outputs = []
122
+ for _, m in enumerate(self.layers):
123
+ x = m(x)
124
+ outputs.append(x)
125
+
126
+ return [outputs[i] for i in self.return_idx]
127
+
128
+
129
+ @register()
130
+ class CSPPAN(nn.Module):
131
+ """
132
+ P5 ---> 1x1 ---------------------------------> concat --> c3 --> det
133
+ | up | conv /2
134
+ P4 ---> concat ---> c3 ---> 1x1 --> concat ---> c3 -----------> det
135
+ | up | conv /2
136
+ P3 -----------------------> concat ---> c3 ---------------------> det
137
+ """
138
+
139
+ __share__ = [
140
+ "depth_multi",
141
+ ]
142
+
143
+ def __init__(self, in_channels=[256, 512, 1024], depth_multi=1.0, act="silu") -> None:
144
+ super().__init__()
145
+ depth = max(round(3 * depth_multi), 1)
146
+
147
+ self.out_channels = in_channels
148
+ self.fpn_stems = nn.ModuleList(
149
+ [
150
+ Conv(cin, cout, 1, 1, act=act)
151
+ for cin, cout in zip(in_channels[::-1], in_channels[::-1][1:])
152
+ ]
153
+ )
154
+ self.fpn_csps = nn.ModuleList(
155
+ [
156
+ C3(cin, cout, depth, False, act=act)
157
+ for cin, cout in zip(in_channels[::-1], in_channels[::-1][1:])
158
+ ]
159
+ )
160
+
161
+ self.pan_stems = nn.ModuleList([Conv(c, c, 3, 2, act=act) for c in in_channels[:-1]])
162
+ self.pan_csps = nn.ModuleList([C3(c, c, depth, False, act=act) for c in in_channels[1:]])
163
+
164
+ def forward(self, feats):
165
+ fpn_feats = []
166
+ for i, feat in enumerate(feats[::-1]):
167
+ if i == 0:
168
+ feat = self.fpn_stems[i](feat)
169
+ fpn_feats.append(feat)
170
+ else:
171
+ _feat = F.interpolate(fpn_feats[-1], scale_factor=2, mode="nearest")
172
+ feat = torch.concat([_feat, feat], dim=1)
173
+ feat = self.fpn_csps[i - 1](feat)
174
+ if i < len(self.fpn_stems):
175
+ feat = self.fpn_stems[i](feat)
176
+ fpn_feats.append(feat)
177
+
178
+ pan_feats = []
179
+ for i, feat in enumerate(fpn_feats[::-1]):
180
+ if i == 0:
181
+ pan_feats.append(feat)
182
+ else:
183
+ _feat = self.pan_stems[i - 1](pan_feats[-1])
184
+ feat = torch.concat([_feat, feat], dim=1)
185
+ feat = self.pan_csps[i - 1](feat)
186
+ pan_feats.append(feat)
187
+
188
+ return pan_feats
189
+
190
+
191
+ if __name__ == "__main__":
192
+ data = torch.rand(1, 3, 320, 640)
193
+
194
+ width_multi = 0.75
195
+ depth_multi = 0.33
196
+
197
+ m = CSPDarkNet(3, width_multi=width_multi, depth_multi=depth_multi, act="silu")
198
+ outputs = m(data)
199
+ print([o.shape for o in outputs])
200
+
201
+ m = CSPPAN(in_channels=m.out_channels, depth_multi=depth_multi, act="silu")
202
+ outputs = m(outputs)
203
+ print([o.shape for o in outputs])
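With those settings the first print should show the stride-8/16/32 features for the 320x640 input, roughly [1, 192, 40, 80], [1, 384, 20, 40] and [1, 768, 10, 20]; CSPPAN keeps the channel count and resolution at each level, so the second print should match.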
src/nn/backbone/csp_resnet.py ADDED
@@ -0,0 +1,302 @@
1
+ """
2
+ https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.6/ppdet/modeling/backbones/cspresnet.py
3
+
4
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
5
+ """
6
+
7
+ from collections import OrderedDict
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from ...core import register
14
+ from .common import get_activation
15
+
16
+ __all__ = ["CSPResNet"]
17
+
18
+
19
+ donwload_url = {
20
+ "s": "https://github.com/lyuwenyu/storage/releases/download/v0.1/CSPResNetb_s_pretrained_from_paddle.pth",
21
+ "m": "https://github.com/lyuwenyu/storage/releases/download/v0.1/CSPResNetb_m_pretrained_from_paddle.pth",
22
+ "l": "https://github.com/lyuwenyu/storage/releases/download/v0.1/CSPResNetb_l_pretrained_from_paddle.pth",
23
+ "x": "https://github.com/lyuwenyu/storage/releases/download/v0.1/CSPResNetb_x_pretrained_from_paddle.pth",
24
+ }
25
+
26
+
27
+ class ConvBNLayer(nn.Module):
28
+ def __init__(self, ch_in, ch_out, filter_size=3, stride=1, groups=1, padding=0, act=None):
29
+ super().__init__()
30
+ self.conv = nn.Conv2d(
31
+ ch_in, ch_out, filter_size, stride, padding, groups=groups, bias=False
32
+ )
33
+ self.bn = nn.BatchNorm2d(ch_out)
34
+ self.act = get_activation(act)
35
+
36
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
37
+ x = self.conv(x)
38
+ x = self.bn(x)
39
+ x = self.act(x)
40
+ return x
41
+
42
+
43
+ class RepVggBlock(nn.Module):
44
+ def __init__(self, ch_in, ch_out, act="relu", alpha: bool = False):
45
+ super().__init__()
46
+ self.ch_in = ch_in
47
+ self.ch_out = ch_out
48
+ self.conv1 = ConvBNLayer(ch_in, ch_out, 3, stride=1, padding=1, act=None)
49
+ self.conv2 = ConvBNLayer(ch_in, ch_out, 1, stride=1, padding=0, act=None)
50
+ self.act = get_activation(act)
51
+
52
+ if alpha:
53
+ self.alpha = nn.Parameter(
54
+ torch.ones(
55
+ 1,
56
+ )
57
+ )
58
+ else:
59
+ self.alpha = None
60
+
61
+ def forward(self, x):
62
+ if hasattr(self, "conv"):
63
+ y = self.conv(x)
64
+ else:
65
+ if self.alpha:
66
+ y = self.conv1(x) + self.alpha * self.conv2(x)
67
+ else:
68
+ y = self.conv1(x) + self.conv2(x)
69
+ y = self.act(y)
70
+ return y
71
+
72
+ def convert_to_deploy(self):
73
+ if not hasattr(self, "conv"):
74
+ self.conv = nn.Conv2d(self.ch_in, self.ch_out, 3, 1, padding=1)
75
+
76
+ kernel, bias = self.get_equivalent_kernel_bias()
77
+ self.conv.weight.data = kernel
78
+ self.conv.bias.data = bias
79
+
80
+ def get_equivalent_kernel_bias(self):
81
+ kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
82
+ kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
83
+
84
+ if self.alpha:
85
+ return kernel3x3 + self.alpha * self._pad_1x1_to_3x3_tensor(
86
+ kernel1x1
87
+ ), bias3x3 + self.alpha * bias1x1
88
+ else:
89
+ return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1), bias3x3 + bias1x1
90
+
91
+ def _pad_1x1_to_3x3_tensor(self, kernel1x1):
92
+ if kernel1x1 is None:
93
+ return 0
94
+ else:
95
+ return F.pad(kernel1x1, [1, 1, 1, 1])
96
+
97
+ def _fuse_bn_tensor(self, branch: ConvBNLayer):
98
+ if branch is None:
99
+ return 0, 0
100
+ kernel = branch.conv.weight
101
+ running_mean = branch.bn.running_mean
102
+ running_var = branch.bn.running_var
103
+ gamma = branch.bn.weight
104
+ beta = branch.bn.bias
105
+ eps = branch.bn.eps
106
+ std = (running_var + eps).sqrt()
107
+ t = (gamma / std).reshape(-1, 1, 1, 1)
108
+ return kernel * t, beta - running_mean * gamma / std
109
+
110
+
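A small self-check sketch (illustrative, using RepVggBlock as defined above; it relies on _fuse_bn_tensor reading the batch norm stored as `bn` on ConvBNLayer): in eval mode the fused 3x3 convolution produced by convert_to_deploy should reproduce the two-branch output up to numerical tolerance.

    import torch

    blk = RepVggBlock(32, 32, act="relu").eval()
    x = torch.rand(1, 32, 16, 16)
    y_two_branch = blk(x)      # conv1(x) + conv2(x), then activation
    blk.convert_to_deploy()    # builds self.conv from the fused kernel/bias
    y_fused = blk(x)           # single 3x3 conv path
    print(torch.allclose(y_two_branch, y_fused, atol=1e-5))  # expected: True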
111
+ class BasicBlock(nn.Module):
112
+ def __init__(self, ch_in, ch_out, act="relu", shortcut=True, use_alpha=False):
113
+ super().__init__()
114
+ assert ch_in == ch_out
115
+ self.conv1 = ConvBNLayer(ch_in, ch_out, 3, stride=1, padding=1, act=act)
116
+ self.conv2 = RepVggBlock(ch_out, ch_out, act=act, alpha=use_alpha)
117
+ self.shortcut = shortcut
118
+
119
+ def forward(self, x):
120
+ y = self.conv1(x)
121
+ y = self.conv2(y)
122
+ if self.shortcut:
123
+ return x + y
124
+ else:
125
+ return y
126
+
127
+
128
+ class EffectiveSELayer(nn.Module):
129
+ """Effective Squeeze-Excitation
130
+ From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
131
+ """
132
+
133
+ def __init__(self, channels, act="hardsigmoid"):
134
+ super(EffectiveSELayer, self).__init__()
135
+ self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
136
+ self.act = get_activation(act)
137
+
138
+ def forward(self, x: torch.Tensor):
139
+ x_se = x.mean((2, 3), keepdim=True)
140
+ x_se = self.fc(x_se)
141
+ x_se = self.act(x_se)
142
+ return x * x_se
143
+
144
+
145
+ class CSPResStage(nn.Module):
146
+ def __init__(self, block_fn, ch_in, ch_out, n, stride, act="relu", attn="eca", use_alpha=False):
147
+ super().__init__()
148
+ ch_mid = (ch_in + ch_out) // 2
149
+ if stride == 2:
150
+ self.conv_down = ConvBNLayer(ch_in, ch_mid, 3, stride=2, padding=1, act=act)
151
+ else:
152
+ self.conv_down = None
153
+ self.conv1 = ConvBNLayer(ch_mid, ch_mid // 2, 1, act=act)
154
+ self.conv2 = ConvBNLayer(ch_mid, ch_mid // 2, 1, act=act)
155
+ self.blocks = nn.Sequential(
156
+ *[
157
+ block_fn(ch_mid // 2, ch_mid // 2, act=act, shortcut=True, use_alpha=use_alpha)
158
+ for i in range(n)
159
+ ]
160
+ )
161
+ if attn:
162
+ self.attn = EffectiveSELayer(ch_mid, act="hardsigmoid")
163
+ else:
164
+ self.attn = None
165
+
166
+ self.conv3 = ConvBNLayer(ch_mid, ch_out, 1, act=act)
167
+
168
+ def forward(self, x):
169
+ if self.conv_down is not None:
170
+ x = self.conv_down(x)
171
+ y1 = self.conv1(x)
172
+ y2 = self.blocks(self.conv2(x))
173
+ y = torch.concat([y1, y2], dim=1)
174
+ if self.attn is not None:
175
+ y = self.attn(y)
176
+ y = self.conv3(y)
177
+ return y
178
+
179
+
180
+ @register()
181
+ class CSPResNet(nn.Module):
182
+ layers = [3, 6, 6, 3]
183
+ channels = [64, 128, 256, 512, 1024]
184
+ model_cfg = {
185
+ "s": {
186
+ "depth_mult": 0.33,
187
+ "width_mult": 0.50,
188
+ },
189
+ "m": {
190
+ "depth_mult": 0.67,
191
+ "width_mult": 0.75,
192
+ },
193
+ "l": {
194
+ "depth_mult": 1.00,
195
+ "width_mult": 1.00,
196
+ },
197
+ "x": {
198
+ "depth_mult": 1.33,
199
+ "width_mult": 1.25,
200
+ },
201
+ }
202
+
203
+ def __init__(
204
+ self,
205
+ name: str,
206
+ act="silu",
207
+ return_idx=[1, 2, 3],
208
+ use_large_stem=True,
209
+ use_alpha=False,
210
+ pretrained=False,
211
+ ):
212
+ super().__init__()
213
+ depth_mult = self.model_cfg[name]["depth_mult"]
214
+ width_mult = self.model_cfg[name]["width_mult"]
215
+
216
+ channels = [max(round(c * width_mult), 1) for c in self.channels]
217
+ layers = [max(round(l * depth_mult), 1) for l in self.layers]
218
+ act = get_activation(act)
219
+
220
+ if use_large_stem:
221
+ self.stem = nn.Sequential(
222
+ OrderedDict(
223
+ [
224
+ (
225
+ "conv1",
226
+ ConvBNLayer(3, channels[0] // 2, 3, stride=2, padding=1, act=act),
227
+ ),
228
+ (
229
+ "conv2",
230
+ ConvBNLayer(
231
+ channels[0] // 2, channels[0] // 2, 3, stride=1, padding=1, act=act
232
+ ),
233
+ ),
234
+ (
235
+ "conv3",
236
+ ConvBNLayer(
237
+ channels[0] // 2, channels[0], 3, stride=1, padding=1, act=act
238
+ ),
239
+ ),
240
+ ]
241
+ )
242
+ )
243
+ else:
244
+ self.stem = nn.Sequential(
245
+ OrderedDict(
246
+ [
247
+ (
248
+ "conv1",
249
+ ConvBNLayer(3, channels[0] // 2, 3, stride=2, padding=1, act=act),
250
+ ),
251
+ (
252
+ "conv2",
253
+ ConvBNLayer(
254
+ channels[0] // 2, channels[0], 3, stride=1, padding=1, act=act
255
+ ),
256
+ ),
257
+ ]
258
+ )
259
+ )
260
+
261
+ n = len(channels) - 1
262
+ self.stages = nn.Sequential(
263
+ OrderedDict(
264
+ [
265
+ (
266
+ str(i),
267
+ CSPResStage(
268
+ BasicBlock,
269
+ channels[i],
270
+ channels[i + 1],
271
+ layers[i],
272
+ 2,
273
+ act=act,
274
+ use_alpha=use_alpha,
275
+ ),
276
+ )
277
+ for i in range(n)
278
+ ]
279
+ )
280
+ )
281
+
282
+ self._out_channels = channels[1:]
283
+ self._out_strides = [4 * 2**i for i in range(n)]
284
+ self.return_idx = return_idx
285
+
286
+ if pretrained:
287
+ if isinstance(pretrained, bool) or "http" in pretrained:
288
+ state = torch.hub.load_state_dict_from_url(donwload_url[name], map_location="cpu")
289
+ else:
290
+ state = torch.load(pretrained, map_location="cpu")
291
+ self.load_state_dict(state)
292
+ print(f"Load CSPResNet_{name} state_dict")
293
+
294
+ def forward(self, x):
295
+ x = self.stem(x)
296
+ outs = []
297
+ for idx, stage in enumerate(self.stages):
298
+ x = stage(x)
299
+ if idx in self.return_idx:
300
+ outs.append(x)
301
+
302
+ return outs
src/nn/backbone/hgnetv2.py ADDED
@@ -0,0 +1,581 @@
1
+ """
2
+ reference
3
+ - https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
4
+
5
+ Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
6
+ """
7
+
8
+ import logging
9
+ import os
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from ...core import register
16
+ from .common import FrozenBatchNorm2d
17
+
18
+ # Constants for initialization
19
+ kaiming_normal_ = nn.init.kaiming_normal_
20
+ zeros_ = nn.init.zeros_
21
+ ones_ = nn.init.ones_
22
+
23
+ __all__ = ["HGNetv2"]
24
+
25
+ def safe_barrier():
26
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
27
+ torch.distributed.barrier()
28
+ else:
29
+ pass
30
+
31
+ def safe_get_rank():
32
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
33
+ return torch.distributed.get_rank()
34
+ else:
35
+ return 0
36
+
37
+ class LearnableAffineBlock(nn.Module):
38
+ def __init__(self, scale_value=1.0, bias_value=0.0):
39
+ super().__init__()
40
+ self.scale = nn.Parameter(torch.tensor([scale_value]), requires_grad=True)
41
+ self.bias = nn.Parameter(torch.tensor([bias_value]), requires_grad=True)
42
+
43
+ def forward(self, x):
44
+ return self.scale * x + self.bias
45
+
46
+
47
+ class ConvBNAct(nn.Module):
48
+ def __init__(
49
+ self,
50
+ in_chs,
51
+ out_chs,
52
+ kernel_size,
53
+ stride=1,
54
+ groups=1,
55
+ padding="",
56
+ use_act=True,
57
+ use_lab=False,
58
+ ):
59
+ super().__init__()
60
+ self.use_act = use_act
61
+ self.use_lab = use_lab
62
+ if padding == "same":
63
+ self.conv = nn.Sequential(
64
+ nn.ZeroPad2d([0, 1, 0, 1]),
65
+ nn.Conv2d(in_chs, out_chs, kernel_size, stride, groups=groups, bias=False),
66
+ )
67
+ else:
68
+ self.conv = nn.Conv2d(
69
+ in_chs,
70
+ out_chs,
71
+ kernel_size,
72
+ stride,
73
+ padding=(kernel_size - 1) // 2,
74
+ groups=groups,
75
+ bias=False,
76
+ )
77
+ self.bn = nn.BatchNorm2d(out_chs)
78
+ if self.use_act:
79
+ self.act = nn.ReLU()
80
+ else:
81
+ self.act = nn.Identity()
82
+ if self.use_act and self.use_lab:
83
+ self.lab = LearnableAffineBlock()
84
+ else:
85
+ self.lab = nn.Identity()
86
+
87
+ def forward(self, x):
88
+ x = self.conv(x)
89
+ x = self.bn(x)
90
+ x = self.act(x)
91
+ x = self.lab(x)
92
+ return x
93
+
94
+
95
+ class LightConvBNAct(nn.Module):
96
+ def __init__(
97
+ self,
98
+ in_chs,
99
+ out_chs,
100
+ kernel_size,
101
+ groups=1,
102
+ use_lab=False,
103
+ ):
104
+ super().__init__()
105
+ self.conv1 = ConvBNAct(
106
+ in_chs,
107
+ out_chs,
108
+ kernel_size=1,
109
+ use_act=False,
110
+ use_lab=use_lab,
111
+ )
112
+ self.conv2 = ConvBNAct(
113
+ out_chs,
114
+ out_chs,
115
+ kernel_size=kernel_size,
116
+ groups=out_chs,
117
+ use_act=True,
118
+ use_lab=use_lab,
119
+ )
120
+
121
+ def forward(self, x):
122
+ x = self.conv1(x)
123
+ x = self.conv2(x)
124
+ return x
125
+
126
+
127
+ class StemBlock(nn.Module):
128
+ # for HGNetv2
129
+ def __init__(self, in_chs, mid_chs, out_chs, use_lab=False):
130
+ super().__init__()
131
+ self.stem1 = ConvBNAct(
132
+ in_chs,
133
+ mid_chs,
134
+ kernel_size=3,
135
+ stride=2,
136
+ use_lab=use_lab,
137
+ )
138
+ self.stem2a = ConvBNAct(
139
+ mid_chs,
140
+ mid_chs // 2,
141
+ kernel_size=2,
142
+ stride=1,
143
+ use_lab=use_lab,
144
+ )
145
+ self.stem2b = ConvBNAct(
146
+ mid_chs // 2,
147
+ mid_chs,
148
+ kernel_size=2,
149
+ stride=1,
150
+ use_lab=use_lab,
151
+ )
152
+ self.stem3 = ConvBNAct(
153
+ mid_chs * 2,
154
+ mid_chs,
155
+ kernel_size=3,
156
+ stride=2,
157
+ use_lab=use_lab,
158
+ )
159
+ self.stem4 = ConvBNAct(
160
+ mid_chs,
161
+ out_chs,
162
+ kernel_size=1,
163
+ stride=1,
164
+ use_lab=use_lab,
165
+ )
166
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=1, ceil_mode=True)
167
+
168
+ def forward(self, x):
169
+ x = self.stem1(x)
170
+ x = F.pad(x, (0, 1, 0, 1))
171
+ x2 = self.stem2a(x)
172
+ x2 = F.pad(x2, (0, 1, 0, 1))
173
+ x2 = self.stem2b(x2)
174
+ x1 = self.pool(x)
175
+ x = torch.cat([x1, x2], dim=1)
176
+ x = self.stem3(x)
177
+ x = self.stem4(x)
178
+ return x
179
+
180
+
181
+ class EseModule(nn.Module):
182
+ def __init__(self, chs):
183
+ super().__init__()
184
+ self.conv = nn.Conv2d(
185
+ chs,
186
+ chs,
187
+ kernel_size=1,
188
+ stride=1,
189
+ padding=0,
190
+ )
191
+ self.sigmoid = nn.Sigmoid()
192
+
193
+ def forward(self, x):
194
+ identity = x
195
+ x = x.mean((2, 3), keepdim=True)
196
+ x = self.conv(x)
197
+ x = self.sigmoid(x)
198
+ return torch.mul(identity, x)
199
+
200
+
201
+ class HG_Block(nn.Module):
202
+ def __init__(
203
+ self,
204
+ in_chs,
205
+ mid_chs,
206
+ out_chs,
207
+ layer_num,
208
+ kernel_size=3,
209
+ residual=False,
210
+ light_block=False,
211
+ use_lab=False,
212
+ agg="ese",
213
+ drop_path=0.0,
214
+ ):
215
+ super().__init__()
216
+ self.residual = residual
217
+
218
+ self.layers = nn.ModuleList()
219
+ for i in range(layer_num):
220
+ if light_block:
221
+ self.layers.append(
222
+ LightConvBNAct(
223
+ in_chs if i == 0 else mid_chs,
224
+ mid_chs,
225
+ kernel_size=kernel_size,
226
+ use_lab=use_lab,
227
+ )
228
+ )
229
+ else:
230
+ self.layers.append(
231
+ ConvBNAct(
232
+ in_chs if i == 0 else mid_chs,
233
+ mid_chs,
234
+ kernel_size=kernel_size,
235
+ stride=1,
236
+ use_lab=use_lab,
237
+ )
238
+ )
239
+
240
+ # feature aggregation
241
+ total_chs = in_chs + layer_num * mid_chs
242
+ if agg == "se":
243
+ aggregation_squeeze_conv = ConvBNAct(
244
+ total_chs,
245
+ out_chs // 2,
246
+ kernel_size=1,
247
+ stride=1,
248
+ use_lab=use_lab,
249
+ )
250
+ aggregation_excitation_conv = ConvBNAct(
251
+ out_chs // 2,
252
+ out_chs,
253
+ kernel_size=1,
254
+ stride=1,
255
+ use_lab=use_lab,
256
+ )
257
+ self.aggregation = nn.Sequential(
258
+ aggregation_squeeze_conv,
259
+ aggregation_excitation_conv,
260
+ )
261
+ else:
262
+ aggregation_conv = ConvBNAct(
263
+ total_chs,
264
+ out_chs,
265
+ kernel_size=1,
266
+ stride=1,
267
+ use_lab=use_lab,
268
+ )
269
+ att = EseModule(out_chs)
270
+ self.aggregation = nn.Sequential(
271
+ aggregation_conv,
272
+ att,
273
+ )
274
+
275
+ self.drop_path = nn.Dropout(drop_path) if drop_path else nn.Identity()
276
+
277
+ def forward(self, x):
278
+ identity = x
279
+ output = [x]
280
+ for layer in self.layers:
281
+ x = layer(x)
282
+ output.append(x)
283
+ x = torch.cat(output, dim=1)
284
+ x = self.aggregation(x)
285
+ if self.residual:
286
+ x = self.drop_path(x) + identity
287
+ return x
288
+
289
+
290
+ class HG_Stage(nn.Module):
291
+ def __init__(
292
+ self,
293
+ in_chs,
294
+ mid_chs,
295
+ out_chs,
296
+ block_num,
297
+ layer_num,
298
+ downsample=True,
299
+ light_block=False,
300
+ kernel_size=3,
301
+ use_lab=False,
302
+ agg="se",
303
+ drop_path=0.0,
304
+ ):
305
+ super().__init__()
306
+ self.downsample = downsample
307
+ if downsample:
308
+ self.downsample = ConvBNAct(
309
+ in_chs,
310
+ in_chs,
311
+ kernel_size=3,
312
+ stride=2,
313
+ groups=in_chs,
314
+ use_act=False,
315
+ use_lab=use_lab,
316
+ )
317
+ else:
318
+ self.downsample = nn.Identity()
319
+
320
+ blocks_list = []
321
+ for i in range(block_num):
322
+ blocks_list.append(
323
+ HG_Block(
324
+ in_chs if i == 0 else out_chs,
325
+ mid_chs,
326
+ out_chs,
327
+ layer_num,
328
+ residual=False if i == 0 else True,
329
+ kernel_size=kernel_size,
330
+ light_block=light_block,
331
+ use_lab=use_lab,
332
+ agg=agg,
333
+ drop_path=drop_path[i] if isinstance(drop_path, (list, tuple)) else drop_path,
334
+ )
335
+ )
336
+ self.blocks = nn.Sequential(*blocks_list)
337
+
338
+ def forward(self, x):
339
+ x = self.downsample(x)
340
+ x = self.blocks(x)
341
+ return x
342
+
343
+
344
+ @register()
345
+ class HGNetv2(nn.Module):
346
+ """
347
+ HGNetV2
348
+ Args:
349
+ stem_channels: list. Number of channels for the stem block.
350
+ stage_type: str. The stage configuration of HGNet. such as the number of channels, stride, etc.
351
+ use_lab: boolean. Whether to use LearnableAffineBlock in network.
352
+ lr_mult_list: list. Control the learning rate of different stages.
353
+ Returns:
354
+ model: nn.Layer. Specific HGNetV2 model depends on args.
355
+ """
356
+
357
+ arch_configs = {
358
+ "B0": {
359
+ "stem_channels": [3, 16, 16],
360
+ "stage_config": {
361
+ # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
362
+ "stage1": [16, 16, 64, 1, False, False, 3, 3],
363
+ "stage2": [64, 32, 256, 1, True, False, 3, 3],
364
+ "stage3": [256, 64, 512, 2, True, True, 5, 3],
365
+ "stage4": [512, 128, 1024, 1, True, True, 5, 3],
366
+ },
367
+ "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B0_stage1.pth",
368
+ },
369
+ "B1": {
370
+ "stem_channels": [3, 24, 32],
371
+ "stage_config": {
372
+ # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
373
+ "stage1": [32, 32, 64, 1, False, False, 3, 3],
374
+ "stage2": [64, 48, 256, 1, True, False, 3, 3],
375
+ "stage3": [256, 96, 512, 2, True, True, 5, 3],
376
+ "stage4": [512, 192, 1024, 1, True, True, 5, 3],
377
+ },
378
+ "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B1_stage1.pth",
379
+ },
380
+ "B2": {
381
+ "stem_channels": [3, 24, 32],
382
+ "stage_config": {
383
+ # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
384
+ "stage1": [32, 32, 96, 1, False, False, 3, 4],
385
+ "stage2": [96, 64, 384, 1, True, False, 3, 4],
386
+ "stage3": [384, 128, 768, 3, True, True, 5, 4],
387
+ "stage4": [768, 256, 1536, 1, True, True, 5, 4],
388
+ },
389
+ "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B2_stage1.pth",
390
+ },
391
+ "B3": {
392
+ "stem_channels": [3, 24, 32],
393
+ "stage_config": {
394
+ # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
395
+ "stage1": [32, 32, 128, 1, False, False, 3, 5],
396
+ "stage2": [128, 64, 512, 1, True, False, 3, 5],
397
+ "stage3": [512, 128, 1024, 3, True, True, 5, 5],
398
+ "stage4": [1024, 256, 2048, 1, True, True, 5, 5],
399
+ },
400
+ "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B3_stage1.pth",
401
+ },
402
+ "B4": {
403
+ "stem_channels": [3, 32, 48],
404
+ "stage_config": {
405
+ # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
406
+ "stage1": [48, 48, 128, 1, False, False, 3, 6],
407
+ "stage2": [128, 96, 512, 1, True, False, 3, 6],
408
+ "stage3": [512, 192, 1024, 3, True, True, 5, 6],
409
+ "stage4": [1024, 384, 2048, 1, True, True, 5, 6],
410
+ },
411
+ "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B4_stage1.pth",
412
+ },
413
+ "B5": {
414
+ "stem_channels": [3, 32, 64],
415
+ "stage_config": {
416
+ # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
417
+ "stage1": [64, 64, 128, 1, False, False, 3, 6],
418
+ "stage2": [128, 128, 512, 2, True, False, 3, 6],
419
+ "stage3": [512, 256, 1024, 5, True, True, 5, 6],
420
+ "stage4": [1024, 512, 2048, 2, True, True, 5, 6],
421
+ },
422
+ "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B5_stage1.pth",
423
+ },
424
+ "B6": {
425
+ "stem_channels": [3, 48, 96],
426
+ "stage_config": {
427
+ # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num
428
+ "stage1": [96, 96, 192, 2, False, False, 3, 6],
429
+ "stage2": [192, 192, 512, 3, True, False, 3, 6],
430
+ "stage3": [512, 384, 1024, 6, True, True, 5, 6],
431
+ "stage4": [1024, 768, 2048, 3, True, True, 5, 6],
432
+ },
433
+ "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B6_stage1.pth",
434
+ },
435
+ }
436
+
437
+ def __init__(
438
+ self,
439
+ name,
440
+ use_lab=False,
441
+ return_idx=[1, 2, 3],
442
+ freeze_stem_only=True,
443
+ freeze_at=0,
444
+ freeze_norm=True,
445
+ pretrained=True,
446
+ local_model_dir="weight/hgnetv2/",
447
+ ):
448
+ super().__init__()
449
+ self.use_lab = use_lab
450
+ self.return_idx = return_idx
451
+
452
+ stem_channels = self.arch_configs[name]["stem_channels"]
453
+ stage_config = self.arch_configs[name]["stage_config"]
454
+ download_url = self.arch_configs[name]["url"]
455
+
456
+ self._out_strides = [4, 8, 16, 32]
457
+ self._out_channels = [stage_config[k][2] for k in stage_config]
458
+
459
+ # stem
460
+ self.stem = StemBlock(
461
+ in_chs=stem_channels[0],
462
+ mid_chs=stem_channels[1],
463
+ out_chs=stem_channels[2],
464
+ use_lab=use_lab,
465
+ )
466
+
467
+ # stages
468
+ self.stages = nn.ModuleList()
469
+ for i, k in enumerate(stage_config):
470
+ (
471
+ in_channels,
472
+ mid_channels,
473
+ out_channels,
474
+ block_num,
475
+ downsample,
476
+ light_block,
477
+ kernel_size,
478
+ layer_num,
479
+ ) = stage_config[k]
480
+ self.stages.append(
481
+ HG_Stage(
482
+ in_channels,
483
+ mid_channels,
484
+ out_channels,
485
+ block_num,
486
+ layer_num,
487
+ downsample,
488
+ light_block,
489
+ kernel_size,
490
+ use_lab,
491
+ )
492
+ )
493
+
494
+ if freeze_at >= 0:
495
+ self._freeze_parameters(self.stem)
496
+ if not freeze_stem_only:
497
+ for i in range(min(freeze_at + 1, len(self.stages))):
498
+ self._freeze_parameters(self.stages[i])
499
+
500
+ if freeze_norm:
501
+ self._freeze_norm(self)
502
+
503
+ if pretrained:
504
+ RED, GREEN, RESET = "\033[91m", "\033[92m", "\033[0m"
505
+ try:
506
+ model_path = local_model_dir + "PPHGNetV2_" + name + "_stage1.pth"
507
+ if os.path.exists(model_path):
508
+ state = torch.load(model_path, map_location="cpu")
509
+ print(f"Loaded stage1 {name} HGNetV2 from local file.")
510
+ else:
511
+ # If the file doesn't exist locally, download from the URL
512
+ if safe_get_rank() == 0:
513
+ print(
514
+ GREEN
515
+ + "If the pretrained HGNetV2 can't be downloaded automatically. Please check your network connection."
516
+ + RESET
517
+ )
518
+ print(
519
+ GREEN
520
+ + "Please check your network connection. Or download the model manually from "
521
+ + RESET
522
+ + f"{download_url}"
523
+ + GREEN
524
+ + " to "
525
+ + RESET
526
+ + f"{local_model_dir}."
527
+ + RESET
528
+ )
529
+ state = torch.hub.load_state_dict_from_url(
530
+ download_url, map_location="cpu", model_dir=local_model_dir
531
+ )
532
+ safe_barrier()
533
+ else:
534
+ safe_barrier()
535
+ state = torch.load(model_path, map_location="cpu")
536
+
537
+ print(f"Loaded stage1 {name} HGNetV2 from URL.")
538
+
539
+ self.load_state_dict(state)
540
+
541
+ except (Exception, KeyboardInterrupt) as e:
542
+ if safe_get_rank() == 0:
543
+ print(f"{str(e)}")
544
+ logging.error(
545
+ RED + "CRITICAL WARNING: Failed to load pretrained HGNetV2 model" + RESET
546
+ )
547
+ logging.error(
548
+ GREEN
549
+ + "Please check your network connection. Or download the model manually from "
550
+ + RESET
551
+ + f"{download_url}"
552
+ + GREEN
553
+ + " to "
554
+ + RESET
555
+ + f"{local_model_dir}."
556
+ + RESET
557
+ )
558
+ exit()
559
+
560
+ def _freeze_norm(self, m: nn.Module):
561
+ if isinstance(m, nn.BatchNorm2d):
562
+ m = FrozenBatchNorm2d(m.num_features)
563
+ else:
564
+ for name, child in m.named_children():
565
+ _child = self._freeze_norm(child)
566
+ if _child is not child:
567
+ setattr(m, name, _child)
568
+ return m
569
+
570
+ def _freeze_parameters(self, m: nn.Module):
571
+ for p in m.parameters():
572
+ p.requires_grad = False
573
+
574
+ def forward(self, x):
575
+ x = self.stem(x)
576
+ outs = []
577
+ for idx, stage in enumerate(self.stages):
578
+ x = stage(x)
579
+ if idx in self.return_idx:
580
+ outs.append(x)
581
+ return outs
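A minimal instantiation sketch (illustrative): the smallest variant without pretrained weights, returning the stride-8/16/32 stages selected by the default return_idx=[1, 2, 3].

    import torch

    m = HGNetv2("B0", pretrained=False)
    feats = m(torch.rand(1, 3, 640, 640))
    print([f.shape for f in feats])
    # expected: [1, 256, 80, 80], [1, 512, 40, 40], [1, 1024, 20, 20]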
src/nn/backbone/presnet.py ADDED
@@ -0,0 +1,263 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ from collections import OrderedDict
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ from ...core import register
13
+ from .common import FrozenBatchNorm2d, get_activation
14
+
15
+ __all__ = ["PResNet"]
16
+
17
+
18
+ ResNet_cfg = {
19
+ 18: [2, 2, 2, 2],
20
+ 34: [3, 4, 6, 3],
21
+ 50: [3, 4, 6, 3],
22
+ 101: [3, 4, 23, 3],
23
+ # 152: [3, 8, 36, 3],
24
+ }
25
+
26
+
27
+ donwload_url = {
28
+ 18: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet18_vd_pretrained_from_paddle.pth",
29
+ 34: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet34_vd_pretrained_from_paddle.pth",
30
+ 50: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet50_vd_ssld_v2_pretrained_from_paddle.pth",
31
+ 101: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet101_vd_ssld_pretrained_from_paddle.pth",
32
+ }
33
+
34
+
35
+ class ConvNormLayer(nn.Module):
36
+ def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None):
37
+ super().__init__()
38
+ self.conv = nn.Conv2d(
39
+ ch_in,
40
+ ch_out,
41
+ kernel_size,
42
+ stride,
43
+ padding=(kernel_size - 1) // 2 if padding is None else padding,
44
+ bias=bias,
45
+ )
46
+ self.norm = nn.BatchNorm2d(ch_out)
47
+ self.act = get_activation(act)
48
+
49
+ def forward(self, x):
50
+ return self.act(self.norm(self.conv(x)))
51
+
52
+
53
+ class BasicBlock(nn.Module):
54
+ expansion = 1
55
+
56
+ def __init__(self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"):
57
+ super().__init__()
58
+
59
+ self.shortcut = shortcut
60
+
61
+ if not shortcut:
62
+ if variant == "d" and stride == 2:
63
+ self.short = nn.Sequential(
64
+ OrderedDict(
65
+ [
66
+ ("pool", nn.AvgPool2d(2, 2, 0, ceil_mode=True)),
67
+ ("conv", ConvNormLayer(ch_in, ch_out, 1, 1)),
68
+ ]
69
+ )
70
+ )
71
+ else:
72
+ self.short = ConvNormLayer(ch_in, ch_out, 1, stride)
73
+
74
+ self.branch2a = ConvNormLayer(ch_in, ch_out, 3, stride, act=act)
75
+ self.branch2b = ConvNormLayer(ch_out, ch_out, 3, 1, act=None)
76
+ self.act = nn.Identity() if act is None else get_activation(act)
77
+
78
+ def forward(self, x):
79
+ out = self.branch2a(x)
80
+ out = self.branch2b(out)
81
+ if self.shortcut:
82
+ short = x
83
+ else:
84
+ short = self.short(x)
85
+
86
+ out = out + short
87
+ out = self.act(out)
88
+
89
+ return out
90
+
91
+
92
+ class BottleNeck(nn.Module):
93
+ expansion = 4
94
+
95
+ def __init__(self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"):
96
+ super().__init__()
97
+
98
+ if variant == "a":
99
+ stride1, stride2 = stride, 1
100
+ else:
101
+ stride1, stride2 = 1, stride
102
+
103
+ width = ch_out
104
+
105
+ self.branch2a = ConvNormLayer(ch_in, width, 1, stride1, act=act)
106
+ self.branch2b = ConvNormLayer(width, width, 3, stride2, act=act)
107
+ self.branch2c = ConvNormLayer(width, ch_out * self.expansion, 1, 1)
108
+
109
+ self.shortcut = shortcut
110
+ if not shortcut:
111
+ if variant == "d" and stride == 2:
112
+ self.short = nn.Sequential(
113
+ OrderedDict(
114
+ [
115
+ ("pool", nn.AvgPool2d(2, 2, 0, ceil_mode=True)),
116
+ ("conv", ConvNormLayer(ch_in, ch_out * self.expansion, 1, 1)),
117
+ ]
118
+ )
119
+ )
120
+ else:
121
+ self.short = ConvNormLayer(ch_in, ch_out * self.expansion, 1, stride)
122
+
123
+ self.act = nn.Identity() if act is None else get_activation(act)
124
+
125
+ def forward(self, x):
126
+ out = self.branch2a(x)
127
+ out = self.branch2b(out)
128
+ out = self.branch2c(out)
129
+
130
+ if self.shortcut:
131
+ short = x
132
+ else:
133
+ short = self.short(x)
134
+
135
+ out = out + short
136
+ out = self.act(out)
137
+
138
+ return out
139
+
140
+
141
+ class Blocks(nn.Module):
142
+ def __init__(self, block, ch_in, ch_out, count, stage_num, act="relu", variant="b"):
143
+ super().__init__()
144
+
145
+ self.blocks = nn.ModuleList()
146
+ for i in range(count):
147
+ self.blocks.append(
148
+ block(
149
+ ch_in,
150
+ ch_out,
151
+ stride=2 if i == 0 and stage_num != 2 else 1,
152
+ shortcut=False if i == 0 else True,
153
+ variant=variant,
154
+ act=act,
155
+ )
156
+ )
157
+
158
+ if i == 0:
159
+ ch_in = ch_out * block.expansion
160
+
161
+ def forward(self, x):
162
+ out = x
163
+ for block in self.blocks:
164
+ out = block(out)
165
+ return out
166
+
167
+
168
+ @register()
169
+ class PResNet(nn.Module):
170
+ def __init__(
171
+ self,
172
+ depth,
173
+ variant="d",
174
+ num_stages=4,
175
+ return_idx=[0, 1, 2, 3],
176
+ act="relu",
177
+ freeze_at=-1,
178
+ freeze_norm=True,
179
+ pretrained=False,
180
+ ):
181
+ super().__init__()
182
+
183
+ block_nums = ResNet_cfg[depth]
184
+ ch_in = 64
185
+ if variant in ["c", "d"]:
186
+ conv_def = [
187
+ [3, ch_in // 2, 3, 2, "conv1_1"],
188
+ [ch_in // 2, ch_in // 2, 3, 1, "conv1_2"],
189
+ [ch_in // 2, ch_in, 3, 1, "conv1_3"],
190
+ ]
191
+ else:
192
+ conv_def = [[3, ch_in, 7, 2, "conv1_1"]]
193
+
194
+ self.conv1 = nn.Sequential(
195
+ OrderedDict(
196
+ [
197
+ (name, ConvNormLayer(cin, cout, k, s, act=act))
198
+ for cin, cout, k, s, name in conv_def
199
+ ]
200
+ )
201
+ )
202
+
203
+ ch_out_list = [64, 128, 256, 512]
204
+ block = BottleNeck if depth >= 50 else BasicBlock
205
+
206
+ _out_channels = [block.expansion * v for v in ch_out_list]
207
+ _out_strides = [4, 8, 16, 32]
208
+
209
+ self.res_layers = nn.ModuleList()
210
+ for i in range(num_stages):
211
+ stage_num = i + 2
212
+ self.res_layers.append(
213
+ Blocks(
214
+ block, ch_in, ch_out_list[i], block_nums[i], stage_num, act=act, variant=variant
215
+ )
216
+ )
217
+ ch_in = _out_channels[i]
218
+
219
+ self.return_idx = return_idx
220
+ self.out_channels = [_out_channels[_i] for _i in return_idx]
221
+ self.out_strides = [_out_strides[_i] for _i in return_idx]
222
+
223
+ if freeze_at >= 0:
224
+ self._freeze_parameters(self.conv1)
225
+ for i in range(min(freeze_at, num_stages)):
226
+ self._freeze_parameters(self.res_layers[i])
227
+
228
+ if freeze_norm:
229
+ self._freeze_norm(self)
230
+
231
+ if pretrained:
232
+ if isinstance(pretrained, bool) or "http" in pretrained:
233
+ state = torch.hub.load_state_dict_from_url(
234
+ donwload_url[depth], map_location="cpu", model_dir="weight"
235
+ )
236
+ else:
237
+ state = torch.load(pretrained, map_location="cpu")
238
+ self.load_state_dict(state)
239
+ print(f"Load PResNet{depth} state_dict")
240
+
241
+ def _freeze_parameters(self, m: nn.Module):
242
+ for p in m.parameters():
243
+ p.requires_grad = False
244
+
245
+ def _freeze_norm(self, m: nn.Module):
246
+ if isinstance(m, nn.BatchNorm2d):
247
+ m = FrozenBatchNorm2d(m.num_features)
248
+ else:
249
+ for name, child in m.named_children():
250
+ _child = self._freeze_norm(child)
251
+ if _child is not child:
252
+ setattr(m, name, _child)
253
+ return m
254
+
255
+ def forward(self, x):
256
+ conv1 = self.conv1(x)
257
+ x = F.max_pool2d(conv1, kernel_size=3, stride=2, padding=1)
258
+ outs = []
259
+ for idx, stage in enumerate(self.res_layers):
260
+ x = stage(x)
261
+ if idx in self.return_idx:
262
+ outs.append(x)
263
+ return outs
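A quick sketch (illustrative): a ResNet-50-style backbone returning the last three stages, the form typically consumed by a detection neck.

    import torch

    m = PResNet(depth=50, return_idx=[1, 2, 3], pretrained=False)
    print(m.out_channels, m.out_strides)  # [512, 1024, 2048] [8, 16, 32]
    feats = m(torch.rand(1, 3, 640, 640))  # [1, 512, 80, 80], [1, 1024, 40, 40], [1, 2048, 20, 20]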
src/nn/backbone/test_resnet.py ADDED
@@ -0,0 +1,83 @@
1
+ from collections import OrderedDict
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from ...core import register
8
+
9
+
10
+ class BasicBlock(nn.Module):
11
+ expansion = 1
12
+
13
+ def __init__(self, in_planes, planes, stride=1):
14
+ super(BasicBlock, self).__init__()
15
+
16
+ self.conv1 = nn.Conv2d(
17
+ in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
18
+ )
19
+ self.bn1 = nn.BatchNorm2d(planes)
20
+
21
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
22
+ self.bn2 = nn.BatchNorm2d(planes)
23
+
24
+ self.shortcut = nn.Sequential()
25
+ if stride != 1 or in_planes != self.expansion * planes:
26
+ self.shortcut = nn.Sequential(
27
+ nn.Conv2d(
28
+ in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False
29
+ ),
30
+ nn.BatchNorm2d(self.expansion * planes),
31
+ )
32
+
33
+ def forward(self, x):
34
+ out = F.relu(self.bn1(self.conv1(x)))
35
+ out = self.bn2(self.conv2(out))
36
+ out += self.shortcut(x)
37
+ out = F.relu(out)
38
+ return out
39
+
40
+
41
+ class _ResNet(nn.Module):
42
+ def __init__(self, block, num_blocks, num_classes=10):
43
+ super().__init__()
44
+ self.in_planes = 64
45
+
46
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
47
+ self.bn1 = nn.BatchNorm2d(64)
48
+
49
+ self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
50
+ self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
51
+ self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
52
+ self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
53
+
54
+ self.linear = nn.Linear(512 * block.expansion, num_classes)
55
+
56
+ def _make_layer(self, block, planes, num_blocks, stride):
57
+ strides = [stride] + [1] * (num_blocks - 1)
58
+ layers = []
59
+ for stride in strides:
60
+ layers.append(block(self.in_planes, planes, stride))
61
+ self.in_planes = planes * block.expansion
62
+ return nn.Sequential(*layers)
63
+
64
+ def forward(self, x):
65
+ out = F.relu(self.bn1(self.conv1(x)))
66
+ out = self.layer1(out)
67
+ out = self.layer2(out)
68
+ out = self.layer3(out)
69
+ out = self.layer4(out)
70
+ out = F.avg_pool2d(out, 4)
71
+ out = out.view(out.size(0), -1)
72
+ out = self.linear(out)
73
+ return out
74
+
75
+
76
+ @register()
77
+ class MResNet(nn.Module):
78
+ def __init__(self, num_classes=10, num_blocks=[2, 2, 2, 2]) -> None:
79
+ super().__init__()
80
+ self.model = _ResNet(BasicBlock, num_blocks, num_classes)
81
+
82
+ def forward(self, x):
83
+ return self.model(x)
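A sketch (illustrative): this is a CIFAR-style ResNet, so it expects small inputs; the 3x3 stem and the final 4x4 average pool assume 32x32 images.

    import torch

    m = MResNet(num_classes=10)
    logits = m(torch.rand(2, 3, 32, 32))  # -> [2, 10]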
src/nn/backbone/timm_model.py ADDED
@@ -0,0 +1,66 @@
1
+ """Copyright(c) 2023 lyuwenyu. All Rights Reserved.
2
+
3
+ https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055#0583
4
+ """
5
+
6
+ import torch
7
+ from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
8
+
9
+ from ...core import register
10
+ from .utils import IntermediateLayerGetter
11
+
12
+
13
+ @register()
14
+ class TimmModel(torch.nn.Module):
15
+ def __init__(
16
+ self, name, return_layers, pretrained=False, exportable=True, features_only=True, **kwargs
17
+ ) -> None:
18
+ super().__init__()
19
+
20
+ import timm
21
+
22
+ model = timm.create_model(
23
+ name,
24
+ pretrained=pretrained,
25
+ exportable=exportable,
26
+ features_only=features_only,
27
+ **kwargs,
28
+ )
29
+ # nodes, _ = get_graph_node_names(model)
30
+ # print(nodes)
31
+ # features = {'': ''}
32
+ # model = create_feature_extractor(model, return_nodes=features)
33
+
34
+ assert set(return_layers).issubset(
35
+ model.feature_info.module_name()
36
+ ), f"return_layers should be a subset of {model.feature_info.module_name()}"
37
+
38
+ # self.model = model
39
+ self.model = IntermediateLayerGetter(model, return_layers)
40
+
41
+ return_idx = [model.feature_info.module_name().index(name) for name in return_layers]
42
+ self.strides = [model.feature_info.reduction()[i] for i in return_idx]
43
+ self.channels = [model.feature_info.channels()[i] for i in return_idx]
44
+ self.return_idx = return_idx
45
+ self.return_layers = return_layers
46
+
47
+ def forward(self, x: torch.Tensor):
48
+ outputs = self.model(x)
49
+ # outputs = [outputs[i] for i in self.return_idx]
50
+ return outputs
51
+
52
+
53
+ if __name__ == "__main__":
54
+ model = TimmModel(name="resnet34", return_layers=["layer2", "layer3"])
55
+ data = torch.rand(1, 3, 640, 640)
56
+ outputs = model(data)
57
+
58
+ for output in outputs:
59
+ print(output.shape)
60
+
61
+ """
62
+ model:
63
+ type: TimmModel
64
+ name: resnet34
65
+ return_layers: ['layer2', 'layer4']
66
+ """
src/nn/backbone/torchvision_model.py ADDED
@@ -0,0 +1,50 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ import torch
7
+ import torchvision
8
+
9
+ from ...core import register
10
+ from .utils import IntermediateLayerGetter
11
+
12
+ __all__ = ["TorchVisionModel"]
13
+
14
+
15
+ @register()
16
+ class TorchVisionModel(torch.nn.Module):
17
+ def __init__(self, name, return_layers, weights=None, **kwargs) -> None:
18
+ super().__init__()
19
+
20
+ if weights is not None:
21
+ weights = getattr(torchvision.models.get_model_weights(name), weights)
22
+
23
+ model = torchvision.models.get_model(name, weights=weights, **kwargs)
24
+
25
+ # TODO hard code.
26
+ if hasattr(model, "features"):
27
+ model = IntermediateLayerGetter(model.features, return_layers)
28
+ else:
29
+ model = IntermediateLayerGetter(model, return_layers)
30
+
31
+ self.model = model
32
+
33
+ def forward(self, x):
34
+ return self.model(x)
35
+
36
+
37
+ # TorchVisionModel('swin_t', return_layers=['5', '7'])
38
+ # TorchVisionModel('resnet34', return_layers=['layer2','layer3', 'layer4'])
39
+
40
+ # TorchVisionModel:
41
+ # name: swin_t
42
+ # return_layers: ['5', '7']
43
+ # weights: DEFAULT
44
+
45
+
46
+ # model:
47
+ # type: TorchVisionModel
48
+ # name: resnet34
49
+ # return_layers: ['layer2','layer3', 'layer4']
50
+ # weights: DEFAULT
src/nn/backbone/utils.py ADDED
@@ -0,0 +1,56 @@
1
+ """
2
+ https://github.com/pytorch/vision/blob/main/torchvision/models/_utils.py
3
+
4
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
5
+ """
6
+
7
+ from collections import OrderedDict
8
+ from typing import Dict, List
9
+
10
+ import torch.nn as nn
11
+
12
+
13
+ class IntermediateLayerGetter(nn.ModuleDict):
14
+ """
15
+ Module wrapper that returns intermediate layers from a model
16
+
17
+ It has a strong assumption that the modules have been registered
18
+ into the model in the same order as they are used.
19
+ This means that one should **not** reuse the same nn.Module
20
+ twice in the forward if you want this to work.
21
+
22
+ Additionally, it is only able to query submodules that are directly
23
+ assigned to the model. So if `model` is passed, `model.feature1` can
24
+ be returned, but not `model.feature1.layer2`.
25
+ """
26
+
27
+ _version = 3
28
+
29
+ def __init__(self, model: nn.Module, return_layers: List[str]) -> None:
30
+ if not set(return_layers).issubset([name for name, _ in model.named_children()]):
31
+ raise ValueError(
32
+ "return_layers are not present in model. {}".format(
33
+ [name for name, _ in model.named_children()]
34
+ )
35
+ )
36
+ orig_return_layers = return_layers
37
+ return_layers = {str(k): str(k) for k in return_layers}
38
+ layers = OrderedDict()
39
+ for name, module in model.named_children():
40
+ layers[name] = module
41
+ if name in return_layers:
42
+ del return_layers[name]
43
+ if not return_layers:
44
+ break
45
+
46
+ super().__init__(layers)
47
+ self.return_layers = orig_return_layers
48
+
49
+ def forward(self, x):
50
+ outputs = []
51
+ for name, module in self.items():
52
+ x = module(x)
53
+ if name in self.return_layers:
54
+ outputs.append(x)
55
+
56
+ return outputs
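A sketch (illustrative): pulling two intermediate stages out of a torchvision ResNet-18 with the getter above. Only direct children can be requested, and they are executed in registration order.

    import torch
    import torchvision

    net = torchvision.models.resnet18(weights=None)
    getter = IntermediateLayerGetter(net, return_layers=["layer2", "layer3"])
    feats = getter(torch.rand(1, 3, 224, 224))
    print([f.shape for f in feats])  # [1, 128, 28, 28], [1, 256, 14, 14]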
src/nn/criterion/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ """
2
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
3
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
4
+ """
5
+
6
+ import torch.nn as nn
7
+
8
+ from ...core import register
9
+ from .det_criterion import DetCriterion
10
+
11
+ CrossEntropyLoss = register()(nn.CrossEntropyLoss)
src/nn/criterion/det_criterion.py ADDED
@@ -0,0 +1,188 @@
+ """
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
+ """
+
+ import torch
+ import torch.distributed
+ import torch.nn.functional as F
+ import torchvision
+
+ from ...core import register
+ from ...misc import box_ops, dist_utils
+
+
+ @register()
+ class DetCriterion(torch.nn.Module):
+     """Default Detection Criterion"""
+
+     __share__ = ["num_classes"]
+     __inject__ = ["matcher"]
+
+     def __init__(
+         self,
+         losses,
+         weight_dict,
+         num_classes=80,
+         alpha=0.75,
+         gamma=2.0,
+         box_fmt="cxcywh",
+         matcher=None,
+     ):
+         """
+         Args:
+             losses (list[str]): requested losses, supports ['boxes', 'giou', 'vfl', 'focal']
+             weight_dict (dict[str, float]): weights for the corresponding losses, including
+                 ['loss_bbox', 'loss_giou', 'loss_vfl', 'loss_focal']
+             box_fmt (str): input box format, 'cxcywh' or 'xyxy'
+             matcher (Matcher): matcher used to match predictions to targets
+         """
+         super().__init__()
+         self.losses = losses
+         self.weight_dict = weight_dict
+         self.alpha = alpha
+         self.gamma = gamma
+         self.num_classes = num_classes
+         self.box_fmt = box_fmt
+         assert matcher is not None, "matcher is required"
+         self.matcher = matcher
+
+     def forward(self, outputs, targets, **kwargs):
+         """
+         Args:
+             outputs (Dict[str, Tensor]): 'pred_boxes', 'pred_logits', 'meta'.
+             targets (List[Dict[str, Tensor]]): len(targets) == batch_size.
+             kwargs: other information such as the current epoch id.
+         Returns:
+             losses (Dict[str, Tensor])
+         """
+         matched = self.matcher(outputs, targets)
+         values = matched["values"]
+         indices = matched["indices"]
+         num_boxes = self._get_positive_nums(indices)
+
+         # Compute all the requested losses
+         losses = {}
+         for loss in self.losses:
+             l_dict = self.get_loss(loss, outputs, targets, indices, num_boxes)
+             l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
+             losses.update(l_dict)
+         return losses
+
+     def _get_src_permutation_idx(self, indices):
+         # permute predictions following indices
+         batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
+         src_idx = torch.cat([src for (src, _) in indices])
+         return batch_idx, src_idx
+
+     def _get_tgt_permutation_idx(self, indices):
+         # permute targets following indices
+         batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
+         tgt_idx = torch.cat([tgt for (_, tgt) in indices])
+         return batch_idx, tgt_idx
+
+     def _get_positive_nums(self, indices):
+         # number of positive samples
+         num_pos = sum(len(i) for (i, _) in indices)
+         num_pos = torch.as_tensor([num_pos], dtype=torch.float32, device=indices[0][0].device)
+         if dist_utils.is_dist_available_and_initialized():
+             torch.distributed.all_reduce(num_pos)
+         num_pos = torch.clamp(num_pos / dist_utils.get_world_size(), min=1).item()
+         return num_pos
+
+     def loss_labels_focal(self, outputs, targets, indices, num_boxes):
+         assert "pred_logits" in outputs
+         src_logits = outputs["pred_logits"]
+
+         idx = self._get_src_permutation_idx(indices)
+         target_classes_o = torch.cat([t["labels"][j] for t, (_, j) in zip(targets, indices)])
+         target_classes = torch.full(
+             src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
+         )
+         target_classes[idx] = target_classes_o
+
+         target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1].to(
+             src_logits.dtype
+         )
+         loss = torchvision.ops.sigmoid_focal_loss(
+             src_logits, target, self.alpha, self.gamma, reduction="none"
+         )
+         loss = loss.sum() / num_boxes
+         return {"loss_focal": loss}
+
+     def loss_labels_vfl(self, outputs, targets, indices, num_boxes):
+         assert "pred_boxes" in outputs
+         idx = self._get_src_permutation_idx(indices)
+
+         src_boxes = outputs["pred_boxes"][idx]
+         target_boxes = torch.cat([t["boxes"][j] for t, (_, j) in zip(targets, indices)], dim=0)
+
+         src_boxes = torchvision.ops.box_convert(src_boxes, in_fmt=self.box_fmt, out_fmt="xyxy")
+         target_boxes = torchvision.ops.box_convert(
+             target_boxes, in_fmt=self.box_fmt, out_fmt="xyxy"
+         )
+         iou, _ = box_ops.elementwise_box_iou(src_boxes.detach(), target_boxes)
+
+         src_logits: torch.Tensor = outputs["pred_logits"]
+         target_classes_o = torch.cat([t["labels"][j] for t, (_, j) in zip(targets, indices)])
+         target_classes = torch.full(
+             src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
+         )
+         target_classes[idx] = target_classes_o
+         target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1]
+
+         target_score_o = torch.zeros_like(target_classes, dtype=src_logits.dtype)
+         target_score_o[idx] = iou.to(src_logits.dtype)
+         target_score = target_score_o.unsqueeze(-1) * target
+
+         src_score = F.sigmoid(src_logits.detach())
+         weight = self.alpha * src_score.pow(self.gamma) * (1 - target) + target_score
+
+         loss = F.binary_cross_entropy_with_logits(
+             src_logits, target_score, weight=weight, reduction="none"
+         )
+         loss = loss.sum() / num_boxes
+         return {"loss_vfl": loss}
+
+     def loss_boxes(self, outputs, targets, indices, num_boxes):
+         assert "pred_boxes" in outputs
+         idx = self._get_src_permutation_idx(indices)
+         src_boxes = outputs["pred_boxes"][idx]
+         target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
+
+         losses = {}
+         loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")
+         losses["loss_bbox"] = loss_bbox.sum() / num_boxes
+
+         src_boxes = torchvision.ops.box_convert(src_boxes, in_fmt=self.box_fmt, out_fmt="xyxy")
+         target_boxes = torchvision.ops.box_convert(
+             target_boxes, in_fmt=self.box_fmt, out_fmt="xyxy"
+         )
+         loss_giou = 1 - box_ops.elementwise_generalized_box_iou(src_boxes, target_boxes)
+         losses["loss_giou"] = loss_giou.sum() / num_boxes
+         return losses
+
+     def loss_boxes_giou(self, outputs, targets, indices, num_boxes):
+         assert "pred_boxes" in outputs
+         idx = self._get_src_permutation_idx(indices)
+         src_boxes = outputs["pred_boxes"][idx]
+         target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
+
+         losses = {}
+         src_boxes = torchvision.ops.box_convert(src_boxes, in_fmt=self.box_fmt, out_fmt="xyxy")
+         target_boxes = torchvision.ops.box_convert(
+             target_boxes, in_fmt=self.box_fmt, out_fmt="xyxy"
+         )
+         loss_giou = 1 - box_ops.elementwise_generalized_box_iou(src_boxes, target_boxes)
+         losses["loss_giou"] = loss_giou.sum() / num_boxes
+         return losses
+
+     def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
+         loss_map = {
+             "boxes": self.loss_boxes,
+             "giou": self.loss_boxes_giou,
+             "vfl": self.loss_labels_vfl,
+             "focal": self.loss_labels_focal,
+         }
+         assert loss in loss_map, f"do you really want to compute {loss} loss?"
+         return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
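In the repo, DetCriterion and its matcher are normally built from the YAML config through the register()/create() machinery; the snippet below is only a minimal, hypothetical sketch of calling the criterion directly. The DummyMatcher stand-in, the import path, and all tensor shapes are illustrative assumptions, not part of this commit.

import torch
from src.nn.criterion.det_criterion import DetCriterion  # import path assumed

class DummyMatcher(torch.nn.Module):
    """Toy stand-in for the real matcher: pairs query i with target i."""
    def forward(self, outputs, targets):
        indices = [
            (torch.arange(len(t["labels"])), torch.arange(len(t["labels"])))
            for t in targets
        ]
        return {"indices": indices, "values": None}

criterion = DetCriterion(
    losses=["vfl", "boxes"],
    weight_dict={"loss_vfl": 1.0, "loss_bbox": 5.0, "loss_giou": 2.0},
    num_classes=80,
    matcher=DummyMatcher(),
)
outputs = {
    "pred_logits": torch.randn(2, 300, 80),  # [batch, queries, classes]
    "pred_boxes": torch.rand(2, 300, 4),     # normalized cxcywh
}
targets = [
    {"labels": torch.tensor([3]), "boxes": torch.rand(1, 4)},
    {"labels": torch.tensor([1, 7]), "boxes": torch.rand(2, 4)},
]
loss_dict = criterion(outputs, targets)  # {'loss_vfl': ..., 'loss_bbox': ..., 'loss_giou': ...}
total_loss = sum(loss_dict.values())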
src/nn/postprocessor/__init__.py ADDED
@@ -0,0 +1,6 @@
+ """
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
+ """
+
+ from .nms_postprocessor import DetNMSPostProcessor
src/nn/postprocessor/box_revert.py ADDED
@@ -0,0 +1,66 @@
+ """
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
+ """
+
+ from enum import Enum
+
+ import torch
+ import torchvision
+ from torch import Tensor
+
+
+ class BoxProcessFormat(Enum):
+     """Box process format
+
+     Available formats are
+     * ``RESIZE``
+     * ``RESIZE_KEEP_RATIO``
+     * ``RESIZE_KEEP_RATIO_PADDING``
+     """
+
+     RESIZE = 1
+     RESIZE_KEEP_RATIO = 2
+     RESIZE_KEEP_RATIO_PADDING = 3
+
+
+ def box_revert(
+     boxes: Tensor,
+     orig_sizes: Tensor = None,
+     eval_sizes: Tensor = None,
+     inpt_sizes: Tensor = None,
+     inpt_padding: Tensor = None,
+     normalized: bool = True,
+     in_fmt: str = "cxcywh",
+     out_fmt: str = "xyxy",
+     process_fmt=BoxProcessFormat.RESIZE,
+ ) -> Tensor:
+     """
+     Args:
+         boxes (Tensor): [N, :, 4], predicted boxes in ``in_fmt``.
+         orig_sizes (Tensor): [N, 2], (w, h), original image sizes.
+         inpt_sizes (Tensor): [N, 2], (w, h), resized input sizes before padding.
+         inpt_padding (Tensor): [N, 2], (w_pad, h_pad).
+             (inpt_sizes + inpt_padding) == eval_sizes
+     """
+     assert in_fmt in ("cxcywh", "xyxy"), "unsupported in_fmt"
+
+     if normalized and eval_sizes is not None:
+         boxes = boxes * eval_sizes.repeat(1, 2).unsqueeze(1)
+
+     if inpt_padding is not None:
+         if in_fmt == "xyxy":
+             boxes -= inpt_padding[:, :2].repeat(1, 2).unsqueeze(1)
+         elif in_fmt == "cxcywh":
+             boxes[..., :2] -= inpt_padding[:, :2].unsqueeze(1)
+
+     if orig_sizes is not None:
+         orig_sizes = orig_sizes.repeat(1, 2).unsqueeze(1)
+         if inpt_sizes is not None:
+             inpt_sizes = inpt_sizes.repeat(1, 2).unsqueeze(1)
+             boxes = boxes * (orig_sizes / inpt_sizes)
+         else:
+             boxes = boxes * orig_sizes
+
+     boxes = torchvision.ops.box_convert(boxes, in_fmt=in_fmt, out_fmt=out_fmt)
+     return boxes
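A small, hypothetical example of the simplest path through box_revert (plain ``RESIZE``: normalized cxcywh predictions mapped back to the original image size); the import path and concrete sizes are invented for illustration:

import torch
from src.nn.postprocessor.box_revert import BoxProcessFormat, box_revert  # path assumed

pred = torch.rand(1, 300, 4)  # normalized cxcywh predictions for one image
boxes_xyxy = box_revert(
    pred,
    orig_sizes=torch.tensor([[1280, 720]]),  # (w, h) of the original image
    normalized=True,
    in_fmt="cxcywh",
    out_fmt="xyxy",
    process_fmt=BoxProcessFormat.RESIZE,
)
print(boxes_xyxy.shape)  # torch.Size([1, 300, 4]), absolute pixel coordinates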
src/nn/postprocessor/detr_postprocessor.py ADDED
@@ -0,0 +1,86 @@
+ """
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision
+
+ __all__ = ["DetDETRPostProcessor"]
+
+ from .box_revert import BoxProcessFormat, box_revert
+
+
+ def mod(a, b):
+     out = a - a // b * b
+     return out
+
+
+ class DetDETRPostProcessor(nn.Module):
+     def __init__(
+         self,
+         num_classes=80,
+         use_focal_loss=True,
+         num_top_queries=300,
+         box_process_format=BoxProcessFormat.RESIZE,
+     ) -> None:
+         super().__init__()
+         self.use_focal_loss = use_focal_loss
+         self.num_top_queries = num_top_queries
+         self.num_classes = int(num_classes)
+         self.box_process_format = box_process_format
+         self.deploy_mode = False
+
+     def extra_repr(self) -> str:
+         return f"use_focal_loss={self.use_focal_loss}, num_classes={self.num_classes}, num_top_queries={self.num_top_queries}"
+
+     def forward(self, outputs, **kwargs):
+         logits, boxes = outputs["pred_logits"], outputs["pred_boxes"]
+
+         if self.use_focal_loss:
+             scores = F.sigmoid(logits)
+             scores, index = torch.topk(scores.flatten(1), self.num_top_queries, dim=-1)
+             labels = index % self.num_classes
+             # labels = mod(index, self.num_classes)  # for tensorrt
+             index = index // self.num_classes
+             boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
+
+         else:
+             scores = F.softmax(logits, dim=-1)[:, :, :-1]
+             scores, labels = scores.max(dim=-1)
+             if scores.shape[1] > self.num_top_queries:
+                 scores, index = torch.topk(scores, self.num_top_queries, dim=-1)
+                 labels = torch.gather(labels, dim=1, index=index)
+                 boxes = torch.gather(
+                     boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1])
+                 )
+
+         if kwargs is not None:
+             boxes = box_revert(
+                 boxes,
+                 in_fmt="cxcywh",
+                 out_fmt="xyxy",
+                 process_fmt=self.box_process_format,
+                 normalized=True,
+                 **kwargs,
+             )
+
+         # TODO for onnx export
+         if self.deploy_mode:
+             return labels, boxes, scores
+
+         results = []
+         for lab, box, sco in zip(labels, boxes, scores):
+             result = dict(labels=lab, boxes=box, scores=sco)
+             results.append(result)
+
+         return results
+
+     def deploy(
+         self,
+     ):
+         self.eval()
+         self.deploy_mode = True
+         return self
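As a rough, hypothetical illustration (import path, shapes, and sizes invented), the post-processor takes raw model outputs plus keyword arguments that are forwarded to box_revert:

import torch
from src.nn.postprocessor.detr_postprocessor import DetDETRPostProcessor  # path assumed

postprocessor = DetDETRPostProcessor(num_classes=80, use_focal_loss=True, num_top_queries=300)
outputs = {
    "pred_logits": torch.randn(2, 300, 80),
    "pred_boxes": torch.rand(2, 300, 4),  # normalized cxcywh
}
# extra kwargs go to box_revert; here only the original (w, h) per image
results = postprocessor(outputs, orig_sizes=torch.tensor([[1920, 1080], [1280, 720]]))
print(results[0]["labels"].shape, results[0]["boxes"].shape, results[0]["scores"].shape)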
src/nn/postprocessor/nms_postprocessor.py ADDED
@@ -0,0 +1,86 @@
+ """
+ Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
+ Copyright(c) 2023 lyuwenyu. All Rights Reserved.
+ """
+
+ from typing import Dict
+
+ import torch
+ import torch.distributed
+ import torch.nn.functional as F
+ import torchvision
+ from torch import Tensor
+
+ from ...core import register
+
+ __all__ = [
+     "DetNMSPostProcessor",
+ ]
+
+
+ @register()
+ class DetNMSPostProcessor(torch.nn.Module):
+     def __init__(
+         self,
+         iou_threshold=0.7,
+         score_threshold=0.01,
+         keep_topk=300,
+         box_fmt="cxcywh",
+         logit_fmt="sigmoid",
+     ) -> None:
+         super().__init__()
+         self.iou_threshold = iou_threshold
+         self.score_threshold = score_threshold
+         self.keep_topk = keep_topk
+         self.box_fmt = box_fmt.lower()
+         self.logit_fmt = logit_fmt.lower()
+         self.logit_func = getattr(F, self.logit_fmt, None)
+         self.deploy_mode = False
+
+     def forward(self, outputs: Dict[str, Tensor], orig_target_sizes: Tensor):
+         logits, boxes = outputs["pred_logits"], outputs["pred_boxes"]
+         pred_boxes = torchvision.ops.box_convert(boxes, in_fmt=self.box_fmt, out_fmt="xyxy")
+         pred_boxes *= orig_target_sizes.repeat(1, 2).unsqueeze(1)
+
+         values, pred_labels = torch.max(logits, dim=-1)
+
+         if self.logit_func:
+             pred_scores = self.logit_func(values)
+         else:
+             pred_scores = values
+
+         # TODO for onnx export
+         if self.deploy_mode:
+             blobs = {
+                 "pred_labels": pred_labels,
+                 "pred_boxes": pred_boxes,
+                 "pred_scores": pred_scores,
+             }
+             return blobs
+
+         results = []
+         for i in range(logits.shape[0]):
+             score_keep = pred_scores[i] > self.score_threshold
+             pred_box = pred_boxes[i][score_keep]
+             pred_label = pred_labels[i][score_keep]
+             pred_score = pred_scores[i][score_keep]
+
+             keep = torchvision.ops.batched_nms(pred_box, pred_score, pred_label, self.iou_threshold)
+             keep = keep[: self.keep_topk]
+
+             blob = {
+                 "labels": pred_label[keep],
+                 "boxes": pred_box[keep],
+                 "scores": pred_score[keep],
+             }
+
+             results.append(blob)
+
+         return results
+
+     def deploy(
+         self,
+     ):
+         self.eval()
+         self.deploy_mode = True
+         return self
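A minimal, hypothetical call sketch (import path, shapes, and sizes invented); in practice this module is instantiated from the YAML config through @register():

import torch
from src.nn.postprocessor.nms_postprocessor import DetNMSPostProcessor  # path assumed

postprocessor = DetNMSPostProcessor(iou_threshold=0.7, score_threshold=0.01, keep_topk=300)
outputs = {
    "pred_logits": torch.randn(2, 300, 80),
    "pred_boxes": torch.rand(2, 300, 4),  # normalized cxcywh
}
orig_target_sizes = torch.tensor([[1920, 1080], [1280, 720]])  # (w, h) per image
results = postprocessor(outputs, orig_target_sizes)
for r in results:
    print(r["labels"].shape, r["boxes"].shape, r["scores"].shape)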