stacked_tensorial_nn / stnn /linalg_backend.py
caleb2's picture
initial commit
d68c650
"""
Module imports and wrapper functions of the linear algebra backend.
If __usecupy__ is "True" and cupy is successfully imported, then
xp --> cupy
spx --> cupyx.scipy.sparse.linalg
Otherwise,
xp --> numpy
spx --> scipy.sparse
For example, if __usecupy__ is False, then
import numpy as np
import scipy.sparse.linalg as sp
is equivalent to
from stnn.linalg_backend import xp, spx
"""
__usecupy__ = True
try:
# If CuPy is not preferred or available, fall back to NumPy
if not __usecupy__:
raise ImportError
import cupy as cp
import cupyx.scipy.sparse.linalg
import cupyx.scipy.sparse as cupy_sparse
xp = cp
spx = cupy_sparse
using_cupy = True
except ImportError:
import numpy as np
import scipy.sparse.linalg
import scipy.sparse as scipy_sparse
xp = np
spx = scipy_sparse
using_cupy = False
def csr_matrix(L):
"""
Create a CSR (Compressed Sparse Row) matrix.
If CuPy is available and enabled, this function will create a CuPy CSR matrix.
Otherwise, it converts the given data to a SciPy CSR matrix.
Parameters:
L (array_like or sparse matrix): 2-D array or sparse matrix to convert.
Returns:
CSR matrix: The converted CSR matrix, using either CuPy or SciPy.
"""
if using_cupy:
return spx.csr_matrix(L, dtype=xp.float64)
return L.tocsr()
def asnumpy(arr):
"""
Convert an array from the backend library (CuPy or NumPy) to NumPy.
If NumPy is enabled, the input array is returned unchanged.
"""
if using_cupy:
return cp.asnumpy(arr)
return arr
def asarray(arr):
"""
Convert the input to an array of the backend library (CuPy or NumPy).
If NumPy is enabled, the input array is returned unchanged.
"""
if using_cupy:
return cp.asarray(arr, dtype=cp.float64)
return arr