# NOTE(review): removed non-source extraction artifacts ("Spaces:" / "Runtime error")
# Copyright (c) Open-CD. All rights reserved. | |
from abc import ABCMeta, abstractmethod | |
from typing import List, Tuple | |
from mmengine.model import BaseModule | |
from mmengine.structures import PixelData | |
from torch import Tensor, nn | |
# from mmseg.models import builder | |
from mmseg.models.utils import resize | |
from mmseg.structures import SegDataSample | |
from mmseg.utils import ConfigType, SampleList, add_prefix | |
from opencd.registry import MODELS | |
class MultiHeadDecoder(BaseModule, metaclass=ABCMeta):
    """Base class for multi-head change-detection decoders.

    Holds one mandatory binary change-detection head and optional semantic
    change-detection heads for the `from` and `to` temporal branches.

    Args:
        binary_cd_head (dict): The decode head config for the binary change
            detection branch.
        binary_cd_neck (dict, optional): The feature fusion part for the
            binary change detection branch. Default: None
        semantic_cd_head (dict, optional): The decode head config for the
            semantic change detection `from` branch. Default: None
        semantic_cd_head_aux (dict, optional): The decode head config for the
            semantic change detection `to` branch. If None, the siamese
            semantic head (shared with the `from` branch) will be used.
            Default: None
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 binary_cd_head,
                 binary_cd_neck=None,
                 semantic_cd_head=None,
                 semantic_cd_head_aux=None,
                 init_cfg=None):
        super().__init__(init_cfg)
        self.binary_cd_head = MODELS.build(binary_cd_head)
        self.siamese_semantic_head = True
        if binary_cd_neck is not None:
            self.binary_cd_neck = MODELS.build(binary_cd_neck)
        if semantic_cd_head is not None:
            self.semantic_cd_head = MODELS.build(semantic_cd_head)
            if semantic_cd_head_aux is not None:
                # A dedicated `to`-branch head was given: not siamese.
                self.siamese_semantic_head = False
                self.semantic_cd_head_aux = MODELS.build(semantic_cd_head_aux)
            else:
                # Siamese: both temporal branches share the same head.
                self.semantic_cd_head_aux = self.semantic_cd_head

    @abstractmethod
    def forward(self, inputs):
        """Placeholder of forward function.

        The return value should be a dict() containing:
        `seg_logits`, `seg_logits_from` and `seg_logits_to`.

        For example:
            return dict(
                seg_logits=out,
                seg_logits_from=out1,
                seg_logits_to=out2)
        """
        pass

    def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList,
             train_cfg: ConfigType) -> dict:
        """Forward function for training.

        Args:
            inputs (Tuple[Tensor]): List of multi-level img features.
            batch_data_samples (list[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `img_metas` or `gt_semantic_seg`.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self.forward(inputs)
        losses = self.loss_by_feat(seg_logits, batch_data_samples)
        return losses

    def predict(self, inputs, batch_img_metas: List[dict], test_cfg,
                **kwargs) -> List[Tensor]:
        """Forward function for testing."""
        seg_logits = self.forward(inputs)
        return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs)

    def predict_by_feat(self, seg_logits: Tensor,
                        batch_img_metas: List[dict]) -> Tensor:
        """Transform a batch of output seg_logits to the input shape.

        Args:
            seg_logits (Tensor): The output from decode head forward function.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.

        Returns:
            Tensor: Outputs segmentation logits map.
        """
        assert ['seg_logits', 'seg_logits_from', 'seg_logits_to'] \
            == list(seg_logits.keys()), \
            "`seg_logits`, `seg_logits_from` and `seg_logits_to` " \
            "should be contained."
        # The semantic heads are optional in `__init__`; fall back to the
        # binary head's `align_corners` when they were not configured.
        semantic_head = getattr(self, 'semantic_cd_head', self.binary_cd_head)
        semantic_head_aux = getattr(self, 'semantic_cd_head_aux',
                                    self.binary_cd_head)
        self.align_corners = {
            'seg_logits': self.binary_cd_head.align_corners,
            'seg_logits_from': semantic_head.align_corners,
            'seg_logits_to': semantic_head_aux.align_corners}
        for seg_name, seg_logit in seg_logits.items():
            seg_logits[seg_name] = resize(
                input=seg_logit,
                size=batch_img_metas[0]['img_shape'],
                mode='bilinear',
                align_corners=self.align_corners[seg_name])
        return seg_logits

    def get_sub_batch_data_samples(self, batch_data_samples: SampleList,
                                   sub_metainfo_name: str,
                                   sub_data_name: str) -> list:
        """Build per-branch data samples from multi-branch annotations.

        Copies each sample's metainfo (minus all branch-specific
        ``*seg_map_path*`` keys) and exposes the selected branch's
        ground truth under the generic ``gt_sem_seg`` / ``seg_map_path``
        names so that single-branch decode heads can consume them.

        Args:
            batch_data_samples (list[:obj:`SegDataSample`]): The seg data
                samples carrying multi-branch annotations.
            sub_metainfo_name (str): Metainfo key holding this branch's
                seg map path (e.g. ``'seg_map_path_from'``).
            sub_data_name (str): Data field holding this branch's ground
                truth (e.g. ``'gt_sem_seg_from'``).

        Returns:
            list[:obj:`SegDataSample`]: Single-branch data samples.
        """
        sub_batch_sample_list = []
        for sample in batch_data_samples:
            data_sample = SegDataSample()
            gt_sem_seg_data = dict(data=sample.get(sub_data_name).data)
            data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)

            img_meta = {}
            seg_map_path = sample.metainfo.get(sub_metainfo_name)
            for key in sample.metainfo.keys():
                # Drop every branch-specific seg-map-path entry; the chosen
                # branch path is re-added below under the generic key.
                if 'seg_map_path' not in key:
                    img_meta[key] = sample.metainfo.get(key)
            img_meta['seg_map_path'] = seg_map_path
            data_sample.set_metainfo(img_meta)
            sub_batch_sample_list.append(data_sample)
        return sub_batch_sample_list

    def loss_by_feat(self, seg_logits: dict,
                     batch_data_samples: SampleList, **kwargs) -> dict:
        """Compute segmentation loss for every configured head.

        Args:
            seg_logits (dict): Output of :meth:`forward`; must contain the
                keys `seg_logits`, `seg_logits_from` and `seg_logits_to`.
            batch_data_samples (list[:obj:`SegDataSample`]): The seg data
                samples carrying multi-branch annotations.

        Returns:
            dict[str, Tensor]: a dictionary of loss components, prefixed
            per branch (`binary_cd`, `semantic_cd_from`, `semantic_cd_to`).
        """
        assert ['seg_logits', 'seg_logits_from', 'seg_logits_to'] \
            == list(seg_logits.keys()), \
            "`seg_logits`, `seg_logits_from` and `seg_logits_to` " \
            "should be contained."

        losses = dict()
        binary_cd_loss_decode = self.binary_cd_head.loss_by_feat(
            seg_logits['seg_logits'],
            self.get_sub_batch_data_samples(batch_data_samples,
                                            sub_metainfo_name='seg_map_path',
                                            sub_data_name='gt_sem_seg'))
        losses.update(add_prefix(binary_cd_loss_decode, 'binary_cd'))

        # Semantic heads are optional; `getattr` with a default avoids the
        # AttributeError the bare `getattr(self, 'semantic_cd_head')` raised
        # when no semantic head was configured.
        if getattr(self, 'semantic_cd_head', None) is not None:
            semantic_cd_loss_decode_from = self.semantic_cd_head.loss_by_feat(
                seg_logits['seg_logits_from'],
                self.get_sub_batch_data_samples(
                    batch_data_samples,
                    sub_metainfo_name='seg_map_path_from',
                    sub_data_name='gt_sem_seg_from'))
            losses.update(
                add_prefix(semantic_cd_loss_decode_from, 'semantic_cd_from'))

            semantic_cd_loss_decode_to = \
                self.semantic_cd_head_aux.loss_by_feat(
                    seg_logits['seg_logits_to'],
                    self.get_sub_batch_data_samples(
                        batch_data_samples,
                        sub_metainfo_name='seg_map_path_to',
                        sub_data_name='gt_sem_seg_to'))
            losses.update(
                add_prefix(semantic_cd_loss_decode_to, 'semantic_cd_to'))

        return losses