Srinivasan Iyer (sviyer) committed
Commit f3e8125 · unverified · 1 Parent(s): c49e251

using apex rmsnorm (#57)


* using apex rmsnorm

* added message for missing apex

* black

* missed a print

---------

Co-authored-by: Srini Iyer <[email protected]>
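
The hunks below all make the same swap: bind the module-level name RMSNorm to Apex's fused kernel when Apex is importable, fall back to PyTorch's built-in nn.RMSNorm otherwise, and delete the hand-written RMSNorm class from bytelatent/base_transformer.py. A minimal standalone sketch of that fallback pattern, assuming PyTorch >= 2.4 (where torch.nn.RMSNorm is available):

from torch import nn

try:
    # Apex ships a fused CUDA implementation of RMS normalization.
    from apex.normalization.fused_layer_norm import FusedRMSNorm

    RMSNorm = FusedRMSNorm
except (ImportError, ModuleNotFoundError):
    # Pure-PyTorch fallback; same constructor shape as FusedRMSNorm.
    print("Apex not found. Using nn.RMSNorm")
    RMSNorm = nn.RMSNorm

# Either way, call sites such as RMSNorm(dim, eps=norm_eps) stay unchanged,
# since both classes accept a normalized shape plus an eps keyword.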

bytelatent/base_transformer.py CHANGED
@@ -17,6 +17,14 @@ from xformers.ops import AttentionBias, fmha
 from bytelatent import probe
 from bytelatent.tokenizers.constants import EOS_ID
 
+try:
+    from apex.normalization.fused_layer_norm import FusedRMSNorm
+
+    RMSNorm = FusedRMSNorm
+except (ImportError, ModuleNotFoundError):
+    print("Apex not found. Using nn.RMSNorm")
+    RMSNorm = nn.RMSNorm
+
 if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
     flex_attention_comp = torch.compile(flex_attention)
 else:
@@ -294,37 +302,6 @@ class RotaryEmbedding(torch.nn.Module):
         return self.freqs_cis[0:seqlen]
 
 
-class RMSNorm(nn.Module):
-    """
-    Initialize the RMSNorm normalization layer.
-
-    Args:
-        dim (int): The dimension of the input tensor.
-        eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
-
-    Attributes:
-        eps (float): A small value added to the denominator for numerical stability.
-        weight (nn.Parameter): Learnable scaling parameter.
-
-    """
-
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
-
-    def _norm(self, x: torch.Tensor):
-        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x: torch.Tensor):
-        x = probe.log_stats(x, "resid")
-        output = self._norm(x.float())
-        return (output * self.weight.float()).type_as(x)
-
-    def reset_parameters(self):
-        torch.nn.init.ones_(self.weight)  # type: ignore
-
-
 def _reshape_for_attn_bias(
     attn_bias: AttentionBias | None,
     *tensors: torch.Tensor,
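
The class removed above computes x * rsqrt(mean(x^2, dim=-1) + eps) with a learnable per-channel weight, i.e. the same root-mean-square normalization that torch.nn.RMSNorm (and Apex's FusedRMSNorm) provide. A quick sanity check of that equivalence, assuming PyTorch >= 2.4; the probe.log_stats hook from the deleted forward is omitted here:

import torch
from torch import nn

def removed_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # The math from the deleted class: normalize in float32, then rescale.
    out = x.float() * torch.rsqrt((x.float() * x.float()).mean(-1, keepdim=True) + eps)
    return (out * weight.float()).type_as(x)

torch.manual_seed(0)
dim = 64
x = torch.randn(2, 8, dim)
ref = nn.RMSNorm(dim, eps=1e-6)  # built-in replacement, weight initialized to ones

# With identical weights the two paths should agree up to float32 rounding.
assert torch.allclose(removed_rmsnorm(x, ref.weight), ref(x), atol=1e-6)
print("hand-written RMSNorm matches nn.RMSNorm")
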
bytelatent/model/latent_transformer.py CHANGED
@@ -12,12 +12,19 @@ from xformers.ops import AttentionBias
 from bytelatent.base_transformer import (
     BaseTransformer,
     BaseTransformerArgs,
-    RMSNorm,
     flex_attention_comp,
     repeat_kv,
 )
 from bytelatent.model.utils import create_causal_mask
 
+try:
+    from apex.normalization.fused_layer_norm import FusedRMSNorm
+
+    RMSNorm = FusedRMSNorm
+except (ImportError, ModuleNotFoundError):
+    print("Apex not found. Using nn.RMSNorm")
+    RMSNorm = nn.RMSNorm
+
 logger = logging.getLogger()
 
 
@@ -44,7 +51,7 @@ class CrossAttention(nn.Module):
         self.n_kv_heads = n_kv_heads
         self.heads_per_group = self.n_heads // self.n_kv_heads
 
-        self.cross_attn_norm_q = RMSNorm(dim, eps=norm_eps)
+        self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps)
         self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps)
 
         self.wq = nn.Linear(
bytelatent/model/local_models.py CHANGED
@@ -14,7 +14,6 @@ from xformers.ops import AttentionBias
 from bytelatent.base_transformer import (
     BaseTransformerArgs,
     InitStdFactor,
-    RMSNorm,
     RotaryEmbedding,
     TransformerBlock,
 )
@@ -22,6 +21,14 @@ from bytelatent.model.latent_transformer import CrossAttention
 from bytelatent.model.utils import create_causal_mask, downsample
 from bytelatent.tokenizers.blt_tokenizer import BOE_ID
 
+try:
+    from apex.normalization.fused_layer_norm import FusedRMSNorm
+
+    RMSNorm = FusedRMSNorm
+except (ImportError, ModuleNotFoundError):
+    print("Apex not found. Using nn.RMSNorm")
+    RMSNorm = nn.RMSNorm
+
 logger = logging.getLogger()
 
 
bytelatent/transformer.py CHANGED
@@ -19,11 +19,18 @@ from xformers.ops import AttentionBias, fmha
 from bytelatent.base_transformer import (
     BaseTransformer,
     BaseTransformerArgs,
-    RMSNorm,
     cross_entropy,
 )
 from bytelatent.model.utils import create_causal_mask
 
+try:
+    from apex.normalization.fused_layer_norm import FusedRMSNorm
+
+    RMSNorm = FusedRMSNorm
+except (ImportError, ModuleNotFoundError):
+    print("Apex not found. Using nn.RMSNorm")
+    RMSNorm = nn.RMSNorm
+
 
 def attention_flops_per_token(n_layers, seq_len, dim, causal):
     # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30