File size: 4,185 Bytes
7934b29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from typing import Optional

import torch
from omegaconf import DictConfig

# Module-level state shared by every AccessMixin instance.
# Options consulted by `AccessMixin.register_accessible_tensor` (via `access_cfg`).
_ACCESS_CFG = DictConfig(dict(detach=False, convert_to_cpu=False))
# Flag read and written through `AccessMixin.is_access_enabled` / `set_access_enabled`.
_ACCESS_ENABLED = False


def set_access_cfg(cfg: 'DictConfig'):
    """
    Replace the module-level access config shared by all ``AccessMixin`` modules.

    Args:
        cfg: The new config. It is stored by reference, so later mutations of ``cfg``
            are seen by every ``AccessMixin`` instance.

    Raises:
        TypeError: If ``cfg`` is not a ``DictConfig`` (this includes ``None``).
    """
    # isinstance() is already False for None, so no separate None check is needed.
    if not isinstance(cfg, DictConfig):
        raise TypeError(f"cfg must be a DictConfig, got {type(cfg).__name__}")
    global _ACCESS_CFG
    _ACCESS_CFG = cfg


class AccessMixin(ABC):
    """
    Allows access to output of intermediate layers of a model.

    Modules inheriting this mixin can record tensors under a string name via
    :meth:`register_accessible_tensor`. Recorded tensors live in a per-module ``_registry``
    dict that maps each name to the list of tensors registered under it. Registration
    options (detaching, moving to CPU) come from a module-level config shared by every
    instance, and a module-level flag records whether access is enabled.
    """

    def __init__(self):
        super().__init__()
        self._registry = {}  # dictionary of lists: name -> [tensor, ...]

    def register_accessible_tensor(self, name, tensor):
        """
        Register tensor for later use.

        The tensor is appended to ``self._registry[name]``. Depending on the shared access
        config it is first detached from the autograd graph (``detach``) and/or copied to
        CPU (``convert_to_cpu``).

        Args:
            name: Key under which the tensor is stored.
            tensor: The tensor to record.
        """
        # Detach before moving to CPU so that, when both options are enabled, the
        # device-to-host copy is not recorded by autograd. The stored tensor is the same.
        if self.access_cfg.get('detach', False):
            tensor = tensor.detach()

        if self.access_cfg.get('convert_to_cpu', False):
            tensor = tensor.cpu()

        # A subclass may not have run AccessMixin.__init__ (e.g. if an earlier base in its
        # MRO does not chain super().__init__()), so create the registry lazily.
        if not hasattr(self, '_registry'):
            self._registry = {}

        self._registry.setdefault(name, []).append(tensor)

    @classmethod
    def get_module_registry(cls, module: torch.nn.Module):
        """
        Collect the non-empty registries of ``module`` and all of its named sub-modules.

        Args:
            module: Root module to search.

        Returns:
            A dict mapping each module's flattened name, as produced by
            ``module.named_modules()`` (the root itself is ``''``), to that module's
            ``_registry`` dict. The registries are returned by reference, not copied.
        """
        return {
            name: m._registry
            for name, m in module.named_modules()
            if hasattr(m, '_registry') and m._registry
        }

    def reset_registry(self: torch.nn.Module, registry_key: Optional[str] = None):
        """
        Reset the registries of this module and all of its named sub-modules.

        Args:
            registry_key: If ``None``, every registry is cleared completely. Otherwise only
                the entry for ``registry_key`` is removed, from every registry that has it.

        Raises:
            KeyError: If ``registry_key`` is given, at least one registry exists, and none
                of them contains ``registry_key``. No registry is modified in that case.

        After a successful reset, access is disabled via ``set_access_enabled(False)``.
        """
        # named_modules() yields `self` first (under the name ''), so the root registry
        # is handled by the same pass as the sub-module registries.
        registries = [m._registry for _, m in self.named_modules() if hasattr(m, "_registry")]

        if registry_key is None:
            for registry in registries:
                registry.clear()
        elif registries:
            # Sub-modules may record different keys, so a registry lacking the key is not
            # an error; only a key absent from every registry is.
            holders = [registry for registry in registries if registry_key in registry]
            if not holders:
                available = list(
                    dict.fromkeys(key for registry in registries for key in registry)
                )
                raise KeyError(
                    f"Registry key `{registry_key}` provided, but registry does not have this key.\n"
                    f"Available keys in registry : {available}"
                )
            for registry in holders:
                registry.pop(registry_key)

        # Explicitly disable registry cache after reset
        AccessMixin.set_access_enabled(access_enabled=False)

    @property
    def access_cfg(self):
        """
        Returns:
            The global access config shared across all access mixin modules.
        """
        return _ACCESS_CFG

    @classmethod
    def update_access_cfg(cls, cfg: dict):
        """
        Merge ``cfg`` into the shared access config in place.

        Args:
            cfg: Mapping of config keys to new values.
        """
        _ACCESS_CFG.update(cfg)

    @classmethod
    def is_access_enabled(cls):
        """
        Returns:
            The module-level flag shared by all access mixin modules.
        """
        return _ACCESS_ENABLED

    @classmethod
    def set_access_enabled(cls, access_enabled: bool):
        """
        Set the module-level access flag shared by all access mixin modules.

        Args:
            access_enabled: New value of the flag.
        """
        global _ACCESS_ENABLED
        _ACCESS_ENABLED = access_enabled