Spaces:
Running
Running
File size: 967 Bytes
966ae59 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
r"""Custom layers used in metrics computations"""
import torch
from typing import Optional
from .filters import hann_filter
class L2Pool2d(torch.nn.Module):
    r"""Applies L2 pooling with a Hann window.

    Each channel is filtered independently (depthwise convolution,
    ``groups=C``) with a ``kernel_size x kernel_size`` window produced by
    ``hann_filter``::

        out = sqrt(conv2d(x ** 2, window, stride, padding) + EPS)

    Args:
        kernel_size: Side length of the Hann window.
        stride: Stride of the pooling window.
        padding: Implicit zero padding added to both sides of the input.

    Shape:
        - Input: ``(N, C, H, W)``
        - Output: ``(N, C, H_out, W_out)``, with spatial size determined by
          ``torch.nn.functional.conv2d`` for the given stride/padding.
    """

    # Added before the square root so the output (and its gradient) stays
    # finite where the pooled energy is exactly zero.
    EPS = 1e-12

    def __init__(self, kernel_size: int = 3, stride: int = 2, padding: int = 1) -> None:
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # Depthwise kernel with one copy of the window per input channel
        # (leading dimension C). Built lazily in ``forward`` because C is only
        # known from the input.
        self.kernel: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Pool ``x`` with the L2 Hann window.

        Args:
            x: Tensor with shape ``(N, C, H, W)``.

        Returns:
            Tensor ``sqrt(depthwise_conv(x ** 2) + EPS)``, with the same device
            and dtype as ``x``.
        """
        channels = x.size(1)
        if self.kernel is None or self.kernel.size(0) != channels:
            # (Re)build when the channel count changes; a kernel cached for a
            # different C would make the grouped convolution fail.
            self.kernel = hann_filter(self.kernel_size).repeat((channels, 1, 1, 1)).to(x)
        elif self.kernel.device != x.device or self.kernel.dtype != x.dtype:
            # The cache is a plain attribute, not a registered buffer, so
            # Module.to()/cuda()/half() do not move it; follow the input.
            self.kernel = self.kernel.to(x)
        out = torch.nn.functional.conv2d(
            x ** 2,
            self.kernel,
            stride=self.stride,
            padding=self.padding,
            groups=channels,
        )
        return (out + self.EPS).sqrt()
|