Update text_image_audio.py
text_image_audio.py  CHANGED  +15 -4
@@ -83,18 +83,29 @@ class AudioEncoder(nn.Module):
         return self.forward(inputs)
 
 class ModalityTokenEncoder(nn.Module):
-    def __init__(self, projection_dim=CFG.projection_dim, token_size=CFG.token_size, device='cpu', *args, **kwargs):
+    def __init__(self, projection_dim=CFG.projection_dim, token_size=CFG.token_size, device='cpu', token_dim=CFG.token_dim, *args, **kwargs):
         super(ModalityTokenEncoder, self).__init__(*args, **kwargs)
         # Attributes
         self.projection_dim = projection_dim
         self.device = device
         self.token_size = token_size
+        self.token_dim = token_dim
         # Models
         audio_variance = torch.rand(1) * 0.5 + 0.1
         self.audio_token = nn.Parameter(torch.normal(mean=0, std=audio_variance.item(),
-                                                     size=(self.token_size, self.projection_dim)).to(self.device))
+                                                     size=(self.token_size, self.token_dim)).to(self.device))
+
+        self.token_projection = nn.Sequential(
+            nn.Linear(self.token_dim, 64),
+            nn.ReLU(),
+            nn.Linear(64, 128),
+            nn.ReLU(),
+            nn.Linear(128, self.projection_dim),
+            nn.LayerNorm(self.projection_dim)
+        )
+
     def forward(self):
-        return self.audio_token
+        return self.token_projection(self.audio_token)
 
     def __call__(self):
         return self.forward()

@@ -205,4 +216,4 @@ class OneEncoder(nn.Module, PyTorchModelHubMixin):
         # fig.suptitle(display(Audio(query['input_values'], rate=self.sample_rate)))
         #plt.show()
         #return values, indices
-
+
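For readers who want to try the new encoder in isolation, below is a self-contained sketch of ModalityTokenEncoder as it stands after this commit. The CFG defaults are not visible in this diff, so projection_dim=256, token_size=50 and token_dim=64 are assumed placeholder values, not the Space's actual configuration.

import torch
import torch.nn as nn

class ModalityTokenEncoder(nn.Module):
    # Standalone sketch of the class after this commit; the CFG.* defaults are
    # replaced by assumed values (projection_dim=256, token_size=50, token_dim=64).
    def __init__(self, projection_dim=256, token_size=50, device='cpu', token_dim=64, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.projection_dim = projection_dim
        self.device = device
        self.token_size = token_size
        self.token_dim = token_dim
        # Learnable audio token of shape (token_size, token_dim); its initial
        # standard deviation is drawn uniformly from [0.1, 0.6).
        audio_variance = torch.rand(1) * 0.5 + 0.1
        self.audio_token = nn.Parameter(
            torch.normal(mean=0, std=audio_variance.item(),
                         size=(self.token_size, self.token_dim)).to(self.device))
        # New in this commit: an MLP that projects the token from token_dim
        # up to projection_dim so it lands in the shared embedding space.
        self.token_projection = nn.Sequential(
            nn.Linear(self.token_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, self.projection_dim),
            nn.LayerNorm(self.projection_dim)
        )

    def forward(self):
        # The token is now projected instead of being returned directly.
        return self.token_projection(self.audio_token)

encoder = ModalityTokenEncoder()
print(encoder().shape)  # torch.Size([50, 256]) with the assumed defaults

With those placeholder values the forward pass returns a (50, 256) tensor: token_size learned tokens, each projected and layer-normalised to projection_dim so they match the other modality embeddings.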