RishabA committed (verified)
Commit 42296ae · 1 Parent(s): 10c688c

Update model.py

Files changed (1):
  1. model.py  +13 -19

model.py CHANGED
@@ -3,28 +3,23 @@ import torch.nn as nn
 import math
 
 
-class PatchEmbedding(nn.Module):
-    def __init__(self, in_channels: int = 3, patch_size: int = 16, d_model: int = 128):
+class ExtractPatches(nn.Module):
+    def __init__(self, patch_size: int = 16):
         super().__init__()
 
         self.patch_size = patch_size
-        self.d_model = d_model
-
         self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
-        self.proj = nn.Linear(in_channels * patch_size * patch_size, d_model)
 
     def forward(self, x):
         batch_size, c, h, w = x.shape
 
-        # Unfold to extract patches: shape becomes (batch_size, in_channels * patch_size * patch_size, num_patches)
-        # num_patches = (H / patch_size) * (W / patch_size)
-        patches = self.unfold(x)
-
-        # Transpose to (batch_size, num_patches, in_channels * patch_size * patch_size)
-        patches = patches.transpose(1, 2)
-
-        # Apply linear projection to each patch: (batch_size, num_patches, in_channels * patch_size * patch_size) -> (batch_size, num_patches, d_model)
-        return self.proj(patches)
+        # Unfold applies a sliding window to generate patches
+        # The transpose and reshape flatten each patch, giving (batch_size, num_patches, c * patch_size * patch_size)
+        return (
+            self.unfold(x)
+            .transpose(1, 2)
+            .reshape(batch_size, -1, c * self.patch_size * self.patch_size)
+        )
 
 
 # Positional Encoding
@@ -139,7 +134,7 @@ class PositionwiseFeedForward(nn.Module):
 
         self.ffn = nn.Sequential(
             nn.Linear(in_features=d_model, out_features=(d_model * 4)),
-            nn.GELU(),
+            nn.ReLU(),
             nn.Linear(in_features=(d_model * 4), out_features=d_model),
             nn.Dropout(p=dropout),
         )
@@ -218,9 +213,8 @@ class Encoder(nn.Module):
 
         self.patch_size = patch_size
 
-        self.patch_emb = PatchEmbedding(
-            patch_size=patch_size, in_channels=in_channels, d_model=d_model
-        )
+        self.extract_patches = ExtractPatches(patch_size=patch_size)
+        self.fc_in = nn.Linear(in_channels * patch_size * patch_size, d_model)
 
         seq_length = (image_size // patch_size) ** 2
 
@@ -245,7 +239,7 @@
 
         # Extract the patches and apply a linear layer
         batch_size = src.shape[0]
-        src = self.patch_emb(src)
+        src = self.fc_in(self.extract_patches(src))
 
         # Add the learned positional embedding
         src = src + self.pos_embedding
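
For context, here is a minimal sketch (not part of the commit) of the new patch-embedding path: nn.Unfold flattens each patch exactly as ExtractPatches.forward does, and the linear projection that previously lived inside PatchEmbedding is now applied in the Encoder as fc_in. The batch size, image size, patch size, and d_model below are illustrative assumptions (the diff's defaults are patch_size=16 and d_model=128).

import torch
import torch.nn as nn

# Assumed example sizes for the shape check
batch_size, in_channels, image_size, patch_size, d_model = 2, 3, 224, 16, 128

# Same patch extraction as ExtractPatches.forward in the diff above
unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
# Same projection as the Encoder's new fc_in
fc_in = nn.Linear(in_channels * patch_size * patch_size, d_model)

src = torch.randn(batch_size, in_channels, image_size, image_size)
patches = (
    unfold(src)                      # (2, 3 * 16 * 16, 196) = (2, 768, 196)
    .transpose(1, 2)                 # (2, 196, 768)
    .reshape(batch_size, -1, in_channels * patch_size * patch_size)
)
src = fc_in(patches)                 # (2, 196, 128), the shape the old PatchEmbedding produced

print(patches.shape, src.shape)      # torch.Size([2, 196, 768]) torch.Size([2, 196, 128])

num_patches comes out to (image_size // patch_size) ** 2 = 196, which matches the seq_length computed in the Encoder.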
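The only other functional change is in PositionwiseFeedForward, where the hidden activation switches from GELU to ReLU. A minimal sketch of the resulting block, with d_model and dropout as assumed example values:

import torch
import torch.nn as nn

d_model, dropout = 128, 0.1          # assumed example values

ffn = nn.Sequential(
    nn.Linear(in_features=d_model, out_features=(d_model * 4)),
    nn.ReLU(),                       # previously nn.GELU()
    nn.Linear(in_features=(d_model * 4), out_features=d_model),
    nn.Dropout(p=dropout),
)

x = torch.randn(2, 196, d_model)     # (batch, num_patches, d_model)
print(ffn(x).shape)                  # torch.Size([2, 196, 128])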