HoneyTian committed
Commit f86fc1a · 1 Parent(s): b576e15
toolbox/torchaudio/models/nx_clean_unet/transformer/transformer.py CHANGED
@@ -70,13 +70,12 @@ class TransformerEncoderLayer(nn.Module):
         self,
         x: torch.Tensor,
         mask: torch.Tensor,
-        attention_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+        attention_cache: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
 
         :param x: torch.Tensor. shape=(batch_size, time, input_dim).
         :param mask: torch.Tensor. mask tensor for the input. shape=(batch_size, time, time).
-        :param position_embedding: torch.Tensor.
         :param attention_cache: torch.Tensor. cache tensor of the KEY & VALUE
             shape=(batch_size=1, head, cache_t1, d_k * 2), head * d_k == input_dim.
         :return:
@@ -177,8 +176,8 @@ class TransformerEncoder(nn.Module):
     def forward_chunk(self,
                       xs: torch.Tensor,
                       offset: int,
-                      attention_mask: torch.Tensor = torch.zeros(0, 0, 0),
-                      attention_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+                      attention_mask: torch.Tensor = None,
+                      attention_cache: torch.Tensor = None,
                       ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Forward just one chunk.
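
The pattern behind this change: a tensor default such as torch.zeros(0, 0, 0, 0) is evaluated once, when the function is defined, so every call shares that single tensor object (created on whatever device was current at import time). A None default with an in-body check materializes a fresh placeholder per call instead. A minimal sketch of that pattern, assuming a free-standing forward_chunk_sketch with a pass-through body (both hypothetical, not taken from this file):

from typing import Optional, Tuple

import torch


def forward_chunk_sketch(
        xs: torch.Tensor,
        offset: int,
        attention_mask: Optional[torch.Tensor] = None,
        attention_cache: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Build the empty placeholders per call instead of in the signature,
    # so no tensor object is shared across invocations.
    if attention_mask is None:
        attention_mask = torch.zeros(0, 0, 0)
    if attention_cache is None:
        attention_cache = torch.zeros(0, 0, 0, 0)
    # Hypothetical body: a real encoder would run attention over xs using
    # offset, attention_mask and attention_cache; this sketch passes through.
    return xs, attention_cache


ys, new_cache = forward_chunk_sketch(torch.randn(1, 16, 64), offset=0)
print(ys.shape, new_cache.shape)  # torch.Size([1, 16, 64]) torch.Size([0, 0, 0, 0])

Note the sketch annotates the parameters as Optional[torch.Tensor]; the committed signatures keep the bare torch.Tensor annotation, which a type checker would flag once the default is None.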