|
import cv2 |
|
import random |
|
import time |
|
import numpy as np |
|
import torch |
|
from torch.utils import data as data |
|
|
|
from basicsr.data.transforms import rgb2lab |
|
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor |
|
from basicsr.utils.registry import DATASET_REGISTRY |
|
from basicsr.data.fmix import sample_mask |
|
|
|
|
|
@DATASET_REGISTRY.register()
class LabDataset(data.Dataset):
    """Dataset used for Lab colorization.

    Every sample is a ground-truth image listed in the meta-info file(s). The image is read
    through a lazily created ``FileClient``, resized to ``gt_size x gt_size``, optionally mixed
    with a second random image (FMix and/or CutMix), converted with ``rgb2lab`` and returned as
    the ``img_l`` / ``img_ab`` pair together with the quantized a/b bin indices.

    Args:
        opt (dict): Dataset options. Keys read by this class:
            io_backend (dict): ``type`` plus keyword arguments forwarded to ``FileClient``.
            dataroot_gt: Stored as ``self.gt_folder``.
            meta_info_file (str | list[str]): Text file(s) with one image path per line.
            gt_size (int): Side length the images are resized to.
            do_fmix (bool), fmix_p (float): Enable flag and threshold for FMix.
            do_cutmix (bool), cutmix_p (float): Enable flag and threshold for CutMix.
    """

    def __init__(self, opt):
        super(LabDataset, self).__init__()
        self.opt = opt

        # The file client is built on first access in __getitem__.
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.gt_folder = opt['dataroot_gt']

        meta_info_file = self.opt['meta_info_file']
        assert meta_info_file is not None
        if not isinstance(meta_info_file, list):
            meta_info_file = [meta_info_file]
        self.paths = []
        for meta_info in meta_info_file:
            with open(meta_info, 'r') as fin:
                # Skip blank lines (e.g. a trailing newline) so they do not become empty paths.
                self.paths.extend([line.strip() for line in fin if line.strip()])

        # Quantization grid for the a/b channels: bins of width ``interval_ab`` over [min_ab, max_ab].
        self.min_ab, self.max_ab = -128, 128
        self.interval_ab = 4
        self.ab_palette = list(range(self.min_ab, self.max_ab + self.interval_ab, self.interval_ab))

        self.do_fmix = opt['do_fmix']
        # 'shape' here is only a default; __getitem__ overrides it with the actual gt_size.
        self.fmix_params = {'alpha': 1., 'decay_power': 3., 'shape': (256, 256), 'max_soft': 0.0, 'reformulate': False}
        self.fmix_p = opt['fmix_p']
        self.do_cutmix = opt['do_cutmix']
        self.cutmix_params = {'alpha': 1.}
        self.cutmix_p = opt['cutmix_p']

    def _random_index(self):
        """Return a uniformly random valid index into ``self.paths``."""
        # random.randint includes both endpoints, so the upper bound must be len - 1.
        return random.randint(0, len(self) - 1)

    def _read_resized(self, path, gt_size):
        """Read ``path`` through the file client and return it resized to ``gt_size x gt_size``."""
        img_bytes = self.file_client.get(path, 'gt')
        img = imfrombytes(img_bytes, float32=True)
        return cv2.resize(img, (gt_size, gt_size))

    def __getitem__(self, index):
        """Load, augment and convert the sample at ``index``.

        Args:
            index (int): Position in ``self.paths``.

        Returns:
            dict: ``lq`` (tensor built from ``img_l``), ``gt`` (tensor built from ``img_ab``),
            ``target_a`` / ``target_b`` (``LongTensor`` bin indices), and ``lq_path`` /
            ``gt_path`` (both the path of the image that was actually loaded).

        Raises:
            RuntimeError: If no image could be read after all retries.
        """
        if self.file_client is None:
            # NOTE: pop() removes 'type' from the shared io_backend options on first use.
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        gt_path = self.paths[index]
        gt_size = self.opt['gt_size']

        # Read the image bytes; on failure fall back to another random sample.
        img_bytes = None
        retry = 3
        while retry > 0:
            try:
                img_bytes = self.file_client.get(gt_path, 'gt')
            except Exception as e:
                logger = get_root_logger()
                logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')

                index = self._random_index()
                gt_path = self.paths[index]
                time.sleep(1)
            else:
                break
            finally:
                retry -= 1
        if img_bytes is None:
            # Previously this fell through and failed with an UnboundLocalError.
            raise RuntimeError('Failed to read an image from the file client after 3 attempts.')

        img_gt = imfrombytes(img_bytes, float32=True)
        img_gt = cv2.resize(img_gt, (gt_size, gt_size))

        # NOTE(review): the augmentations fire when the draw EXCEEDS the threshold, i.e. with
        # probability 1 - p; confirm this is the intended meaning of fmix_p / cutmix_p.
        if self.do_fmix and np.random.uniform(0., 1., size=1)[0] > self.fmix_p:
            # Sample the mask at the real image size so it broadcasts for any gt_size.
            fmix_params = dict(self.fmix_params, shape=(gt_size, gt_size))
            _, mask = sample_mask(**fmix_params)

            fmix_img = self._read_resized(self.paths[self._random_index()], gt_size)

            # The (1, 2, 0) transpose moves the leading mask axis last so it broadcasts over
            # the channel axis (mask is presumably (1, H, W) - confirm against sample_mask).
            mask = mask.transpose(1, 2, 0)
            img_gt = mask * img_gt + (1. - mask) * fmix_img
            img_gt = img_gt.astype(np.float32)

        if self.do_cutmix and np.random.uniform(0., 1., size=1)[0] > self.cutmix_p:
            cmix_img = self._read_resized(self.paths[self._random_index()], gt_size)

            alpha = self.cutmix_params['alpha']
            lam = np.clip(np.random.beta(alpha, alpha), 0.3, 0.4)
            bbx1, bby1, bbx2, bby2 = rand_bbox(cmix_img.shape[:2], lam)

            # Images are (H, W, C) here, so the box indexes the first two axes. The previous
            # channel-first indexing ([:, x1:x2, y1:y2]) sliced the channel axis instead.
            img_gt[bbx1:bbx2, bby1:bby2] = cmix_img[bbx1:bbx2, bby1:bby2]

        img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2RGB)
        img_l, img_ab = rgb2lab(img_gt)

        # Quantize before the tensor conversion: ab2int indexes the last axis as channels.
        target_a, target_b = self.ab2int(img_ab)

        img_l, img_ab = img2tensor([img_l, img_ab], bgr2rgb=False, float32=True)
        target_a, target_b = torch.LongTensor(target_a), torch.LongTensor(target_b)
        return_d = {
            'lq': img_l,
            'gt': img_ab,
            'target_a': target_a,
            'target_b': target_b,
            'lq_path': gt_path,
            'gt_path': gt_path
        }
        return return_d

    def ab2int(self, img_ab):
        """Map a/b channel values to (rounded) bin indices on the ab grid.

        Args:
            img_ab (ndarray): Array whose last axis holds the a channel at index 0 and the
                b channel at index 1.

        Returns:
            tuple[ndarray, ndarray]: Rounded ``(value - min_ab) / interval_ab`` for a and b.
            ``np.round`` keeps the floating dtype of the input.
        """
        img_a, img_b = img_ab[:, :, 0], img_ab[:, :, 1]
        int_a = (img_a - self.min_ab) / self.interval_ab
        int_b = (img_b - self.min_ab) / self.interval_ab

        return np.round(int_a), np.round(int_b)

    def __len__(self):
        """Return the number of image paths collected from the meta-info file(s)."""
        return len(self.paths)
|
|
|
|
|
def rand_bbox(size, lam):
    """Sample a random bounding box for CutMix.

    A box centre is drawn uniformly inside ``size`` and a box whose sides are
    ``sqrt(1 - lam)`` times the corresponding extents is placed around it, then clipped to
    the valid range (so the clipped box may be smaller, or empty).

    Args:
        size (tuple): Extents of the two axes to cut along, e.g. ``(256, 256)``. ``size[0]``
            bounds the ``bbx*`` coordinates and ``size[1]`` bounds the ``bby*`` coordinates.
        lam (float): Mixing ratio; the unclipped box covers about ``1 - lam`` of the area.

    Returns:
        tuple: ``(bbx1, bby1, bbx2, bby2)`` - start coordinates on both axes followed by the
        end coordinates on both axes, each clipped to ``[0, size[i]]``.
    """
    W = size[0]
    H = size[1]
    cut_rat = np.sqrt(1. - lam)
    # The ``np.int`` alias was removed in NumPy 1.24; the builtin ``int`` truncates identically.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Uniformly random box centre (first axis is drawn first to keep the RNG order).
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2