Commit caec8d2 (unverified) · committed by Ink · 1 parent: 1da3dd9

allow flex-attention to be disabled (#19)


* allow flex-attention to silently fail

* allow flex-attn to be disabled via an env var
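In practice (per the diff below), setting the environment variable to any nonzero integer before bytelatent is imported skips the torch.compile step and leaves flex_attention_comp as None. The module path comes from the diff; the rest of this snippet is an illustrative sketch:

import os

# Must be set before importing bytelatent, since the check runs at import time.
# Any nonzero integer value disables compilation of flex-attention.
os.environ["BLT_ALLOW_MISSING_FLEX_ATTENTION"] = "1"

from bytelatent import base_transformer

# With the variable set, no compiled kernel is created.
assert base_transformer.flex_attention_comp is None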

bytelatent/base_transformer.py CHANGED
@@ -1,5 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-
+import os
 from enum import Enum
 from typing import Optional, Tuple, Union
 
@@ -16,7 +16,10 @@ from xformers.ops import AttentionBias, fmha
 
 from bytelatent import probe
 
-flex_attention_comp = torch.compile(flex_attention)
+if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
+    flex_attention_comp = torch.compile(flex_attention)
+else:
+    flex_attention_comp = None
 
 
 class InitStdFactor(Enum):
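Since flex_attention_comp can now be None, downstream call sites presumably need a guard. A minimal hypothetical sketch, assuming an attend() helper that is not in the repo; the SDPA fallback ignores score_mod and is illustrative only:

import torch.nn.functional as F

def attend(q, k, v, score_mod=None):
    # Hypothetical call-site guard: use the compiled flex-attention kernel
    # when it was built, otherwise fall back to plain SDPA (which does not
    # support score_mod, so this fallback is only a sketch).
    if flex_attention_comp is not None:
        return flex_attention_comp(q, k, v, score_mod=score_mod)
    return F.scaled_dot_product_attention(q, k, v)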
bytelatent/distributed.py CHANGED
@@ -1,5 +1,4 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-
 import atexit
 import contextlib
 import logging
@@ -48,9 +47,13 @@ default_no_recompute_ops = {
     torch.ops.aten._scaled_dot_product_flash_attention.default,
     torch.ops.c10d_functional.reduce_scatter_tensor.default,
     torch.ops.xformers_flash.flash_fwd.default,
-    torch.ops.xformers.efficient_attention_forward_cutlass.default,
 }
 
+if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
+    default_no_recompute_ops.add(
+        torch.ops.xformers.efficient_attention_forward_cutlass.default
+    )
+
 
 class DistributedArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
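A note on how the toggle parses: os.environ.get returns the variable's string value when set and the False default otherwise, so the int(...) coercion yields 0 when the variable is unset (int(False) == 0) and the given integer otherwise; non-integer values such as "true" would raise ValueError. A quick self-contained check:

import os

def flex_attention_disabled() -> bool:
    # Mirrors the check in the diff: unset -> int(False) == 0 -> enabled;
    # "1" -> 1 -> disabled; non-integer strings raise ValueError.
    return int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) != 0

os.environ.pop("BLT_ALLOW_MISSING_FLEX_ATTENTION", None)
assert not flex_attention_disabled()

os.environ["BLT_ALLOW_MISSING_FLEX_ATTENTION"] = "1"
assert flex_attention_disabled()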