import torchvision.transforms as transforms # from torchvision.transforms.functional import InterpolationMode from PIL import Image, ImageFilter import random import torch import numpy as np import logging from enum import Enum from .augmentation.warp import Curve, Distort, Stretch from .augmentation.geometry import Rotate, Perspective, Shrink, TranslateX, TranslateY from .augmentation.pattern import VGrid, HGrid, Grid, RectGrid, EllipseGrid from .augmentation.noise import GaussianNoise, ShotNoise, ImpulseNoise, SpeckleNoise from .augmentation.blur import GaussianBlur, DefocusBlur, MotionBlur, GlassBlur, ZoomBlur from .augmentation.camera import Contrast, Brightness, JpegCompression, Pixelate from .augmentation.weather import Fog, Snow, Frost, Rain, Shadow from .augmentation.process import Posterize, Solarize, Invert, Equalize, AutoContrast, Sharpness, Color # 0: InterpolationMode.NEAREST, # 2: InterpolationMode.BILINEAR, # 3: InterpolationMode.BICUBIC, # 4: InterpolationMode.BOX, # 5: InterpolationMode.HAMMING, # 1: InterpolationMode.LANCZOS, class InterpolationMode(): NEAREST = 0 BILINEAR = 2 BICUBIC = 3 BOX = 4 HAMMING = 5 LANCZOS = 1 logger = logging.getLogger(__name__) class ResizePad(object): def __init__(self, imgH=64, imgW=3072, keep_ratio_with_pad=True): self.imgH = imgH self.imgW = imgW assert keep_ratio_with_pad == True self.keep_ratio_with_pad = keep_ratio_with_pad def __call__(self, im): old_size = im.size # old_size[0] is in (width, height) format ratio = float(self.imgH)/old_size[1] new_size = tuple([int(x*ratio) for x in old_size]) im = im.resize(new_size, Image.BICUBIC) new_im = Image.new("RGB", (self.imgW, self.imgH)) new_im.paste(im, (0, 0)) return new_im class WeightedRandomChoice: def __init__(self, trans, weights=None): self.trans = trans if not weights: self.weights = [1] * len(trans) else: assert len(trans) == len(weights) self.weights = weights def __call__(self, img): t = random.choices(self.trans, weights=self.weights, k=1)[0] try: tfm_img = t(img) except Exception as e: logger.warning('Error during data_aug: '+str(e)) return img return tfm_img def __repr__(self): format_string = self.__class__.__name__ + '(' for t in self.transforms: format_string += '\n' format_string += ' {0}'.format(t) format_string += '\n)' return format_string class Dilation(torch.nn.Module): def __init__(self, kernel=3): super().__init__() self.kernel=kernel def forward(self, img): return img.filter(ImageFilter.MaxFilter(self.kernel)) def __repr__(self): return self.__class__.__name__ + '(kernel={})'.format(self.kernel) class Erosion(torch.nn.Module): def __init__(self, kernel=3): super().__init__() self.kernel=kernel def forward(self, img): return img.filter(ImageFilter.MinFilter(self.kernel)) def __repr__(self): return self.__class__.__name__ + '(kernel={})'.format(self.kernel) class Underline(torch.nn.Module): def __init__(self): super().__init__() def forward(self, img): img_np = np.array(img.convert('L')) black_pixels = np.where(img_np < 50) try: y1 = max(black_pixels[0]) x0 = min(black_pixels[1]) x1 = max(black_pixels[1]) except: return img for x in range(x0, x1): for y in range(y1, y1-3, -1): try: img.putpixel((x, y), (0, 0, 0)) except: continue return img class KeepOriginal(torch.nn.Module): def __init__(self): super().__init__() def forward(self, img): return img def build_data_aug(size, mode, resnet=False, resizepad=False): if resnet: norm_tfm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) else: norm_tfm = transforms.Normalize(0.5, 0.5) if resizepad: resize_tfm = ResizePad(imgH=size[0], imgW=size[1]) else: resize_tfm = transforms.Resize(size, interpolation=InterpolationMode.BICUBIC) if mode == 'train': return transforms.Compose([ WeightedRandomChoice([ # transforms.RandomHorizontalFlip(p=1), transforms.RandomRotation(degrees=(-10, 10), expand=True, fill=255), transforms.GaussianBlur(3), Dilation(3), Erosion(3), transforms.Resize((size[0] // 3, size[1] // 3), interpolation=InterpolationMode.NEAREST), Underline(), KeepOriginal(), ]), resize_tfm, transforms.ToTensor(), norm_tfm ]) else: return transforms.Compose([ resize_tfm, transforms.ToTensor(), norm_tfm ]) class OptForDataAugment: def __init__(self, **kwargs): self.__dict__.update(kwargs) def isless(prob=0.5): return np.random.uniform(0,1) < prob class DataAugment(object): ''' Supports with and without data augmentation ''' def __init__(self, opt): self.opt = opt if not opt.eval: self.process = [Posterize(), Solarize(), Invert(), Equalize(), AutoContrast(), Sharpness(), Color()] self.camera = [Contrast(), Brightness(), JpegCompression(), Pixelate()] self.pattern = [VGrid(), HGrid(), Grid(), RectGrid(), EllipseGrid()] self.noise = [GaussianNoise(), ShotNoise(), ImpulseNoise(), SpeckleNoise()] self.blur = [GaussianBlur(), DefocusBlur(), MotionBlur(), GlassBlur(), ZoomBlur()] self.weather = [Fog(), Snow(), Frost(), Rain(), Shadow()] self.noises = [self.blur, self.noise, self.weather] self.processes = [self.camera, self.process] self.warp = [Curve(), Distort(), Stretch()] self.geometry = [Rotate(), Perspective(), Shrink()] self.isbaseline_aug = False # rand augment if self.opt.isrand_aug: self.augs = [self.process, self.camera, self.noise, self.blur, self.weather, self.pattern, self.warp, self.geometry] # semantic augment elif self.opt.issemantic_aug: self.geometry = [Rotate(), Perspective(), Shrink()] self.noise = [GaussianNoise()] self.blur = [MotionBlur()] self.augs = [self.noise, self.blur, self.geometry] self.isbaseline_aug = True # pp-ocr augment elif self.opt.islearning_aug: self.geometry = [Rotate(), Perspective()] self.noise = [GaussianNoise()] self.blur = [MotionBlur()] self.warp = [Distort()] self.augs = [self.warp, self.noise, self.blur, self.geometry] self.isbaseline_aug = True # scatter augment elif self.opt.isscatter_aug: self.geometry = [Shrink()] self.warp = [Distort()] self.augs = [self.warp, self.geometry] self.baseline_aug = True # rotation augment elif self.opt.isrotation_aug: self.geometry = [Rotate()] self.augs = [self.geometry] self.isbaseline_aug = True def __call__(self, img): ''' Must call img.copy() if pattern, Rain or Shadow is used ''' img = img.resize((self.opt.imgW, self.opt.imgH), Image.BICUBIC) if self.opt.eval or isless(self.opt.intact_prob): pass elif self.opt.isrand_aug or self.isbaseline_aug: img = self.rand_aug(img) # individual augment can also be selected elif self.opt.issel_aug: img = self.sel_aug(img) img = transforms.ToTensor()(img) img = transforms.Normalize(0.5, 0.5)(img) return img def rand_aug(self, img): augs = np.random.choice(self.augs, self.opt.augs_num, replace=False) for aug in augs: index = np.random.randint(0, len(aug)) op = aug[index] mag = np.random.randint(0, 3) if self.opt.augs_mag is None else self.opt.augs_mag if type(op).__name__ == "Rain" or type(op).__name__ == "Grid": img = op(img.copy(), mag=mag) else: img = op(img, mag=mag) return img def sel_aug(self, img): prob = 1. if self.opt.process: mag = np.random.randint(0, 3) index = np.random.randint(0, len(self.process)) op = self.process[index] img = op(img, mag=mag, prob=prob) if self.opt.noise: mag = np.random.randint(0, 3) index = np.random.randint(0, len(self.noise)) op = self.noise[index] img = op(img, mag=mag, prob=prob) if self.opt.blur: mag = np.random.randint(0, 3) index = np.random.randint(0, len(self.blur)) op = self.blur[index] img = op(img, mag=mag, prob=prob) if self.opt.weather: mag = np.random.randint(0, 3) index = np.random.randint(0, len(self.weather)) op = self.weather[index] if type(op).__name__ == "Rain": #or "Grid" in type(op).__name__ : img = op(img.copy(), mag=mag, prob=prob) else: img = op(img, mag=mag, prob=prob) if self.opt.camera: mag = np.random.randint(0, 3) index = np.random.randint(0, len(self.camera)) op = self.camera[index] img = op(img, mag=mag, prob=prob) if self.opt.pattern: mag = np.random.randint(0, 3) index = np.random.randint(0, len(self.pattern)) op = self.pattern[index] img = op(img.copy(), mag=mag, prob=prob) iscurve = False if self.opt.warp: mag = np.random.randint(0, 3) index = np.random.randint(0, len(self.warp)) op = self.warp[index] if type(op).__name__ == "Curve": iscurve = True img = op(img, mag=mag, prob=prob) if self.opt.geometry: mag = np.random.randint(0, 3) index = np.random.randint(0, len(self.geometry)) op = self.geometry[index] if type(op).__name__ == "Rotate": img = op(img, iscurve=iscurve, mag=mag, prob=prob) else: img = op(img, mag=mag, prob=prob) return img