import torch


def get_optimizer(
        params,
        optimizer_type='adam',
        learning_rate=1e-6,
        optimizer_params=None
):
    """Build and return the optimizer named by optimizer_type for the given params."""
    if optimizer_params is None:
        optimizer_params = {}
    lower_type = optimizer_type.lower()
if lower_type.startswith("dadaptation"): | |
# dadaptation optimizer does not use standard learning rate. 1 is the default value | |
import dadaptation | |
print("Using DAdaptAdam optimizer") | |
use_lr = learning_rate | |
if use_lr < 0.1: | |
# dadaptation uses different lr that is values of 0.1 to 1.0. default to 1.0 | |
use_lr = 1.0 | |
if lower_type.endswith('lion'): | |
            # note: DAdaptLion takes no eps argument (Lion's update is sign-based)
            optimizer = dadaptation.DAdaptLion(params, lr=use_lr, **optimizer_params)
        elif lower_type.endswith('adam'):
            optimizer = dadaptation.DAdaptAdam(params, eps=1e-6, lr=use_lr, **optimizer_params)
        elif lower_type == 'dadaptation':
            # backwards compatibility
            optimizer = dadaptation.DAdaptAdam(params, eps=1e-6, lr=use_lr, **optimizer_params)
            # warn user that the bare dadaptation type is deprecated
            print("WARNING: Dadaptation optimizer type has been changed to DadaptationAdam. Please update your config.")
        else:
            raise ValueError(f'Unknown optimizer type {optimizer_type}')
elif lower_type.startswith("prodigy8bit"): | |
from toolkit.optimizers.prodigy_8bit import Prodigy8bit | |
print("Using Prodigy optimizer") | |
use_lr = learning_rate | |
if use_lr < 0.1: | |
# dadaptation uses different lr that is values of 0.1 to 1.0. default to 1.0 | |
use_lr = 1.0 | |
print(f"Using lr {use_lr}") | |
# let net be the neural network you want to train | |
# you can choose weight decay value based on your problem, 0 by default | |
optimizer = Prodigy8bit(params, lr=use_lr, eps=1e-6, **optimizer_params) | |
elif lower_type.startswith("prodigy"): | |
from prodigyopt import Prodigy | |
print("Using Prodigy optimizer") | |
use_lr = learning_rate | |
if use_lr < 0.1: | |
# dadaptation uses different lr that is values of 0.1 to 1.0. default to 1.0 | |
use_lr = 1.0 | |
print(f"Using lr {use_lr}") | |
# let net be the neural network you want to train | |
# you can choose weight decay value based on your problem, 0 by default | |
optimizer = Prodigy(params, lr=use_lr, eps=1e-6, **optimizer_params) | |
elif lower_type == "adam8": | |
from toolkit.optimizers.adam8bit import Adam8bit | |
optimizer = Adam8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params) | |
elif lower_type == "adamw8": | |
from toolkit.optimizers.adam8bit import Adam8bit | |
optimizer = Adam8bit(params, lr=learning_rate, eps=1e-6, decouple=True, **optimizer_params) | |
elif lower_type.endswith("8bit"): | |
import bitsandbytes | |
if lower_type == "adam8bit": | |
return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params) | |
if lower_type == "ademamix8bit": | |
return bitsandbytes.optim.AdEMAMix8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params) | |
elif lower_type == "adamw8bit": | |
return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params) | |
elif lower_type == "lion8bit": | |
return bitsandbytes.optim.Lion8bit(params, lr=learning_rate, **optimizer_params) | |
else: | |
raise ValueError(f'Unknown optimizer type {optimizer_type}') | |
    elif lower_type == 'adam':
        optimizer = torch.optim.Adam(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
    elif lower_type == 'adamw':
        optimizer = torch.optim.AdamW(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
    elif lower_type == 'lion':
        try:
            from lion_pytorch import Lion
            return Lion(params, lr=learning_rate, **optimizer_params)
        except ImportError:
            raise ImportError("Please install lion_pytorch to use the Lion optimizer -> pip install lion-pytorch")
    elif lower_type == 'adagrad':
        optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), **optimizer_params)
    elif lower_type == 'adafactor':
        from toolkit.optimizers.adafactor import Adafactor
        # default to a fixed lr instead of Adafactor's relative-step schedule,
        # unless the caller overrides these in optimizer_params
        optimizer_params.setdefault('relative_step', False)
        optimizer_params.setdefault('scale_parameter', False)
        optimizer_params.setdefault('warmup_init', False)
        optimizer = Adafactor(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
    elif lower_type == 'automagic':
        from toolkit.optimizers.automagic import Automagic
        optimizer = Automagic(params, lr=float(learning_rate), **optimizer_params)
    else:
        raise ValueError(f'Unknown optimizer type {optimizer_type}')
    return optimizer
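

# Minimal usage sketch (illustrative, not part of the module's original API):
# the tiny model and hyperparameter values below are assumptions chosen purely
# for demonstration.
if __name__ == '__main__':
    net = torch.nn.Linear(8, 2)
    # plain AdamW with an explicit learning rate
    opt = get_optimizer(net.parameters(), optimizer_type='adamw', learning_rate=1e-4)
    print(type(opt).__name__)  # -> AdamW
    # adaptive types such as 'prodigy' treat lr as a multiplier; values below 0.1
    # are bumped up to 1.0 by get_optimizer
    # opt = get_optimizer(net.parameters(), optimizer_type='prodigy', learning_rate=1.0)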