Commit 6ea1ef7 · Antoni Bigata committed
Parent(s): cb604f6

requirements

Files changed: WavLM_modules.py (+112 -36)
WavLM_modules.py
CHANGED
@@ -121,9 +121,14 @@ class GLU_Linear(nn.Module):
         x = self.linear(x)
 
         if self.glu_type == "bilinear":
-            x = x[:, :, 0 : self.output_dim] * x[:, :, self.output_dim : self.output_dim * 2]
+            x = (
+                x[:, :, 0 : self.output_dim]
+                * x[:, :, self.output_dim : self.output_dim * 2]
+            )
         else:
-            x = x[:, :, 0 : self.output_dim] * self.glu_act(x[:, :, self.output_dim : self.output_dim * 2])
+            x = x[:, :, 0 : self.output_dim] * self.glu_act(
+                x[:, :, self.output_dim : self.output_dim * 2]
+            )
 
         return x
 
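This hunk only re-wraps the GLU forward pass: both branches split the output of self.linear into two halves of width output_dim and multiply them, with the second branch passing the gate half through glu_act. A minimal sketch of the same split, assuming the linear layer produces 2 * output_dim features (variable names here are illustrative, not from the file):

import torch

# Hypothetical shapes for illustration: batch=2, time=5, output_dim=8.
output_dim = 8
x = torch.randn(2, 5, 2 * output_dim)  # stand-in for self.linear(x)

first_half = x[:, :, 0:output_dim]
second_half = x[:, :, output_dim : output_dim * 2]

bilinear = first_half * second_half               # "bilinear" branch
gated = first_half * torch.sigmoid(second_half)   # e.g. glu_act = sigmoid

assert bilinear.shape == gated.shape == (2, 5, output_dim)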
@@ -131,7 +136,9 @@ class GLU_Linear(nn.Module):
 def gelu_accurate(x):
     if not hasattr(gelu_accurate, "_a"):
         gelu_accurate._a = math.sqrt(2 / math.pi)
-    return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
+    return (
+        0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
+    )
 
 
 def gelu(x: torch.Tensor) -> torch.Tensor:
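Only the return statement is re-wrapped; the formula is the usual tanh approximation of GELU, 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))). A quick sanity check against PyTorch's built-in tanh-approximate GELU (editor's sketch; assumes a PyTorch version where F.gelu accepts approximate="tanh"):

import math
import torch
import torch.nn.functional as F

def gelu_accurate(x):
    # local copy of the approximation used in the file
    a = math.sqrt(2 / math.pi)
    return 0.5 * x * (1 + torch.tanh(a * (x + 0.044715 * torch.pow(x, 3))))

x = torch.randn(4, 3)
# should closely match the built-in tanh-approximate GELU
print(torch.allclose(gelu_accurate(x), F.gelu(x, approximate="tanh"), atol=1e-6))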
@@ -223,13 +230,17 @@ def quant_noise(module, p, block_size):
 
     # 2D matrix
     if not is_conv:
-        assert module.weight.size(1) % block_size == 0, "Input features must be a multiple of block sizes"
+        assert module.weight.size(1) % block_size == 0, (
+            "Input features must be a multiple of block sizes"
+        )
 
     # 4D matrix
     else:
         # 1x1 convolutions
         if module.kernel_size == (1, 1):
-            assert module.in_channels % block_size == 0, "Input channels must be a multiple of block sizes"
+            assert module.in_channels % block_size == 0, (
+                "Input channels must be a multiple of block sizes"
+            )
         # regular convolutions
         else:
             k = module.kernel_size[0] * module.kernel_size[1]
@@ -245,7 +256,9 @@ def quant_noise(module, p, block_size):
                 out_features = weight.size(0)
 
                 # split weight matrix into blocks and randomly drop selected blocks
-                mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
+                mask = torch.zeros(
+                    in_features // block_size * out_features, device=weight.device
+                )
                 mask.bernoulli_(p)
                 mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
 
@@ -264,12 +277,20 @@ def quant_noise(module, p, block_size):
                     mask.bernoulli_(p)
                     mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
                 else:
-                    mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
+                    mask = torch.zeros(
+                        weight.size(0), weight.size(1), device=weight.device
+                    )
                     mask.bernoulli_(p)
-                    mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
+                    mask = (
+                        mask.unsqueeze(2)
+                        .unsqueeze(3)
+                        .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
+                    )
 
             # scale weights and apply mask
-            mask = mask.to(torch.bool)  # x.bool() is not currently supported in TorchScript
+            mask = mask.to(
+                torch.bool
+            )  # x.bool() is not currently supported in TorchScript
             s = 1 / (1 - p)
             mod.weight.data = s * weight.masked_fill(mask, 0)
 
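For context, the mask built in these quant_noise hunks drops whole block_size-wide blocks of the weight during training and rescales the surviving entries by 1 / (1 - p). A standalone sketch of the 2D case (illustrative values, not from the file):

import torch

p, block_size = 0.5, 4
weight = torch.ones(6, 8)  # (out_features, in_features)
in_features, out_features = weight.size(1), weight.size(0)

# one Bernoulli draw per block, then expanded back to full width
mask = torch.zeros(in_features // block_size * out_features)
mask.bernoulli_(p)
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features).bool()

noised = (1 / (1 - p)) * weight.masked_fill(mask, 0)
print(noised)  # zeroed 4-wide blocks, remaining entries scaled to 2.0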
@@ -320,14 +341,16 @@ class MultiheadAttention(nn.Module):
         self.head_dim = embed_dim // num_heads
         self.q_head_dim = self.head_dim
         self.k_head_dim = self.head_dim
-        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+        assert self.head_dim * num_heads == self.embed_dim, (
+            "embed_dim must be divisible by num_heads"
+        )
         self.scaling = self.head_dim**-0.5
 
         self.self_attention = self_attention
         self.encoder_decoder_attention = encoder_decoder_attention
 
         assert not self.self_attention or self.qkv_same_dim, (
-            "Self-attention requires query, key and " "value to be of the same size"
+            "Self-attention requires query, key and value to be of the same size"
         )
 
         k_bias = True
@@ -337,11 +360,19 @@ class MultiheadAttention(nn.Module):
         k_embed_dim = embed_dim
         q_embed_dim = embed_dim
 
-        self.k_proj = quant_noise(nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size)
-        self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
-        self.q_proj = quant_noise(nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size)
+        self.k_proj = quant_noise(
+            nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
+        )
+        self.v_proj = quant_noise(
+            nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
+        )
+        self.q_proj = quant_noise(
+            nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
+        )
 
-        self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
+        self.out_proj = quant_noise(
+            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
+        )
 
         if add_bias_kv:
             self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
@@ -390,7 +421,9 @@ class MultiheadAttention(nn.Module):
             relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
             relative_positions = torch.abs(relative_positions)
         else:
-            relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
+            relative_positions = -torch.min(
+                relative_positions, torch.zeros_like(relative_positions)
+            )
 
         max_exact = num_buckets // 2
         is_small = relative_positions < max_exact
@@ -401,18 +434,25 @@ class MultiheadAttention(nn.Module):
             * (num_buckets - max_exact)
         ).to(torch.long)
         relative_postion_if_large = torch.min(
-            relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
+            relative_postion_if_large,
+            torch.full_like(relative_postion_if_large, num_buckets - 1),
         )
 
-        relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
+        relative_buckets += torch.where(
+            is_small, relative_positions, relative_postion_if_large
+        )
         return relative_buckets
 
     def compute_bias(self, query_length, key_length):
         context_position = torch.arange(query_length, dtype=torch.long)[:, None]
         memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
         relative_position = memory_position - context_position
-        relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True)
-        relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
+        relative_position_bucket = self._relative_positions_bucket(
+            relative_position, bidirectional=True
+        )
+        relative_position_bucket = relative_position_bucket.to(
+            self.relative_attention_bias.weight.device
+        )
         values = self.relative_attention_bias(relative_position_bucket)
         values = values.permute([2, 0, 1])
         return values
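The two hunks above only re-wrap the T5-style relative-position bucketing: small offsets keep their own bucket, larger offsets share logarithmically spaced buckets, and positive offsets use the upper half of the bucket range. A standalone sketch of the bidirectional case with assumed toy settings (8 buckets, max_distance 16); the function below is a simplified mirror of _relative_positions_bucket, not code from the file:

import math
import torch

def bucket(relative_positions, num_buckets=8, max_distance=16):
    # bidirectional variant: positive offsets go to the upper half of the buckets
    relative_buckets = (relative_positions > 0).to(torch.long) * (num_buckets // 2)
    relative_positions = torch.abs(relative_positions)
    num_buckets = num_buckets // 2
    max_exact = num_buckets // 2
    is_small = relative_positions < max_exact
    if_large = max_exact + (
        torch.log(relative_positions.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).to(torch.long)
    if_large = torch.min(if_large, torch.full_like(if_large, num_buckets - 1))
    return relative_buckets + torch.where(is_small, relative_positions, if_large)

offsets = torch.arange(-6, 7)
print(offsets.tolist())
print(bucket(offsets).tolist())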
@@ -450,7 +490,7 @@ class MultiheadAttention(nn.Module):
         if need_head_weights:
             need_weights = True
 
-        is_tpu = query.device.type == "xla"
+        is_tpu = False
 
         tgt_len, bsz, embed_dim = query.size()
         src_len = tgt_len
@@ -466,7 +506,9 @@ class MultiheadAttention(nn.Module):
         if self.has_relative_attention_bias and position_bias is None:
             position_bias = self.compute_bias(tgt_len, src_len)
             position_bias = (
-                position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
+                position_bias.unsqueeze(0)
+                .repeat(bsz, 1, 1, 1)
+                .view(bsz * self.num_heads, tgt_len, src_len)
             )
 
         if (
@@ -492,10 +534,14 @@ class MultiheadAttention(nn.Module):
                     _B, _H, _L, __ = query_layer.size()
 
                     gate_a, gate_b = torch.sigmoid(
-                        self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False)
+                        self.grep_linear(query_layer)
+                        .view(_B, _H, _L, 2, 4)
+                        .sum(-1, keepdim=False)
                     ).chunk(2, dim=-1)
                     gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
-                    attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
+                    attn_mask_rel_pos = (
+                        gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
+                    )
 
                 attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
             k_proj_bias = self.k_proj.bias
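The gated relative position bias above is likewise only re-wrapped: two gates are derived from the query via grep_linear, and gate_a_1 rescales the per-head position bias before it is used as an attention mask. A shape-only sketch with assumed toy sizes and stand-ins for the module's parameters:

import torch

bsz, num_heads, tgt_len, head_dim = 2, 4, 6, 8
query_layer = torch.randn(bsz, num_heads, tgt_len, head_dim)
position_bias = torch.randn(bsz * num_heads, tgt_len, tgt_len)
grep_linear = torch.nn.Linear(head_dim, 8)  # stand-in for self.grep_linear
grep_a = torch.ones(1, num_heads, 1, 1)     # stand-in for self.grep_a

_B, _H, _L, __ = query_layer.size()
gate_a, gate_b = torch.sigmoid(
    grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False)
).chunk(2, dim=-1)
gate_a_1 = gate_a * (gate_b * grep_a - 1.0) + 2.0
attn_mask_rel_pos = gate_a_1.view(bsz * num_heads, -1, 1) * position_bias
print(attn_mask_rel_pos.shape)  # torch.Size([8, 6, 6])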
@@ -565,7 +611,9 @@ class MultiheadAttention(nn.Module):
             k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
             v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
             if attn_mask is not None:
-                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+                attn_mask = torch.cat(
+                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
+                )
             if key_padding_mask is not None:
                 key_padding_mask = torch.cat(
                     [
@@ -575,11 +623,23 @@ class MultiheadAttention(nn.Module):
                     dim=1,
                 )
 
-        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.q_head_dim).transpose(0, 1)
+        q = (
+            q.contiguous()
+            .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
+            .transpose(0, 1)
+        )
         if k is not None:
-            k = k.contiguous().view(-1, bsz * self.num_heads, self.k_head_dim).transpose(0, 1)
+            k = (
+                k.contiguous()
+                .view(-1, bsz * self.num_heads, self.k_head_dim)
+                .transpose(0, 1)
+            )
         if v is not None:
-            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+            v = (
+                v.contiguous()
+                .view(-1, bsz * self.num_heads, self.head_dim)
+                .transpose(0, 1)
+            )
 
         if saved_state is not None:
             # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
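The q/k/v reshapes above keep their behaviour: each (seq_len, bsz, embed_dim) projection is folded into (bsz * num_heads, seq_len, head_dim) so that every head becomes an independent batch element for the batched matmul. A shape sketch with assumed toy sizes:

import torch

tgt_len, bsz, num_heads, head_dim = 5, 2, 4, 16
q = torch.randn(tgt_len, bsz, num_heads * head_dim)  # (seq, batch, embed_dim)

q = (
    q.contiguous()
    .view(tgt_len, bsz * num_heads, head_dim)
    .transpose(0, 1)
)
print(q.shape)  # torch.Size([8, 5, 16]) -> (bsz * num_heads, tgt_len, head_dim)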
@@ -638,12 +698,16 @@ class MultiheadAttention(nn.Module):
             k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
             v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
             if attn_mask is not None:
-                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+                attn_mask = torch.cat(
+                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
+                )
             if key_padding_mask is not None:
                 key_padding_mask = torch.cat(
                     [
                         key_padding_mask,
-                        torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
+                        torch.zeros(key_padding_mask.size(0), 1).type_as(
+                            key_padding_mask
+                        ),
                     ],
                     dim=1,
                 )
@@ -679,10 +743,14 @@ class MultiheadAttention(nn.Module):
             query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
             _B, _H, _L, __ = query_layer.size()
             gate_a, gate_b = torch.sigmoid(
-                self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False)
+                self.grep_linear(query_layer)
+                .view(_B, _H, _L, 2, 4)
+                .sum(-1, keepdim=False)
             ).chunk(2, dim=-1)
             gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
-            position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
+            position_bias = (
+                gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
+            )
 
         position_bias = position_bias.view(attn_weights.size())
 
@@ -699,7 +767,9 @@ class MultiheadAttention(nn.Module):
         attn = self.out_proj(attn)
         attn_weights: Optional[Tensor] = None
         if need_weights:
-            attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
+            attn_weights = attn_weights_float.view(
+                bsz, self.num_heads, tgt_len, src_len
+            ).transpose(1, 0)
             if not need_head_weights:
                 # average attention weights over heads
                 attn_weights = attn_weights.mean(dim=0)
@@ -718,7 +788,9 @@ class MultiheadAttention(nn.Module):
         if prev_key_padding_mask is not None and static_kv:
             new_key_padding_mask = prev_key_padding_mask
         elif prev_key_padding_mask is not None and key_padding_mask is not None:
-            new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1)
+            new_key_padding_mask = torch.cat(
+                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
+            )
         # During incremental decoding, as the padding token enters and
         # leaves the frame, there will be a time when prev or current
         # is None
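This hunk and the two that follow only re-wrap _append_prev_key_padding_mask: during incremental decoding the previous and current key padding masks are concatenated along the time axis, with zero filler when one side is shorter or missing. A minimal sketch (illustrative values):

import torch

prev_mask = torch.tensor([[False, True]])  # (batch=1, prev_len=2)
curr_mask = torch.tensor([[False]])        # (batch=1, new_len=1)

new_mask = torch.cat([prev_mask.float(), curr_mask.float()], dim=1)
print(new_mask)  # tensor([[0., 1., 0.]]) -> key 1 stays masked across steps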
@@ -728,7 +800,9 @@ class MultiheadAttention(nn.Module):
                     (batch_size, src_len - prev_key_padding_mask.size(1)),
                     device=prev_key_padding_mask.device,
                 )
-                new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
+                new_key_padding_mask = torch.cat(
+                    [prev_key_padding_mask.float(), filler.float()], dim=1
+                )
             else:
                 new_key_padding_mask = prev_key_padding_mask.float()
         elif key_padding_mask is not None:
@@ -737,7 +811,9 @@ class MultiheadAttention(nn.Module):
                     (batch_size, src_len - key_padding_mask.size(1)),
                     device=key_padding_mask.device,
                 )
-                new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
+                new_key_padding_mask = torch.cat(
+                    [filler.float(), key_padding_mask.float()], dim=1
+                )
             else:
                 new_key_padding_mask = key_padding_mask.float()
         else: