# from this import d import torch from .base_model import BaseModel import importlib from torch.utils.data import DataLoader from easydict import EasyDict as edict class Model(BaseModel): def __init__(self, opt, wandb=None): """Initialize the Generator. Parameters: opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions """ BaseModel.__init__(self, opt,wandb) self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake'] def set_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. Parameters: input (dict): include the data itself and its metadata information. self.real_A: aerial images self.real_B: ground images self.image_paths: images paths of ground images self.sky_mask: the sky mask of ground images self.sky_histc: the histogram of selected sky """ self.real_A = input['sat' ].to(self.device) self.real_B = input['pano'].to(self.device) if 'pano' in input else None # for testing self.image_paths = input['paths'] if self.opt.data.sky_mask: self.sky_mask = input['sky_mask'].to(self.device) if 'sky_mask' in input else None # for testing if self.opt.data.histo_mode and self.opt.data.sky_mask: self.sky_histc = input['sky_histc'].to(self.device) if 'sky_histc' in input else None # for testing else: self.sky_histc = None def forward(self,opt): """Run forward pass; called by both functions and .""" # origin_H_W is the inital localization of camera if opt.task != 'test_vid': opt.origin_H_W=None if hasattr(opt.arch.gen,'style_inject'): # replace the predicted sky with selected sky histogram if opt.arch.gen.style_inject == 'histo': self.out_put = self.netG(self.real_A,self.sky_histc.detach(),opt) else: raise Exception('Unknown style inject mode') else: self.out_put = self.netG(self.real_A,None,opt) self.out_put = edict(self.out_put) self.fake_B = self.out_put.pred # perceptive image def backward_D(self,opt): """Calculate GAN loss for the discriminator""" self.optimizer_D.zero_grad() self.netG.eval() with torch.no_grad(): self.forward(opt) self.out_put.pred = self.out_put.pred.detach() net_D_output = self.netD(self.real_B, self.out_put) output_fake = self._get_outputs(net_D_output, real=False) output_real = self._get_outputs(net_D_output, real=True) fake_loss = self.criteria['GAN'](output_fake, False, dis_update=True) true_loss = self.criteria['GAN'](output_real, True, dis_update=True) self.dis_losses = dict() self.dis_losses['GAN/fake'] = fake_loss self.dis_losses['GAN/true'] = true_loss self.dis_losses['DIS'] = fake_loss + true_loss self.dis_losses['DIS'].backward() self.optimizer_D.step() def backward_G(self,opt): self.optimizer_G.zero_grad() self.loss = {} self.netG.train() self.forward(opt) net_D_output = self.netD(self.real_B, self.out_put) pred_fake = self._get_outputs(net_D_output, real=False) self.loss['GAN'] = self.criteria['GAN'](pred_fake, True, dis_update=False) if 'GaussianKL' in self.criteria: self.loss['GaussianKL'] = self.criteria['GaussianKL'](self.out_put['mu'], self.out_put['logvar']) if 'L1' in self.criteria: self.loss['L1'] = self.criteria['L1'](self.real_B,self.fake_B) if 'L2' in self.criteria: self.loss['L2'] = self.criteria['L2'](self.real_B,self.fake_B) if 'SSIM' in self.criteria: self.loss['SSIM'] = 1-self.criteria['SSIM'](self.real_B, self.fake_B) if 'GaussianKL' in self.criteria: self.loss['GaussianKL'] = self.criteria['GaussianKL'](self.out_put['mu'], self.out_put['logvar']) if 'sky_inner' in self.criteria: self.loss['sky_inner'] = self.criteria['sky_inner'](self.out_put.opacity, 1-self.sky_mask) if 'Perceptual' in self.criteria: self.loss['Perceptual'] = self.criteria['Perceptual'](self.fake_B,self.real_B) if 'feature_matching' in self.criteria: self.loss['feature_matching'] = self.criteria['feature_matching'](net_D_output['fake_features'], net_D_output['real_features']) self.loss_G = 0 for key in self.loss: self.loss_G += self.loss[key] * self.weights[key] self.loss['total'] = self.loss_G self.loss_G.backward() self.optimizer_G.step() # udpate G's weights def load_dataset(self,opt): data = importlib.import_module("data.{}".format(opt.data.dataset)) if opt.task in ["train", "Train"]: train_data = data.Dataset(opt,"train",opt.data.train_sub) self.train_loader = DataLoader(train_data,batch_size=opt.batch_size,shuffle=True,num_workers=opt.data.num_workers,drop_last=True) self.len_train_loader = len(self.train_loader) val_data = data.Dataset(opt,"val") opt.batch_size = 1 if opt.task in ["test" , "val","vis_test",'test_vid','test_sty'] else opt.batch_size opt.batch_size = 1 if opt.task=='test_speed' else opt.batch_size self.val_loader = DataLoader(val_data,batch_size=opt.batch_size,shuffle=False,num_workers=opt.data.num_workers) self.len_val_loader = len(self.val_loader) # you can select one random image as a style of all predicted skys # if None, we use the corresponding style of GT if opt.sty_img: sty_data = data.Dataset(opt,sty_img = opt.sty_img) self.sty_loader = DataLoader(sty_data,batch_size=1,num_workers=1,shuffle=False) # The followings are only used for test the illumination interpolation. if opt.sty_img1: sty1_data = data.Dataset(opt,sty_img = opt.sty_img1) self.sty_loader1 = DataLoader(sty1_data,batch_size=1,num_workers=1,shuffle=False) if opt.sty_img2: sty2_data = data.Dataset(opt,sty_img = opt.sty_img2) self.sty_loader2 = DataLoader(sty2_data,batch_size=1,num_workers=1,shuffle=False) def build_networks(self, opt): if 'imaginaire' in opt.arch.gen.netG: lib_G = importlib.import_module(opt.arch.gen.netG) self.netG = lib_G.Generator(opt).to(self.device) else: raise Exception('Unknown discriminator function') if opt.isTrain: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc if opt.arch.dis.netD == 'imaginaire.discriminators.multires_patch_pano': lib_D = importlib.import_module(opt.arch.dis.netD) self.netD = lib_D.Discriminator(opt.arch.dis).to(self.device) else: raise Exception('Unknown discriminator function')