File size: 5,169 Bytes
f670afc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch,os
from torch.utils.data.dataset import Dataset
from PIL import Image
import scipy.io as sio
import torchvision.transforms as transforms

def data_list(img_root,mode):
    """Return the street-view image paths belonging to one dataset split.

    The split definition is read from ``ACT_data.mat`` inside ``img_root``.
    A sample is kept only when both its ground image (in ``streetview``) and
    its aerial image (in ``satview_correct``) exist on disk.

    Args:
        img_root: Dataset root directory. A trailing path separator is
            optional; paths are assembled with ``os.path.join``.
        mode: ``'train'`` selects the training indices; any other value
            selects the validation indices.

    Returns:
        A list of paths to ``*_grdView.png`` files, in split order.
    """
    # Sets give O(1) membership tests for the existence filter below.
    aerial_files = set(os.listdir(os.path.join(img_root, 'satview_correct')))
    ground_files = set(os.listdir(os.path.join(img_root, 'streetview')))
    meta = sio.loadmat(os.path.join(img_root, 'ACT_data.mat'))

    # (ground file name, aerial file name) for every panorama id in the file.
    pairs = [
        (str(pano_id) + '_grdView.png', str(pano_id) + '_satView_polish.png')
        for pano_id in meta['panoIds']
    ]

    # The stored indices are shifted by one before being used as Python
    # indices (presumably they are 1-based MATLAB indices).
    if mode == 'train':
        indices = meta['trainSet']['trainInd'][0][0] - 1
    else:
        indices = meta['valSet']['valInd'][0][0] - 1
    selected = [pairs[row[0]] for row in indices]

    return [
        os.path.join(img_root, 'streetview', grd_name)
        for grd_name, sat_name in selected
        if grd_name in ground_files and sat_name in aerial_files
    ]
    
def img_read(img,size=None,datatype='RGB'):
    """Load an image, optionally resize it, and convert it to a tensor.

    Args:
        img: Path (or anything else accepted by ``PIL.Image.open``) of the
            image to load.
        size: Optional target size. An ``int`` ``s`` is expanded to
            ``(s, s)``; any other truthy value is handed to ``Image.resize``
            unchanged. A falsy value skips resizing.
        datatype: ``'RGB'`` loads a colour image and resizes it with bicubic
            resampling. Any other value loads a single-channel ``'L'`` image
            and resizes it with nearest-neighbour resampling.

    Returns:
        The tensor produced by ``torchvision.transforms.ToTensor`` for the
        loaded image.
    """
    is_rgb = datatype == 'RGB'
    # convert() returns an independent in-memory image, so the file handle
    # can be released as soon as the conversion is done.
    with Image.open(img) as source:
        image = source.convert('RGB' if is_rgb else "L")
    if size:
        if isinstance(size, int):
            size = (size, size)
        resample = Image.BICUBIC if is_rgb else Image.NEAREST
        image = image.resize(size=size, resample=resample)
    return transforms.ToTensor()(image)


