"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

import logging
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, TypeVar, Union

import numpy as np
import torch as th

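# Recursive container type aliases: arbitrarily nested sequences / mappings whose
# leaves are tensors, ndarrays, modules, strings, or ints.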
TensorOrContainer = Union[
    th.Tensor, str, int, Sequence["TensorOrContainer"], Mapping[str, "TensorOrContainer"]
]
NdarrayOrContainer = Union[
    np.ndarray,
    str,
    int,
    Sequence["NdarrayOrContainer"],
    Mapping[str, "NdarrayOrContainer"],
]
TensorNdarrayOrContainer = Union[
    th.Tensor,
    np.ndarray,
    str,
    int,
    Sequence["TensorNdarrayOrContainer"],
    Mapping[str, "TensorNdarrayOrContainer"],
]
TensorNdarrayModuleOrContainer = Union[
    th.Tensor,
    np.ndarray,
    th.nn.Module,
    str,
    int,
    Sequence["TensorNdarrayModuleOrContainer"],
    Mapping[str, "TensorNdarrayModuleOrContainer"],
]
TTensorOrContainer = TypeVar("TTensorOrContainer", bound=TensorOrContainer)
TNdarrayOrContainer = TypeVar("TNdarrayOrContainer", bound=NdarrayOrContainer)
TTensorNdarrayOrContainer = TypeVar("TTensorNdarrayOrContainer", bound=TensorNdarrayOrContainer)
TTensorNdarrayModuleOrContainer = TypeVar(
    "TTensorNdarrayModuleOrContainer", bound=TensorNdarrayModuleOrContainer
)


logger = logging.getLogger(__name__)


class ParamHolder(th.nn.Module):
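    """Holds one learnable parameter slice per string key.

    All per-key parameters are stored in a single tensor of shape
    ``(len(key_list), *param_shape)``; ``to_idx`` maps keys to row indices and
    ``forward`` gathers the corresponding rows. ``load_state_dict`` can merge a
    checkpoint whose key list differs from the current one.
    """
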
    def __init__(
        self,
        param_shape: Union[int, Tuple[int, ...]],
        key_list: Sequence[str],
        init_value: Union[None, bool, float, int, th.Tensor] = None,
    ) -> None:
        super().__init__()

        if isinstance(param_shape, int):
            param_shape = (param_shape,)
        self.key_list: Sequence[str] = sorted(key_list)
        shp = (len(self.key_list),) + param_shape
        self.params = th.nn.Parameter(th.zeros(*shp))

        if init_value is not None:
            self.params.data[:] = init_value

    def state_dict(self, *args: Any, saving: bool = False, **kwargs: Any) -> Dict[str, Any]:
        sd = super().state_dict(*args, **kwargs)
        if saving:
            assert "key_list" not in sd
            sd["key_list"] = self.key_list
        return sd

    # pyre-fixme[14]: `load_state_dict` overrides method defined in `Module`
    #  inconsistently.
    def load_state_dict(
        self, state_dict: Mapping[str, Any], strict: bool = True, **kwargs: Any
    ) -> th.nn.modules.module._IncompatibleKeys:
        # Note: Mapping is immutable while Dict is mutable. Per pyre ErrorCode[14], the
        # type of state_dict must be Mapping (or a supertype of Mapping) to stay
        # consistent with the overridden method in the superclass.
        sd = dict(state_dict)
        if "key_list" not in sd:
            logger.warning("Missing key list list in state dict, only checking params shape.")
            assert sd["params"].shape == self.params.shape
            sd["key_list"] = self.key_list

        matching_kl = sd["key_list"] == self.key_list
        if not matching_kl:
            logger.warning("Attempting to load from mismatched key lists.")
        assert sd["params"].shape[1:] == self.params.shape[1:]

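        # Key lists differ: build the union of both key lists, carry over the
        # parameters this module already holds, then copy in the checkpoint's
        # parameters so existing keys are updated and new keys are appended.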
        if not matching_kl:
            src_kl = sd["key_list"]
            new_kl = sorted(set(self.key_list) | set(src_kl))
            new_shp = (len(new_kl),) + tuple(self.params.shape[1:])
            new_params = th.zeros(*new_shp, device=self.params.device)
            for f in self.key_list:
                new_params[new_kl.index(f)] = self.params[self.key_list.index(f)]
            upd = 0
            new = 0
            for f in src_kl:
                new_params[new_kl.index(f)] = sd["params"][src_kl.index(f)]
                if f in self.key_list:
                    upd += 1
                else:
                    new += 1
            logger.info(
                f"Updated {upd} keys ({100*upd/len(self.key_list):0.2f}%), added {new} new keys."
            )

            self.key_list = new_kl
            sd["params"] = new_params
            self.params = th.nn.Parameter(new_params)
        del sd["key_list"]
        return super().load_state_dict(sd, strict=strict, **kwargs)

    def to_idx(self, *args: Any) -> th.Tensor:
        if len(args) == 1:
            keys = args[0]
        else:
            keys = zip(*args)

        return th.tensor(
            [self.key_list.index(k) for k in keys],
            dtype=th.long,
            device=self.params.device,
        )

    def from_idx(self, idxs: th.Tensor) -> List[str]:
        return [self.key_list[idx] for idx in idxs]

    def forward(self, idxs: th.Tensor) -> th.Tensor:
        return self.params[idxs]

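
# Example usage of ParamHolder (an illustrative sketch; the key names and shapes
# below are made up rather than taken from any particular model):
#
#   holder = ParamHolder(param_shape=(3,), key_list=["subj_b", "subj_a"], init_value=0.1)
#   idxs = holder.to_idx(["subj_b", "subj_a"])  # key_list is stored sorted -> tensor([1, 0])
#   values = holder(idxs)                       # shape (2, 3), one parameter row per key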

def to_device(
    things: TTensorNdarrayModuleOrContainer,
    device: th.device,
    cache: Optional[Dict[str, th.Tensor]] = None,
    key: Optional[str] = None,
    verbose: bool = False,
    max_bs: Optional[int] = None,
    non_blocking: bool = False,
) -> TTensorNdarrayModuleOrContainer:
    """Sends a potentially nested container of Tensors to the specified
    device. Non-tensors are preserved as-is.

    Args:
        things: Container with tensors or other containers of tensors to send
            to a GPU.

        device: Device to send the tensors to.

        cache: Optional dictionary to use as a cache for device-resident
            tensors. If passed, each tensor is allocated in the cache once and
            then resized / refilled on subsequent calls to to_device() instead
            of being reallocated.

        key: Cache key under which to store the tensor. Set automatically
            during recursion; only for internal use.

        verbose: Print some info when a cached tensor is resized or allocated.

        max_bs: Maximum batch size allowed for tensors in the cache.

        non_blocking: if True and this copy is between CPU and GPU, the copy
            may occur asynchronously with respect to the host. For other cases,
            this argument has no effect.

    Returns:
        collection: The input collection with all tensors transferred to the given device.
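
    Example (an illustrative sketch; the container contents are made up)::

        batch = {"verts": th.randn(2, 100, 3), "frame": 7, "tag": "train"}
        batch = to_device(batch, th.device("cuda:0"))
        # "verts" is now on cuda:0; the int and str leaves are returned unchanged.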
    """
    device = th.device(device)

    pr = print if verbose else lambda *args, **kwargs: None

    if isinstance(things, th.Tensor) and things.device != device:
        if cache is not None:
            assert key is not None
            batch_size = things.shape[0]
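            # Reuse the cached buffer for this key if present, growing it when the
            # incoming batch is larger; otherwise allocate a new buffer on the
            # target device (sized up to max_bs if given).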
            if key in cache:
                assert things.shape[1:] == cache[key].shape[1:]
                if batch_size > cache[key].shape[0]:
                    pr("Resized:", key, "from", cache[key].shape[0], "to", batch_size)
                    cache[key].resize_as_(things)
            else:
                buf_shape = list(things.shape)
                if max_bs is not None:
                    assert max_bs >= batch_size
                    buf_shape[0] = max_bs
                cache[key] = th.zeros(*buf_shape, dtype=things.dtype, device=device)
                pr("Allocated:", key, buf_shape)
            cache[key][:batch_size].copy_(things, non_blocking=non_blocking)

            return cache[key][:batch_size]
        else:
            return things.to(device, non_blocking=non_blocking)
    elif isinstance(things, th.nn.Module):
        return things.to(device, non_blocking=non_blocking)
    elif isinstance(things, dict):
        key = key + "." if key is not None else ""
        return {
            k: to_device(v, device, cache, key + k, verbose, max_bs, non_blocking)
            for k, v in things.items()
        }
    elif isinstance(things, Sequence) and not isinstance(things, str):
        key = key if key is not None else ""
        out = [
            to_device(v, device, cache, key + f"_{i}", verbose, max_bs, non_blocking)
            for i, v in enumerate(things)
        ]
        if isinstance(things, tuple):
            out = tuple(out)
        return out
    elif isinstance(things, np.ndarray):
        return to_device(th.from_numpy(things), device, cache, key, verbose, max_bs, non_blocking)
    else:
        return things