from typing import *
# Module-level configuration, overridable via environment variables below.
BACKEND = "spconv"
DEBUG = False
ATTN = "flash_attn"


def __from_env():
    """Override the module defaults from environment variables.

    Recognized variables:
      SPARSE_BACKEND        -- "spconv" or "torchsparse"
      SPARSE_DEBUG          -- "1" enables debug mode
      SPARSE_ATTN_BACKEND   -- "xformers" or "flash_attn"
      ATTN_BACKEND          -- fallback for SPARSE_ATTN_BACKEND
    Unrecognized values are silently ignored.
    """
    import os

    global BACKEND, DEBUG, ATTN

    backend = os.environ.get("SPARSE_BACKEND")
    debug = os.environ.get("SPARSE_DEBUG")
    attn = os.environ.get("SPARSE_ATTN_BACKEND")
    if attn is None:
        # Legacy variable name, kept for backwards compatibility.
        attn = os.environ.get("ATTN_BACKEND")

    # `x in (...)` is False for None, so no explicit None check is needed.
    if backend in ("spconv", "torchsparse"):
        BACKEND = backend
    if debug is not None:
        DEBUG = debug == "1"
    if attn in ("xformers", "flash_attn"):
        ATTN = attn

    print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")


__from_env()
def set_backend(backend: Literal["spconv", "torchsparse"]) -> None:
    """Select the sparse-convolution backend for subsequent operations."""
    globals()["BACKEND"] = backend
def set_debug(debug: bool) -> None:
    """Enable or disable debug mode for the sparse module."""
    globals()["DEBUG"] = debug
def set_attn(attn: Literal["xformers", "flash_attn"]) -> None:
    """Select the attention backend for subsequent operations."""
    globals()["ATTN"] = attn
import importlib
# Lazily-exported public attribute -> submodule that defines it.
# Written as submodule -> names and inverted, so each submodule's exports
# are listed once; insertion order matches the per-submodule grouping.
__attributes = {
    attr: mod
    for mod, attrs in {
        "basic": [
            "SparseTensor",
            "sparse_batch_broadcast",
            "sparse_batch_op",
            "sparse_cat",
            "sparse_unbind",
        ],
        "norm": [
            "SparseGroupNorm",
            "SparseLayerNorm",
            "SparseGroupNorm32",
            "SparseLayerNorm32",
        ],
        "nonlinearity": [
            "SparseReLU",
            "SparseSiLU",
            "SparseGELU",
            "SparseActivation",
        ],
        "linear": ["SparseLinear"],
        "attention": [
            "sparse_scaled_dot_product_attention",
            "SerializeMode",
            "sparse_serialized_scaled_dot_product_self_attention",
            "sparse_windowed_scaled_dot_product_self_attention",
            "SparseMultiHeadAttention",
        ],
        "conv": ["SparseConv3d", "SparseInverseConv3d"],
        "spatial": ["SparseDownsample", "SparseUpsample", "SparseSubdivide"],
    }.items()
    for attr in attrs
}
# Submodules exposed lazily as whole modules rather than single names.
__submodules = ["transformer"]
__all__ = [*__attributes, *__submodules]
def __getattr__(name):
    """Resolve exported names lazily on first access (PEP 562).

    Imports the defining submodule on demand, caches the result in the
    module globals so subsequent lookups bypass this hook, and raises
    AttributeError for unknown names.
    """
    cached = globals()
    if name in cached:
        return cached[name]
    if name in __attributes:
        submodule = importlib.import_module(f".{__attributes[name]}", __name__)
        cached[name] = getattr(submodule, name)
    elif name in __submodules:
        cached[name] = importlib.import_module(f".{name}", __name__)
    else:
        raise AttributeError(f"module {__name__} has no attribute {name}")
    return cached[name]
# For Pylance: eagerly import every lazily-resolved name so static analysis
# can see them. This branch never executes on a normal package import.
if __name__ == "__main__":
    from .basic import *
    from .norm import *
    from .nonlinearity import *
    from .linear import *
    from .attention import *
    from .conv import *
    # Fix: `import transformer` was an absolute import of a nonexistent
    # top-level module; the submodule must be imported relatively, matching
    # the lazy path in __getattr__ and the entry in __submodules.
    from .spatial import *
    from . import transformer