danifei commited on
Commit
3de3832
·
verified ·
1 Parent(s): 86e5f72

add check_image_size method to the network

Browse files
Files changed (1) hide show
  1. archs/network.py +7 -0
archs/network.py CHANGED
@@ -80,6 +80,7 @@ class Network(nn.Module):
80
 
81
  _, _, H, W = input.shape
82
 
 
83
  x = self.intro(input)
84
 
85
  encs = []
@@ -108,6 +109,12 @@ class Network(nn.Module):
108
 
109
  return x[:, :, :H, :W]
110
 
 
 
 
 
 
 
111
 
112
  if __name__ == '__main__':
113
 
 
80
 
81
  _, _, H, W = input.shape
82
 
83
+ input = self.check_image_size(input)
84
  x = self.intro(input)
85
 
86
  encs = []
 
109
 
110
  return x[:, :, :H, :W]
111
 
112
+ def check_image_size(self, x):
113
+ _, _, h, w = x.size()
114
+ mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
115
+ mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
116
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), value = 0)
117
+ return x
118
 
119
  if __name__ == '__main__':
120