# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).

# --------------------------------------------------------
# CroCo model for downstream tasks
# --------------------------------------------------------

import torch

from .croco import CroCoNet


def croco_args_from_ckpt(ckpt):
    """Recover the CroCoNet constructor arguments from a checkpoint dictionary."""
    if "croco_kwargs" in ckpt:  # CroCo v2 released models
        return ckpt["croco_kwargs"]
    elif "args" in ckpt and hasattr(ckpt["args"], "model"):
        # pretrained using the official code release:
        # e.g. s = "CroCoNet(enc_embed_dim=1024, enc_num_heads=16, enc_depth=24)"
        s = ckpt["args"].model
        assert s.startswith("CroCoNet(")
        # turn the constructor call string into a dict literal and evaluate it
        return eval("dict" + s[len("CroCoNet") :])
    else:  # CroCo v1 released models
        return dict()
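

# Illustrative usage sketch: rebuilding a model from a released checkpoint.
# The path below is a placeholder, the "model" state-dict key is an assumption
# about the checkpoint layout, and `my_head` stands in for a task-specific head.
#
#   ckpt = torch.load("pretrained/CroCo_V2_ViTLarge_BaseDecoder.pth", map_location="cpu")
#   croco_kwargs = croco_args_from_ckpt(ckpt)
#   model = CroCoDownstreamBinocular(my_head, **croco_kwargs)
#   model.load_state_dict(ckpt["model"], strict=False)  # head weights are not in the checkpoint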


class CroCoDownstreamMonocularEncoder(CroCoNet):
    def __init__(self, head, **kwargs):
        """Build network for monocular downstream task, only using the encoder.

        It takes an extra argument head, that is called with the features

          and a dictionary img_info containing 'width' and 'height' keys

        The head is setup with the croconet arguments in this init function

        NOTE: It works by *calling super().__init__() but with redefined setters



        """
        super(CroCoDownstreamMonocularEncoder, self).__init__(**kwargs)
        head.setup(self)
        self.head = head

    def _set_mask_generator(self, *args, **kwargs):
        """No mask generator"""
        return

    def _set_mask_token(self, *args, **kwargs):
        """No mask token"""
        self.mask_token = None
        return

    def _set_decoder(self, *args, **kwargs):
        """No decoder"""
        return

    def _set_prediction_head(self, *args, **kwargs):
        """No 'prediction head' for downstream tasks."""
        return

    def forward(self, img):
        """

        img if of size batch_size x 3 x h x w

        """
        B, C, H, W = img.size()
        img_info = {"height": H, "width": W}
        need_all_layers = (
            hasattr(self.head, "return_all_blocks") and self.head.return_all_blocks
        )
        out, _, _ = self._encode_image(
            img, do_mask=False, return_all_blocks=need_all_layers
        )
        return self.head(out, img_info)
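

# A minimal head sketch (hypothetical, for illustration; real heads live in the
# downstream task code). Both downstream classes in this file assume the head
# provides: setup(croconet), an optional boolean attribute return_all_blocks,
# and a call taking (features, img_info). The attribute name `enc_embed_dim`
# read in setup() below is assumed to be stored on CroCoNet.
class ExampleGlobalPoolHead(torch.nn.Module):
    return_all_blocks = False  # only the last encoder block is needed

    def __init__(self, num_classes=10):
        super().__init__()
        self.num_classes = num_classes
        self.proj = None  # built in setup(), once the encoder width is known

    def setup(self, croconet):
        self.proj = torch.nn.Linear(croconet.enc_embed_dim, self.num_classes)

    def forward(self, feats, img_info):
        # feats: batch_size x num_patches x enc_embed_dim
        return self.proj(feats.mean(dim=1))  # global average pooling + linear probe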


class CroCoDownstreamBinocular(CroCoNet):
    def __init__(self, head, **kwargs):
        """Build network for binocular downstream task

        It takes an extra argument head, that is called with the features

          and a dictionary img_info containing 'width' and 'height' keys

        The head is setup with the croconet arguments in this init function

        """
        super(CroCoDownstreamBinocular, self).__init__(**kwargs)
        head.setup(self)
        self.head = head

    def _set_mask_generator(self, *args, **kwargs):
        """No mask generator"""
        return

    def _set_mask_token(self, *args, **kwargs):
        """No mask token"""
        self.mask_token = None
        return

    def _set_prediction_head(self, *args, **kwargs):
        """No prediction head for downstream tasks, define your own head"""
        return

    def encode_image_pairs(self, img1, img2, return_all_blocks=False):
        """run encoder for a pair of images

        it is actually ~5% faster to concatenate the images along the batch dimension

         than to encode them separately

        """
        ## the two commented lines below are the naive version with separate encoding
        # out, pos, _ = self._encode_image(img1, do_mask=False, return_all_blocks=return_all_blocks)
        # out2, pos2, _ = self._encode_image(img2, do_mask=False, return_all_blocks=False)
        ## and now the faster version
        out, pos, _ = self._encode_image(
            torch.cat((img1, img2), dim=0),
            do_mask=False,
            return_all_blocks=return_all_blocks,
        )
        if return_all_blocks:
            # split each block's output back into its img1 and img2 halves;
            # keep all blocks for img1 but only the last block for img2
            out, out2 = list(map(list, zip(*[o.chunk(2, dim=0) for o in out])))
            out2 = out2[-1]
        else:
            out, out2 = out.chunk(2, dim=0)
        pos, pos2 = pos.chunk(2, dim=0)
        return out, out2, pos, pos2
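
    # Rough micro-benchmark sketch (hypothetical; actual numbers depend on
    # hardware, batch size and model width) behind the ~5% claim above:
    #
    #   import time
    #   t0 = time.time()
    #   for _ in range(100):
    #       model.encode_image_pairs(img1, img2)     # batched path
    #   t1 = time.time()
    #   for _ in range(100):
    #       model._encode_image(img1, do_mask=False) # separate path,
    #       model._encode_image(img2, do_mask=False) # two encoder calls
    #   print("batched:", t1 - t0, "separate:", time.time() - t1)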

    def forward(self, img1, img2):
        B, C, H, W = img1.size()
        img_info = {"height": H, "width": W}
        return_all_blocks = (
            hasattr(self.head, "return_all_blocks") and self.head.return_all_blocks
        )
        out, out2, pos, pos2 = self.encode_image_pairs(
            img1, img2, return_all_blocks=return_all_blocks
        )
        if return_all_blocks:
            decout = self._decoder(
                out[-1], pos, None, out2, pos2, return_all_blocks=return_all_blocks
            )
            decout = out + decout  # list concatenation: all encoder blocks, then all decoder blocks
        else:
            decout = self._decoder(
                out, pos, None, out2, pos2, return_all_blocks=return_all_blocks
            )
        return self.head(decout, img_info)
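

if __name__ == "__main__":
    # Smoke-test sketch (illustrative only): run a tiny binocular model with a
    # dummy identity head on random images. All hyper-parameters below are made
    # up for speed; the keyword names follow the CroCoNet constructor arguments
    # seen in the checkpoint example string above, not any released model.
    class _IdentityHead:
        return_all_blocks = False

        def setup(self, croconet):
            pass  # nothing to configure for this dummy head

        def __call__(self, decout, img_info):
            return decout  # batch_size x num_patches x dec_embed_dim

    model = CroCoDownstreamBinocular(
        _IdentityHead(),
        img_size=64,
        patch_size=16,
        enc_embed_dim=96,
        enc_depth=2,
        enc_num_heads=3,
        dec_embed_dim=64,
        dec_depth=2,
        dec_num_heads=2,
    )
    img1 = torch.randn(2, 3, 64, 64)
    img2 = torch.randn(2, 3, 64, 64)
    with torch.no_grad():
        out = model(img1, img2)
    print(out.shape)  # expected: torch.Size([2, 16, 64]) -- 16 = (64/16)**2 patches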