File size: 5,719 Bytes
7934b29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import logging as pylogger
import operator
import os

from nemo.utils import model_utils

# Silence INFO-level output from Numba's CUDA driver logger; only errors are shown.
cuda_logger = pylogger.getLogger('numba.cuda.cudadrv.driver')
cuda_logger.setLevel(pylogger.ERROR)

# Minimum numba version; the default can be overridden through the
# `NEMO_NUMBA_MINVER` environment variable.
__NUMBA_DEFAULT_MINIMUM_VERSION__ = "0.53.0"
__NUMBA_MINIMUM_VERSION__ = os.getenv("NEMO_NUMBA_MINVER", __NUMBA_DEFAULT_MINIMUM_VERSION__)


# Human-readable installation instructions for numba.
# Every logical line ends with "\n" so that the numbered options render on their own lines.
NUMBA_INSTALLATION_MESSAGE = (
    "Could not import `numba`.\n"
    "Please install numba in one of the following ways.\n"
    "1) If using conda, simply install it with conda using `conda install -c numba numba`\n"
    "2) If using pip (not recommended), `pip install --upgrade numba`\n"
    "followed by `export NUMBAPRO_LIBDEVICE='/usr/local/cuda/nvvm/libdevice/'` and \n"
    "`export NUMBAPRO_NVVM='/usr/local/cuda/nvvm/lib64/libnvvm.so'`.\n"
    "It is advised to always install numba using conda only, "
    "as pip installations might interfere with other libraries such as llvmlite.\n"
    "If pip install does not work, you can also try adding `--ignore-installed` to the pip command,\n"
    "but this is not advised."
)

# Module-wide flag controlling how strict the numba cuda compatibility checks are.
# Strict by default.
STRICT_NUMBA_COMPAT_CHECK = True

# The `STRICT_NUMBA_COMPAT_CHECK` environment variable, when present, overrides the default.
# Only "yes", "true", "t" and "1" (case-insensitive) enable strict mode; any other value disables it.
if 'STRICT_NUMBA_COMPAT_CHECK' in os.environ:
    STRICT_NUMBA_COMPAT_CHECK = os.environ['STRICT_NUMBA_COMPAT_CHECK'].lower() in ("yes", "true", "t", "1")


def is_numba_compat_strict() -> bool:
    """
    Returns strictness level of numba cuda compatibility checks.

    If value is true, numba cuda compatibility matrix must be satisfied.
    If value is false, only cuda availability is checked, not compatibility.
    Numba Cuda may still compile and run without issues in such a case, or it may fail.
    """
    return STRICT_NUMBA_COMPAT_CHECK


def set_numba_compat_strictness(strict: bool):
    """
    Sets the strictness level of numba cuda compatibility checks.

    If value is true, numba cuda compatibility matrix must be satisfied.
    If value is false, only cuda availability is checked, not compatibility.
    Numba Cuda may still compile and run without issues in such a case, or it may fail.

    Args:
        strict: bool value, whether to enforce strict compatibility checks or relax them.
    """
    global STRICT_NUMBA_COMPAT_CHECK
    STRICT_NUMBA_COMPAT_CHECK = strict


@contextlib.contextmanager
def with_numba_compat_strictness(strict: bool):
    """
    Context manager that temporarily overrides the numba cuda compatibility strictness.

    The strictness that was active on entry is restored on exit, including when the
    body of the `with` statement raises an exception.

    Args:
        strict: bool value, the strictness level to apply inside the `with` block.
    """
    initial_strictness = is_numba_compat_strict()
    set_numba_compat_strictness(strict=strict)
    try:
        yield
    finally:
        # Without `finally`, an exception raised inside the `with` body would skip
        # the restore and leave the module-wide flag permanently modified.
        set_numba_compat_strictness(strict=initial_strictness)


def numba_cpu_is_supported(min_version: str) -> bool:
    """
    Tests if an appropriate version of numba is installed.

    Args:
        min_version: The minimum version of numba that is required.

    Returns:
        bool, True only when numba is installed and the `>= min_version` check succeeded.
    """
    module_available, _ = model_utils.check_lib_version('numba', checked_version=min_version, operator=operator.ge)

    # NOTE(review): `check_lib_version` is assumed to be tri-state - None when numba is not
    # installed, otherwise the boolean outcome of the version comparison - confirm against
    # `model_utils`. Previously any non-None result (including a failed version check)
    # was reported as supported; only an explicit True counts now.
    return module_available is True


def numba_cuda_is_supported(min_version: str) -> bool:
    """
    Checks whether numba is installed at a suitable version and, if so,
    whether its cuda backend is usable.

    In strict mode (see `is_numba_compat_strict`) cuda must be both available and
    reported as a supported version by numba; otherwise availability alone suffices.

    Args:
        min_version: The minimum version of numba that is required.

    Returns:
        bool, whether cuda is supported with this current installation or not.
    """
    # Bail out unless the numba (CPU) check explicitly succeeded.
    if numba_cpu_is_supported(min_version) is not True:
        return False

    from numba import cuda

    # `is_supported_version` first arrived in numba 0.53, the minimum version required.
    # Without it compatibility cannot be verified, so cuda is reported as unsupported.
    if not hasattr(cuda, 'is_supported_version'):
        return False

    try:
        has_cuda = cuda.is_available()
        # Compatibility is only queried when cuda is actually available.
        version_ok = cuda.is_supported_version() if has_cuda else False

        if is_numba_compat_strict():
            return has_cuda and version_ok
        return has_cuda

    except OSError:
        # dlopen(libcudart.dylib) might fail if CUDA was never installed in the first place.
        return False


def skip_numba_cuda_test_if_unsupported(min_version: str):
    """
    Skips the currently running pytest test case when numba cuda is not supported.

    Args:
        min_version: The minimum version of numba that is required.
    """
    if numba_cuda_is_supported(min_version):
        return

    # pytest is imported lazily, only when a skip is actually needed.
    import pytest

    pytest.skip(f"Numba cuda test is being skipped. Minimum version required : {min_version}")