PromptNet / modules /optimizers.py
fenglinliu's picture
Upload 55 files
6e32a75 verified
raw
history blame contribute delete
644 Bytes
import inspect

import torch
def build_optimizer(args, model):
    """Build a ``torch.optim`` optimizer with two parameter groups.

    The parameters of ``model.visual_extractor`` form the first group and are
    trained with ``args.lr_ve``; every other parameter of ``model`` forms the
    second group and is trained with ``args.lr_ed``.

    Args:
        args: Configuration object exposing ``optim`` (name of a class in
            ``torch.optim``), ``lr_ve``, ``lr_ed`` and ``weight_decay``.
            ``amsgrad`` is read only when the selected optimizer accepts it.
        model: Module exposing a ``visual_extractor`` sub-module.

    Returns:
        The constructed optimizer instance.

    Raises:
        AttributeError: If ``args.optim`` does not name an attribute of
            ``torch.optim``.
    """
    optimizer_cls = getattr(torch.optim, args.optim)

    # Materialize both groups as lists so they can be inspected and reused,
    # and keep the ids in a set so the membership test below is O(1).
    ve_params = list(model.visual_extractor.parameters())
    ve_param_ids = {id(param) for param in ve_params}
    ed_params = [
        param for param in model.parameters() if id(param) not in ve_param_ids
    ]

    optimizer_kwargs = {'weight_decay': args.weight_decay}
    # Only some optimizers (e.g. Adam, AdamW) declare ``amsgrad``; passing it
    # unconditionally raises TypeError for the others (e.g. SGD).
    if 'amsgrad' in inspect.signature(optimizer_cls).parameters:
        optimizer_kwargs['amsgrad'] = args.amsgrad

    return optimizer_cls(
        [
            {'params': ve_params, 'lr': args.lr_ve},
            {'params': ed_params, 'lr': args.lr_ed},
        ],
        **optimizer_kwargs,
    )
def build_lr_scheduler(args, optimizer):
    """Instantiate the learning-rate scheduler named by ``args.lr_scheduler``.

    The class is looked up by name in ``torch.optim.lr_scheduler`` and called
    with ``optimizer``, ``args.step_size`` and ``args.gamma`` as positional
    arguments, in that order.

    Args:
        args: Configuration object exposing ``lr_scheduler``, ``step_size``
            and ``gamma``.
        optimizer: The optimizer whose learning rate is to be scheduled.

    Returns:
        The constructed scheduler instance.
    """
    scheduler_cls = getattr(torch.optim.lr_scheduler, args.lr_scheduler)
    return scheduler_cls(optimizer, args.step_size, args.gamma)