from math import prod
from typing import Tuple

import numpy as np
import torch
from timm.models.layers import to_2tuple


def bchw_to_bhwc(x: torch.Tensor) -> torch.Tensor:
    """Permute a tensor from shape (B, C, H, W) to (B, H, W, C)."""
    return x.permute(0, 2, 3, 1)


def bhwc_to_bchw(x: torch.Tensor) -> torch.Tensor:
    """Permute a tensor from shape (B, H, W, C) to (B, C, H, W)."""
    return x.permute(0, 3, 1, 2)


def bchw_to_blc(x: torch.Tensor) -> torch.Tensor:
    """Rearrange a tensor from shape (B, C, H, W) to (B, L, C), where L = H * W."""
    return x.flatten(2).transpose(1, 2)


def blc_to_bchw(x: torch.Tensor, x_size: Tuple[int, int]) -> torch.Tensor:
    """Rearrange a tensor from shape (B, L, C) to (B, C, H, W), where (H, W) = x_size."""
    B, _, C = x.shape
    return x.transpose(1, 2).view(B, C, *x_size)


def blc_to_bhwc(x: torch.Tensor, x_size: Tuple[int, int]) -> torch.Tensor:
    """Rearrange a tensor from shape (B, L, C) to (B, H, W, C), where (H, W) = x_size."""
    B, _, C = x.shape
    return x.view(B, *x_size, C)


def window_partition(x, window_size: Tuple[int, int]):
    """
    Args:
        x: (B, H, W, C)
        window_size (Tuple[int, int]): window size

    Returns:
        windows: (num_windows * B, window_size[0], window_size[1], C)
    """
    B, H, W, C = x.shape
    x = x.view(
        B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C
    )
    windows = (
        x.permute(0, 1, 3, 2, 4, 5)
        .contiguous()
        .view(-1, window_size[0], window_size[1], C)
    )
    return windows


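# Shape sketch (hypothetical sizes, not from the original code): for x of shape
# (2, 64, 64, 96) and window_size=(8, 8), window_partition(x, (8, 8)) yields
# 64 windows per image, i.e. a tensor of shape (2 * 64, 8, 8, 96) = (128, 8, 8, 96).

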
def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
    """
    Args:
        windows: (num_windows * B, window_size[0], window_size[1], C)
        window_size (Tuple[int, int]): Window size
        img_size (Tuple[int, int]): Image size

    Returns:
        x: (B, H, W, C)
    """
    H, W = img_size
    B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
    x = windows.view(
        B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1
    )
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


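# window_reverse is the inverse of window_partition: for any x of shape (B, H, W, C)
# whose spatial size is divisible by the window size,
# window_reverse(window_partition(x, ws), ws, (H, W)) recovers x exactly.

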
def _fill_window(input_resolution, window_size, shift_size=None):
    """Label the 3 x 3 image regions produced by a cyclic shift with the indices
    0..8 (the same masking pattern used in Swin Transformer) and return the
    labels grouped per window.

    Returns:
        mask_windows: (num_windows, window_size[0] * window_size[1])
    """
    if shift_size is None:
        shift_size = [s // 2 for s in window_size]

    img_mask = torch.zeros((1, *input_resolution, 1))
    h_slices = (
        slice(0, -window_size[0]),
        slice(-window_size[0], -shift_size[0]),
        slice(-shift_size[0], None),
    )
    w_slices = (
        slice(0, -window_size[1]),
        slice(-window_size[1], -shift_size[1]),
        slice(-shift_size[1], None),
    )
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1

    mask_windows = window_partition(img_mask, window_size)
    mask_windows = mask_windows.view(-1, prod(window_size))
    return mask_windows


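# Shape sketch (hypothetical sizes): _fill_window((16, 16), (8, 8)) uses the default
# shift (4, 4) and returns a (4, 64) tensor: 4 windows of 8 * 8 = 64 tokens, each
# entry being a region label in {0, ..., 8}.

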
def calculate_mask(input_resolution, window_size, shift_size):
    """
    Use case: 1)

    Attention mask for shifted-window self-attention: pairs of tokens that land
    in the same window but originate from different image regions after the
    cyclic shift are masked with -100.
    """
    if isinstance(shift_size, int):
        shift_size = to_2tuple(shift_size)
    mask_windows = _fill_window(input_resolution, window_size, shift_size)

    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
        attn_mask == 0, float(0.0)
    )  # (num_windows, prod(window_size), prod(window_size))

    return attn_mask


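# Shape sketch (hypothetical sizes): calculate_mask((16, 16), (8, 8), (4, 4))
# returns a (4, 64, 64) tensor whose entries are 0 (attend) or -100 (blocked).

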
def calculate_mask_all(
    input_resolution,
    window_size,
    shift_size,
    anchor_window_down_factor=1,
    window_to_anchor=True,
):
    """
    Use case: 3)

    Attention mask for cross-attention between window tokens and anchor-window
    tokens, where the anchor resolution, window size, and shift are reduced by
    anchor_window_down_factor. window_to_anchor selects the direction (window
    tokens as queries and anchor tokens as keys, or vice versa).
    """
    anchor_resolution = [s // anchor_window_down_factor for s in input_resolution]
    aws = [s // anchor_window_down_factor for s in window_size]
    anchor_shift = [s // anchor_window_down_factor for s in shift_size]

    # (num_windows, window_size[0] * window_size[1])
    mask_windows = _fill_window(input_resolution, window_size, shift_size)
    # (num_windows, aws[0] * aws[1])
    mask_anchor = _fill_window(anchor_resolution, aws, anchor_shift)

    if window_to_anchor:
        attn_mask = mask_windows.unsqueeze(2) - mask_anchor.unsqueeze(1)
    else:
        attn_mask = mask_anchor.unsqueeze(2) - mask_windows.unsqueeze(1)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
        attn_mask == 0, float(0.0)
    )

    return attn_mask


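# Shape sketch (hypothetical sizes): calculate_mask_all((16, 16), (8, 8), (4, 4),
# anchor_window_down_factor=2) compares 4 windows of 64 tokens with 4 anchor
# windows of 4 * 4 = 16 tokens and returns a (4, 64, 16) mask.

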
def calculate_win_mask(
    input_resolution1, input_resolution2, window_size1, window_size2
):
    """
    Use case: 2)

    Attention mask between two window grids that may have different resolutions
    and window sizes; the default shift of half a window is used for both.
    """
    mask_windows1 = _fill_window(input_resolution1, window_size1)
    mask_windows2 = _fill_window(input_resolution2, window_size2)

    attn_mask = mask_windows1.unsqueeze(2) - mask_windows2.unsqueeze(1)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
        attn_mask == 0, float(0.0)
    )

    return attn_mask


def _get_meshgrid_coords(start_coords, end_coords):
    coord_h = torch.arange(start_coords[0], end_coords[0])
    coord_w = torch.arange(start_coords[1], end_coords[1])
    coords = torch.stack(torch.meshgrid([coord_h, coord_w], indexing="ij"))
    coords = torch.flatten(coords, 1)
    return coords


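# Example: _get_meshgrid_coords((0, 0), (2, 3)) returns
# tensor([[0, 0, 0, 1, 1, 1],
#         [0, 1, 2, 0, 1, 2]])
# i.e. the flattened (row, column) coordinates of a 2 x 3 grid, shape (2, 6).

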
def get_relative_coords_table(
    window_size, pretrained_window_size=[0, 0], anchor_window_down_factor=1
):
    """
    Use case: 1)
    """
    ws = window_size
    aws = [w // anchor_window_down_factor for w in window_size]
    pws = pretrained_window_size
    paws = [w // anchor_window_down_factor for w in pretrained_window_size]

    # table size: average of the window size and the anchor window size
    ts = [(w1 + w2) // 2 for w1, w2 in zip(ws, aws)]
    pts = [(w1 + w2) // 2 for w1, w2 in zip(pws, paws)]

    coord_h = torch.arange(-(ts[0] - 1), ts[0], dtype=torch.float32)
    coord_w = torch.arange(-(ts[1] - 1), ts[1], dtype=torch.float32)
    table = torch.stack(torch.meshgrid([coord_h, coord_w], indexing="ij")).permute(
        1, 2, 0
    )
    table = table.contiguous().unsqueeze(0)  # (1, 2 * ts[0] - 1, 2 * ts[1] - 1, 2)

    # Normalize to [-1, 1] (using the pretrained window size if given), scale to
    # [-8, 8], and apply the log-spaced mapping used by Swin Transformer V2's
    # continuous position bias.
    if pts[0] > 0:
        table[:, :, :, 0] /= pts[0] - 1
        table[:, :, :, 1] /= pts[1] - 1
    else:
        table[:, :, :, 0] /= ts[0] - 1
        table[:, :, :, 1] /= ts[1] - 1
    table *= 8
    table = torch.sign(table) * torch.log2(torch.abs(table) + 1.0) / np.log2(8)
    return table


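# Shape sketch (hypothetical sizes): get_relative_coords_table((8, 8)) with the
# default anchor_window_down_factor=1 gives ts = (8, 8) and a table of shape
# (1, 15, 15, 2), covering every relative offset from -7 to 7 along each axis.

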
def get_relative_coords_table_all(
    window_size, pretrained_window_size=[0, 0], anchor_window_down_factor=1
):
    """
    Use case: 3)

    Support all window shapes.
    Args:
        window_size: window size (Wh, Ww)
        pretrained_window_size: window size used during pre-training, for
            rescaling the table; (0, 0) means the current window size is used
        anchor_window_down_factor: down-sampling factor of the anchor window

    Returns:
        relative coordinate table of shape (1, Wh + AWh - 1, Ww + AWw - 1, 2)
    """
    ws = window_size
    aws = [w // anchor_window_down_factor for w in window_size]
    pws = pretrained_window_size
    paws = [w // anchor_window_down_factor for w in pretrained_window_size]

    # positive and negative ends of the relative-coordinate range between a
    # window token and an anchor token (the anchor window is centred in the window)
    ts_p = [w1 - 1 - (w1 - w2) // 2 for w1, w2 in zip(ws, aws)]
    ts_n = [-(w2 - 1) - (w1 - w2) // 2 for w1, w2 in zip(ws, aws)]
    pts = [w1 - 1 - (w1 - w2) // 2 for w1, w2 in zip(pws, paws)]

    coord_h = torch.arange(ts_n[0], ts_p[0] + 1, dtype=torch.float32)
    coord_w = torch.arange(ts_n[1], ts_p[1] + 1, dtype=torch.float32)
    table = torch.stack(torch.meshgrid([coord_h, coord_w], indexing="ij")).permute(
        1, 2, 0
    )
    table = table.contiguous().unsqueeze(0)

    # Same log-spaced normalization as get_relative_coords_table.
    if pts[0] > 0:
        table[:, :, :, 0] /= pts[0]
        table[:, :, :, 1] /= pts[1]
    else:
        table[:, :, :, 0] /= ts_p[0]
        table[:, :, :, 1] /= ts_p[1]
    table *= 8
    table = torch.sign(table) * torch.log2(torch.abs(table) + 1.0) / np.log2(8)

    return table


def coords_diff(coords1, coords2, max_diff):
    # Pairwise coordinate differences, shifted to be non-negative and collapsed
    # into a single index: idx = dh_shifted * (2 * max_diff[1] - 1) + dw_shifted.
    coords = coords1[:, :, None] - coords2[:, None, :]
    coords = coords.permute(1, 2, 0).contiguous()
    coords[:, :, 0] += max_diff[0] - 1
    coords[:, :, 1] += max_diff[1] - 1
    coords[:, :, 0] *= 2 * max_diff[1] - 1
    idx = coords.sum(-1)
    return idx


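# The resulting idx has shape (N1, N2) and is intended to index a flattened
# relative-position-bias table with (2 * max_diff[0] - 1) * (2 * max_diff[1] - 1)
# entries, matching the table produced by get_relative_coords_table.

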
def get_relative_position_index(
    window_size, anchor_window_down_factor=1, window_to_anchor=True
):
    """
    Use case: 1)

    Relative position index between window tokens and the tokens of an anchor
    window centred inside the window (anchor size = window size divided by
    anchor_window_down_factor).
    """
    ws = window_size
    aws = [w // anchor_window_down_factor for w in window_size]
    coords_anchor_end = [(w1 + w2) // 2 for w1, w2 in zip(ws, aws)]
    coords_anchor_start = [(w1 - w2) // 2 for w1, w2 in zip(ws, aws)]

    coords = _get_meshgrid_coords((0, 0), window_size)
    coords_anchor = _get_meshgrid_coords(coords_anchor_start, coords_anchor_end)

    if window_to_anchor:
        idx = coords_diff(coords, coords_anchor, max_diff=coords_anchor_end)
    else:
        idx = coords_diff(coords_anchor, coords, max_diff=coords_anchor_end)
    return idx


def coords_diff_odd(coords1, coords2, start_coord, max_diff):
    # Same flattening as coords_diff, but with an explicit offset (start_coord)
    # and an explicit row stride (max_diff) so that rectangular window/anchor
    # combinations are handled correctly.
    coords = coords1[:, :, None] - coords2[:, None, :]
    coords = coords.permute(1, 2, 0).contiguous()
    coords[:, :, 0] += start_coord[0]
    coords[:, :, 1] += start_coord[1]
    coords[:, :, 0] *= max_diff
    idx = coords.sum(-1)
    return idx


def get_relative_position_index_all(
    window_size, anchor_window_down_factor=1, window_to_anchor=True
):
    """
    Use case: 3)
    Support all window shapes:
        square window - square window
        rectangular window - rectangular window
        window - anchor
        anchor - window
        [8, 8] - [8, 8]
        [4, 86] - [2, 43]
    """
    ws = window_size
    aws = [w // anchor_window_down_factor for w in window_size]
    coords_anchor_start = [(w1 - w2) // 2 for w1, w2 in zip(ws, aws)]
    coords_anchor_end = [s + w2 for s, w2 in zip(coords_anchor_start, aws)]

    coords = _get_meshgrid_coords((0, 0), window_size)
    coords_anchor = _get_meshgrid_coords(coords_anchor_start, coords_anchor_end)

    max_horizontal_diff = aws[1] + ws[1] - 1
    if window_to_anchor:
        offset = [w2 + s - 1 for s, w2 in zip(coords_anchor_start, aws)]
        idx = coords_diff_odd(coords, coords_anchor, offset, max_horizontal_diff)
    else:
        offset = [w1 - s - 1 for s, w1 in zip(coords_anchor_start, ws)]
        idx = coords_diff_odd(coords_anchor, coords, offset, max_horizontal_diff)
    return idx


def get_relative_position_index_simple(
    window_size, anchor_window_down_factor=1, window_to_anchor=True
):
    """
    Use case: 3)
    This is a simplified version of get_relative_position_index_all.
    The start coordinate of the anchor window is also (0, 0).
    Get the pair-wise relative position index for each token inside the window.
    """
    ws = window_size
    aws = [w // anchor_window_down_factor for w in window_size]

    coords = _get_meshgrid_coords((0, 0), window_size)
    coords_anchor = _get_meshgrid_coords((0, 0), aws)

    max_horizontal_diff = aws[1] + ws[1] - 1
    if window_to_anchor:
        offset = [w2 - 1 for w2 in aws]
        idx = coords_diff_odd(coords, coords_anchor, offset, max_horizontal_diff)
    else:
        offset = [w1 - 1 for w1 in ws]
        idx = coords_diff_odd(coords_anchor, coords, offset, max_horizontal_diff)
    return idx


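# Shape sketch (hypothetical sizes): get_relative_position_index_simple((8, 8),
# anchor_window_down_factor=2) returns a (64, 16) LongTensor of indices into the
# flattened coordinate table from get_relative_coords_table_all((8, 8),
# anchor_window_down_factor=2), which has (8 + 4 - 1) * (8 + 4 - 1) = 121 rows
# after .view(-1, 2); the __main__ block below prints these shapes as a sanity check.

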
def get_relative_win_position_index(window_size, anchor_window_size):
    """
    Use case: 2)
    """
    ws = window_size
    aws = anchor_window_size
    coords_anchor_end = [(w1 + w2) // 2 for w1, w2 in zip(ws, aws)]
    coords_anchor_start = [(w1 - w2) // 2 for w1, w2 in zip(ws, aws)]

    coords = _get_meshgrid_coords((0, 0), window_size)
    coords_anchor = _get_meshgrid_coords(coords_anchor_start, coords_anchor_end)

    coords = coords[:, :, None] - coords_anchor[:, None, :]
    coords = coords.permute(1, 2, 0).contiguous()
    coords[:, :, 0] += coords_anchor_end[0] - 1
    coords[:, :, 1] += coords_anchor_end[1] - 1
    coords[:, :, 0] *= 2 * coords_anchor_end[1] - 1
    idx = coords.sum(-1)
    return idx


def get_relative_win_coords_table(
    window_size,
    anchor_window_size,
    pretrained_window_size=[0, 0],
    pretrained_anchor_window_size=[0, 0],
):
    """
    Use case: 2)
    """
    ws = window_size
    aws = anchor_window_size
    pws = pretrained_window_size
    paws = pretrained_anchor_window_size

    # table size: average of the window size and the anchor window size
    table_size = [(wsi + awsi) // 2 for wsi, awsi in zip(ws, aws)]
    table_size_pretrained = [(pwsi + pawsi) // 2 for pwsi, pawsi in zip(pws, paws)]
    coord_h = torch.arange(-(table_size[0] - 1), table_size[0], dtype=torch.float32)
    coord_w = torch.arange(-(table_size[1] - 1), table_size[1], dtype=torch.float32)
    table = torch.stack(torch.meshgrid([coord_h, coord_w], indexing="ij")).permute(
        1, 2, 0
    )
    table = table.contiguous().unsqueeze(0)
    if table_size_pretrained[0] > 0:
        table[:, :, :, 0] /= table_size_pretrained[0] - 1
        table[:, :, :, 1] /= table_size_pretrained[1] - 1
    else:
        table[:, :, :, 0] /= table_size[0] - 1
        table[:, :, :, 1] /= table_size[1] - 1
    table *= 8
    table = torch.sign(table) * torch.log2(torch.abs(table) + 1.0) / np.log2(8)
    return table


if __name__ == "__main__":
    table = get_relative_coords_table_all((4, 86), anchor_window_down_factor=2)
    table = table.view(-1, 2)
    index1 = get_relative_position_index_all((4, 86), 2, False)
    index2 = get_relative_position_index_simple((4, 86), 2, False)
    print(index2)
    index3 = get_relative_position_index_all((4, 86), 2)
    index4 = get_relative_position_index_simple((4, 86), 2)
    print(index4)
    print(
        table.shape,
        index2.shape,
        index2.max(),
        index2.min(),
        index4.shape,
        index4.max(),
        index4.min(),
        torch.allclose(index1, index2),
        torch.allclose(index3, index4),
    )

    table = get_relative_coords_table_all((4, 86), anchor_window_down_factor=1)
    table = table.view(-1, 2)
    index1 = get_relative_position_index_all((4, 86), 1, False)
    index2 = get_relative_position_index_simple((4, 86), 1, False)
    index3 = get_relative_position_index_all((4, 86), 1)
    index4 = get_relative_position_index_simple((4, 86), 1)
    print(
        table.shape,
        index2.shape,
        index2.max(),
        index2.min(),
        index4.shape,
        index4.max(),
        index4.min(),
        torch.allclose(index1, index2),
        torch.allclose(index3, index4),
    )

    table = get_relative_coords_table_all((8, 8), anchor_window_down_factor=2)
    table = table.view(-1, 2)
    index1 = get_relative_position_index_all((8, 8), 2, False)
    index2 = get_relative_position_index_simple((8, 8), 2, False)
    index3 = get_relative_position_index_all((8, 8), 2)
    index4 = get_relative_position_index_simple((8, 8), 2)
    print(
        table.shape,
        index2.shape,
        index2.max(),
        index2.min(),
        index4.shape,
        index4.max(),
        index4.min(),
        torch.allclose(index1, index2),
        torch.allclose(index3, index4),
    )

    table = get_relative_coords_table_all((8, 8), anchor_window_down_factor=1)
    table = table.view(-1, 2)
    index1 = get_relative_position_index_all((8, 8), 1, False)
    index2 = get_relative_position_index_simple((8, 8), 1, False)
    index3 = get_relative_position_index_all((8, 8), 1)
    index4 = get_relative_position_index_simple((8, 8), 1)
    print(
        table.shape,
        index2.shape,
        index2.max(),
        index2.min(),
        index4.shape,
        index4.max(),
        index4.min(),
        torch.allclose(index1, index2),
        torch.allclose(index3, index4),
    )
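
    # A further sanity check (not part of the original script): verify that
    # window_partition / window_reverse round-trip exactly and that the shifted
    # window mask has the expected shape. The sizes below are arbitrary examples.
    x = torch.randn(2, 16, 16, 3)
    windows = window_partition(x, (8, 8))
    x_rec = window_reverse(windows, (8, 8), (16, 16))
    attn_mask = calculate_mask((16, 16), (8, 8), (4, 4))
    print(windows.shape, torch.allclose(x, x_rec), attn_mask.shape)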