Spaces:
Sleeping
Sleeping
import torchvision.transforms as transforms | |
# from torchvision.transforms.functional import InterpolationMode | |
from PIL import Image, ImageFilter | |
import random | |
import torch | |
import numpy as np | |
import logging | |
from enum import Enum | |
from .augmentation.warp import Curve, Distort, Stretch | |
from .augmentation.geometry import Rotate, Perspective, Shrink, TranslateX, TranslateY | |
from .augmentation.pattern import VGrid, HGrid, Grid, RectGrid, EllipseGrid | |
from .augmentation.noise import GaussianNoise, ShotNoise, ImpulseNoise, SpeckleNoise | |
from .augmentation.blur import GaussianBlur, DefocusBlur, MotionBlur, GlassBlur, ZoomBlur | |
from .augmentation.camera import Contrast, Brightness, JpegCompression, Pixelate | |
from .augmentation.weather import Fog, Snow, Frost, Rain, Shadow | |
from .augmentation.process import Posterize, Solarize, Invert, Equalize, AutoContrast, Sharpness, Color | |
# 0: InterpolationMode.NEAREST, | |
# 2: InterpolationMode.BILINEAR, | |
# 3: InterpolationMode.BICUBIC, | |
# 4: InterpolationMode.BOX, | |
# 5: InterpolationMode.HAMMING, | |
# 1: InterpolationMode.LANCZOS, | |
class InterpolationMode:
    """Integer resampling codes used in place of torchvision's InterpolationMode.

    The values coincide with Pillow's resampling filter constants, so they can
    be handed to APIs that accept a PIL ``resample`` integer.
    """

    NEAREST = 0
    LANCZOS = 1
    BILINEAR = 2
    BICUBIC = 3
    BOX = 4
    HAMMING = 5
