File size: 2,831 Bytes
719d0db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import random
import pickle
import numpy as np
import torch
import torch.backends.cudnn as cudnn

def fix_seed(seed):
    """Seed every RNG in use (Python, NumPy, PyTorch) for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible CUDA device as well (no-op on CPU-only machines).
    torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels; disable autotuning so kernel choice is fixed.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def check_extension(filename):
    """Return *filename* guaranteed to end in ``.pkl``, appending it if absent."""
    _, ext = os.path.splitext(filename)
    return filename if ext == ".pkl" else filename + ".pkl"

def save_dataset(dataset, filename, display=True):
    """Pickle *dataset* to *filename* (a ``.pkl`` suffix is appended if missing).

    Parameters
    ----------
    dataset : object
        Any picklable object.
    filename : str
        Destination path; missing parent directories are created.
    display : bool
        When True, print the filename before saving.
    """
    filedir = os.path.split(filename)[0]
    if display:
        print(filename)
    # Bug fix: when *filename* has no directory part, filedir is "" and
    # os.makedirs("") raises FileNotFoundError. exist_ok=True also removes
    # the race between the old isdir() check and makedirs().
    if filedir:
        os.makedirs(filedir, exist_ok=True)
    with open(check_extension(filename), 'wb') as f:
        pickle.dump(dataset, f, pickle.HIGHEST_PROTOCOL)

def load_dataset(filename):
    """Unpickle and return the object stored at *filename* (``.pkl`` implied).

    NOTE(review): pickle.load must only be used on trusted files.
    """
    path = check_extension(filename)
    with open(path, 'rb') as f:
        dataset = pickle.load(f)
    return dataset

def split_dataset(dataset, ratio=(0.8, 0.1, 0.1), random_seed=1234):
    """Randomly split *dataset* into ``len(ratio)`` subsets.

    Parameters
    ----------
    dataset : sized dataset
        Anything accepted by ``torch.utils.data.random_split``.
    ratio : sequence of float
        Fraction of samples per split; must sum to 1.
        (Bug fix: the old default ``[8.0, 0.1, 0.1]`` summed to 8.2 and
        failed the function's own sum check — clearly a typo for 0.8.
        A tuple avoids the mutable-default pitfall.)
    random_seed : int
        Seed for the split generator, so splits are reproducible.

    Returns
    -------
    list of torch.utils.data.Subset

    Raises
    ------
    ValueError
        If the ratios do not sum to 1. (Replaces ``assert``, which is
        stripped under ``python -O``.)
    """
    if abs(sum(ratio) - 1) >= 1e-9:
        raise ValueError("sum of ratio should equal to 1.")
    num_samples = len(dataset)
    # Floor every split but the last; the last absorbs the rounding remainder
    # so the sizes always total num_samples.
    split_size = [int(num_samples * r) for r in ratio[:-1]]
    split_size.append(num_samples - sum(split_size))
    print(f"split_size = {split_size}")
    return torch.utils.data.random_split(
        dataset, split_size,
        generator=torch.Generator().manual_seed(random_seed),
    )

def set_device(gpu):
    """
    Parameters
    ----------
    gpu: int 
        Used GPU #. gpu=-1 indicates using cpu.

    Returns
    -------
    use_cuda: bool
        whether a gpu is used or not
    device: str
        device name
    """
    # Guard clause: any negative index means CPU.
    if gpu < 0:
        print(f'selected device: CPU')
        return False, "cpu"
    assert torch.cuda.is_available(), "There is no available GPU."
    torch.cuda.set_device(gpu)
    # Let cuDNN autotune kernels for the selected GPU.
    cudnn.benchmark = True
    print(f'selected device: GPU #{gpu}')
    return True, f"cuda:{gpu}"

def calc_tour_length(tour, coords):
    """Sum the Euclidean distances between consecutive stops of an open tour.

    *tour* is a sequence of indices into *coords* (an array of points);
    the path is not closed back to the start.
    """
    segment_lengths = [
        np.linalg.norm(coords[src] - coords[dst])
        for src, dst in zip(tour, tour[1:])
    ]
    return np.sum(segment_lengths)

def count_trainable_params(model):
    """Return the total element count of parameters with requires_grad=True."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def apply_along_axis(func, inputs, dim: int = 0):
    """Apply *func* to each slice of *inputs* along *dim* and stack the results.

    Torch analogue of ``np.apply_along_axis``; results are stacked on dim 0.
    """
    slices = torch.unbind(inputs, dim=dim)
    return torch.stack(list(map(func, slices)))

def batched_bincount(x, dim, max_value):
    """Row-wise bincount: count occurrences of each value in ``[0, max_value)``.

    *x* is a batch of integer tensors; the result has shape
    ``(x.shape[0], max_value)`` and inherits x's dtype and device.
    NOTE(review): the scatter only makes sense for dim=1 — confirm callers.
    """
    counts = x.new_zeros((x.shape[0], max_value))
    counts.scatter_add_(dim, x, x.new_ones(x.shape))
    return counts