Spaces:
Sleeping
Sleeping
File size: 644 Bytes
6e32a75 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
import torch
def build_optimizer(args, model):
    """Create a ``torch.optim`` optimizer with two learning-rate groups.

    Parameters belonging to ``model.visual_extractor`` form the first
    parameter group and use ``args.lr_ve``. Every other parameter yielded by
    ``model.parameters()`` forms the second group and uses ``args.lr_ed``.

    Args:
        args: Namespace-like object providing ``optim`` (the name of an
            optimizer class in ``torch.optim``), ``lr_ve``, ``lr_ed``,
            ``weight_decay`` and ``amsgrad``.
        model: Module exposing a ``visual_extractor`` submodule.

    Returns:
        The constructed optimizer instance.

    Raises:
        AttributeError: if ``args.optim`` is not an attribute of
            ``torch.optim``.
        TypeError: if the chosen optimizer class does not accept an
            ``amsgrad`` keyword (e.g. ``torch.optim.SGD``).
    """
    ve_params = list(model.visual_extractor.parameters())
    # Compare by identity: tensors' ``==`` is elementwise. A set makes each
    # membership test O(1) instead of a linear scan over a list of ids.
    ve_ids = {id(p) for p in ve_params}
    ed_params = [p for p in model.parameters() if id(p) not in ve_ids]

    optimizer_cls = getattr(torch.optim, args.optim)
    return optimizer_cls(
        [{'params': ve_params, 'lr': args.lr_ve},
         {'params': ed_params, 'lr': args.lr_ed}],
        weight_decay=args.weight_decay,
        amsgrad=args.amsgrad,
    )
def build_lr_scheduler(args, optimizer):
    """Instantiate the scheduler class named by ``args.lr_scheduler``.

    The class is looked up in ``torch.optim.lr_scheduler`` and called
    positionally as ``cls(optimizer, args.step_size, args.gamma)``. That
    argument order matches ``StepLR``; other schedulers receive the same
    two values in their second and third positional slots.

    Args:
        args: Namespace-like object providing ``lr_scheduler``,
            ``step_size`` and ``gamma``.
        optimizer: The optimizer whose learning rates will be scheduled.

    Returns:
        The constructed scheduler instance.
    """
    scheduler_cls = getattr(torch.optim.lr_scheduler, args.lr_scheduler)
    # Kept positional on purpose: keyword names differ between schedulers.
    call_args = (optimizer, args.step_size, args.gamma)
    return scheduler_cls(*call_args)
|