par-meta committed
Commit b0956bd · unverified · 1 Parent(s): 82ab593

Make apex logs less noisy (#60)
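The change is the same in all four files: the apex fallback used to announce itself with a bare print(), which echoed at import time in every process; it now goes through logging.debug(), so the message only surfaces when debug logging is enabled. A condensed sketch of the resulting pattern, assembled from the diffs below (nn stands for torch.nn, as in the repo):

import logging

import torch.nn as nn

logger = logging.getLogger()

try:
    # Prefer apex's fused RMSNorm kernel when the package is installed.
    from apex.normalization.fused_layer_norm import FusedRMSNorm

    RMSNorm = FusedRMSNorm
except (ImportError, ModuleNotFoundError):
    # DEBUG-level record on the root logger: silent under the default
    # WARNING threshold, visible only when the caller opts in.
    logging.debug("Apex not found. Using nn.RMSNorm")
    RMSNorm = nn.RMSNorm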

bytelatent/base_transformer.py CHANGED

@@ -1,4 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import logging
 import os
 from enum import Enum
 from typing import Optional, Tuple, Union
@@ -14,15 +15,16 @@ from torch.nn.attention.flex_attention import (
 )
 from xformers.ops import AttentionBias, fmha
 
-from bytelatent import probe
 from bytelatent.tokenizers.constants import EOS_ID
 
+logger = logging.getLogger()
+
 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm
 
     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm
 
 if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
bytelatent/model/latent_transformer.py CHANGED

@@ -17,16 +17,15 @@ from bytelatent.base_transformer import (
 )
 from bytelatent.model.utils import create_causal_mask
 
+logger = logging.getLogger()
 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm
 
     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm
 
-logger = logging.getLogger()
-
 
 class CrossAttention(nn.Module):
     """
bytelatent/model/local_models.py CHANGED

@@ -6,7 +6,7 @@ from typing import Any, List, Optional, Tuple, Union
 import torch
 import torch.nn
 import torch.nn as nn
-from pydantic import BaseModel, ConfigDict
+from pydantic import ConfigDict
 from torch.nn import functional as F
 from torch.nn.attention.flex_attention import BlockMask
 from xformers.ops import AttentionBias
@@ -21,16 +21,15 @@ from bytelatent.model.latent_transformer import CrossAttention
 from bytelatent.model.utils import create_causal_mask, downsample
 from bytelatent.tokenizers.blt_tokenizer import BOE_ID
 
+logger = logging.getLogger()
 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm
 
     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm
 
-logger = logging.getLogger()
-
 
 class LocalModelArgs(BaseTransformerArgs):
     model_config = ConfigDict(extra="forbid")
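This file's first hunk also drops the now-unused BaseModel from the pydantic import while keeping ConfigDict, which still backs the extra="forbid" setting visible at the end of the second hunk. For reference, a minimal sketch of that pydantic v2 pattern (the StrictArgs class and dim field are illustrative, not from the repo):

from pydantic import BaseModel, ConfigDict

class StrictArgs(BaseModel):
    # extra="forbid" turns unknown keyword arguments into validation
    # errors instead of silently ignoring them.
    model_config = ConfigDict(extra="forbid")
    dim: int = 512

StrictArgs(dim=256)             # ok
# StrictArgs(dim=256, typo=1)   # raises ValidationError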
bytelatent/transformer.py CHANGED

@@ -1,6 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 
-from dataclasses import dataclass
+import logging
 from typing import Optional, Tuple, Union
 
 import torch
@@ -14,7 +14,7 @@ from torch.distributed.tensor.parallel import (
     parallelize_module,
 )
 from torch.nn.attention.flex_attention import BlockMask, create_block_mask
-from xformers.ops import AttentionBias, fmha
+from xformers.ops import AttentionBias
 
 from bytelatent.base_transformer import (
     BaseTransformer,
@@ -23,12 +23,14 @@ from bytelatent.base_transformer import (
 )
 from bytelatent.model.utils import create_causal_mask
 
+logger = logging.getLogger()
+
 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm
 
     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm
 
 
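With print() gone, the fallback notice is invisible by default, since the root logger's threshold is WARNING. A downstream script that wants to confirm which RMSNorm implementation was picked can raise the verbosity before importing, e.g. (hypothetical snippet, not part of this commit):

import logging

# Must run before the bytelatent imports, because the message is
# emitted at module import time.
logging.basicConfig(level=logging.DEBUG)

import bytelatent.base_transformer  # logs "Apex not found. Using nn.RMSNorm" when apex is absent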