# Spaces: Runtime error  (appears to be a hosting-page status banner captured during extraction; not part of the module)
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
from typing import List | |
import torch as th | |
import torch.nn as nn | |
from torchvision.transforms.functional import gaussian_blur | |
class LearnableBlur(nn.Module):
    """Blend an image with two Gaussian-blurred copies using per-camera weights.

    The module stores one row of three raw (pre-softmax) weights per camera
    name. ``forward`` looks up the rows for the requested cameras, applies a
    softmax over the three entries, and uses them to mix the input image, a
    3x3 Gaussian blur of it, and a 7x7 Gaussian blur of it.
    """

    # TODO: should we make this conditional?
    def __init__(self, cameras: List[str]) -> None:
        """Create one learnable weight row per entry of ``cameras``.

        Args:
            cameras: Camera names; the position of a name in this list is the
                row index of its weights in ``weights_raw``.
        """
        super().__init__()
        self.cameras = cameras
        # All-ones rows: the softmax in ``forward`` turns each row into a
        # uniform (1/3, 1/3, 1/3) mix.
        self.register_parameter(
            "weights_raw", nn.Parameter(th.ones(len(cameras), 3, dtype=th.float32))
        )

    def name_to_idx(self, cameras: List[str]) -> th.Tensor:
        """Map camera names to their row indices in ``weights_raw``.

        Args:
            cameras: Camera names to look up; each must be in ``self.cameras``.

        Returns:
            A 1-D ``th.long`` tensor of indices on the same device as
            ``weights_raw``, in the order the names were given.

        Raises:
            ValueError: If a name is not present in ``self.cameras``.
        """
        rows = []
        for name in cameras:
            try:
                rows.append(self.cameras.index(name))
            except ValueError:
                # Same exception type as the bare ``list.index`` failure, but
                # the message now says which camera was not recognised.
                raise ValueError(
                    f"unknown camera {name!r}: not among the "
                    f"{len(self.cameras)} cameras this module was built with"
                ) from None
        return th.tensor(
            rows,
            device=self.weights_raw.device,
            dtype=th.long,
        )

    # pyre-ignore
    def reg(self, cameras: List[str]):
        """Return the raw (pre-softmax) weight rows for ``cameras``.

        Args:
            cameras: Camera names to look up.

        Returns:
            A tensor with one row of three raw weights per requested camera.
        """
        # pyre-ignore
        idxs = self.name_to_idx(cameras)
        # pyre-ignore
        return self.weights_raw[idxs]

    # pyre-ignore
    def forward(self, img: th.Tensor, cameras: List[str]):
        """Mix ``img`` with its 3x3 and 7x7 Gaussian blurs.

        Args:
            img: Image batch; the original comment gives the layout as
                (B, C, H, W).
            cameras: One camera name per batch element. The reshape below
                fails if ``len(cameras)`` differs from ``img.shape[0]``.

        Returns:
            The weighted sum of the image and its two blurred copies.
        """
        B = img.shape[0]
        # B, C, H, W
        idxs = self.name_to_idx(cameras)
        # TODO: mask?
        # pyre-ignore
        weights = th.softmax(self.weights_raw[idxs], dim=-1)
        # (B, 3) -> (B, 3, 1, 1, 1) so each weights[:, k] is (B, 1, 1, 1) and
        # broadcasts across the remaining dimensions of img.
        weights = weights.reshape(B, 3, 1, 1, 1)
        sharp = weights[:, 0] * img
        blur_small = weights[:, 1] * gaussian_blur(img, [3, 3])
        blur_large = weights[:, 2] * gaussian_blur(img, [7, 7])
        return sharp + blur_small + blur_large