Spaces:
Running
Running
r"""Custom layers used in metrics computations""" | |
import torch | |
from typing import Optional | |
from .filters import hann_filter | |
class L2Pool2d(torch.nn.Module):
    r"""Applies L2 pooling with a Hann window (3x3 by default).

    Computes ``sqrt(conv2d(x ** 2, w) + EPS)`` depthwise (one group per input
    channel), where ``w`` is the window returned by ``hann_filter(kernel_size)``
    repeated once per channel.

    Args:
        kernel_size: Side length of the Hann window. Default: 3.
        stride: Stride of the pooling convolution. Default: 2.
        padding: Zero-padding added to both spatial sides. Default: 1.

    Shape:
        - Input: ``(N, C, H, W)``
        - Output: ``(N, C, H_out, W_out)``, with ``H_out``/``W_out`` given by
          the usual ``torch.nn.functional.conv2d`` output-size rule.
    """

    # Added before the square root so neither sqrt nor its gradient is
    # evaluated at exactly zero.
    EPS = 1e-12

    def __init__(self, kernel_size: int = 3, stride: int = 2, padding: int = 1) -> None:
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # Lazily built depthwise kernel. ``forward`` rebuilds it whenever the
        # input's channel count, device or dtype no longer matches.
        self.kernel: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Return the L2-pooled version of ``x`` (expected shape ``(N, C, H, W)``)."""
        channels = x.size(1)
        kernel = self.kernel
        # The cached kernel is only valid for the channel count, device and
        # dtype it was built for; rebuild instead of failing in conv2d.
        if (
            kernel is None
            or kernel.size(0) != channels
            or kernel.device != x.device
            or kernel.dtype != x.dtype
        ):
            # NOTE(review): assumes hann_filter returns a single-channel
            # (..., k, k) window so that repeat yields (C, 1, k, k) for groups=C.
            kernel = hann_filter(self.kernel_size).repeat((channels, 1, 1, 1)).to(x)
            self.kernel = kernel
        out = torch.nn.functional.conv2d(
            x ** 2,
            kernel,
            stride=self.stride,
            padding=self.padding,
            groups=channels,
        )
        return (out + self.EPS).sqrt()