# Tzktz's picture
# Upload 7664 files
# 6fc683c verified
import torchvision.transforms as transforms
# from torchvision.transforms.functional import InterpolationMode
from PIL import Image, ImageFilter
import random
import torch
import numpy as np
import logging
from enum import Enum
from .augmentation.warp import Curve, Distort, Stretch
from .augmentation.geometry import Rotate, Perspective, Shrink, TranslateX, TranslateY
from .augmentation.pattern import VGrid, HGrid, Grid, RectGrid, EllipseGrid
from .augmentation.noise import GaussianNoise, ShotNoise, ImpulseNoise, SpeckleNoise
from .augmentation.blur import GaussianBlur, DefocusBlur, MotionBlur, GlassBlur, ZoomBlur
from .augmentation.camera import Contrast, Brightness, JpegCompression, Pixelate
from .augmentation.weather import Fog, Snow, Frost, Rain, Shadow
from .augmentation.process import Posterize, Solarize, Invert, Equalize, AutoContrast, Sharpness, Color
# Integer stand-ins for torchvision's InterpolationMode members (the
# torchvision import above is commented out). The values are the codes PIL
# uses for its resampling filters:
#   0 NEAREST, 1 LANCZOS, 2 BILINEAR, 3 BICUBIC, 4 BOX, 5 HAMMING
class InterpolationMode:
    """Plain namespace of integer interpolation codes (not an ``enum.Enum``)."""

    NEAREST = 0
    LANCZOS = 1
    BILINEAR = 2
    BICUBIC = 3
    BOX = 4
    HAMMING = 5


logger = logging.getLogger(__name__)
class ResizePad(object):
    """Resize an image to a fixed height (keeping aspect ratio), then right-pad.

    The output is always a new RGB image of size ``(imgW, imgH)``. The
    resized input is pasted at the top-left corner; the remaining area keeps
    ``Image.new``'s default fill (0, i.e. black).

    Args:
        imgH: target height in pixels.
        imgW: target (padded) width in pixels.
        keep_ratio_with_pad: must be truthy; it is the only supported mode.

    Raises:
        ValueError: if ``keep_ratio_with_pad`` is falsy.
    """

    def __init__(self, imgH=64, imgW=3072, keep_ratio_with_pad=True):
        self.imgH = imgH
        self.imgW = imgW
        if not keep_ratio_with_pad:
            raise ValueError("ResizePad only supports keep_ratio_with_pad=True")
        self.keep_ratio_with_pad = keep_ratio_with_pad

    def __call__(self, im):
        width, height = im.size  # PIL size is (width, height)
        ratio = float(self.imgH) / height
        # Height is pinned to imgH exactly: int(height * ratio) can come out
        # one pixel short through float rounding. Width keeps the aspect ratio
        # but is clamped to [1, imgW], so over-long lines are squeezed instead
        # of being silently cropped by the paste below, and very thin inputs
        # never produce a zero-width resize.
        new_width = min(self.imgW, max(1, int(width * ratio)))
        im = im.resize((new_width, self.imgH), Image.BICUBIC)
        new_im = Image.new("RGB", (self.imgW, self.imgH))
        new_im.paste(im, (0, 0))
        return new_im
class WeightedRandomChoice:
    """Apply exactly one transform, chosen at random (optionally weighted).

    If the chosen transform raises, the error is logged as a warning and the
    input is returned unchanged, so one failing augmentation does not abort
    the pipeline.

    Args:
        trans: sequence of callables taking and returning an image.
        weights: optional relative weights, one per transform; a falsy value
            (``None`` or empty) means uniform weights.

    Raises:
        ValueError: if ``weights`` is given and its length differs from
            ``trans``.
    """

    def __init__(self, trans, weights=None):
        self.trans = trans
        if not weights:
            self.weights = [1] * len(trans)
        else:
            if len(trans) != len(weights):
                raise ValueError(
                    "weights must have one entry per transform "
                    "(got %d transforms, %d weights)" % (len(trans), len(weights))
                )
            self.weights = weights

    def __call__(self, img):
        t = random.choices(self.trans, weights=self.weights, k=1)[0]
        try:
            return t(img)
        except Exception as e:  # deliberate best-effort: keep the sample
            logger.warning('Error during data_aug: %s', e)
            return img

    def __repr__(self):
        # Previously iterated the nonexistent ``self.transforms`` and raised
        # AttributeError; the transforms live in ``self.trans``.
        lines = [self.__class__.__name__ + '(']
        lines.extend(' {0}'.format(t) for t in self.trans)
        return '\n'.join(lines) + '\n)'
