Spaces:
Runtime error
Runtime error
File size: 5,192 Bytes
864ebc9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
# --------------------------------------------------------
# CroCo model for downstream tasks
# --------------------------------------------------------
import torch
from .croco import CroCoNet
def croco_args_from_ckpt(ckpt):
    """Recover the ``CroCoNet`` constructor kwargs stored in a checkpoint.

    Three checkpoint layouts are recognised, checked in this order:

    * CroCo v2 released models store the kwargs directly under
      ``ckpt["croco_kwargs"]``; that object is returned as-is.
    * Checkpoints pretrained with the official code release keep their
      training arguments under ``ckpt["args"]``, whose ``model`` attribute is
      a string such as
      ``"CroCoNet(enc_embed_dim=1024, enc_num_heads=16, enc_depth=24)"``.
      The ``CroCoNet`` prefix is swapped for ``dict`` and the result is
      evaluated, giving ``{"enc_embed_dim": 1024, ...}``.
    * Anything else (CroCo v1 released models) yields an empty dict, i.e.
      the ``CroCoNet`` defaults.

    Args:
        ckpt: checkpoint mapping; presumably the object returned by
            ``torch.load`` on a CroCo checkpoint (confirm with callers).

    Returns:
        dict: keyword arguments to pass to ``CroCoNet(...)``.

    Raises:
        ValueError: if ``ckpt["args"].model`` does not start with
            ``"CroCoNet("``.
    """
    if "croco_kwargs" in ckpt:  # CroCo v2 released models
        return ckpt["croco_kwargs"]
    if "args" in ckpt and hasattr(ckpt["args"], "model"):
        # Pretrained using the official code release.
        s = ckpt["args"].model  # e.g. "CroCoNet(enc_embed_dim=1024, enc_num_heads=16, enc_depth=24)"
        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O`, which would let an arbitrary string reach eval below.
        if not s.startswith("CroCoNet("):
            raise ValueError(
                f"unexpected model description in checkpoint args: {s!r} "
                "(expected a string starting with 'CroCoNet(')"
            )
        # SECURITY NOTE(review): eval executes Python code read from the
        # checkpoint. Only use checkpoints from trusted sources.
        # "CroCoNet(a=1, b=2)" -> "dict(a=1, b=2)" -> {"a": 1, "b": 2}
        return eval("dict" + s[len("CroCoNet"):])
    # CroCo v1 released models: no stored kwargs, use CroCoNet defaults.
    return {}
class CroCoDownstreamMonocularEncoder(CroCoNet):
    """Encoder-only CroCo network for single-image downstream tasks.

    The CroCoNet components this variant does not need are disabled by
    overriding their ``_set_*`` builder hooks with no-ops: the mask
    generator, the decoder and the pre-training prediction head are never
    created, and ``mask_token`` is set to ``None``.
    """

    def __init__(self, head, **kwargs):
        """Build the encoder-only network and attach the downstream head.

        Args:
            head: downstream head. After the CroCoNet part is built,
                ``head.setup(self)`` is called so the head can configure
                itself from this network; the head is then stored as
                ``self.head``. In ``forward`` it is called with the encoder
                features and an ``img_info`` dict holding the 'height' and
                'width' of the input images.
            **kwargs: forwarded unchanged to ``CroCoNet.__init__``.

        Note:
            Construction relies on ``CroCoNet.__init__`` invoking the
            ``_set_*`` hooks, which this class redefines to skip the
            components listed in the class docstring.
        """
        super().__init__(**kwargs)
        head.setup(self)
        self.head = head

    def _set_mask_generator(self, *args, **kwargs):
        """Disabled: no mask generator (images are encoded with do_mask=False)."""

    def _set_mask_token(self, *args, **kwargs):
        """Disabled: there is no mask token, so the attribute is set to None."""
        self.mask_token = None

    def _set_decoder(self, *args, **kwargs):
        """Disabled: this variant is encoder-only."""

    def _set_prediction_head(self, *args, **kwargs):
        """Disabled: the downstream ``head`` replaces the prediction head."""

    def forward(self, img):
        """Encode ``img`` and feed the features to the downstream head.

        Args:
            img: image batch of size batch_size x 3 x h x w.

        Returns:
            Whatever ``self.head`` returns for the encoder output.
        """
        _, _, height, width = img.size()
        # A head may opt into receiving every encoder block's output.
        all_blocks = getattr(self.head, "return_all_blocks", False)
        features, _, _ = self._encode_image(
            img, do_mask=False, return_all_blocks=all_blocks
        )
        return self.head(features, {"height": height, "width": width})
class CroCoDownstreamBinocular(CroCoNet):
    """Full CroCo encoder-decoder for image-pair downstream tasks.

    The mask generator and the pre-training prediction head are disabled by
    overriding their ``_set_*`` builder hooks with no-ops (and ``mask_token``
    is set to ``None``); the decoder is kept so the two views can interact.
    """

    def __init__(self, head, **kwargs):
        """Build the binocular network and attach the downstream head.

        Args:
            head: downstream head. After the CroCoNet part is built,
                ``head.setup(self)`` is called so the head can configure
                itself from this network; the head is then stored as
                ``self.head``. In ``forward`` it is called with the decoder
                features and an ``img_info`` dict holding the 'height' and
                'width' of the first image.
            **kwargs: forwarded unchanged to ``CroCoNet.__init__``.
        """
        super().__init__(**kwargs)
        head.setup(self)
        self.head = head

    def _set_mask_generator(self, *args, **kwargs):
        """Disabled: no mask generator (images are encoded with do_mask=False)."""

    def _set_mask_token(self, *args, **kwargs):
        """Disabled: there is no mask token, so the attribute is set to None."""
        self.mask_token = None

    def _set_prediction_head(self, *args, **kwargs):
        """Disabled: downstream tasks define their own head instead."""

    def encode_image_pairs(self, img1, img2, return_all_blocks=False):
        """Encode both images of a pair in a single encoder pass.

        The two batches are stacked along the batch dimension and encoded
        together, reported as ~5% faster than two separate
        ``_encode_image`` calls. The outputs are then split back in halves.
        Because of the concatenation, ``img1`` and ``img2`` must agree on
        every dimension except the first.

        Args:
            img1: first image batch.
            img2: second image batch.
            return_all_blocks: if truthy, ask the encoder for the output of
                every block instead of only the last one.

        Returns:
            ``(out, out2, pos, pos2)``: features and positions of ``img1``
            and ``img2``. With ``return_all_blocks``, ``out`` is a list of
            ``img1`` features (one per encoder block) while ``out2`` is only
            the last block's ``img2`` features.
        """
        stacked = torch.cat((img1, img2), dim=0)
        feats, positions, _ = self._encode_image(
            stacked, do_mask=False, return_all_blocks=return_all_blocks
        )
        if return_all_blocks:
            # Split every block's output in two, then regroup per image:
            # [(a1, b1), (a2, b2), ...] -> [a1, a2, ...] and [b1, b2, ...]
            per_image = [
                list(group)
                for group in zip(*(block.chunk(2, dim=0) for block in feats))
            ]
            out, out2_blocks = per_image
            out2 = out2_blocks[-1]  # the decoder only needs img2's last block
        else:
            out, out2 = feats.chunk(2, dim=0)
        pos, pos2 = positions.chunk(2, dim=0)
        return out, out2, pos, pos2

    def forward(self, img1, img2):
        """Encode and decode an image pair, then apply the downstream head.

        Args:
            img1: first image batch; its height and width populate
                ``img_info``.
            img2: second image batch (same shape as ``img1`` apart from the
                batch dimension).

        Returns:
            Whatever ``self.head`` returns for the decoder output.
        """
        _, _, height, width = img1.size()
        # A head may opt into receiving every block's output.
        all_blocks = getattr(self.head, "return_all_blocks", False)
        out, out2, pos, pos2 = self.encode_image_pairs(
            img1, img2, return_all_blocks=all_blocks
        )
        # With all blocks, `out` is a per-block list: the decoder consumes
        # only its last entry.
        decoder_input = out[-1] if all_blocks else out
        decout = self._decoder(
            decoder_input, pos, None, out2, pos2, return_all_blocks=all_blocks
        )
        if all_blocks:
            # Prepend the img1 encoder blocks to what the decoder returned.
            decout = out + decout
        return self.head(decout, {"height": height, "width": width})
|