|
""" |
|
Optimizer |
|
|
|
Author: Xiaoyang Wu ([email protected]) |
|
Please cite our work if the code is helpful to you. |
|
""" |
|
|
|
import torch |
|
from pointcept.utils.logger import get_root_logger |
|
from pointcept.utils.registry import Registry |
|
|
|
# Registry named "optimizers"; build_optimizer() resolves its cfg through it.
OPTIMIZERS = Registry("optimizers")


# Register the stock PyTorch optimizers under their own class names
# ("SGD", "Adam", "AdamW"), in that order.
for _optimizer_cls in (torch.optim.SGD, torch.optim.Adam, torch.optim.AdamW):
    OPTIMIZERS.register_module(module=_optimizer_cls, name=_optimizer_cls.__name__)
# Do not leave the loop variable behind in the module namespace.
del _optimizer_cls
|
|
|
|
|
def build_optimizer(cfg, model, param_dicts=None):
    """Build the optimizer described by ``cfg`` for the parameters of ``model``.

    ``cfg`` is modified in place: its ``params`` attribute is set to whatever
    should be optimized, and ``cfg`` is then passed to ``OPTIMIZERS.build``.

    Args:
        cfg: Optimizer config. Must support attribute assignment
            (``cfg.params``) and, when ``param_dicts`` is given, expose
            ``cfg.lr``.
        model: Object providing ``parameters()`` and ``named_parameters()``
            (presumably a ``torch.nn.Module`` -- confirm against callers).
        param_dicts: Optional sequence of per-group specs. Each spec must
            expose ``keyword`` and ``keys()``. Of its entries, only ``lr``,
            ``momentum`` and ``weight_decay`` are copied into the group, and
            only when present. A parameter is assigned to the first spec whose
            ``keyword`` is a substring of the parameter name; parameters that
            match no spec go to the default group, which uses ``cfg.lr``.

    Returns:
        The result of ``OPTIMIZERS.build(cfg=cfg)``.
    """
    if param_dicts is None:
        # No grouping requested: optimize every parameter with cfg's settings.
        cfg.params = model.parameters()
        return OPTIMIZERS.build(cfg=cfg)

    # Group 0 is the fallback group and explicitly carries the base lr.
    cfg.params = [dict(names=[], params=[], lr=cfg.lr)]
    # Read the list back from cfg instead of keeping the literal above: cfg
    # may be a config container that converts values on assignment (not
    # visible from here), so all later mutations must target the stored list.
    groups = cfg.params

    for spec in param_dicts:
        group = dict(names=[], params=[])
        # The tuple order fixes the key order, which is the order in which
        # the options appear in the log message below.
        for option in ("lr", "momentum", "weight_decay"):
            if option in spec.keys():
                group[option] = getattr(spec, option)
        groups.append(group)

    for param_name, param in model.named_parameters():
        # Spec k owns groups[k + 1]; the first matching keyword wins.
        for group_index, spec in enumerate(param_dicts, start=1):
            if spec.keyword in param_name:
                target = groups[group_index]
                break
        else:
            target = groups[0]
        target["names"].append(param_name)
        target["params"].append(param)

    logger = get_root_logger()
    for group_number, group in enumerate(groups, start=1):
        # "names" exists only for logging; remove it before the groups are
        # handed to the optimizer.
        param_names = group.pop("names")
        message = "".join(
            f" {key}: {group[key]};" for key in group.keys() if key != "params"
        )
        logger.info(f"Params Group {group_number} -{message} Params: {param_names}.")

    return OPTIMIZERS.build(cfg=cfg)
|
|