class Dilation(torch.nn.Module):
    """Morphological dilation via PIL's ``MaxFilter``.

    Each output pixel is the maximum over a ``kernel`` x ``kernel``
    neighbourhood of the input.
    """

    def __init__(self, kernel=3):
        super().__init__()
        self.kernel = kernel

    def forward(self, img):
        max_filter = ImageFilter.MaxFilter(self.kernel)
        return img.filter(max_filter)

    def __repr__(self):
        return f"{type(self).__name__}(kernel={self.kernel})"
class Erosion(torch.nn.Module):
    """Morphological erosion via PIL's ``MinFilter``.

    Each output pixel is the minimum over a ``kernel`` x ``kernel``
    neighbourhood of the input.
    """

    def __init__(self, kernel=3):
        super().__init__()
        self.kernel = kernel

    def forward(self, img):
        min_filter = ImageFilter.MinFilter(self.kernel)
        return img.filter(min_filter)

    def __repr__(self):
        return f"{type(self).__name__}(kernel={self.kernel})"
class Underline(torch.nn.Module):
    """Draw a black line, up to 3 pixels thick, under the image's dark content.

    Pixels darker than 50 after conversion to 8-bit greyscale ('L') count as
    ink. The line spans the ink's full horizontal extent (inclusive) and
    covers the bottom-most ink row plus up to two rows above it.
    """

    # Modes whose pixels are a single scalar; every other mode gets an RGB
    # tuple as the fill value.
    _SCALAR_FILL_MODES = ('1', 'L', 'I', 'F')

    def __init__(self):
        super().__init__()

    def forward(self, img):
        """Return an underlined copy of ``img``; ``img`` itself is not modified.

        Returns ``img`` unchanged when it has no dark pixels, or when its mode
        rejects the fill value.
        """
        ink_rows, ink_cols = np.nonzero(np.asarray(img.convert('L')) < 50)
        if ink_rows.size == 0:
            return img
        bottom = int(ink_rows.max())
        left, right = int(ink_cols.min()), int(ink_cols.max())
        # A tuple fill raises TypeError on single-band images, which used to
        # make this transform a silent no-op on greyscale input.
        fill = 0 if img.mode in self._SCALAR_FILL_MODES else (0, 0, 0)
        out = img.copy()
        try:
            for x in range(left, right + 1):
                # Stop at row 0 so rows never go negative near the top edge.
                for y in range(bottom, max(bottom - 3, -1), -1):
                    out.putpixel((x, y), fill)
        except (TypeError, ValueError):
            # Mode cannot take this fill value: keep the sample as-is.
            return img
        return out
class KeepOriginal(torch.nn.Module):
    """Identity transform: hands its input back untouched.

    Included as a choice in random-augmentation lists so that some samples
    pass through unaugmented.
    """

    def forward(self, img):
        return img
