File size: 1,755 Bytes
d68c650
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""
Module imports and wrapper functions of the linear algebra backend.

If __usecupy__ is True and CuPy is successfully imported, then

	xp 		--> cupy
	spx		--> cupyx.scipy.sparse
	
Otherwise,

	xp 		--> numpy
	spx		--> scipy.sparse

In both cases the matching sparse.linalg submodule is imported as well, so it
is reachable as spx.linalg.

For example, if __usecupy__ is False, then

				import numpy as xp
				import scipy.sparse as spx
				
is equivalent to

				from stnn.linalg_backend import xp, spx
"""
__usecupy__ = True

try:
	# Switching CuPy off is routed through the same ImportError fallback that
	# a missing CuPy installation takes, so the NumPy setup lives in one place.
	if not __usecupy__:
		raise ImportError
	import cupy as cp
	# Imported so the submodule is reachable as ``spx.linalg``.
	import cupyx.scipy.sparse.linalg
	import cupyx.scipy.sparse as cupy_sparse
except ImportError:
	import numpy as np
	# Imported so the submodule is reachable as ``spx.linalg``.
	import scipy.sparse.linalg
	import scipy.sparse as scipy_sparse

	# NumPy / SciPy backend.
	xp, spx, using_cupy = np, scipy_sparse, False
else:
	# CuPy / cupyx backend.
	xp, spx, using_cupy = cp, cupy_sparse, True


def csr_matrix(L):
	"""
	Create a CSR (Compressed Sparse Row) matrix.

	If CuPy is available and enabled, this function creates a CuPy CSR matrix
	with dtype float64. Otherwise, it converts the given data to a SciPy CSR
	matrix. Sparse input keeps its dtype; for dense input SciPy infers the
	dtype from the data.

	Parameters:
	L (array_like or sparse matrix): 2-D array or sparse matrix to convert.

	Returns:
	CSR matrix: The converted CSR matrix, using either CuPy or SciPy.
	"""
	if using_cupy:
		return spx.csr_matrix(L, dtype=xp.float64)
	# Sparse inputs already know how to convert themselves. Dense array_like
	# inputs (ndarrays, nested lists) have no ``tocsr`` method, so they are
	# built into a CSR matrix directly instead of raising AttributeError.
	if spx.issparse(L):
		return L.tocsr()
	return spx.csr_matrix(L)


def asnumpy(arr):
	"""
	Return *arr* as a NumPy array on the host.

	With the CuPy backend the input is converted with ``cupy.asnumpy``.
	With the NumPy backend *arr* is passed straight through, with no copy
	and no type conversion.
	"""
	return cp.asnumpy(arr) if using_cupy else arr


def asarray(arr):
	"""
	Return *arr* as an array of the active backend.

	With the CuPy backend the input is converted with ``cupy.asarray`` to a
	float64 CuPy array. With the NumPy backend *arr* is returned untouched.
	"""
	if not using_cupy:
		return arr
	return cp.asarray(arr, dtype=cp.float64)