# -*- coding: utf-8 -*-
# Author: ximing
# Description: SVGDreamer - optim
# Copyright (c) 2023, XiMing Xing.
# License: MIT License
from functools import partial

import torch
from omegaconf import DictConfig
def get_optimizer(optimizer_name, parameters, lr=None, config: "DictConfig" = None):
    """Build a ``torch.optim`` optimizer by name.

    Args:
        optimizer_name: One of ``"adam"``, ``"adamW"``, ``"radam"``, ``"sgd"``.
        parameters: Iterable of parameters (or param groups) to optimize.
        lr: Learning rate; when ``None`` the optimizer's own default is used.
        config: Optional mapping (``DictConfig`` or plain ``dict``) with extra
            optimizer options; only the keys relevant to the chosen optimizer
            are read. May be ``None``.

    Returns:
        An instantiated ``torch.optim.Optimizer``.

    Raises:
        NotImplementedError: If ``optimizer_name`` is not supported.
    """
    # Dispatch table: optimizer class + the config keys it honors.
    registry = {
        "adam": (torch.optim.Adam, ("betas", "weight_decay", "eps")),
        "adamW": (torch.optim.AdamW, ("betas", "weight_decay", "eps")),
        "radam": (torch.optim.RAdam, ("betas", "weight_decay")),
        "sgd": (torch.optim.SGD, ("momentum", "weight_decay", "nesterov")),
    }
    if optimizer_name not in registry:
        raise NotImplementedError(f"Optimizer {optimizer_name} not implemented.")
    optim_cls, option_keys = registry[optimizer_name]

    kwargs = {"params": parameters}
    if lr is not None:
        kwargs["lr"] = lr
    # Guard against the default config=None (the old code crashed here), and
    # compare against None rather than truthiness so explicitly-configured
    # falsy values (e.g. weight_decay=0.0, nesterov=False) are still applied.
    if config is not None:
        for key in option_keys:
            value = config.get(key)
            if value is not None:
                kwargs[key] = value
    return optim_cls(**kwargs)