|
|
|
from functools import partial |
|
from typing import List, Optional, Sequence, Tuple, Union
|
|
|
import numpy as np |
|
import torch |
|
from mmengine.structures import InstanceData |
|
from mmengine.utils import digit_version |
|
|
from torch import Tensor |
|
from torch.autograd import Function |
|
from torch.nn import functional as F |
|
|
|
from mmdet.structures import SampleList |
|
from mmdet.structures.bbox import BaseBoxes, get_box_type, stack_boxes |
|
from mmdet.structures.mask import BitmapMasks, PolygonMasks |
|
from mmdet.utils import OptInstanceList |
|
|
|
|
|
class SigmoidGeometricMean(Function): |
|
"""Forward and backward function of geometric mean of two sigmoid |
|
functions. |
|
|
|
This implementation with analytical gradient function substitutes |
|
the autograd function of (x.sigmoid() * y.sigmoid()).sqrt(). The |
|
original implementation incurs none during gradient backprapagation |
|
if both x and y are very small values. |
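
    Examples:
        >>> # A minimal sanity check; sigmoid(0) = 0.5 on both inputs.
        >>> x = torch.zeros(2, requires_grad=True)
        >>> y = torch.zeros(2, requires_grad=True)
        >>> z = sigmoid_geometric_mean(x, y)
        >>> torch.allclose(z, torch.full((2, ), 0.5))
        True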
|
""" |
|
|
|
@staticmethod |
|
def forward(ctx, x, y): |
|
x_sigmoid = x.sigmoid() |
|
y_sigmoid = y.sigmoid() |
|
z = (x_sigmoid * y_sigmoid).sqrt() |
|
ctx.save_for_backward(x_sigmoid, y_sigmoid, z) |
|
return z |
|
|
|
@staticmethod |
|
def backward(ctx, grad_output): |
|
x_sigmoid, y_sigmoid, z = ctx.saved_tensors |
|
grad_x = grad_output * z * (1 - x_sigmoid) / 2 |
|
grad_y = grad_output * z * (1 - y_sigmoid) / 2 |
|
return grad_x, grad_y |
|
|
|
|
|
sigmoid_geometric_mean = SigmoidGeometricMean.apply |
|
|
|
|
|
def interpolate_as(source, target, mode='bilinear', align_corners=False): |
|
"""Interpolate the `source` to the shape of the `target`. |
|
|
|
The `source` must be a Tensor, but the `target` can be a Tensor or a |
|
np.ndarray with the shape (..., target_h, target_w). |
|
|
|
Args: |
|
source (Tensor): A 3D/4D Tensor with the shape (N, H, W) or |
|
(N, C, H, W). |
|
target (Tensor | np.ndarray): The interpolation target with the shape |
|
(..., target_h, target_w). |
|
mode (str): Algorithm used for interpolation. The options are the |
|
same as those in F.interpolate(). Default: ``'bilinear'``. |
|
align_corners (bool): The same as the argument in F.interpolate(). |
|
|
|
Returns: |
|
Tensor: The interpolated source Tensor. |
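
    Examples:
        >>> # A minimal sketch: resize a 3D source to match an 8x8 target.
        >>> source = torch.rand(1, 4, 4)
        >>> target = torch.rand(1, 8, 8)
        >>> interpolate_as(source, target).shape
        torch.Size([1, 8, 8])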
|
""" |
|
assert len(target.shape) >= 2 |
|
|
|
def _interpolate_as(source, target, mode='bilinear', align_corners=False): |
|
"""Interpolate the `source` (4D) to the shape of the `target`.""" |
|
target_h, target_w = target.shape[-2:] |
|
source_h, source_w = source.shape[-2:] |
|
if target_h != source_h or target_w != source_w: |
|
source = F.interpolate( |
|
source, |
|
size=(target_h, target_w), |
|
mode=mode, |
|
align_corners=align_corners) |
|
return source |
|
|
|
if len(source.shape) == 3: |
|
source = source[:, None, :, :] |
|
source = _interpolate_as(source, target, mode, align_corners) |
|
return source[:, 0, :, :] |
|
else: |
|
return _interpolate_as(source, target, mode, align_corners) |
|
|
|
|
|
def unpack_gt_instances(batch_data_samples: SampleList) -> tuple: |
|
"""Unpack ``gt_instances``, ``gt_instances_ignore`` and ``img_metas`` based |
|
on ``batch_data_samples`` |
|
|
|
Args: |
|
batch_data_samples (List[:obj:`DetDataSample`]): The Data |
|
Samples. It usually includes information such as |
|
            `gt_instances`, `gt_panoptic_seg` and `gt_sem_seg`.
|
|
|
Returns: |
|
tuple: |
|
|
|
- batch_gt_instances (list[:obj:`InstanceData`]): Batch of |
|
gt_instance. It usually includes ``bboxes`` and ``labels`` |
|
attributes. |
|
- batch_gt_instances_ignore (list[:obj:`InstanceData`]): |
|
Batch of gt_instances_ignore. It includes ``bboxes`` attribute |
|
data that is ignored during training and testing. |
|
Defaults to None. |
|
- batch_img_metas (list[dict]): Meta information of each image, |
|
e.g., image size, scaling factor, etc. |
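
    Examples:
        >>> # A minimal sketch with a single data sample; assumes
        >>> # DetDataSample and InstanceData accept field kwargs.
        >>> from mmdet.structures import DetDataSample
        >>> data_sample = DetDataSample(metainfo=dict(img_shape=(800, 1196)))
        >>> data_sample.gt_instances = InstanceData(bboxes=torch.rand(2, 4))
        >>> outs = unpack_gt_instances([data_sample])
        >>> gt_instances, gt_instances_ignore, batch_img_metas = outs
        >>> len(gt_instances), len(batch_img_metas)
        (1, 1)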
|
""" |
|
batch_gt_instances = [] |
|
batch_gt_instances_ignore = [] |
|
batch_img_metas = [] |
|
for data_sample in batch_data_samples: |
|
batch_img_metas.append(data_sample.metainfo) |
|
batch_gt_instances.append(data_sample.gt_instances) |
|
if 'ignored_instances' in data_sample: |
|
batch_gt_instances_ignore.append(data_sample.ignored_instances) |
|
else: |
|
batch_gt_instances_ignore.append(None) |
|
|
|
return batch_gt_instances, batch_gt_instances_ignore, batch_img_metas |
|
|
|
|
|
def empty_instances(batch_img_metas: List[dict], |
|
device: torch.device, |
|
task_type: str, |
|
instance_results: OptInstanceList = None, |
|
mask_thr_binary: Union[int, float] = 0, |
|
box_type: Union[str, type] = 'hbox', |
|
use_box_type: bool = False, |
|
num_classes: int = 80, |
|
score_per_cls: bool = False) -> List[InstanceData]: |
|
"""Handle predicted instances when RoI is empty. |
|
|
|
    Note: If ``instance_results`` is not None, it will be modified
    in place internally and then returned.
|
|
|
Args: |
|
batch_img_metas (list[dict]): List of image information. |
|
device (torch.device): Device of tensor. |
|
        task_type (str): Expected returned task type. It currently
            supports ``bbox`` and ``mask``.
|
instance_results (list[:obj:`InstanceData`]): List of instance |
|
results. |
|
        mask_thr_binary (int or float): Mask binarization threshold.
            Defaults to 0.
|
box_type (str or type): The empty box type. Defaults to `hbox`. |
|
        use_box_type (bool): Whether to wrap boxes with the box type.
            Defaults to False.
|
num_classes (int): num_classes of bbox_head. Defaults to 80. |
|
score_per_cls (bool): Whether to generate classwise score for |
|
the empty instance. ``score_per_cls`` will be True when the model |
|
needs to produce raw results without nms. Defaults to False. |
|
|
|
Returns: |
|
        list[:obj:`InstanceData`]: Detection results of each image.
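
    Examples:
        >>> # A minimal sketch: empty bbox results for one image.
        >>> img_metas = [dict(ori_shape=(427, 640))]
        >>> results = empty_instances(img_metas, torch.device('cpu'), 'bbox')
        >>> results[0].bboxes.shape
        torch.Size([0, 4])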
|
""" |
|
    assert task_type in ('bbox', 'mask'), \
        f'Only supports bbox and mask, but got {task_type}'
|
|
|
if instance_results is not None: |
|
assert len(instance_results) == len(batch_img_metas) |
|
|
|
results_list = [] |
|
for img_id in range(len(batch_img_metas)): |
|
if instance_results is not None: |
|
results = instance_results[img_id] |
|
assert isinstance(results, InstanceData) |
|
else: |
|
results = InstanceData() |
|
|
|
if task_type == 'bbox': |
|
_, box_type = get_box_type(box_type) |
|
bboxes = torch.zeros(0, box_type.box_dim, device=device) |
|
if use_box_type: |
|
bboxes = box_type(bboxes, clone=False) |
|
results.bboxes = bboxes |
|
score_shape = (0, num_classes + 1) if score_per_cls else (0, ) |
|
results.scores = torch.zeros(score_shape, device=device) |
|
results.labels = torch.zeros((0, ), |
|
device=device, |
|
dtype=torch.long) |
|
else: |
|
|
|
img_h, img_w = batch_img_metas[img_id]['ori_shape'][:2] |
|
|
|
|
|
im_mask = torch.zeros( |
|
0, |
|
img_h, |
|
img_w, |
|
device=device, |
|
dtype=torch.bool if mask_thr_binary >= 0 else torch.uint8) |
|
results.masks = im_mask |
|
results_list.append(results) |
|
return results_list |
|
|
|
|
|
def multi_apply(func, *args, **kwargs): |
|
"""Apply function to a list of arguments. |
|
|
|
Note: |
|
        This function applies ``func`` to multiple inputs and
        maps the multiple outputs of ``func`` into different
        lists. Each list contains the same type of outputs
        corresponding to different inputs.
|
|
|
Args: |
|
func (Function): A function that will be applied to a list of |
|
arguments |
|
|
|
Returns: |
|
        tuple(list): A tuple containing multiple lists, each of which \
            contains one kind of the results returned by the function.
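
    Examples:
        >>> # A minimal sketch: two outputs per input, grouped into lists.
        >>> def square_and_cube(x):
        ...     return x**2, x**3
        >>> multi_apply(square_and_cube, [1, 2, 3])
        ([1, 4, 9], [1, 8, 27])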
|
""" |
|
pfunc = partial(func, **kwargs) if kwargs else func |
|
map_results = map(pfunc, *args) |
|
return tuple(map(list, zip(*map_results))) |
|
|
|
|
|
def unmap(data, count, inds, fill=0): |
|
"""Unmap a subset of item (data) back to the original set of items (of size |
|
count)""" |
|
if data.dim() == 1: |
|
ret = data.new_full((count, ), fill) |
|
ret[inds.type(torch.bool)] = data |
|
else: |
|
new_size = (count, ) + data.size()[1:] |
|
ret = data.new_full(new_size, fill) |
|
ret[inds.type(torch.bool), :] = data |
|
return ret |
|
|
|
|
|
def mask2ndarray(mask): |
|
"""Convert Mask to ndarray.. |
|
|
|
Args: |
|
mask (:obj:`BitmapMasks` or :obj:`PolygonMasks` or |
|
torch.Tensor or np.ndarray): The mask to be converted. |
|
|
|
Returns: |
|
np.ndarray: Ndarray mask of shape (n, h, w) that has been converted |
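
    Examples:
        >>> # A Tensor mask is detached and converted to numpy.
        >>> mask = torch.zeros(2, 4, 4)
        >>> mask2ndarray(mask).shape
        (2, 4, 4)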
|
""" |
|
if isinstance(mask, (BitmapMasks, PolygonMasks)): |
|
mask = mask.to_ndarray() |
|
elif isinstance(mask, torch.Tensor): |
|
mask = mask.detach().cpu().numpy() |
|
elif not isinstance(mask, np.ndarray): |
|
raise TypeError(f'Unsupported {type(mask)} data type') |
|
return mask |
|
|
|
|
|
def flip_tensor(src_tensor, flip_direction): |
|
"""flip tensor base on flip_direction. |
|
|
|
Args: |
|
src_tensor (Tensor): input feature map, shape (B, C, H, W). |
|
flip_direction (str): The flipping direction. Options are |
|
'horizontal', 'vertical', 'diagonal'. |
|
|
|
Returns: |
|
out_tensor (Tensor): Flipped tensor. |
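
    Examples:
        >>> # Horizontal flip reverses the last (width) dimension.
        >>> t = torch.arange(4.).reshape(1, 1, 2, 2)
        >>> flip_tensor(t, 'horizontal')[0, 0]
        tensor([[1., 0.],
                [3., 2.]])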
|
""" |
|
assert src_tensor.ndim == 4 |
|
valid_directions = ['horizontal', 'vertical', 'diagonal'] |
|
assert flip_direction in valid_directions |
|
if flip_direction == 'horizontal': |
|
out_tensor = torch.flip(src_tensor, [3]) |
|
elif flip_direction == 'vertical': |
|
out_tensor = torch.flip(src_tensor, [2]) |
|
else: |
|
out_tensor = torch.flip(src_tensor, [2, 3]) |
|
return out_tensor |
|
|
|
|
|
def select_single_mlvl(mlvl_tensors, batch_id, detach=True): |
|
"""Extract a multi-scale single image tensor from a multi-scale batch |
|
tensor based on batch index. |
|
|
|
    Note: The default value of ``detach`` is True, because the proposal
    gradient needs to be detached during the training of the two-stage
    model, e.g., Cascade Mask R-CNN.
|
|
|
Args: |
|
mlvl_tensors (list[Tensor]): Batch tensor for all scale levels, |
|
each is a 4D-tensor. |
|
batch_id (int): Batch index. |
|
        detach (bool): Whether to detach the gradient. Default True.
|
|
|
Returns: |
|
list[Tensor]: Multi-scale single image tensor. |
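
    Examples:
        >>> # Pick image 0 from a batch of 2 across two levels.
        >>> feats = [torch.rand(2, 8, 16, 16), torch.rand(2, 8, 8, 8)]
        >>> [f.shape for f in select_single_mlvl(feats, 0)]
        [torch.Size([8, 16, 16]), torch.Size([8, 8, 8])]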
|
""" |
|
assert isinstance(mlvl_tensors, (list, tuple)) |
|
num_levels = len(mlvl_tensors) |
|
|
|
if detach: |
|
mlvl_tensor_list = [ |
|
mlvl_tensors[i][batch_id].detach() for i in range(num_levels) |
|
] |
|
else: |
|
mlvl_tensor_list = [ |
|
mlvl_tensors[i][batch_id] for i in range(num_levels) |
|
] |
|
return mlvl_tensor_list |
|
|
|
|
|
def filter_scores_and_topk(scores, score_thr, topk, results=None): |
|
"""Filter results using score threshold and topk candidates. |
|
|
|
Args: |
|
scores (Tensor): The scores, shape (num_bboxes, K). |
|
score_thr (float): The score filter threshold. |
|
topk (int): The number of topk candidates. |
|
results (dict or list or Tensor, Optional): The results to |
|
which the filtering rule is to be applied. The shape |
|
of each item is (num_bboxes, N). |
|
|
|
Returns: |
|
tuple: Filtered results |
|
|
|
- scores (Tensor): The scores after being filtered, \ |
|
shape (num_bboxes_filtered, ). |
|
- labels (Tensor): The class labels, shape \ |
|
(num_bboxes_filtered, ). |
|
- anchor_idxs (Tensor): The anchor indexes, shape \ |
|
(num_bboxes_filtered, ). |
|
- filtered_results (dict or list or Tensor, Optional): \ |
|
The filtered results. The shape of each item is \ |
|
(num_bboxes_filtered, N). |
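
    Examples:
        >>> # A minimal sketch: 3 boxes, 2 classes, keep top-2 above 0.5.
        >>> scores = torch.tensor([[0.1, 0.9], [0.6, 0.2], [0.3, 0.8]])
        >>> scores, labels, keep_idxs, _ = filter_scores_and_topk(
        ...     scores, 0.5, 2)
        >>> scores
        tensor([0.9000, 0.8000])
        >>> labels
        tensor([1, 1])
        >>> keep_idxs
        tensor([0, 2])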
|
""" |
|
valid_mask = scores > score_thr |
|
scores = scores[valid_mask] |
|
valid_idxs = torch.nonzero(valid_mask) |
|
|
|
num_topk = min(topk, valid_idxs.size(0)) |
|
|
|
scores, idxs = scores.sort(descending=True) |
|
scores = scores[:num_topk] |
|
topk_idxs = valid_idxs[idxs[:num_topk]] |
|
keep_idxs, labels = topk_idxs.unbind(dim=1) |
|
|
|
filtered_results = None |
|
if results is not None: |
|
if isinstance(results, dict): |
|
filtered_results = {k: v[keep_idxs] for k, v in results.items()} |
|
elif isinstance(results, list): |
|
filtered_results = [result[keep_idxs] for result in results] |
|
elif isinstance(results, torch.Tensor): |
|
filtered_results = results[keep_idxs] |
|
else: |
|
            raise NotImplementedError(
                f'Only supports dict or list or Tensor, '
                f'but got {type(results)}.')
|
return scores, labels, keep_idxs, filtered_results |
|
|
|
|
|
def center_of_mass(mask, esp=1e-6): |
|
"""Calculate the centroid coordinates of the mask. |
|
|
|
Args: |
|
mask (Tensor): The mask to be calculated, shape (h, w). |
|
        esp (float): A small value to avoid division by zero. Default: 1e-6.
|
|
|
Returns: |
|
tuple[Tensor]: the coordinates of the center point of the mask. |
|
|
|
- center_h (Tensor): the center point of the height. |
|
- center_w (Tensor): the center point of the width. |
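
    Examples:
        >>> # The centroid of a uniform mask is its geometric center.
        >>> mask = torch.ones(2, 2)
        >>> center_of_mass(mask)
        (tensor(0.5000), tensor(0.5000))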
|
""" |
|
h, w = mask.shape |
|
grid_h = torch.arange(h, device=mask.device)[:, None] |
|
grid_w = torch.arange(w, device=mask.device) |
|
normalizer = mask.sum().float().clamp(min=esp) |
|
center_h = (mask * grid_h).sum() / normalizer |
|
center_w = (mask * grid_w).sum() / normalizer |
|
return center_h, center_w |
|
|
|
|
|
def generate_coordinate(featmap_sizes, device='cuda'): |
|
"""Generate the coordinate. |
|
|
|
Args: |
|
        featmap_sizes (tuple): The size of the feature map to be
            calculated, in the order (N, C, H, W).
        device (str): The device where the feature will be put on.
    Returns:
        coord_feat (Tensor): The coordinate feature, of shape (N, 2, H, W).
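
    Examples:
        >>> # Coordinates are normalized to [-1, 1] along each spatial axis.
        >>> coord_feat = generate_coordinate((2, 8, 4, 6), device='cpu')
        >>> coord_feat.shape
        torch.Size([2, 2, 4, 6])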
|
""" |
|
|
|
x_range = torch.linspace(-1, 1, featmap_sizes[-1], device=device) |
|
y_range = torch.linspace(-1, 1, featmap_sizes[-2], device=device) |
|
y, x = torch.meshgrid(y_range, x_range) |
|
y = y.expand([featmap_sizes[0], 1, -1, -1]) |
|
x = x.expand([featmap_sizes[0], 1, -1, -1]) |
|
coord_feat = torch.cat([x, y], 1) |
|
|
|
return coord_feat |
|
|
|
|
|
def levels_to_images(mlvl_tensor: List[torch.Tensor]) -> List[torch.Tensor]: |
|
"""Concat multi-level feature maps by image. |
|
|
|
[feature_level0, feature_level1...] -> [feature_image0, feature_image1...] |
|
    Convert the shape of each element in mlvl_tensor from (N, C, H, W) to
    (N, H*W, C), then split each element into N elements of shape (H*W, C),
    and concatenate the elements belonging to the same image across all
    levels along the first dimension.
|
|
|
Args: |
|
        mlvl_tensor (list[Tensor]): A list of tensors collected from the
            corresponding feature levels. Each element is of shape
            (N, C, H, W).
|
|
|
Returns: |
|
list[Tensor]: A list that contains N tensors and each tensor is |
|
of shape (num_elements, C) |
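
    Examples:
        >>> # Two levels for a batch of 2 images: 8*8 + 4*4 = 80 elements.
        >>> feats = [torch.rand(2, 4, 8, 8), torch.rand(2, 4, 4, 4)]
        >>> per_img = levels_to_images(feats)
        >>> len(per_img), per_img[0].shape
        (2, torch.Size([80, 4]))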
|
""" |
|
batch_size = mlvl_tensor[0].size(0) |
|
batch_list = [[] for _ in range(batch_size)] |
|
channels = mlvl_tensor[0].size(1) |
|
for t in mlvl_tensor: |
|
t = t.permute(0, 2, 3, 1) |
|
t = t.view(batch_size, -1, channels).contiguous() |
|
for img in range(batch_size): |
|
batch_list[img].append(t[img]) |
|
return [torch.cat(item, 0) for item in batch_list] |
|
|
|
|
|
def images_to_levels(target, num_levels): |
|
"""Convert targets by image to targets by feature level. |
|
|
|
[target_img0, target_img1] -> [target_level0, target_level1, ...] |
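
    Examples:
        >>> # Two per-image targets of 4 anchors split into levels of 3 + 1.
        >>> target = [torch.rand(4, 4), torch.rand(4, 4)]
        >>> [lvl.shape for lvl in images_to_levels(target, [3, 1])]
        [torch.Size([2, 3, 4]), torch.Size([2, 1, 4])]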
|
""" |
|
target = stack_boxes(target, 0) |
|
level_targets = [] |
|
start = 0 |
|
for n in num_levels: |
|
end = start + n |
|
|
|
level_targets.append(target[:, start:end]) |
|
start = end |
|
return level_targets |
|
|
|
|
|
def samplelist_boxtype2tensor(batch_data_samples: SampleList) -> None:
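    """Convert the ``bboxes`` in data samples from box type (a ``BaseBoxes``
    subclass) to plain tensors, modifying the data samples in place."""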
|
for data_samples in batch_data_samples: |
|
if 'gt_instances' in data_samples: |
|
bboxes = data_samples.gt_instances.get('bboxes', None) |
|
if isinstance(bboxes, BaseBoxes): |
|
data_samples.gt_instances.bboxes = bboxes.tensor |
|
if 'pred_instances' in data_samples: |
|
bboxes = data_samples.pred_instances.get('bboxes', None) |
|
if isinstance(bboxes, BaseBoxes): |
|
data_samples.pred_instances.bboxes = bboxes.tensor |
|
if 'ignored_instances' in data_samples: |
|
bboxes = data_samples.ignored_instances.get('bboxes', None) |
|
if isinstance(bboxes, BaseBoxes): |
|
data_samples.ignored_instances.bboxes = bboxes.tensor |
|
|
|
|
|
_torch_version_div_indexing = ( |
|
'parrots' not in torch.__version__ |
|
and digit_version(torch.__version__) >= digit_version('1.8')) |
|
|
|
|
|
def floordiv(dividend, divisor, rounding_mode='trunc'): |
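    """Integer division compatible across torch versions: new versions use
    ``torch.div`` with ``rounding_mode``, old ones fall back to ``//``."""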
|
if _torch_version_div_indexing: |
|
return torch.div(dividend, divisor, rounding_mode=rounding_mode) |
|
else: |
|
return dividend // divisor |
|
|
|
|
|
def _filter_gt_instances_by_score(batch_data_samples: SampleList, |
|
score_thr: float) -> SampleList: |
|
"""Filter ground truth (GT) instances by score. |
|
|
|
Args: |
|
batch_data_samples (SampleList): The Data |
|
Samples. It usually includes information such as |
|
            `gt_instances`, `gt_panoptic_seg` and `gt_sem_seg`.
|
score_thr (float): The score filter threshold. |
|
|
|
Returns: |
|
SampleList: The Data Samples filtered by score. |
|
""" |
|
for data_samples in batch_data_samples: |
|
        assert 'scores' in data_samples.gt_instances, \
            "'scores' does not exist in gt_instances"
|
if data_samples.gt_instances.bboxes.shape[0] > 0: |
|
data_samples.gt_instances = data_samples.gt_instances[ |
|
data_samples.gt_instances.scores > score_thr] |
|
return batch_data_samples |
|
|
|
|
|
def _filter_gt_instances_by_size(batch_data_samples: SampleList, |
|
wh_thr: tuple) -> SampleList: |
|
"""Filter ground truth (GT) instances by size. |
|
|
|
Args: |
|
batch_data_samples (SampleList): The Data |
|
Samples. It usually includes information such as |
|
            `gt_instances`, `gt_panoptic_seg` and `gt_sem_seg`.
|
wh_thr (tuple): Minimum width and height of bbox. |
|
|
|
Returns: |
|
        SampleList: The Data Samples filtered by size.
|
""" |
|
for data_samples in batch_data_samples: |
|
bboxes = data_samples.gt_instances.bboxes |
|
if bboxes.shape[0] > 0: |
|
w = bboxes[:, 2] - bboxes[:, 0] |
|
h = bboxes[:, 3] - bboxes[:, 1] |
|
data_samples.gt_instances = data_samples.gt_instances[ |
|
(w > wh_thr[0]) & (h > wh_thr[1])] |
|
return batch_data_samples |
|
|
|
|
|
def filter_gt_instances(batch_data_samples: SampleList,
                        score_thr: Optional[float] = None,
                        wh_thr: Optional[tuple] = None):
|
"""Filter ground truth (GT) instances by score and/or size. |
|
|
|
Args: |
|
batch_data_samples (SampleList): The Data |
|
Samples. It usually includes information such as |
|
            `gt_instances`, `gt_panoptic_seg` and `gt_sem_seg`.
|
score_thr (float): The score filter threshold. |
|
wh_thr (tuple): Minimum width and height of bbox. |
|
|
|
Returns: |
|
SampleList: The Data Samples filtered by score and/or size. |
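
    Examples:
        >>> # A minimal sketch: drop GT boxes scored below the threshold.
        >>> from mmdet.structures import DetDataSample
        >>> ds = DetDataSample()
        >>> ds.gt_instances = InstanceData(
        ...     bboxes=torch.tensor([[0., 0., 10., 10.], [0., 0., 2., 2.]]),
        ...     scores=torch.tensor([0.9, 0.3]))
        >>> out = filter_gt_instances([ds], score_thr=0.5)
        >>> len(out[0].gt_instances)
        1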
|
""" |
|
|
|
if score_thr is not None: |
|
batch_data_samples = _filter_gt_instances_by_score( |
|
batch_data_samples, score_thr) |
|
if wh_thr is not None: |
|
batch_data_samples = _filter_gt_instances_by_size( |
|
batch_data_samples, wh_thr) |
|
return batch_data_samples |
|
|
|
|
|
def rename_loss_dict(prefix: str, losses: dict) -> dict: |
|
"""Rename the key names in loss dict by adding a prefix. |
|
|
|
Args: |
|
prefix (str): The prefix for loss components. |
|
losses (dict): A dictionary of loss components. |
|
|
|
Returns: |
|
dict: A dictionary of loss components with prefix. |
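
    Examples:
        >>> losses = dict(loss_cls=torch.tensor(2.0))
        >>> rename_loss_dict('unsup_', losses)
        {'unsup_loss_cls': tensor(2.)}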
|
""" |
|
return {prefix + k: v for k, v in losses.items()} |
|
|
|
|
|
def reweight_loss_dict(losses: dict, weight: float) -> dict: |
|
"""Reweight losses in the dict by weight. |
|
|
|
Args: |
|
losses (dict): A dictionary of loss components. |
|
weight (float): Weight for loss components. |
|
|
|
Returns: |
|
dict: A dictionary of weighted loss components. |
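
    Examples:
        >>> # Only keys containing 'loss' are reweighted.
        >>> losses = dict(loss_cls=torch.tensor(2.0), num_pos=torch.tensor(5.))
        >>> out = reweight_loss_dict(losses, 0.5)
        >>> out['loss_cls'], out['num_pos']
        (tensor(1.), tensor(5.))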
|
""" |
|
for name, loss in losses.items(): |
|
if 'loss' in name: |
|
if isinstance(loss, Sequence): |
|
losses[name] = [item * weight for item in loss] |
|
else: |
|
losses[name] = loss * weight |
|
return losses |
|
|
|
|
|
def relative_coordinate_maps( |
|
locations: Tensor, |
|
centers: Tensor, |
|
strides: Tensor, |
|
size_of_interest: int, |
|
    feat_sizes: Tuple[int, int],
|
) -> Tensor: |
|
"""Generate the relative coordinate maps with feat_stride. |
|
|
|
Args: |
|
        locations (Tensor): The prior locations of the mask feature map.
            It has shape (num_priors, 2).
        centers (Tensor): The prior points of an object in
            all feature pyramid levels. It has shape (num_pos, 2).
        strides (Tensor): The prior strides of an object in
            all feature pyramid levels. It has shape (num_pos, ).
        size_of_interest (int): The size of the region used in rel coord.
        feat_sizes (Tuple[int, int]): The feature sizes H and W.
|
Returns: |
|
rel_coord_feat (Tensor): The coordinate feature |
|
of shape (num_pos, 2, H, W). |
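
    Examples:
        >>> # A minimal shape check for 2 positive samples on a 4x4 map.
        >>> H, W = 4, 4
        >>> locations = torch.rand(H * W, 2)
        >>> centers = torch.rand(2, 2)
        >>> strides = torch.tensor([8., 8.])
        >>> relative_coordinate_maps(
        ...     locations, centers, strides, 8, (H, W)).shape
        torch.Size([2, 2, 4, 4])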
|
""" |
|
|
|
H, W = feat_sizes |
|
rel_coordinates = centers.reshape(-1, 1, 2) - locations.reshape(1, -1, 2) |
|
rel_coordinates = rel_coordinates.permute(0, 2, 1).float() |
|
rel_coordinates = rel_coordinates / ( |
|
strides[:, None, None] * size_of_interest) |
|
return rel_coordinates.reshape(-1, 2, H, W) |
|
|
|
|
|
def aligned_bilinear(tensor: Tensor, factor: int) -> Tensor: |
|
"""aligned bilinear, used in original implement in CondInst: |
|
|
|
https://github.com/aim-uofa/AdelaiDet/blob/\ |
|
c0b2092ce72442b0f40972f7c6dda8bb52c46d16/adet/utils/comm.py#L23 |
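
    Examples:
        >>> # Upsampling by factor 2 doubles the spatial size.
        >>> t = torch.rand(1, 1, 4, 4)
        >>> aligned_bilinear(t, 2).shape
        torch.Size([1, 1, 8, 8])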
|
""" |
|
|
|
assert tensor.dim() == 4 |
|
assert factor >= 1 |
|
assert int(factor) == factor |
|
|
|
if factor == 1: |
|
return tensor |
|
|
|
h, w = tensor.size()[2:] |
|
tensor = F.pad(tensor, pad=(0, 1, 0, 1), mode='replicate') |
|
oh = factor * h + 1 |
|
ow = factor * w + 1 |
|
tensor = F.interpolate( |
|
tensor, size=(oh, ow), mode='bilinear', align_corners=True) |
|
tensor = F.pad( |
|
tensor, pad=(factor // 2, 0, factor // 2, 0), mode='replicate') |
|
|
|
return tensor[:, :, :oh - 1, :ow - 1] |
|
|
|
|
|
def unfold_wo_center(x, kernel_size: int, dilation: int) -> Tensor: |
|
"""unfold_wo_center, used in original implement in BoxInst: |
|
|
|
https://github.com/aim-uofa/AdelaiDet/blob/\ |
|
4a3a1f7372c35b48ebf5f6adc59f135a0fa28d60/\ |
|
adet/modeling/condinst/condinst.py#L53 |
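
    Examples:
        >>> # Each pixel gets its k*k - 1 neighbors (center excluded).
        >>> x = torch.rand(1, 3, 8, 8)
        >>> unfold_wo_center(x, kernel_size=3, dilation=1).shape
        torch.Size([1, 3, 8, 8, 8])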
|
""" |
|
assert x.dim() == 4 |
|
assert kernel_size % 2 == 1 |
|
|
|
|
|
padding = (kernel_size + (dilation - 1) * (kernel_size - 1)) // 2 |
|
unfolded_x = F.unfold( |
|
x, kernel_size=kernel_size, padding=padding, dilation=dilation) |
|
unfolded_x = unfolded_x.reshape( |
|
x.size(0), x.size(1), -1, x.size(2), x.size(3)) |
|
|
|
size = kernel_size**2 |
|
unfolded_x = torch.cat( |
|
(unfolded_x[:, :, :size // 2], unfolded_x[:, :, size // 2 + 1:]), |
|
dim=2) |
|
|
|
return unfolded_x |
|
|