File size: 2,390 Bytes
5eff22e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from collections import OrderedDict
import torch.nn as nn


class ModelBook:
    """Maintain the mapping between modules and their paths.

    A *path* is a tuple of attribute names leading from the root model to a
    submodule, e.g. ``('features', '0')``.

    Example:
        book = ModelBook(model_ft)
        for p, m in book.conv2d_modules():
            print('path:', p, 'num of filters:', m.out_channels)
            assert m is book.get_module(p)
    """

    def __init__(self, model):
        """Build the path <-> module maps by walking the module tree of *model*."""
        self._model = model
        # path tuple -> module, in pre-order registration order
        self._modules = OrderedDict()
        # module -> path tuple (reverse map of _modules)
        self._paths = OrderedDict()
        self._construct(self._model, [])

    def _construct(self, module, path):
        """Recursively register every descendant of *module* under prefix *path*.

        Recursion bottoms out naturally: a leaf module's ``_modules`` dict is
        empty, so the loop body never runs.
        """
        for name, child in module._modules.items():
            cur_path = tuple(path + [name])
            self._paths[child] = cur_path
            self._modules[cur_path] = child
            self._construct(child, path + [name])

    def conv2d_modules(self):
        """Yield (path, module) pairs for all nn.Conv2d submodules."""
        return self.modules(nn.Conv2d)

    def linear_modules(self):
        """Yield (path, module) pairs for all nn.Linear submodules."""
        return self.modules(nn.Linear)

    def modules(self, module_type=None):
        """Yield (path, module) pairs, optionally filtered by *module_type*.

        Args:
            module_type: a class (or tuple of classes) accepted by
                ``isinstance``; None yields every registered submodule.
        """
        for path, module in self._modules.items():
            if module_type is None or isinstance(module, module_type):
                yield path, module

    def num_of_conv2d_modules(self):
        """Return the number of nn.Conv2d submodules."""
        return self.num_of_modules(nn.Conv2d)

    def num_of_conv2d_filters(self):
        """Return the sum of out_channels of all conv2d layers.

        Here we treat the sub weight with size of [in_channels, h, w] as a
        single filter.
        """
        return sum(m.out_channels for _, m in self.conv2d_modules())

    def num_of_linear_modules(self):
        """Return the number of nn.Linear submodules."""
        return self.num_of_modules(nn.Linear)

    def num_of_linear_filters(self):
        """Return the sum of out_features of all linear layers."""
        return sum(m.out_features for _, m in self.linear_modules())

    def num_of_modules(self, module_type=None):
        """Return the count of registered submodules of *module_type* (all if None)."""
        # Reuse the modules() filter instead of duplicating its logic.
        return sum(1 for _ in self.modules(module_type))

    def get_module(self, path):
        """Return the module registered at *path*, or None if unknown."""
        return self._modules.get(path)

    def get_path(self, module):
        """Return the path tuple of *module*, or None if unknown."""
        return self._paths.get(module)

    def update(self, path, module):
        """Replace the module at *path* with *module*, keeping both maps in sync.

        Raises:
            KeyError: if *path* is not registered.

        NOTE(review): descendants of the old module remain registered and
        descendants of the new module are not added -- callers appear to be
        expected to replace leaf modules; confirm before relying on this for
        container modules.
        """
        old_module = self._modules[path]
        del self._paths[old_module]
        self._paths[module] = path
        self._modules[path] = module