File size: 2,122 Bytes
505e401
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import copy
import torch

def apply_overrides(params, overrides):
    """Return a deep copy of ``params`` with ``overrides`` applied on top.

    The input ``params`` dict is never modified; the copy is what gets
    updated and returned.

    Args:
        params: Dict of parameter name -> default value.
        overrides: Mapping of parameter name -> replacement value. ``None``
            or an empty mapping means "no overrides".

    Returns:
        A new dict holding the defaults with the overrides substituted in.

    Raises:
        ValueError: If an override names a parameter that is not present in
            ``params``. The message identifies the offending name.
    """
    params = copy.deepcopy(params)
    for param_name, value in (overrides or {}).items():
        # Only existing parameters may be overridden, so a typo in an
        # override name fails loudly instead of silently adding a new key.
        if param_name not in params:
            raise ValueError(
                f'override failed: no parameter named {param_name}')
        params[param_name] = value
    return params

def get_default_params_train(overrides=None):
    """Return the default training parameters with ``overrides`` applied.

    Args:
        overrides: Optional mapping of parameter name -> replacement value.
            ``None`` (the default) or an empty mapping keeps every default.

    Returns:
        A dict of training parameters, as returned by ``apply_overrides``.

    Raises:
        ValueError: Propagated from ``apply_overrides`` when an override
            names a parameter that is not one of the defaults below.
    """
    # Use a None sentinel rather than a shared mutable ``{}`` default.
    if overrides is None:
        overrides = {}

    params = {}

    # ---- misc ----
    # NOTE(review): hard-coded to 'cuda' here, whereas the eval defaults pick
    # the device based on torch.cuda.is_available(); override with 'cpu' when
    # no GPU is present.
    params['device'] = 'cuda'  # cuda, cpu
    params['save_base'] = './experiments/'
    params['experiment_name'] = 'demo'
    params['timestamp'] = False

    # ---- data ----
    params['species_set'] = 'all'  # all, snt_birds
    params['hard_cap_seed'] = 9472
    params['hard_cap_num_per_class'] = -1  # -1 for no hard capping
    params['aux_species_seed'] = 8099
    params['num_aux_species'] = 0  # for snt_birds case, how many other species to add in

    # ---- model ----
    params['model'] = 'ResidualFCNet'  # ResidualFCNet, LinNet
    params['num_filts'] = 256  # embedding dimension
    params['input_enc'] = 'sin_cos'  # sin_cos, env, sin_cos_env
    params['depth'] = 4

    # ---- loss ----
    params['loss'] = 'an_full'  # an_full, an_ssdl, an_slds
    params['pos_weight'] = 2048

    # ---- optimization ----
    params['batch_size'] = 2048
    params['lr'] = 0.0005
    params['lr_decay'] = 0.98
    params['num_epochs'] = 10

    # ---- saving ----
    params['log_frequency'] = 512

    # Unknown override names are rejected by apply_overrides.
    return apply_overrides(params, overrides)

def get_default_params_eval(overrides=None):
    """Return the default evaluation parameters with ``overrides`` applied.

    Args:
        overrides: Optional mapping of parameter name -> replacement value.
            ``None`` (the default) or an empty mapping keeps every default.

    Returns:
        A dict of evaluation parameters, as returned by ``apply_overrides``.

    Raises:
        ValueError: Propagated from ``apply_overrides`` when an override
            names a parameter that is not one of the defaults below.
    """
    # Use a None sentinel rather than a shared mutable ``{}`` default.
    if overrides is None:
        overrides = {}

    params = {}

    # ---- misc ----
    # The device is a torch.device resolved at call time: CUDA when
    # available, otherwise CPU.
    params['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    params['seed'] = 2022
    params['exp_base'] = './experiments'
    params['ckp_name'] = 'model.pt'
    params['eval_type'] = 'snt'  # snt, iucn, geo_prior, geo_feature
    params['experiment_name'] = 'demo'

    # ---- geo prior ----
    params['batch_size'] = 2048

    # ---- geo feature ----
    params['cell_size'] = 25

    # Unknown override names are rejected by apply_overrides.
    return apply_overrides(params, overrides)