bilalfaye commited on
Commit
3022d51
·
verified ·
1 Parent(s): 7aa0ea9

Update text_image.py

Browse files
Files changed (1) hide show
  1. text_image.py +19 -9
text_image.py CHANGED
@@ -194,26 +194,37 @@ class TextEncoder(nn.Module):
194
 
195
 
196
  class ModalityTokenEncoder(nn.Module):
197
- def __init__(self, projection_dim=CFG.projection_dim, token_size=CFG.token_size, device='cpu', *args, **kwargs):
198
  """
199
  Modality token encoder module for encoding modality token information.
200
 
201
  :param projection_dim: Dimensionality of projected features (default: CFG.projection_dim).
202
  :param token_size: Token size.
203
  :param device: Device to run the module on (default: 'cpu').
 
204
  """
205
  super(ModalityTokenEncoder, self).__init__(*args, **kwargs)
206
  # Attributes
207
  self.projection_dim = projection_dim
208
  self.device = device
209
  self.token_size = token_size
 
210
  # Models
211
  text_variance = torch.rand(1) * 0.5 + 0.1
212
  image_variance = torch.rand(1) * 0.5 + 0.1
213
  self.text_token = nn.Parameter(torch.normal(mean=0, std=text_variance.item(),
214
- size=(self.token_size, self.projection_dim)).to(self.device))
215
  self.image_token = nn.Parameter(torch.normal(mean=0, std=image_variance.item(),
216
- size=(self.token_size, self.projection_dim)).to(self.device))
 
 
 
 
 
 
 
 
 
217
 
218
  def forward(self, modality_type):
219
  """
@@ -223,7 +234,8 @@ class ModalityTokenEncoder(nn.Module):
223
  :return: Projected features.
224
  """
225
  token = torch.where(torch.tensor(modality_type == "image"), self.image_token, self.text_token)
226
- return token
 
227
 
228
  def __call__(self, modality_type):
229
  """
@@ -234,7 +246,6 @@ class ModalityTokenEncoder(nn.Module):
234
  """
235
  return self.forward(modality_type)
236
 
237
-
238
  class UniversalProjectionOutput:
239
  def __init__(self, outputs):
240
  """
@@ -534,10 +545,9 @@ class OneEncoder(nn.Module, PyTorchModelHubMixin):
534
  image = self.load_image(query)
535
  # Plot the image
536
  plt.imshow(image)
537
- plt.title('Random Image')
538
  plt.axis('off')
539
  plt.savefig("img.png")
540
- plt.show()
541
- return matches, values
542
-
543
 
 
194
 
195
 
196
  class ModalityTokenEncoder(nn.Module):
197
+ def __init__(self, projection_dim=CFG.projection_dim, token_size=CFG.token_size, device='cpu', token_dim=CFG.token_dim, *args, **kwargs):
198
  """
199
  Modality token encoder module for encoding modality token information.
200
 
201
  :param projection_dim: Dimensionality of projected features (default: CFG.projection_dim).
202
  :param token_size: Token size.
203
  :param device: Device to run the module on (default: 'cpu').
204
+ :param token_dim: Dimension of tokens
205
  """
206
  super(ModalityTokenEncoder, self).__init__(*args, **kwargs)
207
  # Attributes
208
  self.projection_dim = projection_dim
209
  self.device = device
210
  self.token_size = token_size
211
+ self.token_dim = token_dim
212
  # Models
213
  text_variance = torch.rand(1) * 0.5 + 0.1
214
  image_variance = torch.rand(1) * 0.5 + 0.1
215
  self.text_token = nn.Parameter(torch.normal(mean=0, std=text_variance.item(),
216
+ size=(self.token_size, self.token_dim)).to(self.device))
217
  self.image_token = nn.Parameter(torch.normal(mean=0, std=image_variance.item(),
218
+ size=(self.token_size, self.token_dim)).to(self.device))
219
+ # Projection
220
+ self.token_projection = nn.Sequential(
221
+ nn.Linear(self.token_dim, 64),
222
+ nn.ReLU(),
223
+ nn.Linear(64, 128),
224
+ nn.ReLU(),
225
+ nn.Linear(128, self.projection_dim),
226
+ nn.LayerNorm(self.projection_dim)
227
+ )
228
 
229
  def forward(self, modality_type):
230
  """
 
234
  :return: Projected features.
235
  """
236
  token = torch.where(torch.tensor(modality_type == "image"), self.image_token, self.text_token)
237
+ token_features = self.token_projection(token)
238
+ return token_features
239
 
240
  def __call__(self, modality_type):
241
  """
 
246
  """
247
  return self.forward(modality_type)
248
 
 
249
  class UniversalProjectionOutput:
250
  def __init__(self, outputs):
251
  """
 
545
  image = self.load_image(query)
546
  # Plot the image
547
  plt.imshow(image)
548
+ #plt.title('Random Image')
549
  plt.axis('off')
550
  plt.savefig("img.png")
551
+ #plt.show()
552
+ #return matches, values
 
553