from typing import Tuple, List, Union, cast

import torch

from kornia.geometry.transform import vflip, rotate

UnionType = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]


def _to_batch(input: torch.Tensor) -> torch.Tensor:
    """Flatten every leading dimension of ``input`` into one batch axis.

    A leading axis is prepended first, so a single image ``(C, H, W)`` becomes
    ``(1, C, H, W)`` and a batch ``(*, C, H, W)`` becomes ``(B, C, H, W)`` where
    ``B`` is the product of the ``*`` dimensions.
    """
    input = input.unsqueeze(0)
    # ``reshape`` (unlike ``view``) also accepts non-contiguous tensors, and
    # ``-1, *shape`` avoids the ``(*x)`` form that Python >= 3.9 rejects.
    return input.reshape(-1, *input.shape[-3:])


def random_rotate(input: torch.Tensor) -> torch.Tensor:
    r"""Rotate a tensor image or a batch of tensor images by random angles.

    Each image gets its own angle, drawn uniformly from :math:`[-180, 180)`
    (kornia's ``rotate`` interprets the angle in degrees).

    Input should be a tensor of shape :math:`(C, H, W)` or a batch of tensors
    :math:`(*, C, H, W)`.

    Args:
        input (torch.Tensor): the image tensor to rotate.

    Returns:
        torch.Tensor: The rotated input, flattened to shape :math:`(B, C, H, W)`
        where ``B`` is the product of the leading dimensions (1 for a single
        image).

    Raises:
        TypeError: if ``input`` is not a ``torch.Tensor``.
    """
    if not torch.is_tensor(input):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    batch: torch.Tensor = _to_batch(input)

    # One angle per image; the range must be non-degenerate, otherwise every
    # image would be rotated by the same fixed angle.
    angle: torch.Tensor = torch.empty(batch.shape[0], device=input.device).uniform_(-180, 180)
    # Match the image's floating dtype so ``rotate`` does not mix dtypes.
    if input.is_floating_point():
        angle = angle.to(input.dtype)

    return rotate(batch, angle)


def random_vflip(input: torch.Tensor, p: float = 0.5, return_transform: bool = False) -> UnionType:
    r"""Vertically flip a tensor image or a batch of tensor images randomly with a given probability.

    Input should be a tensor of shape :math:`(C, H, W)` or a batch of tensors
    :math:`(*, C, H, W)`. Each image is flipped independently with probability
    ``p``.

    Args:
        input (torch.Tensor): the image tensor to flip.
        p (float): probability of each image being flipped. Default value is 0.5
        return_transform (bool): if ``True`` return the matrix describing the
            transformation applied to each input tensor.

    Returns:
        torch.Tensor: The vertically flipped input, flattened to shape
        :math:`(B, C, H, W)` where ``B`` is the product of the leading
        dimensions (1 for a single image).
        torch.Tensor: The applied transformation matrix :math:`(B, 3, 3)` if
        return_transform flag is set to ``True``. Flipped images get
        ``[[1, 0, 0], [0, -1, H - 1], [0, 0, 1]]``; the others get the identity.

    Raises:
        TypeError: if ``input`` is not a tensor, ``p`` is not a float, or
            ``return_transform`` is not a bool.
    """
    if not torch.is_tensor(input):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    if not isinstance(p, float):
        raise TypeError(f"The probability should be a float number. Got {type(p)}")

    if not isinstance(return_transform, bool):
        raise TypeError(f"The return_transform flag must be a bool. Got {type(return_transform)}")

    device: torch.device = input.device
    dtype: torch.dtype = input.dtype

    batch: torch.Tensor = _to_batch(input)
    batch_size: int = batch.shape[0]

    # A uniform [0, 1) sample is below ``p`` with probability ``p``.
    probs: torch.Tensor = torch.empty(batch_size, device=device).uniform_(0, 1)
    to_flip: torch.Tensor = probs < p

    flipped: torch.Tensor = batch.clone()
    flipped[to_flip] = vflip(batch[to_flip])

    if not return_transform:
        return flipped

    # ``repeat`` allocates real per-image storage. An ``expand``-ed identity
    # aliases one 3x3 buffer across the batch, and PyTorch does not support
    # in-place writes into such a tensor.
    trans_mat: torch.Tensor = torch.eye(3, device=device, dtype=dtype).repeat(batch_size, 1, 1)

    # A vertical flip reverses the row axis: pixel row ``y`` moves to
    # ``H - 1 - y`` while the column ``x`` is unchanged.
    h: int = batch.shape[-2]
    flip_mat: torch.Tensor = torch.tensor([[1, 0, 0],
                                           [0, -1, h - 1],
                                           [0, 0, 1]], device=device).to(dtype)
    trans_mat[to_flip] = flip_mat

    return flipped, trans_mat