Spaces:
Runtime error
Runtime error
from collections import OrderedDict | |
import torch.nn as nn | |
class ModelBook:
    """Maintain the mapping between modules and their paths.

    A *path* is a tuple of child names leading from the root model to a
    submodule, e.g. ``('layer1', '0', 'conv1')``. Every descendant of the
    model (the root itself is not recorded) is stored in pre-order traversal
    order. Child slots that hold ``None`` are skipped.

    If the same module object is reachable under several paths, each path is
    recorded in the path -> module mapping, while ``get_path`` returns the
    last of those paths in traversal order.

    Example:
        book = ModelBook(model_ft)
        for p, m in book.conv2d_modules():
            print('path:', p, 'num of filters:', m.out_channels)
            assert m is book.get_module(p)
    """

    def __init__(self, model):
        self._model = model
        # path tuple -> module, in pre-order traversal order.
        self._modules = OrderedDict()
        # module -> path tuple (last path wins for modules shared by several paths).
        self._paths = OrderedDict()
        self._construct(self._model, [])

    @staticmethod
    def _iter_descendants(module, path):
        """Yield ``(path, submodule)`` for every descendant of *module*, pre-order.

        *path* is the path of *module* itself (list or tuple). ``None`` entries
        in ``module._modules`` are empty slots, not modules; they are skipped
        because recursing into them would raise AttributeError.
        """
        for name, child in module._modules.items():
            if child is None:
                continue
            child_path = tuple(path) + (name,)
            yield child_path, child
            yield from ModelBook._iter_descendants(child, child_path)

    def _construct(self, module, path):
        """Record every descendant of *module*, whose own path is *path*."""
        for sub_path, sub_module in self._iter_descendants(module, path):
            self._paths[sub_module] = sub_path
            self._modules[sub_path] = sub_module

    def conv2d_modules(self):
        """Yield ``(path, module)`` for every ``nn.Conv2d`` in traversal order."""
        return self.modules(nn.Conv2d)

    def linear_modules(self):
        """Yield ``(path, module)`` for every ``nn.Linear`` in traversal order."""
        return self.modules(nn.Linear)

    def modules(self, module_type=None):
        """Yield ``(path, module)`` pairs in traversal order.

        Args:
            module_type: a class or tuple of classes accepted by ``isinstance``.
                When falsy (the default ``None``), every recorded module is
                yielded.
        """
        for p, m in self._modules.items():
            if not module_type or isinstance(m, module_type):
                yield p, m

    def num_of_conv2d_modules(self):
        """Return the number of recorded ``nn.Conv2d`` modules."""
        return self.num_of_modules(nn.Conv2d)

    def num_of_conv2d_filters(self):
        """Return the sum of ``out_channels`` over all conv2d layers.

        Each output channel's weight slice, of shape
        ``[in_channels // groups, kernel_h, kernel_w]``, is treated as a
        single filter.
        """
        return sum(m.out_channels for _, m in self.conv2d_modules())

    def num_of_linear_modules(self):
        """Return the number of recorded ``nn.Linear`` modules."""
        return self.num_of_modules(nn.Linear)

    def num_of_linear_filters(self):
        """Return the sum of ``out_features`` over all linear layers."""
        return sum(m.out_features for _, m in self.linear_modules())

    def num_of_modules(self, module_type=None):
        """Return how many recorded modules match *module_type* (see ``modules``)."""
        return sum(1 for _ in self.modules(module_type))

    def get_module(self, path):
        """Return the module recorded at *path*, or ``None`` if unknown."""
        return self._modules.get(path)

    def get_path(self, module):
        """Return the path recorded for *module*, or ``None`` if unknown."""
        return self._paths.get(module)

    def update(self, path, module):
        """Record *module* as the module living at *path*.

        Only the book is changed; ``self._model`` is not touched, so the
        caller is responsible for installing *module* into the model.

        Paths below *path* that belonged to the old module are dropped, and
        the descendants of *module* are recorded in their place (keeping the
        traversal order), so the book stays consistent when a container is
        replaced.

        Raises:
            KeyError: if *path* is not a recorded path.
        """
        if path not in self._modules:
            raise KeyError(path)
        depth = len(path)
        rebuilt = OrderedDict()
        for p, m in self._modules.items():
            if p == path:
                rebuilt[p] = module
                if module is not None:
                    rebuilt.update(self._iter_descendants(module, p))
            elif p[:depth] != path:
                rebuilt[p] = m
            # Otherwise p is a descendant of the replaced module: drop it.
        self._modules = rebuilt
        # Rebuild the reverse index: modules shared with other paths keep a
        # valid entry, and no stale entries for the old subtree survive.
        self._paths = OrderedDict((m, p) for p, m in rebuilt.items())