Spaces:
Runtime error
Runtime error
File size: 3,038 Bytes
864ebc9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
# --------------------------------------------------------
# Heads for downstream tasks
# --------------------------------------------------------
"""
A head is a module where the __init__ defines only the head hyperparameters.
A method setup(croconet) takes a CroCoNet and sets all layers according to the head and croconet attributes.
The forward takes the features as well as a dictionary img_info containing the keys 'width' and 'height'
"""
import torch
import torch.nn as nn
from .dpt_block import DPTOutputAdapter
class PixelwiseTaskWithDPT(nn.Module):
    """DPT head producing a pixelwise prediction from CroCo features.

    The constructor only records hyperparameters; the actual layers are
    created by :meth:`setup`, which needs the backbone to know its depths
    and embedding dimensions.

    When ``hooks_idx`` is left to ``None``, :meth:`setup` picks four hooks:

    * encoder + decoder (the backbone has a ``dec_blocks`` attribute): four
      indices spaced ``step`` apart and ending at
      ``enc_depth + dec_depth - 1``, where ``step`` is 3, 4 or 8 for a
      decoder depth of 8, 12 or 24 respectively;
    * encoder only: four indices spaced ``enc_depth // 4`` apart and ending
      at ``enc_depth - 1``.

    Args:
        hooks_idx: Indices of the backbone layers fed to the DPT adapter, or
            ``None`` to have them chosen automatically in :meth:`setup`.
        layer_dims: Forwarded to ``DPTOutputAdapter`` as ``layer_dims``.
            ``None`` (the default) means ``[96, 192, 384, 768]``. The value
            is copied, so instances never share a list with each other or
            with the caller.
        output_width_ratio: Forwarded to ``DPTOutputAdapter``.
        num_channels: Forwarded to ``DPTOutputAdapter``.
        postprocess: Optional callable applied to the adapter output in
            :meth:`forward`; skipped when falsy.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        *,
        hooks_idx=None,
        layer_dims=None,
        output_width_ratio=1,
        num_channels=1,
        postprocess=None,
        **kwargs,
    ):
        super().__init__()
        self.return_all_blocks = True  # backbone needs to return all layers
        self.postprocess = postprocess
        self.output_width_ratio = output_width_ratio
        self.num_channels = num_channels
        self.hooks_idx = hooks_idx
        # A list literal used as the default would be one object shared by
        # every instance; build a fresh list per instance instead.
        if layer_dims is None:
            self.layer_dims = [96, 192, 384, 768]
        else:
            self.layer_dims = list(layer_dims)

    def setup(self, croconet):
        """Build the DPT adapter from the head and backbone attributes.

        Reads ``croconet.enc_depth`` and ``croconet.enc_embed_dim``, and,
        when needed, ``croconet.dec_depth`` and ``croconet.dec_embed_dim``.
        Sets ``self.dpt`` and, if the hooks were not given at construction
        time, ``self.hooks_idx``.

        Raises:
            KeyError: if the hooks must be chosen automatically for an
                encoder + decoder backbone whose ``dec_depth`` is not one of
                8, 12 or 24.
        """
        dpt_args = {
            "output_width_ratio": self.output_width_ratio,
            "num_channels": self.num_channels,
        }
        if self.hooks_idx is None:
            if hasattr(croconet, "dec_blocks"):  # encoder + decoder
                steps_by_depth = {8: 3, 12: 4, 24: 8}
                try:
                    step = steps_by_depth[croconet.dec_depth]
                except KeyError:
                    raise KeyError(
                        f"cannot choose hooks automatically for "
                        f"dec_depth={croconet.dec_depth!r}; supported depths "
                        f"are {sorted(steps_by_depth)}, otherwise pass "
                        f"hooks_idx explicitly"
                    ) from None
                hooks_idx = [
                    croconet.dec_depth + croconet.enc_depth - 1 - i * step
                    for i in range(3, -1, -1)
                ]
            else:  # encoder only
                step = croconet.enc_depth // 4
                hooks_idx = [
                    croconet.enc_depth - 1 - i * step for i in range(3, -1, -1)
                ]
            self.hooks_idx = hooks_idx
            print(
                f" PixelwiseTaskWithDPT: automatically setting hook_idxs={self.hooks_idx}"
            )
        dpt_args["hooks"] = self.hooks_idx
        dpt_args["layer_dims"] = self.layer_dims
        self.dpt = DPTOutputAdapter(**dpt_args)
        # Hooks below enc_depth use the encoder embedding size, the others
        # the decoder embedding size.
        dim_tokens = [
            croconet.enc_embed_dim
            if hook < croconet.enc_depth
            else croconet.dec_embed_dim
            for hook in self.hooks_idx
        ]
        dpt_init_args = {"dim_tokens_enc": dim_tokens}
        self.dpt.init(**dpt_init_args)

    def forward(self, x, img_info):
        """Run the DPT adapter on ``x`` and apply the optional postprocess.

        Args:
            x: Features passed unchanged to the adapter.
            img_info: Mapping with the keys ``'height'`` and ``'width'``,
                passed to the adapter as ``image_size=(height, width)``.

        Returns:
            The adapter output, passed through ``self.postprocess`` when one
            is set.
        """
        out = self.dpt(x, image_size=(img_info["height"], img_info["width"]))
        if self.postprocess:
            out = self.postprocess(out)
        return out
|