# Module-level logger; WeightedRandomChoice reports failed augmentations here.
logger: logging.Logger = logging.getLogger(name=__name__)
class ResizePad(object):
    """Resize an image to a fixed height, then place it on a fixed-size canvas.

    The image is scaled (keeping its aspect ratio) so its height equals
    ``imgH`` and is pasted at the top-left corner of a new ``imgW`` x ``imgH``
    RGB canvas created with ``Image.new``'s default fill. Content wider than
    ``imgW`` is cut off on the right by the paste.

    Args:
        imgH: output height in pixels.
        imgW: output width in pixels.
        keep_ratio_with_pad: only ``True`` is supported.

    Raises:
        ValueError: if ``keep_ratio_with_pad`` is false.
    """

    def __init__(self, imgH=64, imgW=3072, keep_ratio_with_pad=True):
        self.imgH = imgH
        self.imgW = imgW
        # Only the aspect-preserving, padded mode is implemented. Raise rather
        # than assert so the check survives ``python -O``.
        if not keep_ratio_with_pad:
            raise ValueError('ResizePad only supports keep_ratio_with_pad=True')
        self.keep_ratio_with_pad = keep_ratio_with_pad

    def _target_size(self, old_size):
        """Return the ``(width, height)`` to resize to, given PIL ``(width, height)``."""
        old_w, old_h = old_size
        # Exact integer floor instead of int(x * float_ratio): e.g. with
        # imgH=64 and height 49 the float path yields 63.999... -> 63 and
        # leaves an empty bottom row. Clamp width so PIL never gets 0.
        new_w = max(1, (old_w * self.imgH) // old_h)
        return new_w, self.imgH

    def __call__(self, im):
        im = im.resize(self._target_size(im.size), Image.BICUBIC)
        new_im = Image.new("RGB", (self.imgW, self.imgH))
        new_im.paste(im, (0, 0))
        return new_im
class WeightedRandomChoice:
    """Apply one transform drawn at random (optionally weighted) from ``trans``.

    If the chosen transform raises, a warning is logged and the input image is
    returned unchanged.

    Args:
        trans: sequence of callables taking and returning an image.
        weights: relative weights, one per transform; missing or empty means
            uniform.

    Raises:
        ValueError: if ``weights`` is given with a different length than ``trans``.
    """

    def __init__(self, trans, weights=None):
        self.trans = trans
        if not weights:
            self.weights = [1] * len(trans)
        else:
            if len(trans) != len(weights):
                raise ValueError(
                    'weights has {} entries but trans has {}'.format(len(weights), len(trans)))
            self.weights = weights

    def __call__(self, img):
        t = random.choices(self.trans, weights=self.weights, k=1)[0]
        try:
            tfm_img = t(img)
        except Exception as e:
            # Deliberate best effort: a failing augmentation falls back to the
            # un-augmented image instead of aborting the sample.
            logger.warning('Error during data_aug: %s', e)
            return img
        return tfm_img

    def __repr__(self):
        # Iterates ``self.trans``; the previous version referenced the
        # non-existent ``self.transforms`` and always raised AttributeError.
        body = ''.join('\n    {0}'.format(t) for t in self.trans)
        return self.__class__.__name__ + '(' + body + '\n)'
class Dilation(torch.nn.Module):
    """Grayscale morphological dilation using PIL's ``MaxFilter``.

    Each output pixel is the maximum over a ``kernel`` x ``kernel`` window,
    so bright regions grow.
    """

    def __init__(self, kernel=3):
        super().__init__()
        self.kernel = kernel

    def forward(self, img):
        max_filter = ImageFilter.MaxFilter(self.kernel)
        return img.filter(max_filter)

    def __repr__(self):
        return f"{type(self).__name__}(kernel={self.kernel})"
class Erosion(torch.nn.Module):
    """Grayscale morphological erosion using PIL's ``MinFilter``.

    Each output pixel is the minimum over a ``kernel`` x ``kernel`` window,
    so dark regions grow.
    """

    def __init__(self, kernel=3):
        super().__init__()
        self.kernel = kernel

    def forward(self, img):
        min_filter = ImageFilter.MinFilter(self.kernel)
        return img.filter(min_filter)

    def __repr__(self):
        return f"{type(self).__name__}(kernel={self.kernel})"
class Underline(torch.nn.Module):
    """Paint a black bar across the bottom of the image's dark-pixel region.

    "Dark" pixels have a grayscale value below 50. The bar spans every column
    from the leftmost to the rightmost dark pixel and covers the lowest dark
    row plus the two rows above it. The image is modified in place (via
    ``putpixel``) and returned. Images with no dark pixels are returned as-is.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _bar_region(gray):
        """Return ``(xs, ys)`` ranges to paint for a 2-D grayscale array, or None."""
        rows, cols = np.where(gray < 50)
        if rows.size == 0:
            return None
        y1 = int(rows.max())
        # Inclusive of the rightmost dark column (the old range stopped one short).
        xs = range(int(cols.min()), int(cols.max()) + 1)
        # Stop at row 0: a negative y is not a valid row (Pillow may wrap it to
        # the bottom edge or raise, depending on version).
        ys = range(y1, max(y1 - 3, -1), -1)
        return xs, ys

    def forward(self, img):
        region = self._bar_region(np.array(img.convert('L')))
        if region is None:
            return img
        xs, ys = region
        try:
            for x in xs:
                for y in ys:
                    img.putpixel((x, y), (0, 0, 0))
        except TypeError:
            # Modes whose pixels are not RGB-like (e.g. single-band 'L')
            # reject the (0, 0, 0) tuple on the first call; leave the image
            # unchanged, as the old per-pixel swallow effectively did.
            pass
        return img
class KeepOriginal(torch.nn.Module):
    """Identity transform: hands back exactly the object it was given."""

    def __init__(self):
        super().__init__()

    def forward(self, img):
        unchanged = img
        return unchanged
def build_data_aug(size, mode, resnet=False, resizepad=False):
    """Build the torchvision preprocessing pipeline.

    Args:
        size: ``(height, width)`` target of the final resize.
        mode: ``'train'`` prepends one randomly chosen augmentation; any other
            value yields resize -> ToTensor -> Normalize only.
        resnet: normalize with per-channel mean [0.485, 0.456, 0.406] and
            std [0.229, 0.224, 0.225] (the standard ImageNet statistics)
            instead of mean = std = 0.5.
        resizepad: use ``ResizePad`` instead of a bicubic ``transforms.Resize``.

    Returns:
        A ``transforms.Compose`` pipeline.
    """
    if resnet:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    else:
        normalize = transforms.Normalize(0.5, 0.5)

    resize = (ResizePad(imgH=size[0], imgW=size[1]) if resizepad
              else transforms.Resize(size, interpolation=InterpolationMode.BICUBIC))

    steps = [resize, transforms.ToTensor(), normalize]
    if mode == 'train':
        # One of these is applied per sample (uniform weights); the 1/3
        # downscale is undone by the final resize, lowering effective detail.
        augment = WeightedRandomChoice([
            # transforms.RandomHorizontalFlip(p=1),
            transforms.RandomRotation(degrees=(-10, 10), expand=True, fill=255),
            transforms.GaussianBlur(3),
            Dilation(3),
            Erosion(3),
            transforms.Resize((size[0] // 3, size[1] // 3), interpolation=InterpolationMode.NEAREST),
            Underline(),
            KeepOriginal(),
        ])
        steps.insert(0, augment)
    return transforms.Compose(steps)
class OptForDataAugment:
    """Simple options namespace: each keyword argument becomes an attribute."""

    def __init__(self, **kwargs):
        vars(self).update(kwargs)
def isless(prob=0.5):
    """Return True with probability ``prob`` (one uniform draw from [0, 1))."""
    draw = np.random.uniform(0, 1)
    return prob > draw
class DataAugment(object):
    '''
    Resize, optionally augment, and normalize a PIL image into a tensor.

    Supports running with and without data augmentation. When ``opt.eval`` is
    true no augmentation ops are built. Otherwise the first true flag among
    ``opt.isrand_aug``, ``opt.issemantic_aug``, ``opt.islearning_aug``,
    ``opt.isscatter_aug`` and ``opt.isrotation_aug`` selects the op groups that
    ``rand_aug`` samples from; if none is set, ``opt.issel_aug`` enables the
    per-group ``sel_aug`` path.
    '''

    def __init__(self, opt):
        self.opt = opt
        # Defined unconditionally so the attribute also exists in eval mode.
        self.isbaseline_aug = False
        if not opt.eval:
            self.process = [Posterize(), Solarize(), Invert(), Equalize(), AutoContrast(), Sharpness(), Color()]
            self.camera = [Contrast(), Brightness(), JpegCompression(), Pixelate()]
            self.pattern = [VGrid(), HGrid(), Grid(), RectGrid(), EllipseGrid()]
            self.noise = [GaussianNoise(), ShotNoise(), ImpulseNoise(), SpeckleNoise()]
            self.blur = [GaussianBlur(), DefocusBlur(), MotionBlur(), GlassBlur(), ZoomBlur()]
            self.weather = [Fog(), Snow(), Frost(), Rain(), Shadow()]
            self.noises = [self.blur, self.noise, self.weather]
            self.processes = [self.camera, self.process]
            self.warp = [Curve(), Distort(), Stretch()]
            self.geometry = [Rotate(), Perspective(), Shrink()]
            # rand augment: sample from every group
            if self.opt.isrand_aug:
                self.augs = [self.process, self.camera, self.noise, self.blur, self.weather, self.pattern, self.warp, self.geometry]
            # semantic augment
            elif self.opt.issemantic_aug:
                self.geometry = [Rotate(), Perspective(), Shrink()]
                self.noise = [GaussianNoise()]
                self.blur = [MotionBlur()]
                self.augs = [self.noise, self.blur, self.geometry]
                self.isbaseline_aug = True
            # pp-ocr augment
            elif self.opt.islearning_aug:
                self.geometry = [Rotate(), Perspective()]
                self.noise = [GaussianNoise()]
                self.blur = [MotionBlur()]
                self.warp = [Distort()]
                self.augs = [self.warp, self.noise, self.blur, self.geometry]
                self.isbaseline_aug = True
            # scatter augment
            elif self.opt.isscatter_aug:
                self.geometry = [Shrink()]
                self.warp = [Distort()]
                self.augs = [self.warp, self.geometry]
                # Previously misspelled ``self.baseline_aug``, so __call__
                # never routed scatter augment to rand_aug.
                self.isbaseline_aug = True
            # rotation augment
            elif self.opt.isrotation_aug:
                self.geometry = [Rotate()]
                self.augs = [self.geometry]
                self.isbaseline_aug = True

    def __call__(self, img):
        '''
        Resize to (opt.imgW, opt.imgH), optionally augment, and return a tensor
        normalized with mean = std = 0.5.

        Augmentation is skipped in eval mode and, otherwise, with probability
        ``opt.intact_prob``.
        '''
        img = img.resize((self.opt.imgW, self.opt.imgH), Image.BICUBIC)
        if self.opt.eval or isless(self.opt.intact_prob):
            pass
        elif self.opt.isrand_aug or self.isbaseline_aug:
            img = self.rand_aug(img)
        # individual augment can also be selected
        elif self.opt.issel_aug:
            img = self.sel_aug(img)
        img = transforms.ToTensor()(img)
        img = transforms.Normalize(0.5, 0.5)(img)
        return img

    @staticmethod
    def _needs_copy(op):
        '''True for ops that must receive ``img.copy()``: pattern (*Grid), Rain, Shadow.'''
        name = type(op).__name__
        return name in ("Rain", "Shadow") or "Grid" in name

    def _apply(self, op, img, **kwargs):
        '''Call ``op(img, **kwargs)``, passing a copy of ``img`` when ``_needs_copy(op)``.'''
        if self._needs_copy(op):
            img = img.copy()
        return op(img, **kwargs)

    @staticmethod
    def _pick(ops):
        '''Draw a magnitude in {0, 1, 2}, then a random op from ``ops`` (in that order).'''
        mag = np.random.randint(0, 3)
        op = ops[np.random.randint(0, len(ops))]
        return op, mag

    def rand_aug(self, img):
        '''
        Apply ``opt.augs_num`` distinct groups from ``self.augs``, one random op each.

        The magnitude is ``opt.augs_mag`` if set, else random in {0, 1, 2}.
        At most ``len(self.augs)`` groups are applied.
        '''
        # Sample group indices, not the groups: np.random.choice needs a 1-D
        # array, and a list of op lists is ragged (or 2-D when every group has
        # one op), which numpy rejects. Index sampling draws the same RNG values.
        num = min(self.opt.augs_num, len(self.augs))
        for group_idx in np.random.choice(len(self.augs), num, replace=False):
            group = self.augs[group_idx]
            op = group[np.random.randint(0, len(group))]
            mag = np.random.randint(0, 3) if self.opt.augs_mag is None else self.opt.augs_mag
            img = self._apply(op, img, mag=mag)
        return img

    def sel_aug(self, img):
        '''
        Apply one random op from each group enabled in ``opt``, with ``prob=1.``.

        Groups run in the order process, noise, blur, weather, camera, pattern,
        warp, geometry. If the chosen warp is ``Curve``, a chosen ``Rotate`` is
        told so via ``iscurve=True``.
        '''
        prob = 1.
        groups = (
            (self.opt.process, self.process),
            (self.opt.noise, self.noise),
            (self.opt.blur, self.blur),
            (self.opt.weather, self.weather),
            (self.opt.camera, self.camera),
            (self.opt.pattern, self.pattern),
        )
        for enabled, ops in groups:
            if enabled:
                op, mag = self._pick(ops)
                img = self._apply(op, img, mag=mag, prob=prob)
        iscurve = False
        if self.opt.warp:
            op, mag = self._pick(self.warp)
            iscurve = type(op).__name__ == "Curve"
            img = self._apply(op, img, mag=mag, prob=prob)
        if self.opt.geometry:
            op, mag = self._pick(self.geometry)
            if type(op).__name__ == "Rotate":
                img = self._apply(op, img, iscurve=iscurve, mag=mag, prob=prob)
            else:
                img = self._apply(op, img, mag=mag, prob=prob)
        return img