|
import copy |
|
import torch |
|
|
|
def apply_overrides(params, overrides):
    """Return a copy of ``params`` with values replaced from ``overrides``.

    Args:
        params: Base parameter dict; left unmodified (a deep copy is made).
        overrides: Mapping of parameter names to replacement values. Every
            key must already exist in ``params``.

    Returns:
        A deep copy of ``params`` with the overridden values applied.

    Raises:
        ValueError: If an override key is not an existing parameter name.
    """
    params = copy.deepcopy(params)
    for param_name, value in overrides.items():
        if param_name not in params:
            # Fail loudly on unknown keys (likely typos) instead of silently
            # creating new entries. Keep the print for parity with the old
            # behavior, but also carry the message in the exception itself.
            msg = f'override failed: no parameter named {param_name}'
            print(msg)
            raise ValueError(msg)
        params[param_name] = value
    return params
|
|
|
def get_default_params_train(overrides=None):
    """Build the default training parameter dict, with optional overrides.

    Args:
        overrides: Optional mapping of parameter names to replacement
            values; every key must match an existing default. Defaults to
            no overrides.

    Returns:
        A dict of training hyperparameters with any overrides applied.

    Raises:
        ValueError: If ``overrides`` contains an unknown parameter name.
    """
    # NOTE: previously a mutable default argument (``overrides={}``);
    # use a None sentinel instead so callers can never share state.
    overrides = {} if overrides is None else overrides

    params = {}

    # misc
    # NOTE(review): hard-codes 'cuda' as a string, unlike the eval defaults
    # which fall back to CPU via torch.device — confirm this is intended.
    params['device'] = 'cuda'
    params['save_base'] = './experiments/'
    params['experiment_name'] = 'demo'
    params['timestamp'] = False

    # data
    params['species_set'] = 'all'
    params['hard_cap_seed'] = 9472
    params['hard_cap_num_per_class'] = -1  # -1 presumably means "no cap" — verify against consumer
    params['aux_species_seed'] = 8099
    params['num_aux_species'] = 0

    # model
    params['model'] = 'ResidualFCNet'
    params['num_filts'] = 256
    params['input_enc'] = 'sin_cos'
    params['depth'] = 4

    # loss
    params['loss'] = 'an_full'
    params['pos_weight'] = 2048

    # optimization
    params['batch_size'] = 2048
    params['lr'] = 0.0005
    params['lr_decay'] = 0.98
    params['num_epochs'] = 10

    # saving
    params['log_frequency'] = 512

    params = apply_overrides(params, overrides)

    return params
|
|
|
def get_default_params_eval(overrides=None):
    """Build the default evaluation parameter dict, with optional overrides.

    Args:
        overrides: Optional mapping of parameter names to replacement
            values; every key must match an existing default. Defaults to
            no overrides.

    Returns:
        A dict of evaluation parameters with any overrides applied.

    Raises:
        ValueError: If ``overrides`` contains an unknown parameter name.
    """
    # NOTE: previously a mutable default argument (``overrides={}``);
    # use a None sentinel instead so callers can never share state.
    overrides = {} if overrides is None else overrides

    params = {}

    # misc
    params['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    params['seed'] = 2022
    params['exp_base'] = './experiments'
    params['ckp_name'] = 'model.pt'
    params['eval_type'] = 'snt'
    params['experiment_name'] = 'demo'

    # geo prior
    params['batch_size'] = 2048

    # geo feature
    params['cell_size'] = 25

    params = apply_overrides(params, overrides)

    return params
|
|