from detectron2.checkpoint import DetectionCheckpointer

from typing import Any
import torch
import torch.nn as nn
from fvcore.common.checkpoint import _IncompatibleKeys, _strip_prefix_if_present, TORCH_VERSION, quantization, \
    ObserverBase, FakeQuantizeBase
from torch import distributed as dist
from scipy import interpolate
import numpy as np
import torch.nn.functional as F
from collections import OrderedDict


def append_prefix(k):
    """Prepend the detectron2 wrapper prefix 'backbone.bottom_up.backbone.' to a raw
    backbone key; keys that already carry the prefix are returned unchanged."""
    prefix = 'backbone.bottom_up.backbone.'
    return prefix + k if not k.startswith(prefix) else k


def modify_ckpt_state(model, state_dict, logger=None):
    """Adapt a pretrained state_dict to the current model: reshape/interpolate absolute
    and relative position embeddings and expand shared relative position bias tables."""
    # reshape absolute position embedding for Swin
    if state_dict.get(append_prefix('absolute_pos_embed')) is not None:
        absolute_pos_embed = state_dict[append_prefix('absolute_pos_embed')]
        N1, L, C1 = absolute_pos_embed.size()
        N2, C2, H, W = model.backbone.bottom_up.backbone.absolute_pos_embed.size()
        if N1 != N2 or C1 != C2 or L != H * W:
            logger.warning("Error in loading absolute_pos_embed, pass")
        else:
            state_dict[append_prefix('absolute_pos_embed')] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)

    def get_dist_info():
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        else:
            rank = 0
            world_size = 1
        return rank, world_size

    def resize_position_embeddings(max_position_embeddings, old_vocab_size,
                                   _k='backbone.bottom_up.backbone.embeddings.position_embeddings.weight',
                                   initializer_range=0.02, reuse_position_embedding=True):
        '''
        Reference: unilm.
        Also see the discussions in:
        https://github.com/pytorch/fairseq/issues/1685
        https://github.com/google-research/bert/issues/27
        '''

        new_position_embedding = state_dict[_k].data.new_tensor(torch.ones(
            size=(max_position_embeddings, state_dict[_k].shape[1])), dtype=torch.float)
        new_position_embedding = nn.Parameter(data=new_position_embedding, requires_grad=True)
        new_position_embedding.data.normal_(mean=0.0, std=initializer_range)
        if max_position_embeddings > old_vocab_size:
            logger.info("Resize > position embeddings !")
            max_range = max_position_embeddings if reuse_position_embedding else old_vocab_size
            shift = 0
            while shift < max_range:
                delta = min(old_vocab_size, max_range - shift)
                new_position_embedding.data[shift: shift + delta, :] = state_dict[_k][:delta, :]
                logger.info("  CP [%d ~ %d] into [%d ~ %d]  " % (0, delta, shift, shift + delta))
                shift += delta
            state_dict[_k] = new_position_embedding.data
            del new_position_embedding
        elif max_position_embeddings < old_vocab_size:
            logger.info("Resize < position embeddings !")
            new_position_embedding.data.copy_(state_dict[_k][:max_position_embeddings, :])
            state_dict[_k] = new_position_embedding.data
            del new_position_embedding

    rank, _ = get_dist_info()
    all_keys = list(state_dict.keys())
    for key in all_keys:
        if "embeddings.position_embeddings.weight" in key:
            if key not in model.state_dict():  # image only models do not use this key
                continue
            max_position_embeddings = model.state_dict()[key].shape[0]
            old_vocab_size = state_dict[key].shape[0]
            if max_position_embeddings != old_vocab_size:
                resize_position_embeddings(max_position_embeddings, old_vocab_size, _k=key)

        if "relative_position_index" in key:
            state_dict.pop(key)

        if "relative_position_bias_table" in key:
            rel_pos_bias = state_dict[key]
            src_num_pos, num_attn_heads = rel_pos_bias.size()
            if key not in model.state_dict():
                continue
            dst_num_pos, _ = model.state_dict()[key].size()
            dst_patch_shape = model.backbone.bottom_up.backbone.patch_embed.patch_shape
            if dst_patch_shape[0] != dst_patch_shape[1]:
                raise NotImplementedError()
            num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
            src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
            dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
            if src_size != dst_size:
                if rank == 0:
                    print("Position interpolate for %s from %dx%d to %dx%d" % (
                        key, src_size, src_size, dst_size, dst_size))
                extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
                rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]

                def geometric_progression(a, r, n):
                    return a * (1.0 - r ** n) / (1.0 - r)

                left, right = 1.01, 1.5
                while right - left > 1e-6:
                    q = (left + right) / 2.0
                    gp = geometric_progression(1, q, src_size // 2)
                    if gp > dst_size // 2:
                        right = q
                    else:
                        left = q

                # if q > 1.13492:
                #     q = 1.13492

                dis = []
                cur = 1
                for i in range(src_size // 2):
                    dis.append(cur)
                    cur += q ** (i + 1)

                r_ids = [-_ for _ in reversed(dis)]

                x = r_ids + [0] + dis
                y = r_ids + [0] + dis

                t = dst_size // 2.0
                dx = np.arange(-t, t + 0.1, 1.0)
                dy = np.arange(-t, t + 0.1, 1.0)
                if rank == 0:
                    print("x = {}".format(x))
                    print("dx = {}".format(dx))

                all_rel_pos_bias = []

                for i in range(num_attn_heads):
                    z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
                    # NOTE: scipy.interpolate.interp2d is deprecated since SciPy 1.10
                    # and removed in 1.14; pin an older SciPy or port to RectBivariateSpline.
                    f = interpolate.interp2d(x, y, z, kind='cubic')
                    all_rel_pos_bias.append(
                        torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))

                rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
                new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
                state_dict[key] = new_rel_pos_bias

    if append_prefix('pos_embed') in state_dict:
        pos_embed_checkpoint = state_dict[append_prefix('pos_embed')]
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.backbone.bottom_up.backbone.patch_embed.num_patches
        num_extra_tokens = model.backbone.bottom_up.backbone.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        # new_size = int(num_patches ** 0.5)
        new_size_w = model.backbone.bottom_up.backbone.patch_embed.num_patches_w
        new_size_h = model.backbone.bottom_up.backbone.patch_embed.num_patches_h
        # class_token and dist_token are kept unchanged
        if orig_size != new_size_h or orig_size != new_size_w:
            if rank == 0:
                print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size_w, new_size_h))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size_w, new_size_h), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            state_dict[append_prefix('pos_embed')] = new_pos_embed

    # interpolate position bias table if needed
    relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
    for table_key in relative_position_bias_table_keys:
        table_pretrained = state_dict[table_key]
        if table_key not in model.state_dict():
            continue
        table_current = model.state_dict()[table_key]
        L1, nH1 = table_pretrained.size()
        L2, nH2 = table_current.size()
        if nH1 != nH2:
            logger.warning(f"Error in loading {table_key}, pass")
        else:
            if L1 != L2:
                S1 = int(L1 ** 0.5)
                S2 = int(L2 ** 0.5)
                table_pretrained_resized = F.interpolate(
                    table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
                    size=(S2, S2), mode='bicubic')
                state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)

    if append_prefix('rel_pos_bias.relative_position_bias_table') in state_dict and \
            model.backbone.bottom_up.backbone.use_rel_pos_bias and \
            not model.backbone.bottom_up.backbone.use_shared_rel_pos_bias and \
            append_prefix('blocks.0.attn.relative_position_bias_table') not in state_dict:
        logger.info("[BEIT] Expand the shared relative position embedding to each transformer block. ")
        num_layers = model.backbone.bottom_up.backbone.get_num_layers()
        rel_pos_bias = state_dict[append_prefix("rel_pos_bias.relative_position_bias_table")]
        for i in range(num_layers):
            state_dict["blocks.%d.attn.relative_position_bias_table" % i] = rel_pos_bias.clone()
        state_dict.pop(append_prefix("rel_pos_bias.relative_position_bias_table"))

    return state_dict
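

# --- Illustrative sketch (not part of the original file) -------------------------
# A minimal, self-contained demo of the bicubic position-embedding resize performed
# in modify_ckpt_state above, using dummy tensors. The function name and all sizes
# below are made up for illustration only.
def _demo_resize_pos_embed(orig_size=14, new_size=16, dim=8, num_extra_tokens=1):
    pos_embed = torch.randn(1, num_extra_tokens + orig_size * orig_size, dim)
    extra = pos_embed[:, :num_extra_tokens]                    # e.g. the [CLS] token
    grid = pos_embed[:, num_extra_tokens:]                     # (1, H*W, C)
    grid = grid.reshape(1, orig_size, orig_size, dim).permute(0, 3, 1, 2)  # (1, C, H, W)
    grid = F.interpolate(grid, size=(new_size, new_size), mode='bicubic',
                         align_corners=False)
    grid = grid.permute(0, 2, 3, 1).flatten(1, 2)              # back to (1, H*W, C)
    return torch.cat((extra, grid), dim=1)                     # (1, 1 + new*new, C)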


class MyDetectionCheckpointer(DetectionCheckpointer):
    def _load_model(self, checkpoint: Any) -> _IncompatibleKeys:
        """
        Load weights from a checkpoint.

        Args:
            checkpoint (Any): checkpoint contains the weights.

        Returns:
            ``NamedTuple`` with ``missing_keys``, ``unexpected_keys``,
                and ``incorrect_shapes`` fields:
                * **missing_keys** is a list of str containing the missing keys
                * **unexpected_keys** is a list of str containing the unexpected keys
                * **incorrect_shapes** is a list of (key, shape in checkpoint, shape in model)

            This is just like the return value of
            :func:`torch.nn.Module.load_state_dict`, but with extra support
            for ``incorrect_shapes``.
        """
        checkpoint_state_dict = checkpoint.pop("model")
        checkpoint_state_dict = self.rename_state_dict(checkpoint_state_dict)
        self._convert_ndarray_to_tensor(checkpoint_state_dict)

        # if the state_dict comes from a model that was wrapped in a
        # DataParallel or DistributedDataParallel during serialization,
        # remove the "module" prefix before performing the matching.
        _strip_prefix_if_present(checkpoint_state_dict, "module.")

        # workaround https://github.com/pytorch/pytorch/issues/24139
        model_state_dict = self.model.state_dict()
        incorrect_shapes = []

        # Rename the parameters in checkpoint_state_dict to match the model's key layout.
        # Known limitation: re-loading an already-converted checkpoint is not fully supported.

        if 'backbone.fpn_lateral2.weight' not in checkpoint_state_dict.keys():
            checkpoint_state_dict = {
                append_prefix(k): checkpoint_state_dict[k]
                for k in checkpoint_state_dict.keys()
            }
        # else: resuming a detectron2-format checkpoint, so the prefix is already present

        checkpoint_state_dict = modify_ckpt_state(self.model, checkpoint_state_dict, logger=self.logger)

        for k in list(checkpoint_state_dict.keys()):
            if k in model_state_dict:
                model_param = model_state_dict[k]
                # Allow mismatch for uninitialized parameters
                if TORCH_VERSION >= (1, 8) and isinstance(
                        model_param, nn.parameter.UninitializedParameter
                ):
                    continue
                shape_model = tuple(model_param.shape)
                shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
                if shape_model != shape_checkpoint:

                    has_observer_base_classes = (
                            TORCH_VERSION >= (1, 8)
                            and hasattr(quantization, "ObserverBase")
                            and hasattr(quantization, "FakeQuantizeBase")
                    )
                    if has_observer_base_classes:
                        # Handle the special case of quantization per channel observers,
                        # where buffer shape mismatches are expected.
                        def _get_module_for_key(
                                model: torch.nn.Module, key: str
                        ) -> torch.nn.Module:
                            # foo.bar.param_or_buffer_name -> [foo, bar]
                            key_parts = key.split(".")[:-1]
                            cur_module = model
                            for key_part in key_parts:
                                cur_module = getattr(cur_module, key_part)
                            return cur_module

                        cls_to_skip = (
                            ObserverBase,
                            FakeQuantizeBase,
                        )
                        target_module = _get_module_for_key(self.model, k)
                        if isinstance(target_module, cls_to_skip):
                            # Do not drop these keys even though their shapes mismatch:
                            # these modules have special logic in _load_from_state_dict
                            # to handle the mismatch themselves.
                            continue

                    incorrect_shapes.append((k, shape_checkpoint, shape_model))
                    checkpoint_state_dict.pop(k)
        incompatible = self.model.load_state_dict(checkpoint_state_dict, strict=False)
        return _IncompatibleKeys(
            missing_keys=incompatible.missing_keys,
            unexpected_keys=incompatible.unexpected_keys,
            incorrect_shapes=incorrect_shapes,
        )

    def rename_state_dict(self, state_dict):
        """Strip the 'layoutlmv3.' prefix from HuggingFace-style keys; if no key
        contains 'layoutlmv3', the state_dict is returned unchanged."""
        new_state_dict = OrderedDict()
        layoutlm = False
        for k, v in state_dict.items():
            if 'layoutlmv3' in k:
                layoutlm = True
                new_state_dict[k.replace('layoutlmv3.', '')] = v
        if layoutlm:
            return new_state_dict
        return state_dict
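

# --- Example usage (illustrative sketch, not part of the original file) ----------
# How this checkpointer might be wired up once a detectron2 model and config exist.
# `cfg`, `model`, and the helper name below are placeholders supplied for illustration;
# cfg.OUTPUT_DIR and cfg.MODEL.WEIGHTS are the standard detectron2 config fields.
def _example_load_pretrained(cfg, model):
    checkpointer = MyDetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
    # resume_or_load() is inherited from fvcore's Checkpointer; only the _load_model()
    # and rename_state_dict() overrides above change how the raw weights are adapted.
    return checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=False)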