"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
import copy
import re

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from ._config import BaseConfig
from .workspace import create
from .yaml_utils import load_config, merge_config, merge_dict


class YAMLConfig(BaseConfig):
    def __init__(self, cfg_path: str, **kwargs) -> None:
        super().__init__()

        cfg = load_config(cfg_path)
        cfg = merge_dict(cfg, kwargs)

        self.yaml_cfg = copy.deepcopy(cfg)

        # Mirror matching top-level config keys onto the public attributes
        # initialized by BaseConfig.
        for k in super().__dict__:
            if not k.startswith("_") and k in cfg:
                self.__dict__[k] = cfg[k]
    @property
    def global_cfg(self):
        return merge_config(self.yaml_cfg, inplace=False, overwrite=False)

    @property
    def model(self) -> torch.nn.Module:
        if self._model is None and "model" in self.yaml_cfg:
            self._model = create(self.yaml_cfg["model"], self.global_cfg)
        return super().model

    @property
    def postprocessor(self) -> torch.nn.Module:
        if self._postprocessor is None and "postprocessor" in self.yaml_cfg:
            self._postprocessor = create(self.yaml_cfg["postprocessor"], self.global_cfg)
        return super().postprocessor

    @property
    def criterion(self) -> torch.nn.Module:
        if self._criterion is None and "criterion" in self.yaml_cfg:
            self._criterion = create(self.yaml_cfg["criterion"], self.global_cfg)
        return super().criterion
    @property
    def optimizer(self) -> optim.Optimizer:
        if self._optimizer is None and "optimizer" in self.yaml_cfg:
            params = self.get_optim_params(self.yaml_cfg["optimizer"], self.model)
            self._optimizer = create("optimizer", self.global_cfg, params=params)
        return super().optimizer

    @property
    def lr_scheduler(self) -> optim.lr_scheduler.LRScheduler:
        if self._lr_scheduler is None and "lr_scheduler" in self.yaml_cfg:
            self._lr_scheduler = create("lr_scheduler", self.global_cfg, optimizer=self.optimizer)
            print(f"Initial lr: {self._lr_scheduler.get_last_lr()}")
        return super().lr_scheduler

    @property
    def lr_warmup_scheduler(self) -> optim.lr_scheduler.LRScheduler:
        if self._lr_warmup_scheduler is None and "lr_warmup_scheduler" in self.yaml_cfg:
            self._lr_warmup_scheduler = create(
                "lr_warmup_scheduler", self.global_cfg, lr_scheduler=self.lr_scheduler
            )
        return super().lr_warmup_scheduler
    @property
    def train_dataloader(self) -> DataLoader:
        if self._train_dataloader is None and "train_dataloader" in self.yaml_cfg:
            self._train_dataloader = self.build_dataloader("train_dataloader")
        return super().train_dataloader

    @property
    def val_dataloader(self) -> DataLoader:
        if self._val_dataloader is None and "val_dataloader" in self.yaml_cfg:
            self._val_dataloader = self.build_dataloader("val_dataloader")
        return super().val_dataloader

    @property
    def ema(self) -> torch.nn.Module:
        if self._ema is None and self.yaml_cfg.get("use_ema", False):
            self._ema = create("ema", self.global_cfg, model=self.model)
        return super().ema

    @property
    def scaler(self):
        if self._scaler is None and self.yaml_cfg.get("use_amp", False):
            self._scaler = create("scaler", self.global_cfg)
        return super().scaler
    @property
    def evaluator(self):
        if self._evaluator is None and "evaluator" in self.yaml_cfg:
            if self.yaml_cfg["evaluator"]["type"] == "CocoEvaluator":
                from ..data import get_coco_api_from_dataset

                base_ds = get_coco_api_from_dataset(self.val_dataloader.dataset)
                self._evaluator = create("evaluator", self.global_cfg, coco_gt=base_ds)
            else:
                raise NotImplementedError(f"{self.yaml_cfg['evaluator']['type']}")
        return super().evaluator

    @property
    def use_wandb(self) -> bool:
        return self.yaml_cfg.get("use_wandb", False)
    @staticmethod
    def get_optim_params(cfg: dict, model: nn.Module):
        """Build optimizer parameter groups from regex patterns in the config.

        Each entry in ``cfg["params"]`` holds a regex that is matched against
        parameter names, e.g.:
            ^(?=.*a)(?=.*b).*$     matches names containing both a and b
            ^(?=.*(?:a|b)).*$      matches names containing a or b
            ^(?=.*a)(?!.*b).*$     matches names containing a, but not b
        """
        assert "type" in cfg, "optimizer config must define a `type`"
        cfg = copy.deepcopy(cfg)

        if "params" not in cfg:
            return model.parameters()

        assert isinstance(cfg["params"], list), "`params` must be a list of parameter groups"

        param_groups = []
        visited = []
        for pg in cfg["params"]:
            pattern = pg["params"]
            params = {
                k: v
                for k, v in model.named_parameters()
                if v.requires_grad and len(re.findall(pattern, k)) > 0
            }
            pg["params"] = params.values()
            param_groups.append(pg)
            visited.extend(list(params.keys()))
            # print(params.keys())

        # Collect any trainable parameters not matched above into a final group
        # that uses the optimizer's default settings.
        names = [k for k, v in model.named_parameters() if v.requires_grad]

        if len(visited) < len(names):
            unseen = set(names) - set(visited)
            params = {k: v for k, v in model.named_parameters() if v.requires_grad and k in unseen}
            param_groups.append({"params": params.values()})
            visited.extend(list(params.keys()))
            # print(params.keys())

        assert len(visited) == len(names), "every trainable parameter must be assigned to exactly one group"

        return param_groups
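    # Illustrative only: a hypothetical `optimizer` section that this method
    # would split into parameter groups (values made up, not from any shipped
    # config). Patterns should be mutually exclusive; a name matched by two
    # patterns is counted twice in `visited` and trips the final assertion.
    #
    #   optimizer:
    #     type: AdamW
    #     lr: 0.0001
    #     params:
    #       - params: '^(?=.*backbone)(?!.*(?:norm|bn)).*$'   # backbone weights, lower lr
    #         lr: 0.00001
    #       - params: '^(?=.*(?:norm|bn)).*$'                 # all norm layers, no weight decay
    #         weight_decay: 0.
    #
    # Everything not matched (e.g. decoder weights) falls into the trailing
    # default group appended by the code above.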
    @staticmethod
    def get_rank_batch_size(cfg):
        """Compute the per-rank batch size when `total_batch_size` is provided."""
        assert ("total_batch_size" in cfg or "batch_size" in cfg) and not (
            "total_batch_size" in cfg and "batch_size" in cfg
        ), "exactly one of `batch_size` or `total_batch_size` should be specified"
        total_batch_size = cfg.get("total_batch_size", None)
        if total_batch_size is None:
            bs = cfg.get("batch_size")
        else:
            from ..misc import dist_utils

            assert (
                total_batch_size % dist_utils.get_world_size() == 0
            ), "total_batch_size should be divisible by world size"
            bs = total_batch_size // dist_utils.get_world_size()
        return bs
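    # For example, with `total_batch_size: 32` in the dataloader config and a
    # world size of 4 (4 GPUs), each rank gets a batch size of 32 // 4 = 8;
    # with `batch_size: 8` instead, that value is used as-is on every rank.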
    def build_dataloader(self, name: str):
        bs = self.get_rank_batch_size(self.yaml_cfg[name])
        global_cfg = self.global_cfg
        if "total_batch_size" in global_cfg[name]:
            # pop unexpected key for dataloader init
            _ = global_cfg[name].pop("total_batch_size")

        print(f"building {name} with batch_size={bs}...")
        loader = create(name, global_cfg, batch_size=bs)
        loader.shuffle = self.yaml_cfg[name].get("shuffle", False)
        return loader
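# Example usage (a minimal sketch; the config path and import path below are
# hypothetical and should be adapted to the surrounding package layout):
#
#   from .yaml_config import YAMLConfig
#
#   cfg = YAMLConfig("configs/model.yml", use_amp=True)  # kwargs are merged into the YAML config
#   model = cfg.model                 # built lazily from the `model` entry
#   optimizer = cfg.optimizer         # parameter groups resolved via get_optim_params
#   train_loader = cfg.train_dataloader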