import cv2
import sys
import numpy as np
import torch
from torch.utils import data
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from PIL import Image
from collections import OrderedDict

sys.path.insert(1, './schp')
from utils.transforms import get_affine_transform, transform_logits
import networks


class PILImageDataset(data.Dataset):
    """Dataset that wraps a list of in-memory PIL images for SCHP inference."""

    def __init__(self, img_lst=[], input_size=[512, 512], transform=None):
        self.img_lst = img_lst
        self.transform = transform
        self.aspect_ratio = input_size[1] * 1.0 / input_size[0]
        self.input_size = np.asarray(input_size)

    def __len__(self):
        return len(self.img_lst)

    def _box2cs(self, box):
        x, y, w, h = box[:4]
        return self._xywh2cs(x, y, w, h)

    def _xywh2cs(self, x, y, w, h):
        center = np.zeros((2), dtype=np.float32)
        center[0] = x + w * 0.5
        center[1] = y + h * 0.5
        if w > self.aspect_ratio * h:
            h = w * 1.0 / self.aspect_ratio
        elif w < self.aspect_ratio * h:
            w = h * self.aspect_ratio
        scale = np.array([w, h], dtype=np.float32)
        return center, scale

    def __getitem__(self, index):
        img = np.array(self.img_lst[index])[:, :, ::-1]  # PIL RGB -> BGR
        h, w, _ = img.shape

        # Get person center and scale
        person_center, s = self._box2cs([0, 0, w - 1, h - 1])
        r = 0
        trans = get_affine_transform(person_center, s, r, self.input_size)
        input = cv2.warpAffine(
            img, trans,
            (int(self.input_size[1]), int(self.input_size[0])),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=(0, 0, 0))

        input = self.transform(input)
        meta = {
            'center': person_center,
            'height': h,
            'width': w,
            'scale': s,
            'rotation': r
        }
        return input, meta


# Merged part classes and, in val_list below, the LIP label indices that make up
# each group (e.g. 'Face' covers Hat, Sunglasses and Face; 'Arm' covers both arms).
PALLETE_DICT = {
    'Background': [],
    'Face': [],
    'Upper-clothes': [],
    'Dress': [],
    'Coat': [],
    'Socks': [],
    'Pants': [],
    'Jumpsuits': [],
    'Scarf': [],
    'Skirt': [],
    'Arm': [],
    'Leg': [],
    'Shoe': []
}
val_list = [[0], [1, 4, 13], [5], [6], [7], [8], [9], [10], [11], [12],
            [14, 15], [16, 17], [18, 19]]

for c, j in enumerate(PALLETE_DICT.keys()):
    val = val_list[c]
    pallete = []
    for i in range(60):
        # Each label index owns 3 consecutive palette entries (R, G, B):
        # white for indices in this group, black for everything else.
        if (i // 3) in val:
            pallete.append(255)
        else:
            pallete.append(0)
    PALLETE_DICT[j] = pallete

DATASET_SETTINGS = {
    'lip': {
        'input_size': [473, 473],
        'num_classes': 20,
        'label': ['Background', 'Hat', 'Hair', 'Glove', 'Sunglasses', 'Upper-clothes',
                  'Dress', 'Coat', 'Socks', 'Pants', 'Jumpsuits', 'Scarf', 'Skirt',
                  'Face', 'Left-arm', 'Right-arm', 'Left-leg', 'Right-leg',
                  'Left-shoe', 'Right-shoe']
    },
    'atr': {
        'input_size': [512, 512],
        'num_classes': 18,
        'label': ['Background', 'Hat', 'Hair', 'Sunglasses', 'Upper-clothes', 'Skirt',
                  'Pants', 'Dress', 'Belt', 'Left-shoe', 'Right-shoe', 'Face',
                  'Left-leg', 'Right-leg', 'Left-arm', 'Right-arm', 'Bag', 'Scarf']
    },
    'pascal': {
        'input_size': [512, 512],
        'num_classes': 7,
        'label': ['Background', 'Head', 'Torso', 'Upper Arms', 'Lower Arms',
                  'Upper Legs', 'Lower Legs']
    }
}


class SCHPParser:
    def __init__(self, checkpoint_path, dataset_settings):
        self.cp_path = checkpoint_path
        self.ops = []
        self.num_classes = dataset_settings['lip']['num_classes']
        self.input_size = dataset_settings['lip']['input_size']
        self.label = dataset_settings['lip']['label']
        self.pallete_dict = PALLETE_DICT
        self.img_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.406, 0.456, 0.485],
                                 std=[0.225, 0.224, 0.229])
        ])
        self.model = self.load_model()

    def load_model(self):
        model = networks.init_model('resnet101', num_classes=self.num_classes, pretrained=None)
        state_dict = torch.load(self.cp_path)['state_dict']
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:]  # remove `module.` prefix added by DataParallel
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)
        model.cuda()
        model.eval()
        return model

    def create_dataloader(self, img_lst):
        dataset = PILImageDataset(img_lst, input_size=self.input_size, transform=self.img_transforms)
        # dataset = SimpleFolderDataset('inputs', input_size, transform)
        dataloader = DataLoader(dataset)
        return dataloader

    def get_image_masks(self, img_lst):
        print("Evaluating total class number {} with {}".format(self.num_classes, self.label))
        dataloader = self.create_dataloader(img_lst)
        with torch.no_grad():
            for batch in dataloader:
                op_dict = {}
                image, meta = batch
                c = meta['center'].numpy()[0]
                s = meta['scale'].numpy()[0]
                w = meta['width'].numpy()[0]
                h = meta['height'].numpy()[0]

                output = self.model(image.cuda())
                upsample = torch.nn.Upsample(size=self.input_size, mode='bilinear', align_corners=True)
                upsample_output = upsample(output[0][-1][0].unsqueeze(0))
                upsample_output = upsample_output.squeeze()
                upsample_output = upsample_output.permute(1, 2, 0)  # CHW -> HWC

                # Undo the affine crop so the logits align with the original image size.
                logits_result = transform_logits(upsample_output.data.cpu().numpy(),
                                                 c, s, w, h, input_size=self.input_size)
                parsing_result = np.argmax(logits_result, axis=2)
                output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))

                # Re-palette the label map once per merged class to get a binary mask.
                for loc, key in enumerate(self.pallete_dict.keys()):
                    output_img.putpalette(self.pallete_dict[key])
                    op_dict.update({key: output_img.convert('L')})
                self.ops.append(op_dict)
        return self.ops
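

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how SCHPParser might be driven. The checkpoint path and
# input image below are placeholder assumptions; a CUDA-capable GPU is required
# because load_model() moves the network to the GPU.
if __name__ == '__main__':
    parser = SCHPParser(
        checkpoint_path='path/to/schp_lip_checkpoint.pth',  # assumed: pretrained SCHP LIP weights
        dataset_settings=DATASET_SETTINGS)
    imgs = [Image.open('person.jpg').convert('RGB')]  # assumed: any RGB person photo
    results = parser.get_image_masks(imgs)
    # results[0] maps merged part names (e.g. 'Face', 'Upper-clothes', 'Shoe')
    # to binary PIL 'L' masks for the first input image.
    results[0]['Upper-clothes'].save('upper_clothes_mask.png')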