from typing import *

# Name of the attention backend currently selected. Defaults to "flash_attn";
# may be replaced by the ATTN_BACKEND environment variable or set_backend().
BACKEND = "flash_attn"
# Debug flag, off by default; may be replaced by the ATTN_DEBUG environment
# variable or set_debug().
DEBUG = False
def __from_env():
    """Override ``BACKEND`` and ``DEBUG`` from the process environment.

    Reads ``ATTN_BACKEND`` and ``ATTN_DEBUG``. ``ATTN_BACKEND`` is applied
    only when it names one of the recognised backends; any other value (or
    an unset variable) leaves ``BACKEND`` untouched. ``ATTN_DEBUG``, when
    set, turns ``DEBUG`` on for the exact string ``"1"`` and off for any
    other value. The resulting backend is always printed.
    """
    import os

    global BACKEND
    global DEBUG

    env_attn_backend = os.environ.get("ATTN_BACKEND")
    env_attn_debug = os.environ.get("ATTN_DEBUG")

    # An unset variable yields None, which is never a member of the tuple,
    # so no separate None check is needed.
    if env_attn_backend in ("xformers", "flash_attn", "sdpa", "naive"):
        BACKEND = env_attn_backend
    if env_attn_debug is not None:
        DEBUG = env_attn_debug == "1"

    print(f"[ATTENTION] Using backend: {BACKEND}")
# Apply any environment overrides once, when this module is first executed.
__from_env()
def set_backend(backend: Literal["xformers", "flash_attn", "sdpa", "naive"]):
    """Set the module-level ``BACKEND`` to ``backend``.

    The value is stored exactly as given; no validation is performed here.
    The annotation lists the same backend names that the ``ATTN_BACKEND``
    environment variable accepts.
    """
    global BACKEND
    BACKEND = backend
def set_debug(debug: bool):
    """Set the module-level ``DEBUG`` flag to ``debug``."""
    global DEBUG
    DEBUG = debug
# NOTE(review): these imports sit after the configuration code above —
# presumably the submodules read BACKEND/DEBUG when imported; confirm before
# moving them to the top of the file.
from .full_attn import *
from .modules import *