Spaces:
Runtime error
Runtime error
File size: 7,296 Bytes
f670afc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
# from this import d
import torch
from .base_model import BaseModel
import importlib
from torch.utils.data import DataLoader
from easydict import EasyDict as edict
class Model(BaseModel):
def __init__(self, opt, wandb=None):
"""Initialize the Generator.
Parameters:
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
BaseModel.__init__(self, opt,wandb)
self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): include the data itself and its metadata information.
self.real_A: aerial images
self.real_B: ground images
self.image_paths: images paths of ground images
self.sky_mask: the sky mask of ground images
self.sky_histc: the histogram of selected sky
"""
self.real_A = input['sat' ].to(self.device)
self.real_B = input['pano'].to(self.device) if 'pano' in input else None # for testing
self.image_paths = input['paths']
if self.opt.data.sky_mask:
self.sky_mask = input['sky_mask'].to(self.device) if 'sky_mask' in input else None # for testing
if self.opt.data.histo_mode and self.opt.data.sky_mask:
self.sky_histc = input['sky_histc'].to(self.device) if 'sky_histc' in input else None # for testing
else: self.sky_histc = None
def forward(self,opt):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
# origin_H_W is the inital localization of camera
if opt.task != 'test_vid':
opt.origin_H_W=None
if hasattr(opt.arch.gen,'style_inject'):
# replace the predicted sky with selected sky histogram
if opt.arch.gen.style_inject == 'histo':
self.out_put = self.netG(self.real_A,self.sky_histc.detach(),opt)
else:
raise Exception('Unknown style inject mode')
else:
self.out_put = self.netG(self.real_A,None,opt)
self.out_put = edict(self.out_put)
self.fake_B = self.out_put.pred
# perceptive image
def backward_D(self,opt):
"""Calculate GAN loss for the discriminator"""
self.optimizer_D.zero_grad()
self.netG.eval()
with torch.no_grad():
self.forward(opt)
self.out_put.pred = self.out_put.pred.detach()
net_D_output = self.netD(self.real_B, self.out_put)
output_fake = self._get_outputs(net_D_output, real=False)
output_real = self._get_outputs(net_D_output, real=True)
fake_loss = self.criteria['GAN'](output_fake, False, dis_update=True)
true_loss = self.criteria['GAN'](output_real, True, dis_update=True)
self.dis_losses = dict()
self.dis_losses['GAN/fake'] = fake_loss
self.dis_losses['GAN/true'] = true_loss
self.dis_losses['DIS'] = fake_loss + true_loss
self.dis_losses['DIS'].backward()
self.optimizer_D.step()
def backward_G(self,opt):
self.optimizer_G.zero_grad()
self.loss = {}
self.netG.train()
self.forward(opt)
net_D_output = self.netD(self.real_B, self.out_put)
pred_fake = self._get_outputs(net_D_output, real=False)
self.loss['GAN'] = self.criteria['GAN'](pred_fake, True, dis_update=False)
if 'GaussianKL' in self.criteria:
self.loss['GaussianKL'] = self.criteria['GaussianKL'](self.out_put['mu'], self.out_put['logvar'])
if 'L1' in self.criteria:
self.loss['L1'] = self.criteria['L1'](self.real_B,self.fake_B)
if 'L2' in self.criteria:
self.loss['L2'] = self.criteria['L2'](self.real_B,self.fake_B)
if 'SSIM' in self.criteria:
self.loss['SSIM'] = 1-self.criteria['SSIM'](self.real_B, self.fake_B)
if 'GaussianKL' in self.criteria:
self.loss['GaussianKL'] = self.criteria['GaussianKL'](self.out_put['mu'], self.out_put['logvar'])
if 'sky_inner' in self.criteria:
self.loss['sky_inner'] = self.criteria['sky_inner'](self.out_put.opacity, 1-self.sky_mask)
if 'Perceptual' in self.criteria:
self.loss['Perceptual'] = self.criteria['Perceptual'](self.fake_B,self.real_B)
if 'feature_matching' in self.criteria:
self.loss['feature_matching'] = self.criteria['feature_matching'](net_D_output['fake_features'], net_D_output['real_features'])
self.loss_G = 0
for key in self.loss:
self.loss_G += self.loss[key] * self.weights[key]
self.loss['total'] = self.loss_G
self.loss_G.backward()
self.optimizer_G.step() # udpate G's weights
def load_dataset(self,opt):
data = importlib.import_module("data.{}".format(opt.data.dataset))
if opt.task in ["train", "Train"]:
train_data = data.Dataset(opt,"train",opt.data.train_sub)
self.train_loader = DataLoader(train_data,batch_size=opt.batch_size,shuffle=True,num_workers=opt.data.num_workers,drop_last=True)
self.len_train_loader = len(self.train_loader)
val_data = data.Dataset(opt,"val")
opt.batch_size = 1 if opt.task in ["test" , "val","vis_test",'test_vid','test_sty'] else opt.batch_size
opt.batch_size = 1 if opt.task=='test_speed' else opt.batch_size
self.val_loader = DataLoader(val_data,batch_size=opt.batch_size,shuffle=False,num_workers=opt.data.num_workers)
self.len_val_loader = len(self.val_loader)
# you can select one random image as a style of all predicted skys
# if None, we use the corresponding style of GT
if opt.sty_img:
sty_data = data.Dataset(opt,sty_img = opt.sty_img)
self.sty_loader = DataLoader(sty_data,batch_size=1,num_workers=1,shuffle=False)
# The followings are only used for test the illumination interpolation.
if opt.sty_img1:
sty1_data = data.Dataset(opt,sty_img = opt.sty_img1)
self.sty_loader1 = DataLoader(sty1_data,batch_size=1,num_workers=1,shuffle=False)
if opt.sty_img2:
sty2_data = data.Dataset(opt,sty_img = opt.sty_img2)
self.sty_loader2 = DataLoader(sty2_data,batch_size=1,num_workers=1,shuffle=False)
def build_networks(self, opt):
if 'imaginaire' in opt.arch.gen.netG:
lib_G = importlib.import_module(opt.arch.gen.netG)
self.netG = lib_G.Generator(opt).to(self.device)
else:
raise Exception('Unknown discriminator function')
if opt.isTrain: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
if opt.arch.dis.netD == 'imaginaire.discriminators.multires_patch_pano':
lib_D = importlib.import_module(opt.arch.dis.netD)
self.netD = lib_D.Discriminator(opt.arch.dis).to(self.device)
else:
raise Exception('Unknown discriminator function')
|