def build_data_aug(size, mode, resnet=False, resizepad=False):
    """Build the torchvision preprocessing pipeline.

    Args:
        size: target ``(height, width)``; passed as-is to ``transforms.Resize``
            unless ``resizepad`` is set.
        mode: ``'train'`` prepends one randomly chosen augmentation; any
            other value yields resize -> tensor -> normalise only.
        resnet: normalise with ImageNet mean/std instead of 0.5/0.5.
        resizepad: resize with ``ResizePad`` (aspect-preserving, right-padded)
            instead of ``transforms.Resize``.

    Returns:
        A ``transforms.Compose`` pipeline.
    """
    if resnet:
        imagenet_mean = [0.485, 0.456, 0.406]
        imagenet_std = [0.229, 0.224, 0.225]
        norm_tfm = transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
    else:
        norm_tfm = transforms.Normalize(0.5, 0.5)

    resize_tfm = (
        ResizePad(imgH=size[0], imgW=size[1])
        if resizepad
        else transforms.Resize(size, interpolation=InterpolationMode.BICUBIC)
    )
    common_tail = [resize_tfm, transforms.ToTensor(), norm_tfm]

    if mode != 'train':
        return transforms.Compose(common_tail)

    height, width = size[0], size[1]
    one_random_aug = WeightedRandomChoice([
        transforms.RandomRotation(degrees=(-10, 10), expand=True, fill=255),
        transforms.GaussianBlur(3),
        Dilation(3),
        Erosion(3),
        # Downscale to a third; the resize in common_tail scales back up,
        # giving a low-resolution look.
        transforms.Resize((height // 3, width // 3), interpolation=InterpolationMode.NEAREST),
        Underline(),
        KeepOriginal(),
    ])
    return transforms.Compose([one_random_aug] + common_tail)
class OptForDataAugment:
    """Simple options bag: every keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        vars(self).update(kwargs)
def isless(prob=0.5):
    """Return True with probability ``prob``.

    Draws one sample from NumPy's global RNG (``np.random.uniform(0, 1)``)
    and reports whether it falls below ``prob``.
    """
    draw = np.random.uniform(0, 1)
    return prob > draw
class DataAugment(object):
    """Resize, optionally augment, and normalise a PIL image.

    Supports running with and without data augmentation. Behaviour is driven
    by attributes of ``opt``:

    * ``eval``: when true, no augmentation ops are built and images are only
      resized and normalised.
    * ``isrand_aug`` / ``issemantic_aug`` / ``islearning_aug`` /
      ``isscatter_aug`` / ``isrotation_aug``: choose the op groups that
      ``rand_aug`` samples from. They are checked in that order and the first
      true flag wins.
    * ``issel_aug`` together with the per-group flags ``process``, ``noise``,
      ``blur``, ``weather``, ``camera``, ``pattern``, ``warp``, ``geometry``:
      apply one random op from each enabled group (see ``sel_aug``).
    * ``imgW``, ``imgH``, ``intact_prob``, ``augs_num``, ``augs_mag``: see
      ``__call__`` and ``rand_aug``.
    """

    # Op class names that must be given a copy of the image. Every pattern op
    # (VGrid, HGrid, Grid, RectGrid, EllipseGrid) also needs one; all of their
    # names contain "Grid".
    _COPY_OP_NAMES = ("Rain", "Shadow")

    def __init__(self, opt):
        self.opt = opt
        # Set unconditionally so that __call__ can always read it.
        self.isbaseline_aug = False
        if opt.eval:
            return
        self.process = [Posterize(), Solarize(), Invert(), Equalize(), AutoContrast(), Sharpness(), Color()]
        self.camera = [Contrast(), Brightness(), JpegCompression(), Pixelate()]
        self.pattern = [VGrid(), HGrid(), Grid(), RectGrid(), EllipseGrid()]
        self.noise = [GaussianNoise(), ShotNoise(), ImpulseNoise(), SpeckleNoise()]
        self.blur = [GaussianBlur(), DefocusBlur(), MotionBlur(), GlassBlur(), ZoomBlur()]
        self.weather = [Fog(), Snow(), Frost(), Rain(), Shadow()]
        self.noises = [self.blur, self.noise, self.weather]
        self.processes = [self.camera, self.process]
        self.warp = [Curve(), Distort(), Stretch()]
        self.geometry = [Rotate(), Perspective(), Shrink()]
        # rand augment: sample from every group
        if self.opt.isrand_aug:
            self.augs = [self.process, self.camera, self.noise, self.blur, self.weather, self.pattern, self.warp, self.geometry]
        # semantic augment
        elif self.opt.issemantic_aug:
            self.geometry = [Rotate(), Perspective(), Shrink()]
            self.noise = [GaussianNoise()]
            self.blur = [MotionBlur()]
            self.augs = [self.noise, self.blur, self.geometry]
            self.isbaseline_aug = True
        # pp-ocr augment
        elif self.opt.islearning_aug:
            self.geometry = [Rotate(), Perspective()]
            self.noise = [GaussianNoise()]
            self.blur = [MotionBlur()]
            self.warp = [Distort()]
            self.augs = [self.warp, self.noise, self.blur, self.geometry]
            self.isbaseline_aug = True
        # scatter augment
        elif self.opt.isscatter_aug:
            self.geometry = [Shrink()]
            self.warp = [Distort()]
            self.augs = [self.warp, self.geometry]
            # This was `self.baseline_aug`, a typo that left isbaseline_aug
            # False, so __call__ never applied the scatter ops.
            self.isbaseline_aug = True
        # rotation augment
        elif self.opt.isrotation_aug:
            self.geometry = [Rotate()]
            self.augs = [self.geometry]
            self.isbaseline_aug = True

    def __call__(self, img):
        """Resize ``img`` to ``(opt.imgW, opt.imgH)``, maybe augment it, then normalise.

        Augmentation is skipped in eval mode and, otherwise, with probability
        ``opt.intact_prob``. Returns the result of ``ToTensor`` followed by
        ``Normalize(0.5, 0.5)``.
        """
        img = img.resize((self.opt.imgW, self.opt.imgH), Image.BICUBIC)
        if self.opt.eval or isless(self.opt.intact_prob):
            pass  # keep this sample intact
        elif self.opt.isrand_aug or self.isbaseline_aug:
            img = self.rand_aug(img)
        # individual augment can also be selected
        elif self.opt.issel_aug:
            img = self.sel_aug(img)
        img = transforms.ToTensor()(img)
        img = transforms.Normalize(0.5, 0.5)(img)
        return img

    @classmethod
    def _needs_copy(cls, op):
        """Return whether ``op`` must receive ``img.copy()``.

        This covers the pattern ops, Rain and Shadow.
        """
        name = type(op).__name__
        return "Grid" in name or name in cls._COPY_OP_NAMES

    @staticmethod
    def _pick(ops):
        """Draw a magnitude in {0, 1, 2}, then pick a random op from ``ops``.

        The magnitude is drawn first, matching the RNG order sel_aug has
        always used.
        """
        mag = np.random.randint(0, 3)
        op = ops[np.random.randint(0, len(ops))]
        return op, mag

    def rand_aug(self, img):
        """Apply ``opt.augs_num`` ops, each from a distinct randomly chosen group.

        The magnitude is ``opt.augs_mag`` when it is not None; otherwise it is
        drawn from {0, 1, 2} for each op.
        """
        # Sample group *indices*. Passing the list of lists itself to
        # np.random.choice makes NumPy build an array from it, which raises
        # ValueError for equal-length groups (a 2-D array) and for ragged
        # groups on NumPy >= 1.24.
        picks = np.random.choice(len(self.augs), self.opt.augs_num, replace=False)
        for pick in picks:
            group = self.augs[pick]
            op = group[np.random.randint(0, len(group))]
            if self.opt.augs_mag is None:
                mag = np.random.randint(0, 3)
            else:
                mag = self.opt.augs_mag
            src = img.copy() if self._needs_copy(op) else img
            img = op(src, mag=mag)
        return img

    def sel_aug(self, img):
        """Apply one random op, always with prob=1, from each group enabled on ``opt``.

        The groups run in a fixed order, and each draws a magnitude in
        {0, 1, 2}. If the chosen warp op is ``Curve``, a subsequent ``Rotate``
        is told so via ``iscurve=True``.
        """
        prob = 1.
        simple_groups = (
            ("process", self.process),
            ("noise", self.noise),
            ("blur", self.blur),
            ("weather", self.weather),
            ("camera", self.camera),
            ("pattern", self.pattern),
        )
        for flag, ops in simple_groups:
            if getattr(self.opt, flag):
                op, mag = self._pick(ops)
                src = img.copy() if self._needs_copy(op) else img
                img = op(src, mag=mag, prob=prob)
        iscurve = False
        if self.opt.warp:
            op, mag = self._pick(self.warp)
            iscurve = type(op).__name__ == "Curve"
            img = op(img, mag=mag, prob=prob)
        if self.opt.geometry:
            op, mag = self._pick(self.geometry)
            if type(op).__name__ == "Rotate":
                img = op(img, iscurve=iscurve, mag=mag, prob=prob)
            else:
                img = op(img, mag=mag, prob=prob)
        return img