Spaces:
Sleeping
Sleeping
Update model.py
Browse files
model.py
CHANGED
@@ -3,28 +3,23 @@ import torch.nn as nn
|
|
3 |
import math
|
4 |
|
5 |
|
6 |
-
class
|
7 |
-
def __init__(self,
|
8 |
super().__init__()
|
9 |
|
10 |
self.patch_size = patch_size
|
11 |
-
self.d_model = d_model
|
12 |
-
|
13 |
self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
|
14 |
-
self.proj = nn.Linear(in_channels * patch_size * patch_size, d_model)
|
15 |
|
16 |
def forward(self, x):
|
17 |
batch_size, c, h, w = x.shape
|
18 |
|
19 |
-
# Unfold
|
20 |
-
#
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
# Apply linear projection to each patch: (batch_size, num_patches, in_channels * patch_size * patch_size) -> (batch_size, num_patches, d_model)
|
27 |
-
return self.proj(patches)
|
28 |
|
29 |
|
30 |
# Positional Encoding
|
@@ -139,7 +134,7 @@ class PositionwiseFeedForward(nn.Module):
|
|
139 |
|
140 |
self.ffn = nn.Sequential(
|
141 |
nn.Linear(in_features=d_model, out_features=(d_model * 4)),
|
142 |
-
nn.
|
143 |
nn.Linear(in_features=(d_model * 4), out_features=d_model),
|
144 |
nn.Dropout(p=dropout),
|
145 |
)
|
@@ -218,9 +213,8 @@ class Encoder(nn.Module):
|
|
218 |
|
219 |
self.patch_size = patch_size
|
220 |
|
221 |
-
self.
|
222 |
-
|
223 |
-
)
|
224 |
|
225 |
seq_length = (image_size // patch_size) ** 2
|
226 |
|
@@ -245,7 +239,7 @@ class Encoder(nn.Module):
|
|
245 |
|
246 |
# Extract the patches and apply a linear layer
|
247 |
batch_size = src.shape[0]
|
248 |
-
src = self.
|
249 |
|
250 |
# Add the learned positional embedding
|
251 |
src = src + self.pos_embedding
|
|
|
3 |
import math
|
4 |
|
5 |
|
6 |
+
class ExtractPatches(nn.Module):
|
7 |
+
def __init__(self, patch_size: int = 16):
|
8 |
super().__init__()
|
9 |
|
10 |
self.patch_size = patch_size
|
|
|
|
|
11 |
self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
|
|
|
12 |
|
13 |
def forward(self, x):
|
14 |
batch_size, c, h, w = x.shape
|
15 |
|
16 |
+
# Unfold applies a sliding window to generate patches
|
17 |
+
# The transpose and reshape change the shape to (batch_size, num_patches, 3 * patch_size * patch_size), flattening the patches
|
18 |
+
return (
|
19 |
+
self.unfold(x)
|
20 |
+
.transpose(1, 2)
|
21 |
+
.reshape(batch_size, -1, c * self.patch_size * self.patch_size)
|
22 |
+
)
|
|
|
|
|
23 |
|
24 |
|
25 |
# Positional Encoding
|
|
|
134 |
|
135 |
self.ffn = nn.Sequential(
|
136 |
nn.Linear(in_features=d_model, out_features=(d_model * 4)),
|
137 |
+
nn.ReLU(),
|
138 |
nn.Linear(in_features=(d_model * 4), out_features=d_model),
|
139 |
nn.Dropout(p=dropout),
|
140 |
)
|
|
|
213 |
|
214 |
self.patch_size = patch_size
|
215 |
|
216 |
+
self.extract_patches = ExtractPatches(patch_size=patch_size)
|
217 |
+
self.fc_in = nn.Linear(in_channels * patch_size * patch_size, d_model)
|
|
|
218 |
|
219 |
seq_length = (image_size // patch_size) ** 2
|
220 |
|
|
|
239 |
|
240 |
# Extract the patches and apply a linear layer
|
241 |
batch_size = src.shape[0]
|
242 |
+
src = self.fc_in(self.extract_patches(src))
|
243 |
|
244 |
# Add the learned positional embedding
|
245 |
src = src + self.pos_embedding
|