Antoni Bigata committed
Commit 6ea1ef7 · 1 Parent(s): cb604f6

requirements

Files changed (1):
  1. WavLM_modules.py  +112 -36
WavLM_modules.py CHANGED
@@ -121,9 +121,14 @@ class GLU_Linear(nn.Module):
         x = self.linear(x)
 
         if self.glu_type == "bilinear":
-            x = x[:, :, 0 : self.output_dim] * x[:, :, self.output_dim : self.output_dim * 2]
+            x = (
+                x[:, :, 0 : self.output_dim]
+                * x[:, :, self.output_dim : self.output_dim * 2]
+            )
         else:
-            x = x[:, :, 0 : self.output_dim] * self.glu_act(x[:, :, self.output_dim : self.output_dim * 2])
+            x = x[:, :, 0 : self.output_dim] * self.glu_act(
+                x[:, :, self.output_dim : self.output_dim * 2]
+            )
 
         return x
 
@@ -131,7 +136,9 @@ class GLU_Linear(nn.Module):
 def gelu_accurate(x):
     if not hasattr(gelu_accurate, "_a"):
         gelu_accurate._a = math.sqrt(2 / math.pi)
-    return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
+    return (
+        0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
+    )
 
 
 def gelu(x: torch.Tensor) -> torch.Tensor:
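Note (illustrative, not part of the commit): the two hunks above only re-wrap long lines. A minimal standalone sketch of what the reformatted code computes follows, assuming an input whose last dimension is 2 * output_dim; glu_bilinear and gelu_tanh are hypothetical helper names, not identifiers from WavLM_modules.py.

import math
import torch

def glu_bilinear(x: torch.Tensor, output_dim: int) -> torch.Tensor:
    # "bilinear" GLU: the first half of the projection gates the second half,
    # with no extra nonlinearity, as in the bilinear branch of the diff.
    return x[..., :output_dim] * x[..., output_dim : output_dim * 2]

def gelu_tanh(x: torch.Tensor) -> torch.Tensor:
    # tanh approximation of GELU, the same formula as gelu_accurate above.
    a = math.sqrt(2 / math.pi)
    return 0.5 * x * (1 + torch.tanh(a * (x + 0.044715 * torch.pow(x, 3))))

x = torch.randn(2, 10, 16)        # (batch, time, 2 * output_dim)
gated = glu_bilinear(x, 8)        # -> (2, 10, 8)
activated = gelu_tanh(gated)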
@@ -223,13 +230,17 @@ def quant_noise(module, p, block_size):
 
     # 2D matrix
     if not is_conv:
-        assert module.weight.size(1) % block_size == 0, "Input features must be a multiple of block sizes"
+        assert module.weight.size(1) % block_size == 0, (
+            "Input features must be a multiple of block sizes"
+        )
 
     # 4D matrix
     else:
         # 1x1 convolutions
         if module.kernel_size == (1, 1):
-            assert module.in_channels % block_size == 0, "Input channels must be a multiple of block sizes"
+            assert module.in_channels % block_size == 0, (
+                "Input channels must be a multiple of block sizes"
+            )
         # regular convolutions
         else:
             k = module.kernel_size[0] * module.kernel_size[1]
@@ -245,7 +256,9 @@ def quant_noise(module, p, block_size):
                 out_features = weight.size(0)
 
                 # split weight matrix into blocks and randomly drop selected blocks
-                mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
+                mask = torch.zeros(
+                    in_features // block_size * out_features, device=weight.device
+                )
                 mask.bernoulli_(p)
                 mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
 
@@ -264,12 +277,20 @@ def quant_noise(module, p, block_size):
                     mask.bernoulli_(p)
                     mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
                 else:
-                    mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
+                    mask = torch.zeros(
+                        weight.size(0), weight.size(1), device=weight.device
+                    )
                     mask.bernoulli_(p)
-                    mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
+                    mask = (
+                        mask.unsqueeze(2)
+                        .unsqueeze(3)
+                        .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
+                    )
 
             # scale weights and apply mask
-            mask = mask.to(torch.bool)  # x.bool() is not currently supported in TorchScript
+            mask = mask.to(
+                torch.bool
+            )  # x.bool() is not currently supported in TorchScript
             s = 1 / (1 - p)
             mod.weight.data = s * weight.masked_fill(mask, 0)
 
@@ -320,14 +341,16 @@ class MultiheadAttention(nn.Module):
         self.head_dim = embed_dim // num_heads
         self.q_head_dim = self.head_dim
         self.k_head_dim = self.head_dim
-        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+        assert self.head_dim * num_heads == self.embed_dim, (
+            "embed_dim must be divisible by num_heads"
+        )
         self.scaling = self.head_dim**-0.5
 
         self.self_attention = self_attention
         self.encoder_decoder_attention = encoder_decoder_attention
 
         assert not self.self_attention or self.qkv_same_dim, (
-            "Self-attention requires query, key and " "value to be of the same size"
+            "Self-attention requires query, key and value to be of the same size"
         )
 
         k_bias = True
@@ -337,11 +360,19 @@ class MultiheadAttention(nn.Module):
         k_embed_dim = embed_dim
         q_embed_dim = embed_dim
 
-        self.k_proj = quant_noise(nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size)
-        self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
-        self.q_proj = quant_noise(nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size)
+        self.k_proj = quant_noise(
+            nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
+        )
+        self.v_proj = quant_noise(
+            nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
+        )
+        self.q_proj = quant_noise(
+            nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
+        )
 
-        self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
+        self.out_proj = quant_noise(
+            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
+        )
 
         if add_bias_kv:
             self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
@@ -390,7 +421,9 @@ class MultiheadAttention(nn.Module):
             relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
             relative_positions = torch.abs(relative_positions)
         else:
-            relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
+            relative_positions = -torch.min(
+                relative_positions, torch.zeros_like(relative_positions)
+            )
 
         max_exact = num_buckets // 2
         is_small = relative_positions < max_exact
@@ -401,18 +434,25 @@ class MultiheadAttention(nn.Module):
             * (num_buckets - max_exact)
         ).to(torch.long)
         relative_postion_if_large = torch.min(
-            relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
+            relative_postion_if_large,
+            torch.full_like(relative_postion_if_large, num_buckets - 1),
        )
 
-        relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
+        relative_buckets += torch.where(
+            is_small, relative_positions, relative_postion_if_large
+        )
         return relative_buckets
 
     def compute_bias(self, query_length, key_length):
         context_position = torch.arange(query_length, dtype=torch.long)[:, None]
         memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
         relative_position = memory_position - context_position
-        relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True)
-        relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
+        relative_position_bucket = self._relative_positions_bucket(
+            relative_position, bidirectional=True
+        )
+        relative_position_bucket = relative_position_bucket.to(
+            self.relative_attention_bias.weight.device
+        )
         values = self.relative_attention_bias(relative_position_bucket)
         values = values.permute([2, 0, 1])
         return values
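Note (illustrative, not part of the commit): compute_bias above buckets the signed offset j - i between key position j and query position i, then looks each bucket id up in relative_attention_bias. A toy-sized sketch of the offset matrix it starts from:

import torch

query_length, key_length = 4, 6
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
relative_position = memory_position - context_position   # shape (4, 6); entry [i, j] = j - i
# _relative_positions_bucket then maps these offsets to num_buckets ids:
# small offsets get exact buckets, larger offsets share log-spaced buckets.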
@@ -450,7 +490,7 @@ class MultiheadAttention(nn.Module):
         if need_head_weights:
             need_weights = True
 
-        is_tpu = query.device.type == "xla"
+        is_tpu = False
 
         tgt_len, bsz, embed_dim = query.size()
         src_len = tgt_len
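Note: this hunk is the only change in the diff that is not pure line re-wrapping. is_tpu is now hardcoded to False instead of being derived from query.device.type == "xla", so the XLA-specific handling later in forward() is never taken.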
@@ -466,7 +506,9 @@ class MultiheadAttention(nn.Module):
         if self.has_relative_attention_bias and position_bias is None:
             position_bias = self.compute_bias(tgt_len, src_len)
             position_bias = (
-                position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
+                position_bias.unsqueeze(0)
+                .repeat(bsz, 1, 1, 1)
+                .view(bsz * self.num_heads, tgt_len, src_len)
             )
 
         if (
@@ -492,10 +534,14 @@ class MultiheadAttention(nn.Module):
                     _B, _H, _L, __ = query_layer.size()
 
                     gate_a, gate_b = torch.sigmoid(
-                        self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False)
+                        self.grep_linear(query_layer)
+                        .view(_B, _H, _L, 2, 4)
+                        .sum(-1, keepdim=False)
                     ).chunk(2, dim=-1)
                     gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
-                    attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
+                    attn_mask_rel_pos = (
+                        gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
+                    )
 
                 attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
             k_proj_bias = self.k_proj.bias
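Note (illustrative, not part of the commit): the gating re-wrapped above scales the relative position bias per query position and head. A standalone sketch with assumed shapes (B, H, L, head_dim) and toy modules standing in for self.grep_linear and self.grep_a:

import torch

B, H, L, head_dim = 2, 4, 5, 32
grep_linear = torch.nn.Linear(head_dim, 8)            # stand-in for self.grep_linear
grep_a = torch.nn.Parameter(torch.ones(1, H, 1, 1))   # stand-in for self.grep_a
query_layer = torch.randn(B, H, L, head_dim)
position_bias = torch.randn(B * H, L, L)

# eight numbers per query vector, summed in two groups of four -> two gates
gate_a, gate_b = torch.sigmoid(
    grep_linear(query_layer).view(B, H, L, 2, 4).sum(-1, keepdim=False)
).chunk(2, dim=-1)
gate_a_1 = gate_a * (gate_b * grep_a - 1.0) + 2.0                  # (B, H, L, 1)
attn_mask_rel_pos = gate_a_1.view(B * H, -1, 1) * position_bias    # (B*H, L, L)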
@@ -565,7 +611,9 @@ class MultiheadAttention(nn.Module):
             k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
             v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
             if attn_mask is not None:
-                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+                attn_mask = torch.cat(
+                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
+                )
             if key_padding_mask is not None:
                 key_padding_mask = torch.cat(
                     [
@@ -575,11 +623,23 @@ class MultiheadAttention(nn.Module):
                     dim=1,
                 )
 
-        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.q_head_dim).transpose(0, 1)
+        q = (
+            q.contiguous()
+            .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
+            .transpose(0, 1)
+        )
         if k is not None:
-            k = k.contiguous().view(-1, bsz * self.num_heads, self.k_head_dim).transpose(0, 1)
+            k = (
+                k.contiguous()
+                .view(-1, bsz * self.num_heads, self.k_head_dim)
+                .transpose(0, 1)
+            )
         if v is not None:
-            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+            v = (
+                v.contiguous()
+                .view(-1, bsz * self.num_heads, self.head_dim)
+                .transpose(0, 1)
+            )
 
         if saved_state is not None:
             # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
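Note (illustrative, not part of the commit): the q/k/v re-wrapping above is the usual multi-head reshape from (tgt_len, bsz, embed_dim) to (bsz * num_heads, tgt_len, head_dim). A toy-shaped sketch:

import torch

tgt_len, bsz, num_heads, head_dim = 5, 2, 4, 16
q = torch.randn(tgt_len, bsz, num_heads * head_dim)   # (tgt_len, bsz, embed_dim)

q = (
    q.contiguous()
    .view(tgt_len, bsz * num_heads, head_dim)
    .transpose(0, 1)
)   # -> (bsz * num_heads, tgt_len, head_dim)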
@@ -638,12 +698,16 @@ class MultiheadAttention(nn.Module):
             k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
             v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
             if attn_mask is not None:
-                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+                attn_mask = torch.cat(
+                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
+                )
             if key_padding_mask is not None:
                 key_padding_mask = torch.cat(
                     [
                         key_padding_mask,
-                        torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
+                        torch.zeros(key_padding_mask.size(0), 1).type_as(
+                            key_padding_mask
+                        ),
                     ],
                     dim=1,
                 )
@@ -679,10 +743,14 @@ class MultiheadAttention(nn.Module):
                 query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
                 _B, _H, _L, __ = query_layer.size()
                 gate_a, gate_b = torch.sigmoid(
-                    self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False)
+                    self.grep_linear(query_layer)
+                    .view(_B, _H, _L, 2, 4)
+                    .sum(-1, keepdim=False)
                 ).chunk(2, dim=-1)
                 gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
-                position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
+                position_bias = (
+                    gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
+                )
 
             position_bias = position_bias.view(attn_weights.size())
 
@@ -699,7 +767,9 @@ class MultiheadAttention(nn.Module):
         attn = self.out_proj(attn)
         attn_weights: Optional[Tensor] = None
         if need_weights:
-            attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
+            attn_weights = attn_weights_float.view(
+                bsz, self.num_heads, tgt_len, src_len
+            ).transpose(1, 0)
             if not need_head_weights:
                 # average attention weights over heads
                 attn_weights = attn_weights.mean(dim=0)
@@ -718,7 +788,9 @@ class MultiheadAttention(nn.Module):
         if prev_key_padding_mask is not None and static_kv:
             new_key_padding_mask = prev_key_padding_mask
         elif prev_key_padding_mask is not None and key_padding_mask is not None:
-            new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1)
+            new_key_padding_mask = torch.cat(
+                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
+            )
         # During incremental decoding, as the padding token enters and
         # leaves the frame, there will be a time when prev or current
         # is None
@@ -728,7 +800,9 @@ class MultiheadAttention(nn.Module):
                     (batch_size, src_len - prev_key_padding_mask.size(1)),
                     device=prev_key_padding_mask.device,
                 )
-                new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
+                new_key_padding_mask = torch.cat(
+                    [prev_key_padding_mask.float(), filler.float()], dim=1
+                )
             else:
                 new_key_padding_mask = prev_key_padding_mask.float()
         elif key_padding_mask is not None:
@@ -737,7 +811,9 @@ class MultiheadAttention(nn.Module):
                     (batch_size, src_len - key_padding_mask.size(1)),
                     device=key_padding_mask.device,
                 )
-                new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
+                new_key_padding_mask = torch.cat(
+                    [filler.float(), key_padding_mask.float()], dim=1
+                )
             else:
                 new_key_padding_mask = key_padding_mask.float()
         else:
 