Spaces:
Running
Running
add check_image_size method to the network
Browse files- archs/network.py +7 -0
archs/network.py
CHANGED
@@ -80,6 +80,7 @@ class Network(nn.Module):
|
|
80 |
|
81 |
_, _, H, W = input.shape
|
82 |
|
|
|
83 |
x = self.intro(input)
|
84 |
|
85 |
encs = []
|
@@ -108,6 +109,12 @@ class Network(nn.Module):
|
|
108 |
|
109 |
return x[:, :, :H, :W]
|
110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
|
112 |
if __name__ == '__main__':
|
113 |
|
|
|
80 |
|
81 |
_, _, H, W = input.shape
|
82 |
|
83 |
+
input = self.check_image_size(input)
|
84 |
x = self.intro(input)
|
85 |
|
86 |
encs = []
|
|
|
109 |
|
110 |
return x[:, :, :H, :W]
|
111 |
|
112 |
+
def check_image_size(self, x):
|
113 |
+
_, _, h, w = x.size()
|
114 |
+
mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
|
115 |
+
mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
|
116 |
+
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), value = 0)
|
117 |
+
return x
|
118 |
|
119 |
if __name__ == '__main__':
|
120 |
|