bilalfaye committed on
Commit
e925821
·
verified ·
1 Parent(s): 3022d51

Update text_image_audio.py

Browse files
Files changed (1) hide show
  1. text_image_audio.py +15 -4
text_image_audio.py CHANGED
@@ -83,18 +83,29 @@ class AudioEncoder(nn.Module):
83
  return self.forward(inputs)
84
 
85
  class ModalityTokenEncoder(nn.Module):
86
- def __init__(self, projection_dim=CFG.projection_dim, token_size=CFG.token_size, device='cpu', *args, **kwargs):
87
  super(ModalityTokenEncoder, self).__init__(*args, **kwargs)
88
  # Attributes
89
  self.projection_dim = projection_dim
90
  self.device = device
91
  self.token_size = token_size
 
92
  # Models
93
  audio_variance = torch.rand(1) * 0.5 + 0.1
94
  self.audio_token = nn.Parameter(torch.normal(mean=0, std=audio_variance.item(),
95
- size=(self.token_size, self.projection_dim)).to(self.device))
 
 
 
 
 
 
 
 
 
 
96
  def forward(self):
97
- return self.audio_token
98
 
99
  def __call__(self):
100
  return self.forward()
@@ -205,4 +216,4 @@ class OneEncoder(nn.Module, PyTorchModelHubMixin):
205
  # fig.suptitle(display(Audio(query['input_values'], rate=self.sample_rate)))
206
  #plt.show()
207
  #return values, indices
208
-
 
83
  return self.forward(inputs)
84
 
85
  class ModalityTokenEncoder(nn.Module):
86
+ def __init__(self, projection_dim=CFG.projection_dim, token_size=CFG.token_size, device='cpu', token_dim=CFG.token_dim, *args, **kwargs):
87
  super(ModalityTokenEncoder, self).__init__(*args, **kwargs)
88
  # Attributes
89
  self.projection_dim = projection_dim
90
  self.device = device
91
  self.token_size = token_size
92
+ self.token_dim = token_dim
93
  # Models
94
  audio_variance = torch.rand(1) * 0.5 + 0.1
95
  self.audio_token = nn.Parameter(torch.normal(mean=0, std=audio_variance.item(),
96
+ size=(self.token_size, self.token_dim)).to(self.device))
97
+
98
+ self.token_projection = nn.Sequential(
99
+ nn.Linear(self.token_dim, 64),
100
+ nn.ReLU(),
101
+ nn.Linear(64, 128),
102
+ nn.ReLU(),
103
+ nn.Linear(128, self.projection_dim),
104
+ nn.LayerNorm(self.projection_dim)
105
+ )
106
+
107
  def forward(self):
108
+ return self.token_projection(self.audio_token)
109
 
110
  def __call__(self):
111
  return self.forward()
 
216
  # fig.suptitle(display(Audio(query['input_values'], rate=self.sample_rate)))
217
  #plt.show()
218
  #return values, indices
219
+