File size: 2,897 Bytes
db6a3b7
 
a6bbecf
db6a3b7
a6bbecf
 
db6a3b7
 
 
a6bbecf
db6a3b7
 
 
a6bbecf
 
 
 
db6a3b7
a6bbecf
db6a3b7
a6bbecf
 
 
 
db6a3b7
 
a6bbecf
 
db6a3b7
a6bbecf
690b53e
a6bbecf
db6a3b7
 
 
a6bbecf
 
db6a3b7
 
 
a6bbecf
db6a3b7
 
 
 
a6bbecf
 
db6a3b7
 
a6bbecf
 
db6a3b7
 
 
a6bbecf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db6a3b7
 
a6bbecf
db6a3b7
 
 
a6bbecf
db6a3b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a6bbecf
db6a3b7
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from typing import *

# Module-wide defaults; they may be replaced by environment variables when the
# module is imported, or later through set_backend / set_debug / set_attn.
BACKEND: str = "spconv"  # name of the sparse backend in use
DEBUG: bool = False  # debug flag
ATTN: str = "flash_attn"  # name of the attention backend in use


def __from_env():
    """Override the module defaults from environment variables, then report them.

    Variables consulted:

    * ``SPARSE_BACKEND`` -- applied only when it is ``"spconv"`` or
      ``"torchsparse"``.
    * ``SPARSE_DEBUG`` -- when set, ``DEBUG`` becomes ``True`` only for the
      exact value ``"1"`` and ``False`` for anything else.
    * ``SPARSE_ATTN_BACKEND`` (``ATTN_BACKEND`` is read only when the former
      is not set at all) -- applied only when it is ``"xformers"`` or
      ``"flash_attn"``.

    Unsupported values are ignored and leave the current setting in place.
    The resulting backend and attention choice are always printed.
    """
    import os

    global BACKEND, DEBUG, ATTN

    env = os.environ
    backend_choice = env.get("SPARSE_BACKEND")
    debug_flag = env.get("SPARSE_DEBUG")
    # The generic variable is only the default for the lookup, so a present
    # (even unsupported) SPARSE_ATTN_BACKEND still takes precedence over it.
    attn_choice = env.get("SPARSE_ATTN_BACKEND", env.get("ATTN_BACKEND"))

    # None is never a member of the tuples below, so the membership tests
    # also cover the "variable not set" case.
    if backend_choice in ("spconv", "torchsparse"):
        BACKEND = backend_choice
    if debug_flag is not None:
        DEBUG = debug_flag == "1"
    if attn_choice in ("xformers", "flash_attn"):
        ATTN = attn_choice

    print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")


# Apply the environment overrides (and print the resulting configuration)
# once, when this module is imported.
__from_env()


def set_backend(backend: Literal["spconv", "torchsparse"]):
    """Replace the module-level ``BACKEND`` setting with ``backend``.

    The value is stored as given; no validation is performed at runtime.
    """
    # Writing through the module namespace rebinds the module-level name.
    globals()["BACKEND"] = backend


def set_debug(debug: bool):
    """Replace the module-level ``DEBUG`` flag with ``debug``.

    The value is stored as given; it is not coerced to ``bool``.
    """
    # Writing through the module namespace rebinds the module-level name.
    globals()["DEBUG"] = debug


def set_attn(attn: Literal["xformers", "flash_attn"]):
    """Replace the module-level ``ATTN`` setting with ``attn``.

    The value is stored as given; no validation is performed at runtime.
    """
    # Writing through the module namespace rebinds the module-level name.
    globals()["ATTN"] = attn


import importlib

# Maps each lazily exported attribute name to the sibling submodule that
# defines it. Written grouped by submodule and flattened here; the insertion
# order (groups top to bottom, names left to right) is preserved.
__attributes = {
    attr_name: module_name
    for module_name, attr_names in (
        (
            "basic",
            (
                "SparseTensor",
                "sparse_batch_broadcast",
                "sparse_batch_op",
                "sparse_cat",
                "sparse_unbind",
            ),
        ),
        (
            "norm",
            (
                "SparseGroupNorm",
                "SparseLayerNorm",
                "SparseGroupNorm32",
                "SparseLayerNorm32",
            ),
        ),
        (
            "nonlinearity",
            (
                "SparseReLU",
                "SparseSiLU",
                "SparseGELU",
                "SparseActivation",
            ),
        ),
        ("linear", ("SparseLinear",)),
        (
            "attention",
            (
                "sparse_scaled_dot_product_attention",
                "SerializeMode",
                "sparse_serialized_scaled_dot_product_self_attention",
                "sparse_windowed_scaled_dot_product_self_attention",
                "SparseMultiHeadAttention",
            ),
        ),
        ("conv", ("SparseConv3d", "SparseInverseConv3d")),
        ("spatial", ("SparseDownsample", "SparseUpsample", "SparseSubdivide")),
    )
    for attr_name in attr_names
}

# Submodules that the module-level __getattr__ imports on first access.
__submodules = ["transformer"]

# Public names: every lazily resolved attribute, followed by the submodules.
__all__ = [*__attributes, *__submodules]


def __getattr__(name):
    if name not in globals():
        if name in __attributes:
            module_name = __attributes[name]
            module = importlib.import_module(f".{module_name}", __name__)
            globals()[name] = getattr(module, name)
        elif name in __submodules:
            module = importlib.import_module(f".{name}", __name__)
            globals()[name] = module
        else:
            raise AttributeError(f"module {__name__} has no attribute {name}")
    return globals()[name]


# For Pylance: eager imports so static analysis can see the names that are
# otherwise resolved lazily by the module-level __getattr__. The guard is
# false whenever this module is imported as part of its package, so nothing
# below runs on a normal import.
if __name__ == "__main__":
    from .basic import *
    from .norm import *
    from .nonlinearity import *
    from .linear import *
    from .attention import *
    from .conv import *
    from .spatial import *
    # Relative import, matching the lazy loader's f".{name}" lookup; a bare
    # `import transformer` would refer to an unrelated top-level module.
    from . import transformer