# -*- coding: utf-8 -*-
# Author: ximing
# Description: SVGDreamer - optim
# Copyright (c) 2023, XiMing Xing.
# License: MIT License
from functools import partial
import torch
from omegaconf import DictConfig

def get_optimizer(optimizer_name, parameters, lr=None, config: DictConfig = None):
    """Build a torch optimizer by name, pulling optional hyperparameters from `config`."""
    if config is None:  # guard: the body below calls config.get(), which would fail on None
        config = DictConfig({})

    param_dict = {}
    if optimizer_name == "adam":
        optimizer = partial(torch.optim.Adam, params=parameters)
        if lr is not None:
            optimizer = partial(torch.optim.Adam, params=parameters, lr=lr)
        if config.get('betas'):
            param_dict['betas'] = config.betas
        if config.get('weight_decay'):
            param_dict['weight_decay'] = config.weight_decay
        if config.get('eps'):
            param_dict['eps'] = config.eps
    elif optimizer_name == "adamW":
        optimizer = partial(torch.optim.AdamW, params=parameters)
        if lr is not None:
            optimizer = partial(torch.optim.AdamW, params=parameters, lr=lr)
        if config.get('betas'):
            param_dict['betas'] = config.betas
        if config.get('weight_decay'):
            param_dict['weight_decay'] = config.weight_decay
        if config.get('eps'):
            param_dict['eps'] = config.eps
    elif optimizer_name == "radam":
        optimizer = partial(torch.optim.RAdam, params=parameters)
        if lr is not None:
            optimizer = partial(torch.optim.RAdam, params=parameters, lr=lr)
        if config.get('betas'):
            param_dict['betas'] = config.betas
        if config.get('weight_decay'):
            param_dict['weight_decay'] = config.weight_decay
    elif optimizer_name == "sgd":
        optimizer = partial(torch.optim.SGD, params=parameters)
        if lr is not None:
            optimizer = partial(torch.optim.SGD, params=parameters, lr=lr)
        if config.get('momentum'):
            param_dict['momentum'] = config.momentum
        if config.get('weight_decay'):
            param_dict['weight_decay'] = config.weight_decay
        if config.get('nesterov'):
            param_dict['nesterov'] = config.nesterov
    else:
        raise NotImplementedError(f"Optimizer {optimizer_name} not implemented.")

    # Instantiate the optimizer with any hyperparameters collected above;
    # note the truthiness checks above skip values that are 0, False, or empty.
    if len(param_dict.keys()) > 0:
        return optimizer(**param_dict)
    else:
        return optimizer()
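

# Usage sketch (not part of the original file): a minimal, hypothetical example of
# calling get_optimizer with an OmegaConf config. The toy parameter tensor and the
# hyperparameter values below are illustrative assumptions, not SVGDreamer defaults.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    point_params = [torch.nn.Parameter(torch.randn(4, 2))]  # hypothetical control points
    optim_cfg = OmegaConf.create({'betas': [0.9, 0.999], 'eps': 1e-8})
    optim = get_optimizer("adam", point_params, lr=0.01, config=optim_cfg)

    # One optimization step over a toy objective.
    loss = (point_params[0] ** 2).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()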