hjc-owo
init repo
966ae59
# -*- coding: utf-8 -*-
# Author: ximing
# Description: SVGDreamer - optim
# Copyright (c) 2023, XiMing Xing.
# License: MIT License
from functools import partial
import torch
from omegaconf import DictConfig
# Supported optimizer names -> (torch optimizer class, config keys forwarded
# as constructor kwargs). Keys not listed for an optimizer are ignored even if
# present in the config (e.g. "eps" is not forwarded for "radam").
_OPTIMIZER_SPECS = {
    "adam": (torch.optim.Adam, ("betas", "weight_decay", "eps")),
    "adamW": (torch.optim.AdamW, ("betas", "weight_decay", "eps")),
    "radam": (torch.optim.RAdam, ("betas", "weight_decay")),
    "sgd": (torch.optim.SGD, ("momentum", "weight_decay", "nesterov")),
}


def get_optimizer(optimizer_name, parameters, lr=None, config: "DictConfig | None" = None):
    """Build a torch optimizer selected by name.

    Args:
        optimizer_name: One of "adam", "adamW", "radam" or "sgd" (case-sensitive).
        parameters: Parameters (or param groups) handed to the optimizer as ``params``.
        lr: Learning rate; when ``None`` the torch default for that optimizer is used.
        config: Optional mapping (e.g. an omegaconf ``DictConfig``) read via ``.get``
            for per-optimizer hyper-parameters. ``None`` means "no overrides".

    Returns:
        The constructed ``torch.optim.Optimizer`` instance.

    Raises:
        NotImplementedError: If ``optimizer_name`` is not supported.
    """
    try:
        optimizer_cls, forwarded_keys = _OPTIMIZER_SPECS[optimizer_name]
    except KeyError:
        raise NotImplementedError(f"Optimizer {optimizer_name} not implemented.") from None

    kwargs = {}
    if lr is not None:
        kwargs["lr"] = lr

    # Previously config=None crashed on config.get(); treat it as "no overrides".
    if config is not None:
        for key in forwarded_keys:
            value = config.get(key)
            # Falsy values (missing key, 0, False, empty) are skipped so torch's
            # own default applies, matching the original truthiness checks.
            if not value:
                continue
            # torch documents betas as Tuple[float, float]; convert ListConfig/list
            # so no omegaconf object ends up in param_groups or checkpoints.
            kwargs[key] = tuple(value) if key == "betas" else value

    return optimizer_cls(params=parameters, **kwargs)