# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from nncore.nn import MODELS


class Permute(nn.Module):
    """Swap the last two dimensions, e.g. (B, T, C) <-> (B, C, T)."""

    def __init__(self):
        super(Permute, self).__init__()

    def forward(self, x):
        return x.transpose(-1, -2)
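

# A minimal shape check for Permute (illustrative only; the tensor sizes are
# assumptions, not values taken from this repo):
#
#   x = torch.randn(2, 100, 256)            # (batch, time, channels)
#   assert Permute()(x).shape == (2, 256, 100)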


class ConvPyramid(nn.Module):
    """Build a temporal feature pyramid with one branch per stride.

    Strides > 1 downsample with strided Conv1d, strides < 1 upsample with
    ConvTranspose1d, and a stride of 1 keeps the resolution and only applies
    a ReLU. All strides are expected to be powers of two.
    """

    def __init__(self, dims, strides):
        super(ConvPyramid, self).__init__()

        self.blocks = nn.ModuleList()

        for s in strides:
            # Number of 2x down-/up-sampling steps for this branch.
            p = int(math.log2(s))
            if p == 0:
                layers = nn.ReLU(inplace=True)
            else:
                layers = nn.Sequential()
                conv_cls = nn.Conv1d if p > 0 else nn.ConvTranspose1d
                for _ in range(abs(p)):
                    layers.extend([
                        Permute(),
                        conv_cls(dims, dims, 2, stride=2),
                        Permute(),
                        nn.LayerNorm(dims),
                        nn.ReLU(inplace=True)
                    ])
            self.blocks.append(layers)

        self.strides = strides

    def forward(self, x, mask, return_mask=False):
        pymid, pymid_msk = [], []

        for s, blk in zip(self.strides, self.blocks):
            # Skip branches whose stride exceeds the sequence length.
            if x.size(1) < s:
                continue

            pymid.append(blk(x))

            if return_mask:
                if s > 1:
                    # Downsampled mask: a pooled step is valid if any frame
                    # inside its window is valid.
                    msk = F.max_pool1d(mask.float(), s, stride=s).long()
                elif s < 1:
                    # Upsampled mask: repeat each entry along the time axis.
                    msk = mask.repeat_interleave(int(1 / s), dim=1)
                else:
                    msk = mask
                pymid_msk.append(msk)

        return pymid, pymid_msk
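

# A minimal sketch of how ConvPyramid might be driven (the dims, strides and
# tensor sizes below are assumptions for illustration):
#
#   pyramid = ConvPyramid(dims=256, strides=(1, 2, 4))
#   x = torch.randn(2, 128, 256)            # (batch, time, channels)
#   mask = torch.ones(2, 128, dtype=torch.long)
#   feats, masks = pyramid(x, mask, return_mask=True)
#   # feats holds one tensor per stride: 128, 64 and 32 time steps here.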


class AdaPooling(nn.Module):
    """Attention-weighted pooling over the temporal dimension."""

    def __init__(self, dims):
        super(AdaPooling, self).__init__()
        self.att = nn.Linear(dims, 1, bias=False)

    def forward(self, x, mask):
        # Mask out padded positions with -inf so they vanish after softmax.
        a = self.att(x) + torch.where(mask.unsqueeze(2) == 1, .0, float('-inf'))
        a = a.softmax(dim=1)
        # Weighted sum over time: (B, C, T) @ (B, T, 1) -> (B, C, 1).
        x = torch.matmul(x.transpose(1, 2), a)
        x = x.squeeze(2).unsqueeze(1)
        return x
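

# Illustrative shapes for AdaPooling (the sizes are assumptions):
#
#   pool = AdaPooling(dims=256)
#   x = torch.randn(2, 128, 256)
#   mask = torch.ones(2, 128, dtype=torch.long)
#   assert pool(x, mask).shape == (2, 1, 256)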


class ConvHead(nn.Module):
    """Two-layer temporal convolution head for per-step predictions."""

    def __init__(self, dims, out_dims, kernel_size=3):
        super(ConvHead, self).__init__()

        # yapf:disable
        self.module = nn.Sequential(
            Permute(),
            nn.Conv1d(dims, dims, kernel_size, padding=kernel_size // 2),
            nn.ReLU(inplace=True),
            nn.Conv1d(dims, out_dims, kernel_size, padding=kernel_size // 2),
            Permute())
        # yapf:enable

    def forward(self, x):
        return self.module(x)
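

# Illustrative shape check for ConvHead (the sizes are assumptions):
#
#   head = ConvHead(dims=256, out_dims=1)
#   x = torch.randn(2, 128, 256)
#   assert head(x).shape == (2, 128, 1)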