class Dataset(Dataset):
    """Dataset of street-view panoramas paired with their aerial images.

    Each item is a dict with the aerial image (``'sat'``), the panorama
    (``'pano'``) and the panorama path (``'paths'``). When
    ``opt.data.sky_mask`` is set, a sky mask and, depending on
    ``opt.data.histo_mode``, a histogram of the sky pixels are added.
    """

    # Fixed ground images used to test the influence of different skies
    # (different skies guide different illumination intensity, colours, etc.).
    _STY_DEMO_PANOS = (
        'dataset/CVACT/streetview/pPfo7qQ1fP_24rXrJ2Uxog_grdView.png',
        'dataset/CVACT/streetview/YL81FiK9PucIvAkr1FHkpA_grdView.png',
        'dataset/CVACT/streetview/Tzis1jBKHjbXiVB2oRYwAQ_grdView.png',
        'dataset/CVACT/streetview/eqGgeBLGXRhSj6c-0h0KoQ_grdView.png',
        'dataset/CVACT/streetview/pdZmLHYEhe2PHj_8-WHMhw_grdView.png',
        'dataset/CVACT/streetview/ehsu9Q3iTin5t52DM-MwyQ_grdView.png',
        'dataset/CVACT/streetview/agLEcuq3_-qFj7wwGbktVg_grdView.png',
        'dataset/CVACT/streetview/HwQIDdMI3GfHyPGtCSo6aA_grdView.png',
        'dataset/CVACT/streetview/hV8svb3ZVXcQ0AtTRFE1dQ_grdView.png',
        'dataset/CVACT/streetview/fzq2mBfKP3UIczAd9KpMMg_grdView.png',
        'dataset/CVACT/streetview/acRP98sACUIlwl2ZIsEyiQ_grdView.png',
        'dataset/CVACT/streetview/WSh9tNVryLdupUlU0ri2tQ_grdView.png',
        'dataset/CVACT/streetview/FhEuB9NA5o08VJ_TBCbHjw_grdView.png',
        'dataset/CVACT/streetview/YHfpn2Mgu1lqgT2OUeBpOg_grdView.png',
        'dataset/CVACT/streetview/vNhv7ZP1dUkJ93UwFXagJw_grdView.png',
    )

    def __init__(self, opt,split='train',sub=None,sty_img=None):
        """Build the list of panorama paths this dataset serves.

        Args:
            opt: Options object. Reads ``opt.task``, ``opt.data.root`` and,
                for the video/interpolation tasks, ``opt.demo_img``.
            split: Split name forwarded to ``data_list`` as ``mode``.
            sub: If truthy, keep only the first ``sub`` entries of the split.
            sty_img: Optional single ground image file name (must end with
                ``grdView.png``); when given it becomes the only entry.
        """
        if sty_img:
            # A single explicitly requested ground image.
            assert sty_img.endswith('grdView.png')
            panos = [os.path.join(opt.data.root, 'streetview', sty_img)]
        elif opt.task in ['test_vid', 'test_interpolation']:
            # The demo image is named after the aerial file; map it to the
            # matching ground-view file name.
            grd_name = opt.demo_img.replace('satView_polish.png', 'grdView.png')
            panos = [os.path.join(opt.data.root, 'streetview', grd_name)]
        else:
            panos = data_list(img_root=opt.data.root, mode=split)
            if sub:
                panos = panos[:sub]

        # The style test always runs on the fixed demo panoramas, overriding
        # whatever was selected above.
        if opt.task == 'test_sty':
            panos = list(self._STY_DEMO_PANOS)

        self.pano_list = panos
        self.opt = opt

    def __len__(self):
        """Return the number of panoramas in the dataset."""
        return len(self.pano_list)

    @staticmethod
    def _sky_only(pano, sky):
        """Blend the panorama with an all-zero image using the sky mask."""
        return pano * sky + torch.zeros_like(pano) * (1 - sky)

    def __getitem__(self, index):
        """Load one sample as a dict of tensors plus the panorama path."""
        pano_path = self.pano_list[index]
        # The aerial image lives in a sibling folder under a parallel name.
        aer_path = pano_path.replace('streetview', 'satview_correct').replace('_grdView', '_satView_polish')
        data_opt = self.opt.data
        use_sky = data_opt.sky_mask

        sample = {
            'sat': img_read(aer_path, size=data_opt.sat_size),
            'pano': img_read(pano_path, size=data_opt.pano_size),
            'paths': pano_path,
        }
        if not use_sky:
            return sample

        sky_path = pano_path.replace('streetview', 'pano_sky_mask')
        sky = img_read(sky_path, size=data_opt.pano_size, datatype='L')
        sample['sky_mask'] = sky

        pano = sample['pano']
        histo_mode = data_opt.histo_mode
        if histo_mode == 'grey':
            # One histogram over all channels; the first 10 bins are dropped
            # (presumably to ignore the zeroed non-sky pixels).
            sample['sky_histc'] = self._sky_only(pano, sky).histc()[10:]
        elif histo_mode in ('rgb', 'RGB'):
            per_channel = [channel.histc()[10:] for channel in self._sky_only(pano, sky)]
            # Histograms are concatenated last channel first.
            sample['sky_histc'] = torch.cat(per_channel[::-1], dim=0)
        return sample