import os
import time
from copy import deepcopy
from typing import Optional

import torch.backends.cudnn
import torch.distributed
import torch.nn as nn

from ..apps.utils import (
    dist_init,
    dump_config,
    get_dist_local_rank,
    get_dist_rank,
    get_dist_size,
    init_modules,
    is_master,
    load_config,
    partial_update_config,
    zero_last_gamma,
)
from ..models.utils import build_kwargs_from_config, load_state_dict_from_file

__all__ = [
    "save_exp_config",
    "setup_dist_env",
    "setup_seed",
    "setup_exp_config",
    "init_model",
]


def save_exp_config(exp_config: dict, path: str, name="config.yaml") -> None:
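    """Save the experiment config as a YAML file under ``path``; only the master process writes."""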
    if not is_master():
        return
    dump_config(exp_config, os.path.join(path, name))


def setup_dist_env(gpu: Optional[str] = None) -> None:
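    """Restrict visible GPUs, initialize torch.distributed if needed, enable cudnn
    benchmarking, and bind this process to its local-rank CUDA device.
    """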
    if gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    if not torch.distributed.is_initialized():
        dist_init()
    torch.backends.cudnn.benchmark = True
    torch.cuda.set_device(get_dist_local_rank())


def setup_seed(manual_seed: int, resume: bool) -> None:
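    """Seed torch's CPU and CUDA RNGs, offset by the distributed rank; when
    resuming, the base seed is taken from wall-clock time instead of ``manual_seed``.
    """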
    if resume:
        manual_seed = int(time.time())
    manual_seed = get_dist_rank() + manual_seed
    torch.manual_seed(manual_seed)
    torch.cuda.manual_seed_all(manual_seed)


def setup_exp_config(config_path: str, recursive=True, opt_args: Optional[dict] = None) -> dict:
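    """Build the experiment config.

    Loads ``config_path`` and, when ``recursive`` is True, merges it on top of any
    ``default.<ext>`` files found in its parent directories (outermost first).
    ``opt_args`` is applied last as a partial override.
    """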
    # load config
    if not os.path.isfile(config_path):
        raise ValueError(f"config file not found: {config_path}")

    fpaths = [config_path]
    if recursive:
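        # collect `default.<ext>` files from every parent directory, then reverse
        # so the outermost default is applied first and config_path itself last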
        extension = os.path.splitext(config_path)[1]
        while os.path.dirname(config_path) != config_path:
            config_path = os.path.dirname(config_path)
            fpath = os.path.join(config_path, "default" + extension)
            if os.path.isfile(fpath):
                fpaths.append(fpath)
        fpaths = fpaths[::-1]

    default_config = load_config(fpaths[0])
    exp_config = deepcopy(default_config)
    for fpath in fpaths[1:]:
        partial_update_config(exp_config, load_config(fpath))
    # update config via args
    if opt_args is not None:
        partial_update_config(exp_config, opt_args)

    return exp_config


def init_model(
    network: nn.Module,
    init_from: Optional[str] = None,
    backbone_init_from: Optional[str] = None,
    rand_init="trunc_normal",
    last_gamma=None,
) -> None:
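    """Randomly initialize ``network``, optionally zero the last BN gamma of each
    block, then load full or backbone-only weights if a checkpoint file exists.
    """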
    # initialization
    init_modules(network, init_type=rand_init)
    # zero gamma of last bn in each block
    if last_gamma is not None:
        zero_last_gamma(network, last_gamma)

    # load weight
    if init_from is not None and os.path.isfile(init_from):
        network.load_state_dict(load_state_dict_from_file(init_from))
        print(f"Loaded init from {init_from}")
    elif backbone_init_from is not None and os.path.isfile(backbone_init_from):
        network.backbone.load_state_dict(load_state_dict_from_file(backbone_init_from))
        print(f"Loaded backbone init from {backbone_init_from}")
    else:
        print(f"Random init ({rand_init}) with last gamma {last_gamma}")