Spaces:
Runtime error
Runtime error
File size: 919 Bytes
3b96cb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
# Copyright (c) Open-CD. All rights reserved.
from typing import List, Optional
import torch
from torch import Tensor
from opencd.registry import MODELS
from .siamencoder_decoder import SiamEncoderDecoder
@MODELS.register_module()
class DIEncoderDecoder(SiamEncoderDecoder):
    """Dual Input Encoder Decoder segmentor.

    DIEncoderDecoder typically consists of a backbone, a decode_head and an
    auxiliary_head. The two input images arrive concatenated along dim 1 and
    are split apart here, then passed to the backbone as two separate
    arguments.

    Note that auxiliary_head is only used for deep supervision during
    training and could be dropped during inference.
    """

    def extract_feat(self, inputs: Tensor) -> List[Tensor]:
        """Extract features from a concatenated image pair.

        Args:
            inputs (Tensor): The two input images concatenated along dim 1.
                Presumably shaped (N, 2 * C, H, W) — TODO confirm against
                the data preprocessor.

        Returns:
            List[Tensor]: Output of ``self.backbone(img_from, img_to)``,
            passed through ``self.neck`` when ``self.with_neck`` is true.

        Raises:
            ValueError: If ``self.backbone_inchannels`` is an int and
                ``inputs`` does not have exactly twice that many channels
                along dim 1.
        """
        # `in_channels` is not an attribute of every backbone class, so the
        # per-image channel count is read from `self.backbone_inchannels`.
        split_size = self.backbone_inchannels
        if isinstance(split_size, int) and inputs.shape[1] != 2 * split_size:
            # Without this check a channel remainder would silently become a
            # narrower second image (e.g. 5 channels with split_size=3 splits
            # into 3 + 2), or unpacking would fail with an opaque error.
            raise ValueError(
                f'Expected inputs with {2 * split_size} channels along dim 1 '
                f'(two images of {split_size} channels each), '
                f'got {inputs.shape[1]}.')
        img_from, img_to = torch.split(inputs, split_size, dim=1)
        x = self.backbone(img_from, img_to)
        if self.with_neck:
            x = self.neck(x)
        return x
|