Spaces:
Running
Running
update
Browse files
toolbox/torchaudio/models/nx_clean_unet/transformer/transformer.py
CHANGED
@@ -70,13 +70,12 @@ class TransformerEncoderLayer(nn.Module):
|
|
70 |
self,
|
71 |
x: torch.Tensor,
|
72 |
mask: torch.Tensor,
|
73 |
-
attention_cache: torch.Tensor =
|
74 |
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
75 |
"""
|
76 |
|
77 |
:param x: torch.Tensor. shape=(batch_size, time, input_dim).
|
78 |
:param mask: torch.Tensor. mask tensor for the input. shape=(batch_size, time,time).
|
79 |
-
:param position_embedding: torch.Tensor.
|
80 |
:param attention_cache: torch.Tensor. cache tensor of the KEY & VALUE
|
81 |
shape=(batch_size=1, head, cache_t1, d_k * 2), head * d_k == input_dim.
|
82 |
:return:
|
@@ -177,8 +176,8 @@ class TransformerEncoder(nn.Module):
|
|
177 |
def forward_chunk(self,
|
178 |
xs: torch.Tensor,
|
179 |
offset: int,
|
180 |
-
attention_mask: torch.Tensor =
|
181 |
-
attention_cache: torch.Tensor =
|
182 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
183 |
"""
|
184 |
Forward just one chunk.
|
|
|
70 |
self,
|
71 |
x: torch.Tensor,
|
72 |
mask: torch.Tensor,
|
73 |
+
attention_cache: torch.Tensor = None,
|
74 |
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
75 |
"""
|
76 |
|
77 |
:param x: torch.Tensor. shape=(batch_size, time, input_dim).
|
78 |
:param mask: torch.Tensor. mask tensor for the input. shape=(batch_size, time,time).
|
|
|
79 |
:param attention_cache: torch.Tensor. cache tensor of the KEY & VALUE
|
80 |
shape=(batch_size=1, head, cache_t1, d_k * 2), head * d_k == input_dim.
|
81 |
:return:
|
|
|
176 |
def forward_chunk(self,
|
177 |
xs: torch.Tensor,
|
178 |
offset: int,
|
179 |
+
attention_mask: torch.Tensor = None,
|
180 |
+
attention_cache: torch.Tensor = None,
|
181 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
182 |
"""
|
183 |
Forward just one chunk.
|