r""" Superclass for semantic correspondence datasets """ | |
import os | |
from torch.utils.data import Dataset | |
from torchvision import transforms | |
from PIL import Image | |
import torch | |
from model.base.geometry import Geometry | |


class CorrespondenceDataset(Dataset):
    r"""Parent class of PFPascal, PFWillow, and SPair"""

    def __init__(self, benchmark, datapath, thres, split):
        r"""CorrespondenceDataset constructor"""
        super(CorrespondenceDataset, self).__init__()
        # {Directory name, Layout path, Image path, Annotation path, PCK threshold}
        self.metadata = {
            'pfwillow': ('PF-WILLOW',
                         'test_pairs.csv',
                         '',
                         '',
                         'bbox'),
            'pfpascal': ('PF-PASCAL',
                         '_pairs.csv',
                         'JPEGImages',
                         'Annotations',
                         'img'),
            'spair': ('SPair-71k',
                      'Layout/large',
                      'JPEGImages',
                      'PairAnnotation',
                      'bbox')
        }
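
        # Illustrative example (not in the original source): with benchmark='spair'
        # and split='test', the paths below resolve, relative to
        # os.path.abspath(datapath), to
        #   spt_path: SPair-71k/Layout/large/test.txt
        #   img_path: SPair-71k/JPEGImages
        #   ann_path: SPair-71k/PairAnnotation/test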
        # Directory path for train, val, or test splits
        base_path = os.path.join(os.path.abspath(datapath), self.metadata[benchmark][0])
        if benchmark == 'pfpascal':
            self.spt_path = os.path.join(base_path, split + '_pairs.csv')
        elif benchmark == 'spair':
            self.spt_path = os.path.join(base_path, self.metadata[benchmark][1], split + '.txt')
        else:
            self.spt_path = os.path.join(base_path, self.metadata[benchmark][1])

        # Directory path for images
        self.img_path = os.path.join(base_path, self.metadata[benchmark][2])

        # Directory path for annotations
        if benchmark == 'spair':
            self.ann_path = os.path.join(base_path, self.metadata[benchmark][3], split)
        else:
            self.ann_path = os.path.join(base_path, self.metadata[benchmark][3])
        # Miscellaneous
        self.max_pts = 40
        self.split = split
        self.img_size = Geometry.img_size
        self.benchmark = benchmark
        self.range_ts = torch.arange(self.max_pts)
        self.thres = self.metadata[benchmark][4] if thres == 'auto' else thres

        # Resize to a square input and normalize with standard ImageNet statistics
        self.transform = transforms.Compose([transforms.Resize((self.img_size, self.img_size)),
                                             transforms.ToTensor(),
                                             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                  std=[0.229, 0.224, 0.225])])
        # To be initialized in subclass constructors
        self.train_data = []
        self.src_imnames = []
        self.trg_imnames = []
        self.cls = []
        self.cls_ids = []
        self.src_kps = []
        self.trg_kps = []

    def __len__(self):
        r"""Returns the number of pairs"""
        return len(self.train_data)

    def __getitem__(self, idx):
        r"""Constructs and returns a batch"""
        batch = dict()

        # Image names
        batch['src_imname'] = self.src_imnames[idx]
        batch['trg_imname'] = self.trg_imnames[idx]

        # Object category
        batch['category_id'] = self.cls_ids[idx]
        batch['category'] = self.cls[batch['category_id']]

        # Original image sizes as (width, height)
        src_pil = self.get_image(self.src_imnames, idx)
        trg_pil = self.get_image(self.trg_imnames, idx)
        batch['src_imsize'] = src_pil.size
        batch['trg_imsize'] = trg_pil.size

        # Images as normalized tensors
        batch['src_img'] = self.transform(src_pil)
        batch['trg_img'] = self.transform(trg_pil)

        # Key-points (rescaled to the resized image)
        batch['src_kps'], num_pts = self.get_points(self.src_kps, idx, src_pil.size)
        batch['trg_kps'], _ = self.get_points(self.trg_kps, idx, trg_pil.size)
        batch['n_pts'] = torch.tensor(num_pts)

        # Total number of pairs in the current split
        batch['datalen'] = len(self.train_data)

        return batch
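
    # Illustrative summary (not in the original source): shapes of the batch
    # entries returned above, with img_size = Geometry.img_size:
    #   src_img / trg_img: float tensors of shape (3, img_size, img_size)
    #   src_kps / trg_kps: (2, max_pts) tensors, padded with -2 beyond n_pts
    #   n_pts:             scalar tensor with the number of valid key-points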

    def get_image(self, imnames, idx):
        r"""Reads a PIL image from its path"""
        path = os.path.join(self.img_path, imnames[idx])
        return Image.open(path).convert('RGB')

    def get_pckthres(self, batch, imsize):
        r"""Computes the PCK threshold"""
        if self.thres == 'bbox':
            # Longer side of the target object's bounding box
            bbox = batch['trg_bbox'].clone()
            bbox_w = (bbox[2] - bbox[0])
            bbox_h = (bbox[3] - bbox[1])
            pckthres = torch.max(bbox_w, bbox_h)
        elif self.thres == 'img':
            # Longer side of the target image
            imsize_t = batch['trg_img'].size()
            pckthres = torch.tensor(max(imsize_t[1], imsize_t[2]))
        else:
            raise Exception('Invalid pck threshold type: %s' % self.thres)
        return pckthres.float()
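
    # Background note (not from this file): downstream PCK evaluation typically
    # counts a predicted key-point as correct if its distance to the ground
    # truth is at most alpha * pckthres, commonly with alpha = 0.1.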

    def get_points(self, pts_list, idx, org_imsize):
        r"""Returns key-points of an image"""
        xy, n_pts = pts_list[idx].size()
        # Pad with -2 up to max_pts so key-point tensors can be batch-collated
        pad_pts = torch.zeros((xy, self.max_pts - n_pts)) - 2
        # Rescale coordinates from the original image size to (img_size, img_size)
        x_crds = pts_list[idx][0] * (self.img_size / org_imsize[0])
        y_crds = pts_list[idx][1] * (self.img_size / org_imsize[1])
        kps = torch.cat([torch.stack([x_crds, y_crds]), pad_pts], dim=1)

        return kps, n_pts
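

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file). `SPairDataset`
# and its module path are hypothetical stand-ins for a concrete subclass that
# populates train_data, src/trg_imnames, cls, cls_ids, and src/trg_kps in its
# constructor:
#
#   from torch.utils.data import DataLoader
#   from data.spair import SPairDataset  # hypothetical subclass
#
#   dataset = SPairDataset(benchmark='spair', datapath='Datasets',
#                          thres='auto', split='test')
#   loader = DataLoader(dataset, batch_size=8, shuffle=False)
#   for batch in loader:
#       src, trg = batch['src_img'], batch['trg_img']  # (B, 3, img_size, img_size)
# ---------------------------------------------------------------------------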