# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).

import torch
import torchvision.transforms
import torchvision.transforms.functional as F

# "Pair": apply a transform on a pair
# "Both": apply the exact same transform to both images


class ComposePair(torchvision.transforms.Compose):
    """Chain pair transforms: each one takes and returns ``(img1, img2)``."""

    def __call__(self, img1, img2):
        for t in self.transforms:
            img1, img2 = t(img1, img2)
        return img1, img2


class NormalizeBoth(torchvision.transforms.Normalize):
    """Normalize both images with the same mean/std."""

    def forward(self, img1, img2):
        img1 = super().forward(img1)
        img2 = super().forward(img2)
        return img1, img2


class ToTensorBoth(torchvision.transforms.ToTensor):
    """Convert both images to tensors with the parent ``ToTensor``."""

    def __call__(self, img1, img2):
        img1 = super().__call__(img1)
        img2 = super().__call__(img2)
        return img1, img2


class RandomCropPair(torchvision.transforms.RandomCrop):
    """Random-crop each image of the pair with a separate parent call."""

    # the crop will be intentionally different for the two images with this class
    def forward(self, img1, img2):
        img1 = super().forward(img1)
        img2 = super().forward(img2)
        return img1, img2


class ColorJitterPair(torchvision.transforms.ColorJitter):
    """Color jitter applied to a pair of images.

    Can be symmetric (same jitter for both images) or asymmetric (different
    jitter params for each image), depending on ``assymetric_prob``.
    """

    def __init__(self, assymetric_prob, **kwargs):
        """Store the asymmetry probability; ``kwargs`` go to ``ColorJitter``.

        Args:
            assymetric_prob: threshold compared against a ``torch.rand(1)``
                draw in ``forward``; when the draw is below it, the second
                image gets its own jitter parameters.
            **kwargs: forwarded unchanged to ``ColorJitter.__init__``.
        """
        super().__init__(**kwargs)
        self.assymetric_prob = assymetric_prob

    def jitter_one(
        self,
        img,
        fn_idx,
        brightness_factor,
        contrast_factor,
        saturation_factor,
        hue_factor,
    ):
        """Apply one set of jitter parameters to a single image.

        The adjustments are applied in the order given by ``fn_idx``
        (0: brightness, 1: contrast, 2: saturation, 3: hue); an adjustment
        whose factor is ``None`` is skipped.
        """
        for fn_id in fn_idx:
            if fn_id == 0 and brightness_factor is not None:
                img = F.adjust_brightness(img, brightness_factor)
            elif fn_id == 1 and contrast_factor is not None:
                img = F.adjust_contrast(img, contrast_factor)
            elif fn_id == 2 and saturation_factor is not None:
                img = F.adjust_saturation(img, saturation_factor)
            elif fn_id == 3 and hue_factor is not None:
                img = F.adjust_hue(img, hue_factor)
        return img

    def forward(self, img1, img2):
        """Jitter both images and return them as ``(img1, img2)``.

        ``img1`` always uses the first parameter draw. ``img2`` reuses it
        unless a ``torch.rand(1)`` draw falls below ``assymetric_prob``, in
        which case a second, independent set of parameters is drawn.
        """
        # params is (fn_idx, brightness, contrast, saturation, hue), which is
        # exactly the positional order jitter_one expects after the image.
        params = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue
        )
        img1 = self.jitter_one(img1, *params)
        if torch.rand(1) < self.assymetric_prob:  # asymmetric
            params = self.get_params(
                self.brightness, self.contrast, self.saturation, self.hue
            )
        img2 = self.jitter_one(img2, *params)
        return img1, img2


def get_pair_transforms(transform_str, totensor=True, normalize=True):
    """Build the pair transform described by ``transform_str``.

    Args:
        transform_str: "+"-separated augmentation names, e.g.
            ``"crop224+acolor"``. Recognized tokens are ``crop<size>``
            (``RandomCropPair(size)``), ``acolor`` (always-asymmetric
            ``ColorJitterPair``) and the empty string (ignored).
        totensor: if true, append ``ToTensorBoth``.
        normalize: if true, append ``NormalizeBoth`` with the mean/std
            constants below (the usual ImageNet statistics).

    Returns:
        ``None`` when no transform was requested, the transform itself when
        exactly one was requested, otherwise a ``ComposePair`` of all of them.
        Every non-``None`` result is callable as ``t(img1, img2)``.

    Raises:
        NotImplementedError: on an unrecognized augmentation token.
        ValueError: if the size after ``crop`` is not an integer.
    """
    trfs = []
    for s in transform_str.split("+"):
        if s.startswith("crop"):
            size = int(s[len("crop"):])
            trfs.append(RandomCropPair(size))
        elif s == "acolor":
            trfs.append(
                ColorJitterPair(
                    assymetric_prob=1.0,
                    brightness=(0.6, 1.4),
                    contrast=(0.6, 1.4),
                    saturation=(0.6, 1.4),
                    hue=0.0,
                )
            )
        elif s == "":  # if transform_str was ""
            pass
        else:
            raise NotImplementedError("Unknown augmentation: " + s)

    if totensor:
        trfs.append(ToTensorBoth())
    if normalize:
        trfs.append(
            NormalizeBoth(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        )

    if len(trfs) == 0:
        return None
    elif len(trfs) == 1:
        # Return the transform itself, not the one-element list: a list is
        # not callable, so callers doing transform(img1, img2) would fail.
        return trfs[0]
    else:
        return ComposePair(trfs)