Update text_image_audio.py

text_image_audio.py CHANGED (+15 -4)

@@ -83,18 +83,29 @@ class AudioEncoder(nn.Module):
         return self.forward(inputs)
 
 class ModalityTokenEncoder(nn.Module):
-    def __init__(self, projection_dim=CFG.projection_dim, token_size=CFG.token_size, device='cpu', *args, **kwargs):
+    def __init__(self, projection_dim=CFG.projection_dim, token_size=CFG.token_size, device='cpu', token_dim=CFG.token_dim, *args, **kwargs):
         super(ModalityTokenEncoder, self).__init__(*args, **kwargs)
         # Attributes
         self.projection_dim = projection_dim
         self.device = device
         self.token_size = token_size
+        self.token_dim = token_dim
         # Models
         audio_variance = torch.rand(1) * 0.5 + 0.1
         self.audio_token = nn.Parameter(torch.normal(mean=0, std=audio_variance.item(),
-                                                     size=(self.token_size, self.projection_dim)).to(self.device))
+                                                     size=(self.token_size, self.token_dim)).to(self.device))
+
+        self.token_projection = nn.Sequential(
+            nn.Linear(self.token_dim, 64),
+            nn.ReLU(),
+            nn.Linear(64, 128),
+            nn.ReLU(),
+            nn.Linear(128, self.projection_dim),
+            nn.LayerNorm(self.projection_dim)
+        )
+
     def forward(self):
-        return self.audio_token
+        return self.token_projection(self.audio_token)
 
     def __call__(self):
         return self.forward()

@@ -205,4 +216,4 @@ class OneEncoder(nn.Module, PyTorchModelHubMixin):
         # fig.suptitle(display(Audio(query['input_values'], rate=self.sample_rate)))
         #plt.show()
         #return values, indices
-
+
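
In effect, forward() no longer returns the raw audio_token parameter: the learnable token is now stored at token_dim and lifted into the shared projection_dim space by the new token_projection MLP. A minimal usage sketch of the updated class, assuming it is imported from this file and using illustrative sizes (token_size=4, token_dim=32, projection_dim=256) in place of the CFG defaults:

from text_image_audio import ModalityTokenEncoder  # assumed import path for this Space's module

# Illustrative sizes only; the Space normally takes its defaults from CFG.
encoder = ModalityTokenEncoder(projection_dim=256, token_size=4, token_dim=32, device='cpu')

tokens = encoder()    # forward(): audio_token (4, 32) -> token_projection -> (4, 256)
print(tokens.shape)   # torch.Size([4, 256]), i.e. (token_size, projection_dim)

Because audio_token is an nn.Parameter and token_projection is a registered submodule, both the token and its projection MLP remain trainable alongside the rest of the model.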