initial commit
- README.md +8 -1
- T5_config.json +40 -0
- T5_weights.h5 +3 -0
- app.py +316 -0
- requirements.txt +10 -0
- stnn/__init__.py +0 -0
- stnn/__pycache__/__init__.cpython-311.pyc +0 -0
- stnn/data/__init__.py +0 -0
- stnn/data/function_generators.py +239 -0
- stnn/data/preprocessing.py +312 -0
- stnn/data/test_functions.py +197 -0
- stnn/linalg_backend.py +81 -0
- stnn/nn/__init__.py +0 -0
- stnn/nn/stnn.py +66 -0
- stnn/nn/stnn_layers.py +274 -0
- stnn/pde/__init__.py +0 -0
- stnn/pde/__pycache__/__init__.cpython-311.pyc +0 -0
- stnn/pde/circle.py +156 -0
- stnn/pde/common.py +91 -0
- stnn/pde/ellipse.py +154 -0
- stnn/pde/pde_system.py +223 -0
- stnn/tests/test_circle.py +63 -0
- stnn/tests/test_dependencies.py +25 -0
- stnn/tests/test_differential_ops.py +115 -0
- stnn/tests/test_ellipse.py +67 -0
- stnn/tests/test_file.py +121 -0
- stnn/tests/test_pde_system.py +54 -0
- stnn/tests/test_preprocessing.py +99 -0
- stnn/tests/test_stats.py +41 -0
- stnn/tests/test_stnn_config.py +61 -0
- stnn/tests/test_ttl.py +51 -0
- stnn/utils/__init__.py +0 -0
- stnn/utils/input_output.py +135 -0
- stnn/utils/network_visualization.py +121 -0
- stnn/utils/plotting.py +93 -0
- stnn/utils/stats.py +27 -0
README.md
CHANGED
@@ -10,4 +10,11 @@ pinned: false
 license: mit
 ---
 
-
+## Stacked Tensorial Neural Network (STNN) demo
+This demo uses the model architecture from [arXiv:2312.14979](https://arxiv.org/abs/2312.14979)
+to solve a parametric PDE problem on an elliptical annular domain. See the paper for a
+detailed description of the problem and its applications.
+
+The [GitHub repo](https://github.com/caleb399/stacked_tensorial_nn) contains additional examples, including
+instructions for solving the PDE using a conventional iterative method (GMRES). Due to the long runtime of
+solving the PDE in this way, it is not included in the demo.
T5_config.json
ADDED
@@ -0,0 +1,40 @@
{
    "K": 20,
    "d": 8,
    "W": 2,
    "ranks": [
        1,
        16,
        16,
        16,
        16,
        16,
        7,
        1
    ],
    "shape1": [
        4,
        4,
        4,
        4,
        4,
        4,
        4
    ],
    "shape2": [
        4,
        2,
        2,
        2,
        2,
        2,
        2
    ],
    "nx1": 256,
    "nx2": 64,
    "nx3": 32,
    "ell_min": 0.01,
    "ell_max": 100.0,
    "a2_min": 2.0,
    "a2_max": 20.0
}
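For orientation, a minimal sketch of how this config file is consumed at startup, mirroring the model-loading lines in app.py below (model.summary() is standard Keras, not part of the demo code):

import json
from stnn.nn import stnn

# Load hyperparameters and grid dimensions, build the model, restore the trained weights
with open('T5_config.json', 'r', encoding = 'utf-8') as json_file:
    stnn_config = json.load(json_file)

model = stnn.build_stnn(stnn_config)
model.load_weights('T5_weights.h5')
model.summary()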
T5_weights.h5
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:51b60beefbbd6e5c8ade77d6f4f64c54146f34dc79e2b91605e0864ecbfcbd07
size 957976
app.py
ADDED
@@ -0,0 +1,316 @@
import json
import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import sympy
from matplotlib.cm import get_cmap
from stnn.nn import stnn
from stnn.pde.pde_system import PDESystem


def adjust_to_nice_number(value, round_down = False):
    """
    Adjust the given value to the nearest "nice" number. Used for colorbar tick marks.
    """
    if value == 0:
        return value

    is_negative = False
    if value < 0:
        round_down = True
        is_negative = True
        value = -value
    exponent = np.floor(np.log10(value))  # Find exponent of 10
    fractional_part = value / 10**exponent  # Find leading digit(s)

    if round_down:
        if fractional_part < 1.5:
            nice_fractional = 1
        elif fractional_part < 3:
            nice_fractional = 2
        elif fractional_part < 7:
            nice_fractional = 5
        else:
            nice_fractional = 10
    else:
        if fractional_part <= 1:
            nice_fractional = 1
        elif fractional_part <= 2:
            nice_fractional = 2
        elif fractional_part <= 5:
            nice_fractional = 5
        else:
            nice_fractional = 10

    nice_value = nice_fractional * 10**exponent if round_down or nice_fractional != 10 else 10**(exponent + 1)
    if is_negative:
        nice_value = -nice_value
    return nice_value


def find_nice_values(min_val_raw, max_val, num_values = 4):
    """
    Calculate 'num_values' evenly spaced "nice" values within the given range. Used for colorbar tick marks.
    """
    # Calculate rough spacing between values
    min_val = adjust_to_nice_number(min_val_raw)
    frac_val = (min_val - min_val_raw) / (max_val - min_val_raw)
    if frac_val < 1 / num_values:
        min_val = min_val_raw
    raw_spacing = (max_val - min_val) / (num_values - 1)

    # Calculate order of magnitude of the spacing
    magnitude = np.floor(np.log10(raw_spacing))

    nice_factors = np.array([1, 2, 5, 10])
    normalized_spacing = raw_spacing / (10**magnitude)
    closest_factor = nice_factors[np.argmin(np.abs(nice_factors - normalized_spacing))]
    nice_spacing = closest_factor * (10**magnitude)

    nice_values = min_val + nice_spacing * np.arange(num_values)

    # Extend by one value if the last value falls short of max_val
    if nice_values[-1] < max_val - nice_spacing:
        last_val = nice_values[-1]
        nice_values = np.append(nice_values, [last_val + nice_spacing])

    return [val for val in nice_values if min_val <= val <= max_val]


def format_tick_label(val):
    """
    Format with scientific notation for large/small values.
    """
    if val != 0:
        magnitude = np.abs(np.floor(np.log10(np.abs(val))))
        if magnitude > 2:
            return f'{val:.1e}'
        elif magnitude > 1:
            return f'{val:.0f}'
        elif magnitude > 0:
            return f'{val:.1f}'
        else:
            return f'{val:.2f}'
    else:
        return f'{val}'


def plot_simple(system, rho, fontscale = 1):
    # Major axis of outer boundary
    b2 = system.b2

    # Get x, y grids from 'PDESystem' object
    x, y = system.get_xy_grids()

    # Wrap around values for continuity
    rho = np.append(rho, rho[:, 0:1], axis = 1)

    # Colorbar limits
    vmin = np.nanmin(rho)
    vmax = np.nanmax(rho)

    fig = plt.figure(figsize = (5, 5))
    ax = plt.gca()

    im = ax.contourf(x, y, rho, levels = np.linspace(vmin, vmax, 100), cmap = get_cmap('hsv'))
    ax.set_title('rho(x,y)', fontsize = fontscale * 16)
    for label in ax.get_xticklabels() + ax.get_yticklabels():
        label.set_fontsize(fontscale * 12)
    ax.set_aspect(1.0)
    fac = 1.05
    ax.set_xlim([-fac * b2, fac * b2])
    ax.set_ylim([-fac * b2, fac * b2])

    cbar = fig.colorbar(im, shrink = 0.8)

    # Set colorbar ticks and labels to "nice" values
    nice_values = find_nice_values(vmin, vmax, num_values = 5)
    cbar.set_ticks(nice_values)
    cbar.ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda x, pos: format_tick_label(x)))

    return fig


def evaluate_2d_expression(expr_str, xvals, yvals):
    x, y = sympy.symbols('s t')
    expr = sympy.sympify(expr_str)
    f = sympy.lambdify((x, y), expr, modules = ['numpy'])
    result = f(xvals, yvals)
    if isinstance(result, (int, float)):
        # Constant expressions evaluate to a scalar; broadcast to the grid shape
        return result * np.ones(xvals.shape)
    return result


'''
# Currently unused in gradio interface
# (requires timeit, warnings, and the stnn.linalg_backend imports, which are omitted above)
def direct_solution(ell, a2, eccentricity, ibc_str, obc_str, max_krylov_dim, max_iterations):
    # Direct solution
    start = timeit.default_timer()

    pde_config = {}
    for key in ['nx1', 'nx2', 'nx3']:
        pde_config[key] = stnn_config[key]
    pde_config['ell'] = ell
    pde_config['eccentricity'] = eccentricity
    pde_config['a2'] = a2
    system = PDESystem(pde_config)

    try:
        ibf_data = evaluate_2d_expression(ibc_str, system.x2_ib, system.x2_ib - system.x3_ib)[system.ib_slice]
    except Exception:
        raise ValueError(f"Failed to parse the expression `{ibc_str}` for the boundary condition at the inner boundary.")
    try:
        obf_data = evaluate_2d_expression(obc_str, system.x2_ob, system.x2_ob - system.x3_ob)[system.ob_slice]
    except Exception:
        raise ValueError(f"Failed to parse the expression `{obc_str}` for the boundary condition at the outer boundary.")

    if np.any(np.isnan(ibf_data)):
        raise ValueError(f"The expression `{ibc_str}` evaluates to NaN at one or more grid points.")

    if np.any(np.isnan(obf_data)):
        raise ValueError(f"The expression `{obc_str}` evaluates to NaN at one or more grid points.")

    ibf_data, obf_data, b = system.convert_boundary_data(ibf_data, obf_data)

    L_xp = csr_matrix(system.L)  # Sparse matrix representation of the PDE operator
    nx1, nx2, nx3 = system.params['nx1'], system.params['nx2'], system.params['nx3']
    b_xp = asarray(b.reshape((nx1 * nx2 * nx3,)))  # r.h.s. vector

    def callback(res):
        print(f'GMRES residual: {res}')

    f_xp, info = spx.linalg.gmres(L_xp, b_xp, maxiter = max_iterations, tol = 1e-7, restart = max_krylov_dim, callback = callback)

    residual = (xp.linalg.norm(b_xp - L_xp @ f_xp) / xp.linalg.norm(b_xp))

    if info > 0:
        warnings.simplefilter('always')
        warnings.warn(f'GMRES solver did not converge. Number of iterations: {info}; residual: {residual}', RuntimeWarning)

    f = asnumpy(f_xp)
    rho_direct = np.sum(f.reshape((nx1, nx2, nx3)), axis = -1)
    direct_time = timeit.default_timer() - start
    print(f'Done with direct solution. Time: {direct_time} seconds.')

    fig = plot_simple(system, rho_direct)
    return fig, info
'''


def predict_pde_solution(ell, a2, eccentricity, ibc_str, obc_str):
    if a2 <= eccentricity:
        raise ValueError(f'Outer minor axis must be greater than the eccentricity (here, {eccentricity}).')

    pde_config = {}
    for key in ['nx1', 'nx2', 'nx3']:
        pde_config[key] = stnn_config[key]
    pde_config['ell'] = ell
    pde_config['eccentricity'] = eccentricity
    pde_config['a2'] = a2
    system = PDESystem(pde_config)

    try:
        ibf_data = evaluate_2d_expression(ibc_str, system.x2_ib, system.x2_ib - system.x3_ib)[system.ib_slice]
    except Exception:
        raise ValueError(f"Failed to parse the expression `{ibc_str}` for the boundary condition at the inner boundary.")
    try:
        obf_data = evaluate_2d_expression(obc_str, system.x2_ob, system.x2_ob - system.x3_ob)[system.ob_slice]
    except Exception:
        raise ValueError(f"Failed to parse the expression `{obc_str}` for the boundary condition at the outer boundary.")

    if np.any(np.isnan(ibf_data)):
        raise ValueError(f"The expression `{ibc_str}` evaluates to NaN at one or more grid points.")

    if np.any(np.isnan(obf_data)):
        raise ValueError(f"The expression `{obc_str}` evaluates to NaN at one or more grid points.")

    # Permute and reshape boundary data to the format expected by the STNN model
    ibf_data, obf_data, b = system.convert_boundary_data(ibf_data, obf_data)

    '''
    # Currently unused in gradio interface
    ibf_data, obf_data, b, _ = system.generate_random_bc(func_gen_id)
    '''

    # Load some relevant quantities from the config dictionaries
    ell_min, ell_max = stnn_config['ell_min'], stnn_config['ell_max']
    a2_min, a2_max = stnn_config['a2_min'], stnn_config['a2_max']
    nx1, nx2, nx3 = pde_config['nx1'], pde_config['nx2'], pde_config['nx3']

    # Combine boundary data into a single array
    bf = np.zeros((1, 2 * nx2, nx3 // 2))
    bf[:, :nx2, :] = ibf_data[np.newaxis, ...]
    bf[:, nx2:, :] = obf_data[np.newaxis, ...]

    # Normalize and combine parameters
    params = np.zeros((1, 3))
    params[0, 0] = (a2 - a2_min) / (a2_max - a2_min)
    params[0, 1] = (ell - ell_min) / (ell_max - ell_min)
    params[0, 2] = eccentricity

    rho = model.predict([params, bf])
    fig = plot_simple(system, rho[0, ...])
    return fig


with open('T5_config.json', 'r', encoding = 'utf-8') as json_file:
    stnn_config = json.load(json_file)

model = stnn.build_stnn(stnn_config)
model.load_weights('T5_weights.h5')

with gr.Blocks() as demo:
    gr.Markdown("# Stacked Tensorial Neural Network (STNN) demo"
                "\nThis demo uses the model architecture from [arXiv:2312.14979](https://arxiv.org/abs/2312.14979) "
                "to solve a parametric PDE problem on an elliptical annular domain. "
                "See the paper for a detailed description of the problem and its applications."
                "<br/>The [GitHub repo](https://github.com/caleb399/stacked_tensorial_nn) contains additional examples, "
                "including instructions for solving the PDE using a conventional iterative method (GMRES). "
                "Due to the long runtime of solving the PDE in this way, it is not included in the demo.")
    gr.Markdown("<br/>The PDE is "
                "$\ell \\left( \\boldsymbol{\hat{u}} \cdot \\nabla \\right) f(\\boldsymbol{r}, w) = \partial_{ww} f(\\boldsymbol{r}, w)$, "
                "where $\ell$ is a parameter and $\\boldsymbol{\hat{u}} = (\\cos w, \\sin w)$. "
                "Here, $\\boldsymbol{r}$ is the 2D position vector, and $w$ is an angular coordinate unrelated to "
                "the spatial domain. The model predicts the density !\\rho(\\boldsymbol{r}) = \int f(\\boldsymbol{r}, w) dw! "
                "on elliptical annular domains parameterized as shown below. ",
                latex_delimiters = [{"left": "$", "right": "$", "display": False}, {"left": "!", "right": "!", "display": True}])
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                "## PDE Parameters \n The model was trained on solutions of the PDE with $\ell$ between 0.01 and 100, $a$ between 2 and 20, "
                "and $ecc$ between 0 and 0.8.", latex_delimiters = [{"left": "$", "right": "$", "display": False},
                                                                    {"left": "!", "right": "!", "display": True}])
            ell_input = gr.Number(label = "ell (must be > 0)", value = 1.0)
            eccentricity_input = gr.Number(
                label = "ecc: eccentricity of the inner boundary (must be >= 0 and <= 0.999)",
                value = 0.5, minimum = 0.0, maximum = 0.999)
            a2_input = gr.Number(label = "a: Minor axis of outer boundary (must be > eccentricity)", value = 2.0)
            gr.Markdown(
                "## Boundary Conditions \n $(s, t)$ are angular coordinates parameterizing the PDE domain, "
                "related to $\\boldsymbol{r}$ and $w$ by a coordinate transformation. "
                "Specifically, $s$ is the polar elliptical coordinate along the boundary (inner or outer), with values "
                "between $-\pi$ and $\pi$, while $t = s - w$. Boundary conditions are generated from grid points "
                "distributed uniformly over the allowable values of $s$ and $t$."
                "<br/><br/>For the PDE problem to be well-posed, boundary data should only be specified where "
                "$\\boldsymbol{\hat{u}} \cdot \\boldsymbol{\hat{n}} > 0$, where $\\boldsymbol{\hat{n}}$ is the "
                "inward-pointing unit normal vector. This requirement constrains the allowable values of $t$"
                " and is automatically enforced when building boundary conditions from the user-specified expressions below.",
                latex_delimiters = [{"left": "$", "right": "$", "display": False}])

            inner_boundary = gr.Textbox(label = "Inner boundary condition", value = "0.5 * (1 + sign(cos(s)))")
            outer_boundary = gr.Textbox(label = "Outer boundary condition", value = "1 + 0.1 * cos(4*s)")

            submit_button = gr.Button("Submit")

        with gr.Column():
            gr.Markdown("## Predicted Solution")
            predicted_output_plot = gr.Plot()

    submit_button.click(
        fn = predict_pde_solution,
        inputs = [ell_input, a2_input, eccentricity_input, inner_boundary, outer_boundary],
        outputs = [predicted_output_plot]
    )

demo.launch()
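The prediction pipeline above can also be exercised without the Gradio UI. A minimal sketch of such a headless call, using the interface's default inputs (hypothetical usage: it assumes the config/weight-loading lines have already run and that demo.launch() is skipped):

# Hypothetical headless call to predict_pde_solution with the UI's default values
fig = predict_pde_solution(
    ell = 1.0,
    a2 = 2.0,
    eccentricity = 0.5,
    ibc_str = "0.5 * (1 + sign(cos(s)))",
    obc_str = "1 + 0.1 * cos(4*s)",
)
fig.savefig('predicted_rho.png')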
requirements.txt
ADDED
@@ -0,0 +1,10 @@
tensorflow>=2.15.0
numpy~=1.26.0
t3f~=1.2.0
scipy~=1.12.0
h5py~=3.10.0
matplotlib~=3.8.2
pydot~=1.4.2
openvino~=2023.3.0
pyyaml>=6.0.1
sympy
stnn/__init__.py
ADDED
File without changes
stnn/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (164 Bytes)
stnn/data/__init__.py
ADDED
File without changes
stnn/data/function_generators.py
ADDED
@@ -0,0 +1,239 @@
import numpy as np


def generate_piecewise_linear_function(num_pieces, lower, upper, delta = 0.3):
    """
    Generates a piecewise linear function on the interval (lower, upper) that is 2*pi-periodic.

    Args:
        num_pieces (int): Number of linear pieces in the function.
        lower (float): The lower end of the interval
        upper (float): The upper end of the interval
        delta (float): Parameter determining how rapidly the function varies between the grid points (i.e., modulates
                       the slopes of the piecewise functions). Larger values mean more variability. Default is 0.3.

    Returns:
        function: A piecewise linear function.
    """
    # Generate equally spaced points in the interval (lower, upper)
    x_points = np.linspace(lower, upper, num_pieces + 1)

    # Generate random y-values for each point as a random walk with step size 'delta'
    y_points = np.zeros(num_pieces + 1)
    y_points[0] = np.random.uniform(-1, 1)
    for n in range(1, y_points.shape[0]):
        y_points[n] = y_points[n - 1] + delta * np.random.uniform(-1, 1)
    y_points[0] = y_points[-1]  # enforce periodicity
    min_y = y_points.min()
    # Ensure y values are nonnegative
    if min_y < 0:
        y_points -= min_y
    y_points += np.random.uniform(0, 0.5)  # random (constant) offset

    def piecewise_linear(x):
        """
        Evaluates the piecewise linear function at a given x.

        Args:
            x (float): The x-coordinate at which to evaluate the function.

        Returns:
            float: The y-coordinate of the function at x.
        """
        for i in range(num_pieces):
            if x_points[i] <= x < x_points[i + 1]:
                # Linear interpolation between the two points
                slope = (y_points[i + 1] - y_points[i]) / (x_points[i + 1] - x_points[i])
                return slope * (x - x_points[i]) + y_points[i]
        return y_points[0]  # For x = upper

    return piecewise_linear


def generate_piecewise_constant_function(num_pieces, lower, upper):
    """
    Generates a piecewise constant function on the interval (lower, upper) that is 2*pi-periodic.

    Args:
        num_pieces (int): Number of constant pieces in the function.
        lower (float): The lower end of the interval
        upper (float): The upper end of the interval

    Returns:
        function: A piecewise constant function.
    """
    # Generate equally spaced points in the interval (lower, upper)
    x_points = np.linspace(lower, upper, num_pieces + 1)

    # Generate random y-values (between 0 and 2) for each constant piece
    y_values = 2 * np.random.rand(num_pieces)

    # Ensure the function is 2*pi periodic
    y_values = np.append(y_values, y_values[0])

    def piecewise_constant(x):
        """
        Evaluates the piecewise constant function at a given x.

        Args:
            x (float): The x-coordinate at which to evaluate the function.

        Returns:
            float: The y-coordinate of the function at x.
        """
        for i in range(num_pieces):
            if x_points[i] <= x < x_points[i + 1]:
                return y_values[i]
        return y_values[0]  # For x = upper

    return piecewise_constant


def generate_piecewise_bc(x2_grid, x3_grid, num_pieces):
    """
    Generates a piecewise linear function on the domain defined by x2_grid and x3_grid.

    Args:
        x2_grid (numpy.ndarray): A 2D array, x2 grid values
        x3_grid (numpy.ndarray): A 2D array, x3 grid values
        num_pieces (int): The number of pieces in the piecewise linear function.

    Returns:
        numpy.ndarray: A 2D array representing the piecewise linear function
    """
    x2_fun = generate_piecewise_linear_function(num_pieces, lower = x2_grid.min(), upper = x2_grid.max())
    x3_fun = generate_piecewise_linear_function(num_pieces, lower = x3_grid.min(), upper = x3_grid.max())

    x2vals_1d = np.zeros_like(x2_grid[:, 0])
    x3vals_1d = np.zeros_like(x3_grid[0, :])
    for i in range(x2vals_1d.shape[0]):
        x2vals_1d[i] = x2_fun(x2_grid[i, 0])
    for i in range(x3vals_1d.shape[0]):
        x3vals_1d[i] = x3_fun(x3_grid[0, i])
    x2vals_2d, x3vals_2d = np.meshgrid(x2vals_1d, x3vals_1d, indexing = 'ij')
    return x2vals_2d * x3vals_2d


def random_2d_gaussian(theta, phi):
    """
    Generates a 2D Gaussian G(x,y), where
        x = np.cos(0.5 * freq_x * theta - phase_x)
        y = np.cos(0.5 * freq_y * phi - phase_y)
    Here, the frequencies and phases are randomly sampled, and (theta, phi) define a 2D meshgrid.

    Args:
        theta (numpy.ndarray): 2D array, meshgrid of the first coordinate
        phi (numpy.ndarray): 2D array, meshgrid of the second coordinate

    Returns:
        numpy.ndarray: A 2D array representing the values of the Gaussian on the grid.
    """
    phase_x = np.random.uniform(0, 2 * np.pi)
    phase_y = np.random.uniform(0, 2 * np.pi)
    freq_x = np.random.randint(1, 2)  # always 1 with the current bounds
    freq_y = np.random.randint(1, 2)

    x = np.cos(0.5 * freq_x * theta - phase_x)
    y = np.cos(0.5 * freq_y * phi - phase_y)

    sigma_x = np.random.uniform(0.1, 3.0)
    sigma_y = np.random.uniform(0.1, 1.0)
    rho = 0

    covariance_matrix = np.array([[sigma_x**2, rho * sigma_x * sigma_y],
                                  [rho * sigma_x * sigma_y, sigma_y**2]])
    inv_sigma_xx = 1.0 / sigma_x**2
    inv_sigma_yy = 1.0 / sigma_y**2
    inv_sigma_xy = -rho / (sigma_x * sigma_y)

    if np.any(np.linalg.eigvals(covariance_matrix) < 0):
        raise ValueError('Covariance matrix is not positive semi-definite.')

    def gaussian_2d(x, y):
        return np.exp(-0.5 * (inv_sigma_xx * x**2 + inv_sigma_yy * y**2 + 2 * inv_sigma_xy * x * y))

    gaussian_values = gaussian_2d(x, y)
    return gaussian_values


def generate_random_functions(N, X, Y, num_terms = 16, min_freq = 1, max_freq = 16, func_gen_id = 0):
    """
    Generates N random 2*pi-periodic functions on a 2D grid as a Fourier series, with different types of
    modulation applied to the amplitudes.

    Args:
        N (int): Number of functions to generate.
        X (numpy.ndarray): 2D array representing the values of the first coordinate on the grid
        Y (numpy.ndarray): 2D array representing the values of the second coordinate on the grid
        num_terms (int, optional): Number of terms in the Fourier series expansion. Default is 16.
        min_freq (int, optional): Minimum frequency for the Fourier series terms. Default is 1.
        max_freq (int, optional): Maximum frequency for the Fourier series terms. Default is 16.
        func_gen_id (int, optional): Type of function to generate, based on the decay of the expansion coefficients
                                     as frequency is increased. Values can range from -1 to 4. Default is 0.

    Returns:
        numpy.ndarray: A 3D numpy array of shape (N, nx, ny) containing the function values.

    Raises:
        ValueError: If max_freq is less than min_freq or if an invalid func_gen_id is provided.
    """

    # Check if the maximum frequency is less than the minimum frequency
    if max_freq < min_freq:
        raise ValueError('max_freq cannot be less than min_freq')

    # Generate uniformly distributed functions if func_gen_id is -1
    if func_gen_id == -1:
        F_batch = np.random.uniform(0, 1, size = (N,) + X.shape)
        return F_batch

    # Initialize the batch of functions with zeros
    F_batch = np.zeros((N,) + X.shape)

    # Loop through each function to be generated
    for n in range(N):
        # Add a cosine term with a half frequency with 20% chance
        if np.random.uniform(0, 1) < 0.2:
            amp_cos_half = np.random.uniform(0, 1)  # Amplitude for cosine term
            phase_cos_half = np.random.uniform(0, 2 * np.pi)  # Phase shift for cosine term
            F_batch[n] += amp_cos_half * np.cos(0.5 * X - phase_cos_half)

        # Fourier series
        for _ in range(num_terms):
            amplitude = np.random.uniform(-1, 1)  # Random amplitude for the term
            kx, ky = np.random.randint(min_freq, max_freq + 1, 2)  # Frequencies for x and y components
            phase_x = np.random.uniform(0, 2 * np.pi)  # Phase shift for x-component
            phase_y = np.random.uniform(0, 2 * np.pi)  # Phase shift for y-component

            # Determine the coefficient amplitude based on the func_gen_id
            if func_gen_id == 0:
                # No decay applied to amplitude
                pass
            elif func_gen_id == 1:
                if np.random.uniform(0, 1) < 0.5:
                    amplitude = amplitude / kx
                else:
                    amplitude = amplitude / ky
            elif func_gen_id == 2:
                amplitude = amplitude / (kx * ky)
            elif func_gen_id == 3:
                amplitude = amplitude / (kx * kx * ky * ky)
            elif func_gen_id == 4:
                # Gaussian decay with random covariance matrix
                sxx = np.random.uniform(0.1, 1.0)
                syy = np.random.uniform(0.1, 1.0)
                sxy = np.random.uniform(0.1, 1.0)
                amplitude = amplitude * np.exp(-(sxx * kx**2 + syy * ky**2 + sxy * kx * ky))
            else:
                raise ValueError(
                    f'Invalid func_gen_id. Should be an integer in the range [-1, 4], but received {func_gen_id}')

            # Add the term to the nth function in the batch
            F_batch[n] += amplitude * np.cos(kx * X - phase_x) * np.cos(ky * Y - phase_y)

        # Shift the function to ensure it is nonnegative
        minF = np.min(F_batch[n])
        if minF < 0:
            F_batch[n] -= minF

    return F_batch
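A minimal sketch of calling generate_random_functions to produce a batch of random boundary functions; the grid sizes here are illustrative assumptions, not values taken from the demo config:

import numpy as np
from stnn.data.function_generators import generate_random_functions

# Angular meshgrid; sizes chosen for illustration only
s = np.linspace(-np.pi, np.pi, 64)
t = np.linspace(-np.pi, np.pi, 16)
S, T = np.meshgrid(s, t, indexing = 'ij')

# 8 random nonnegative 2*pi-periodic functions with 1/(kx*ky) amplitude decay (func_gen_id = 2)
F = generate_random_functions(8, S, T, func_gen_id = 2)
print(F.shape)  # (8, 64, 16)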
stnn/data/preprocessing.py
ADDED
@@ -0,0 +1,312 @@
import h5py
import numpy as np

# If STRICT_WARNING = True, the program exits when negative values are detected in ibf, obf, or rho.
# This is important to check because negative values are unphysical.
STRICT_WARNING = True


def verify_nonnegative(fname, ibf, obf, rho):
    """
    Check ibf, obf, and rho for negative values.
    """
    found_warning = False
    if np.any(ibf < 0):
        print(f'Warning: negative values detected in array "ibf" in {fname}; min val: {ibf.min()}')
        found_warning = True
    elif np.any(obf < 0):
        print(f'Warning: negative values detected in array "obf" in {fname}')
        found_warning = True
    elif np.any(rho < 0):
        print(f'Warning: negative values detected in array "rho" in {fname}')
        found_warning = True

    if found_warning and STRICT_WARNING:
        print(f'Exiting program. To avoid exiting on this warning, set STRICT_WARNING to False in {__file__}')
        exit()


def get_data_from_file(fname, nx2, nx3, Nrange = None):
    """
    Retrieves training data from the given HDF5 file. Assumes that the PDE parameters
    are stored in datasets with their respective names, i.e., 'ell', 'a1', 'a2'. Likewise,
    the density rho(x1,x2) and boundary data ibf(x2,x3) / obf(x2,x3) are stored in datasets
    'rho', 'ibf', and 'obf'.

    Args:
        fname (str): Path to the HDF5 file containing the data.
        nx2 (int): Second grid dimension
        nx3 (int): Third grid dimension
        Nrange (tuple, optional): A tuple of two integers specifying the range of samples to extract (start, end).
                                  Defaults to None.

    Returns:
        tuple: Tuple of extracted data

    Raises:
        ValueError: If the file does not contain the required datasets.
    """
    if not isinstance(fname, str):
        raise TypeError('Filename must be a string.')
    type_check1 = not (Nrange is None or isinstance(Nrange, (tuple, list)))
    type_check2 = False
    if isinstance(Nrange, (tuple, list)):
        type_check2 = len(Nrange) != 2
        if not type_check2:
            type_check2 = not all((isinstance(i, int) or i is None) for i in Nrange)
    if type_check1 or type_check2:
        raise TypeError('Nrange must be a length-2 tuple or list of integers.')

    if Nrange is None:
        N1, N2 = None, None
    else:
        N1, N2 = Nrange

    # Check that all datasets are present
    dset_names = ['ell', 'a1', 'a2', 'rho', 'ibf', 'obf']
    with h5py.File(fname, 'r') as input_file:
        missing_keys = [key for key in dset_names if key not in input_file.keys()]
        if missing_keys:
            raise ValueError(f"Missing / incorrectly labeled datasets in file {fname}. "
                             f"Could not find datasets: {', '.join(missing_keys)}")

        ell = input_file['ell'][N1:N2]
        a2 = input_file['a2'][N1:N2]  # minor axis of outer boundary
        a1 = input_file['a1'][N1:N2]  # minor axis of inner boundary
        eccentricity = np.ones_like(a1) - a1  # eccentricity of inner boundary
        rho = input_file['rho'][N1:N2]
        ibf = input_file['ibf'][N1:N2]  # boundary data on inner boundary
        obf = input_file['obf'][N1:N2]  # boundary data on outer boundary

    verify_nonnegative(fname, ibf, obf, rho)

    # Combine 'ibf' and 'obf' into a single array
    N = rho.shape[0]
    bf = np.zeros((N, 2 * nx2, nx3 // 2), dtype = np.float32)
    bf[:, :nx2, :] = ibf
    bf[:, nx2:, :] = obf

    return a2, ell, eccentricity, bf, rho


def reshape_and_stack(a2, ell, ecc):
    a2 = a2.reshape((-1, 1))
    ell = ell.reshape((-1, 1))
    ecc = ecc.reshape((-1, 1))
    return np.hstack([a2, ell, ecc])


def apply_normalization(bf, rho):
    fac = np.average(np.abs(rho), axis = (1, 2))
    fac = fac.reshape((-1, 1, 1))
    bf /= fac
    rho /= fac
    return bf, rho


def load_data(files, nx2, nx3, ell_min, ell_max, a2_min, a2_max,
              Nrange_list = None, params_slice = None, normalize_data = False):
    """
    Loads data from the specified files and processes it for use with the STNN.

    Args:
        files (str or list of str): List of file paths containing the data
        nx2 (int): Second grid dimension
        nx3 (int): Third grid dimension
        ell_min / ell_max (float): Minimum / maximum value of 'ell' over parameter space
        a2_min / a2_max (float): Minimum / maximum value of 'a2' over parameter space
        Nrange_list (list of tuples, optional): Slice indices for extracting data from the corresponding file. If
                                                given, must have the same number of elements as 'files'. Defaults
                                                to None.
        params_slice (slice, optional): Boolean array for selecting data over a subset of parameter space (ell, a1, a2).
                                        Defaults to None.
        normalize_data (bool, optional): Flag to normalize 'bf' and 'rho'. Defaults to False.

    Returns:
        tuple: A tuple containing the values of ell, a1, a2, bf, and rho. The parameters
               ell, a1, a2 are combined into a single array 'params'.
    """
    if isinstance(files, (list, tuple)) and len(files) == 0:
        raise ValueError('List of files provided to "load_data" is empty.')
    if not isinstance(files, (list, tuple)):
        files = [files]
    if Nrange_list is None or len(Nrange_list) == 0:
        # Default
        Nrange_list = [None for _ in range(len(files))]
    else:
        # User-specified; check shapes
        if not isinstance(Nrange_list, (list, tuple)):
            Nrange_list = [Nrange_list]
        if len(files) != len(Nrange_list):
            raise ValueError('List of input files must have same length as list of Nrange tuples')
    a2_list = []
    ell_list = []
    ecc_list = []
    bf_list = []
    rho_list = []

    # Get data from each file and add to the lists
    for file, Nrange in zip(files, Nrange_list):
        a2, ell, ecc, bf, rho = get_data_from_file(file, nx2, nx3, Nrange = Nrange)
        a2_list.append(a2)
        ell_list.append(ell)
        ecc_list.append(ecc)
        bf_list.append(bf)
        rho_list.append(rho)

    a2 = np.concatenate(a2_list)
    ell = np.concatenate(ell_list)
    ecc = np.concatenate(ecc_list)
    bf = np.vstack(bf_list)
    rho = np.vstack(rho_list)

    # Map ell and a2 values onto [0, 1]
    ell = (ell - ell_min) / (ell_max - ell_min)
    a2 = (a2 - a2_min) / (a2_max - a2_min)

    params = reshape_and_stack(a2, ell, ecc)

    if params_slice is not None:
        # Extract subset of the data, if params_slice is given
        params = params[params_slice, ...]
        bf = bf[params_slice, ...]
        rho = rho[params_slice, ...]

    if normalize_data:
        bf, rho = apply_normalization(bf, rho)

    return params, bf, rho


def load_training_data(file_list, nx2, nx3, ell_min, ell_max, a2_min, a2_max, Nrange_list = None,
                       params_slice = None, test_size = 0.1, random_state = 23, normalize_data = True):
    """
    Loads training data from specified files and preprocesses it for training the STNN.

    This function wraps the 'load_data' function, adding additional steps specific to preparing training data.

    Args:
        file_list (list of str): List of file paths containing the data
        nx2 (int): Second grid dimension
        nx3 (int): Third grid dimension
        ell_min / ell_max (float): Minimum / maximum value of 'ell' over parameter space
        a2_min / a2_max (float): Minimum / maximum value of 'a2' over parameter space
        Nrange_list (list of tuples, optional): Slice indices for extracting data from the corresponding file. If
                                                given, must have the same number of elements as 'file_list'. Defaults
                                                to None.
        params_slice (slice, optional): Boolean array for selecting data over a subset of parameter space (ell, a1, a2).
                                        Defaults to None.
        test_size (float, optional): Size of the test/validation dataset as a fraction of the total dataset size.
                                     Defaults to 0.1.
        random_state (int, optional): Random seed used to select the train-test split. Defaults to 23.
        normalize_data (bool, optional): Flag to normalize 'bf' and 'rho'. Defaults to True.

    Returns:
        tuple: Train and test splits of 'params', 'bf', and 'rho'.
    """
    params, bf, rho = load_data(file_list, nx2, nx3, ell_min, ell_max, a2_min, a2_max,
                                Nrange_list = Nrange_list, params_slice = params_slice, normalize_data = normalize_data)

    (rho_train, rho_test,
     Y_train, Y_test) = train_test_split(rho, [params, bf], test_size = test_size, random_state = random_state)

    params_train = Y_train[0]
    params_test = Y_test[0]
    bf_train = Y_train[1]
    bf_test = Y_test[1]

    print('Finished loading training data:')
    print(f'    params_train.shape:\t{params_train.shape}')
    print(f'    bf_train.shape:\t{bf_train.shape}')
    print(f'    rho_train.shape:\t{rho_train.shape}')
    print(f'    params_test.shape:\t{params_test.shape}')
    print(f'    bf_test.shape:\t{bf_test.shape}')
    print(f'    rho_test.shape:\t{rho_test.shape}')

    # Compute min/max extent of the training data in parameter space.
    # Note that 'params' is denormalized before computing the max/min.
    min_a2 = np.min(a2_min + (a2_max - a2_min) * params[:, 0])
    min_ell = np.min(ell_min + (ell_max - ell_min) * params[:, 1])
    min_ecc = np.min(params[:, 2])

    max_a2 = np.max(a2_min + (a2_max - a2_min) * params[:, 0])
    max_ell = np.max(ell_min + (ell_max - ell_min) * params[:, 1])
    max_ecc = np.max(params[:, 2])

    print('')
    print(f'    Number of circle samples (train):\t{np.sum(params[:, 2] < 1e-7)}')
    print(f'    Number of ellipse samples (train):\t{np.sum(params[:, 2] > 0)}')
    print('    Min .. Max in training data:')
    print(f'        ell:\t{min_ell:.2f} .. {max_ell:.2f}')
    print(f'        a2:\t{min_a2:.2f} .. {max_a2:.2f}')
    print(f'        ecc:\t{min_ecc:.2f} .. {max_ecc:.2f}')
    print('-------------------------------------------')

    return params_train, bf_train, rho_train, params_test, bf_test, rho_test


def train_test_split(X, Y, test_size = 0.1, random_state = None):
    """
    Split (X, Y) pairs into random train and test subsets.

    Args:
        X (np.ndarray or list of arrays): Input dataset
        Y (np.ndarray or list of arrays): Labels for the dataset
        test_size (float): Proportion of the dataset to include in the test split
        random_state (int): Controls the shuffling applied to X and Y before applying the split

    Returns:
        X_train, X_test, Y_train, Y_test: Train-test split of the dataset. The format is
        the same as the input. For example, if 'X' is an array and 'Y' is a list of arrays, then X_train
        and X_test will be arrays, and Y_train and Y_test will be lists of arrays.

    Note: This function is included primarily to reduce module dependency requirements, and it may not be memory-efficient
          for large datasets. sklearn.model_selection.train_test_split has similar functionality and may be preferred
          for performance-critical applications.
    """
    if len(X) == 0 or len(Y) == 0:
        raise ValueError('Input arrays/lists X and Y cannot be empty.')

    input_X_is_array = isinstance(X, np.ndarray)
    input_Y_is_array = isinstance(Y, np.ndarray)

    if input_X_is_array:
        X = [X]
    if input_Y_is_array:
        Y = [Y]

    total_samples = X[0].shape[0]

    # Check for consistent number of samples across all datasets
    if any(x.shape[0] != total_samples for x in X) or any(y.shape[0] != total_samples for y in Y):
        raise ValueError('Inconsistent number of samples.')

    Ntest = int(test_size * total_samples)
    if Ntest < 1 or Ntest > total_samples:
        raise ValueError('Size of test dataset cannot be less than 1 or greater than the total number of samples.')

    if random_state is not None:
        np.random.seed(random_state)

    # Shuffle indices
    indices = np.arange(total_samples)
    np.random.shuffle(indices)

    # Apply shuffled indices to all datasets
    shuffled_X = [x[indices] for x in X]
    shuffled_Y = [y[indices] for y in Y]

    # Split X and Y
    X_train = [x[:-Ntest] for x in shuffled_X]
    X_test = [x[-Ntest:] for x in shuffled_X]
    Y_train = [y[:-Ntest] for y in shuffled_Y]
    Y_test = [y[-Ntest:] for y in shuffled_Y]

    # Convert back to arrays if the original input was an array
    if input_X_is_array:
        X_train, X_test = X_train[0], X_test[0]
    if input_Y_is_array:
        Y_train, Y_test = Y_train[0], Y_test[0]

    return X_train, X_test, Y_train, Y_test
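A minimal sketch of the array-X / list-Y calling convention of train_test_split above, on toy data (the shapes are illustrative only):

import numpy as np
from stnn.data.preprocessing import train_test_split

# Toy dataset: 100 samples
rho = np.random.rand(100, 256, 64)
params = np.random.rand(100, 3)
bf = np.random.rand(100, 128, 16)

# 'rho' is an array, so rho_train/rho_test are arrays; [params, bf] is a list,
# so Y_train/Y_test come back as lists of arrays, matching the input format
rho_train, rho_test, Y_train, Y_test = train_test_split(rho, [params, bf], test_size = 0.1, random_state = 23)
print(rho_train.shape, Y_train[0].shape, Y_train[1].shape)  # (90, 256, 64) (90, 3) (90, 128, 16)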
stnn/data/test_functions.py
ADDED
@@ -0,0 +1,197 @@
"""
Interface to a subset of the test functions listed at
https://en.wikipedia.org/wiki/Test_functions_for_optimization
"""
import numpy as np


def rastrigin(x, y):
    args = (x, y)
    A = 10
    return A * len(args) + sum([(xi**2 - A * np.cos(2 * np.pi * xi)) for xi in args])


def ackley(x, y):
    return -20 * np.exp(-0.2 * np.sqrt(0.5 * (x**2 + y**2))) - \
        np.exp(0.5 * (np.cos(2 * np.pi * x) + np.cos(2 * np.pi * y))) + np.e + 20


def sphere(x, y):
    return x**2 + y**2


def rosenbrock(x, y):
    return 100 * (y - x**2)**2 + (1 - x)**2


def beale(x, y):
    return (1.5 - x + x * y)**2 + (2.25 - x + x * y**2)**2 + (2.625 - x + x * y**3)**2


def goldstein_price(x, y):
    return (1 + (x + y + 1)**2 * (19 - 14 * x + 3 * x**2 - 14 * y + 6 * x * y + 3 * y**2)) * \
        (30 + (2 * x - 3 * y)**2 * (18 - 32 * x + 12 * x**2 + 48 * y - 36 * x * y + 27 * y**2))


def booth(x, y):
    return (x + 2 * y - 7)**2 + (2 * x + y - 5)**2


def bukin(x, y):
    return 100 * np.sqrt(abs(y - 0.01 * x**2)) + 0.01 * abs(x + 10)


def matyas(x, y):
    return 0.26 * (x**2 + y**2) - 0.48 * x * y


def levi(x, y):
    return np.sin(3 * np.pi * x)**2 + (x - 1)**2 * (1 + np.sin(3 * np.pi * y)**2) + \
        (y - 1)**2 * (1 + np.sin(2 * np.pi * y)**2)


def himmelblau(x, y):
    return (x**2 + y - 11)**2 + (x + y**2 - 7)**2


def three_hump_camel(x, y):
    return 2 * x**2 - 1.05 * x**4 + x**6 / 6 + x * y + y**2


def easom(x, y):
    return -np.cos(x) * np.cos(y) * np.exp(-((x - np.pi)**2 + (y - np.pi)**2))


def cross_in_tray(x, y):
    return -0.0001 * (abs(np.sin(x) * np.sin(y) * np.exp(abs(100 - np.sqrt(x**2 + y**2) / np.pi))) + 1)**0.1


def eggholder(x, y):
    return -(y + 47) * np.sin(np.sqrt(abs(x / 2 + (y + 47)))) - x * np.sin(np.sqrt(abs(x - (y + 47))))


def holder_table(x, y):
    return -abs(np.sin(x) * np.cos(y) * np.exp(abs(1 - np.sqrt(x**2 + y**2) / np.pi)))


def mccormick(x, y):
    return np.sin(x + y) + (x - y)**2 - 1.5 * x + 2.5 * y + 1


def schaffer2(x, y):
    return 0.5 + (np.sin(x**2 - y**2)**2 - 0.5) / (1 + 0.001 * (x**2 + y**2))**2


def schaffer4(x, y):
    return 0.5 + (np.cos(np.sin(abs(x**2 - y**2)))**2 - 0.5) / (1 + 0.001 * (x**2 + y**2))**2


def styblinski_tang(x, y):
    args = (x, y)
    return sum([xi**4 - 16 * xi**2 + 5 * xi for xi in args]) / 2


functions = [
    rastrigin,
    ackley,
    sphere,
    rosenbrock,
    beale,
    goldstein_price,
    booth,
    bukin,
    matyas,
    levi,
    himmelblau,
    three_hump_camel,
    easom,
    cross_in_tray,
    eggholder,
    holder_table,
    mccormick,
    schaffer2,
    schaffer4,
    styblinski_tang
]

function_names = [
    'rastrigin',
    'ackley',
    'sphere',
    'rosenbrock',
    'beale',
    'goldstein_price',
    'booth',
    'bukin',
    'matyas',
    'levi',
    'himmelblau',
    'three_hump_camel',
    'easom',
    'cross_in_tray',
    'eggholder',
    'holder_table',
    'mccormick',
    'schaffer2',
    'schaffer4',
    'styblinski_tang'
]

domains = {
    'rastrigin': (-5.12, 5.12),
    'ackley': (-5, 5),
    'sphere': (-1, 1),
    'rosenbrock': {'x': (-2, 2), 'y': (-10, 10)},
    'beale': (-4.5, 4.5),
    'goldstein_price': (-2, 2),
    'booth': (-10, 10),
    'bukin': {'x': (-15, -5), 'y': (-3, 3)},
    'matyas': (-10, 10),
    'levi': (-10, 10),
    'himmelblau': (-5, 5),
    'three_hump_camel': (-5, 5),
    'easom': (-100, 100),
    'cross_in_tray': (-10, 10),
    'eggholder': (-512, 512),
    'holder_table': (-10, 10),
    'mccormick': {'x': (-1.5, 4), 'y': (-3, 4)},
    'schaffer2': (-100, 100),
    'schaffer4': (-100, 100),
    'styblinski_tang': (-5, 5)
}


def scale_input(x, domain):
    min_d, max_d = domain
    return min_d + (max_d - min_d) * x


def get_test_function(X, Y, fun_idx):
    """
    Evaluates a function on inputs (X, Y).

    Args:
        X (float or array-like): The X input values to be scaled and used in the function.
        Y (float or array-like): The Y input values to be scaled and used in the function.
                                 Ignored if the function takes only one argument.
        fun_idx (int): The index of the function to be retrieved from the predefined list 'functions'.

    Returns:
        (float or array-like): values of the function on the grid
    """
    func = functions[fun_idx]

    domain = domains[func.__name__]
    if isinstance(domain, dict):
        x_scaled = scale_input(X, domain['x'])
        y_scaled = scale_input(Y, domain['y'])
    else:
        x_scaled = scale_input(X, domain)
        y_scaled = scale_input(Y, domain)

    try:
        output = func(x_scaled, y_scaled)
    except TypeError:
        output = func(x_scaled)
    return output
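A short sketch of evaluating one of these functions through get_test_function, which rescales unit-interval inputs onto the function's native domain:

import numpy as np
from stnn.data.test_functions import get_test_function, function_names

# Evaluate Himmelblau's function on a 32 x 32 grid over the unit square;
# get_test_function maps the inputs onto the (-5, 5) domain internally
X, Y = np.meshgrid(np.linspace(0, 1, 32), np.linspace(0, 1, 32), indexing = 'ij')
Z = get_test_function(X, Y, function_names.index('himmelblau'))
print(Z.shape)  # (32, 32)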
stnn/linalg_backend.py
ADDED
@@ -0,0 +1,81 @@
"""
Module imports and wrapper functions for the linear algebra backend.

If __usecupy__ is True and cupy is successfully imported, then

    xp  --> cupy
    spx --> cupyx.scipy.sparse

Otherwise,

    xp  --> numpy
    spx --> scipy.sparse

For example, if __usecupy__ is False, then

    import numpy as xp
    import scipy.sparse as spx

is equivalent to

    from stnn.linalg_backend import xp, spx
"""
__usecupy__ = True

try:
    # If CuPy is not preferred or available, fall back to NumPy
    if not __usecupy__:
        raise ImportError
    import cupy as cp
    import cupyx.scipy.sparse.linalg
    import cupyx.scipy.sparse as cupy_sparse

    xp = cp
    spx = cupy_sparse
    using_cupy = True
except ImportError:
    import numpy as np
    import scipy.sparse.linalg
    import scipy.sparse as scipy_sparse

    xp = np
    spx = scipy_sparse
    using_cupy = False


def csr_matrix(L):
    """
    Create a CSR (Compressed Sparse Row) matrix.

    If CuPy is available and enabled, this function will create a CuPy CSR matrix.
    Otherwise, it converts the given data to a SciPy CSR matrix.

    Parameters:
        L (array_like or sparse matrix): 2-D array or sparse matrix to convert.

    Returns:
        CSR matrix: The converted CSR matrix, using either CuPy or SciPy.
    """
    if using_cupy:
        return spx.csr_matrix(L, dtype=xp.float64)
    return L.tocsr()


def asnumpy(arr):
    """
    Convert an array from the backend library (CuPy or NumPy) to NumPy.
    If NumPy is the backend, the input array is returned unchanged.
    """
    if using_cupy:
        return cp.asnumpy(arr)
    return arr


def asarray(arr):
    """
    Convert the input to an array of the backend library (CuPy or NumPy).
    If NumPy is the backend, the input array is returned unchanged.
    """
    if using_cupy:
        return cp.asarray(arr, dtype=cp.float64)
    return arr
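A minimal sketch of the backend in use, solving a toy sparse system with GMRES the same way the (commented-out) direct solver in app.py does; the system itself is made up for illustration:

import numpy as np
import scipy.sparse as sp
from stnn.linalg_backend import xp, spx, csr_matrix, asarray, asnumpy

# Toy 1D Laplacian-like system, for illustration only
L = sp.diags([2.0, -1.0, -1.0], [0, -1, 1], shape = (100, 100))
b = np.ones(100)

L_xp = csr_matrix(L)   # CuPy CSR if available, SciPy CSR otherwise
b_xp = asarray(b)
f_xp, info = spx.linalg.gmres(L_xp, b_xp, tol = 1e-7)  # 'tol' as in scipy ~1.12 (per requirements.txt)
print(asnumpy(f_xp)[:3], info)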
stnn/nn/__init__.py
ADDED
File without changes
stnn/nn/stnn.py
ADDED
@@ -0,0 +1,66 @@
import tensorflow as tf
from tensorflow.keras.layers import Input, Multiply, Add
from tensorflow.keras.models import Model

from .stnn_layers import TTL, SoftmaxEmbeddingLayer


def build_stnn(config):
    """
    Constructs a Stacked Tensorial Neural Network (STNN) as a TensorFlow model based on
    the provided configuration dictionary.

    Args:
        config (dict): Configuration dictionary for the STNN model. Must contain the following entries:
            - 'K' (int): Number of tensor networks to be stacked
            - 'd' (int): Number of dense layers in the model's SoftmaxEmbeddingLayer
            - 'nx1', 'nx2', 'nx3' (int): Dimensions of the finite-difference grid
            - All other required entries for the TTL class, not already listed above.

    Returns:
        tf.keras.Model: The constructed STNN model.

    Raises:
        ValueError: If any of 'nx1', 'nx2', 'nx3', 'K', 'd' in the config dictionary is not a positive integer,
            or if config['nx3'] is not divisible by 2.
    """
    required_keys = ['nx1', 'nx2', 'nx3', 'K', 'd', 'shape1', 'shape2', 'ranks', 'W']
    missing_keys = [key for key in required_keys if key not in config]
    if missing_keys:
        raise KeyError(f"Missing keys in config: {', '.join(missing_keys)}")

    for key in ['nx1', 'nx2', 'nx3', 'K', 'd']:
        if not isinstance(config[key], int):
            raise TypeError(f"{key} must be an integer.")

    for key in ['nx1', 'nx2', 'nx3', 'K', 'd']:
        if config[key] <= 0:
            raise ValueError(f"{key} must be positive.")

    if config['nx3'] % 2 == 1:
        raise ValueError('Config error: nx3 must be divisible by 2.')

    K = config['K']  # Number of tensor networks
    d = config['d']  # Number of dense layers in SoftmaxEmbeddingLayer
    input_shape = (2 * config['nx2'], config['nx3'] // 2, 1)
    input_tensor = Input(shape = input_shape)

    # Process parameter array (ell, a1, a2) and output weights for stacking the tensor networks
    preprocess_layer = SoftmaxEmbeddingLayer(K, d)
    params_input = Input(shape = (3,))
    stack_weights = preprocess_layer(params_input)[:, tf.newaxis, tf.newaxis, :]

    # Build the tensor networks using the custom Keras layer class TTL
    models = [TTL(config) for _ in range(K)]

    # Combine the tensor networks based on the weights output by 'preprocess_layer'
    weighted_outputs = []
    for i, model in enumerate(models):
        processed_output = model(input_tensor)
        weighted_output = Multiply()([processed_output, stack_weights[..., i]])
        weighted_outputs.append(weighted_output)
    final_output = Add()(weighted_outputs)

    model = Model(inputs = [params_input, input_tensor], outputs = final_output)

    return model
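
As a quick sanity check of the shape constraints that build_stnn and TTL enforce, the following sketch loads the configuration shipped with this Space (it assumes the script runs from the repo root, where T5_config.json lives):

    import json
    import numpy as np

    with open('T5_config.json') as f:
        config = json.load(f)

    # EinsumTTL reduces the (2*nx2, nx3//2) boundary data to a vector of length 2*nx2*W (for W > 0);
    # the tensor-train layer then maps prod(shape2) inputs to prod(shape1) = nx1*nx2 outputs.
    print(np.prod(config['shape2']) == 2 * config['nx2'] * config['W'])   # True
    print(np.prod(config['shape1']) == config['nx1'] * config['nx2'])     # True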
stnn/nn/stnn_layers.py
ADDED
@@ -0,0 +1,274 @@
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Reshape, Flatten
import t3f

import os
import logging

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # FATAL
logging.getLogger('tensorflow').setLevel(logging.FATAL)


class SoftmaxEmbeddingLayer(tf.keras.layers.Layer):
    """
    Parameter embedding layer that generates the weights used for stacking the tensor networks. It
    takes the parameter array, lambda = (ell, a1, a2), as input and outputs K numbers that sum to 1.

    Attributes:
        output_dim (int): The dimension of the output.
        expansion_dim (int): The dimension used for expanding the input in intermediate layers.
    """

    def __init__(self, output_dim, d, expansion_dim = 30, **kwargs):
        super(SoftmaxEmbeddingLayer, self).__init__(**kwargs)
        self.reduction_layer = None
        self.expansion_layers = None
        self.output_dim = output_dim
        self.expansion_dim = expansion_dim
        self.d = d  # Number of dense layers

    def build(self, input_shape):
        # Expansion layers to increase dimensionality
        self.expansion_layers = [tf.keras.layers.Dense(self.expansion_dim, activation = 'relu') for _ in range(self.d)]
        # Reduction layer to bring dimensionality back to the desired output dimension
        self.reduction_layer = tf.keras.layers.Dense(self.output_dim)

    def call(self, inputs):
        expanded = inputs
        for layer in self.expansion_layers:
            expanded = layer(expanded)
        return tf.nn.softmax(self.reduction_layer(expanded))

    def get_config(self):
        # 'd' is included so the layer can be reconstructed from its config
        return {'output_dim': self.output_dim, 'd': self.d, 'expansion_dim': self.expansion_dim}


class EinsumTTLRegularizer(tf.keras.regularizers.Regularizer):
    """
    Regularizer for the Einsum layer of the TTL layer class, penalizing high-frequency components of the
    weights vector.

    Attributes:
        strength (float): The regularization strength.
        midpoint (int): Index demarcating the inner and outer boundaries, i.e. x[:midpoint] contains
            data for the inner boundary, and x[midpoint:] contains data for the outer boundary.
            The regularization is designed so it does not penalize variations across this index.
    """

    def __init__(self, strength, midpoint):
        self.strength = strength
        self.midpoint = midpoint

    def __call__(self, x):
        diff = tf.abs(x[1:self.midpoint - 1] - x[0:self.midpoint - 2]) \
               + tf.abs(x[self.midpoint + 1:2 * self.midpoint - 1] - x[self.midpoint:2 * self.midpoint - 2])
        return self.strength * tf.reduce_sum(diff)

    def get_config(self):
        return {'strength': self.strength, 'midpoint': self.midpoint}


def cosine_initializer(kx = 1.0):
    """
    Initializer for the Einsum layer of the TTL layer class. Sets the weights to a random linear combination
    of |cos(kx * x)| and |cos(2 * kx * x)|, where x is a uniform grid on [-pi, pi] with one point per weight.

    Args:
        kx (float, optional): Frequency of the cosine terms. Defaults to 1.0.

    Returns:
        _initializer: Weight initializer function
    """

    def _initializer(shape, dtype = None):
        x_values = np.linspace(-np.pi, np.pi, shape[0])
        cos_values = np.random.uniform(-0.1, 0.3) * np.abs(np.cos(kx * x_values)) \
                     + np.random.uniform(-0.05, 0.05) * np.abs(np.cos(2.0 * kx * x_values))
        return tf.convert_to_tensor(-cos_values, dtype = dtype)

    return _initializer


class EinsumTTL(tf.keras.layers.Layer):
    """
    Layer that contracts the input tensor over the second dimension before passing it to the TTL.
    If regularization is enabled, it applies an `EinsumTTLRegularizer` to the kernels.

    Attributes:
        (nx2, nx3) (integers): Shape parameters characterizing the input tensor dimensions.
            The shape of the input tensor is (2*nx2, nx3//2).
        W (int): Number of einsum contractions
        kernels (list): List of weight vectors, one for each einsum contraction
        regularization_strength (float): The strength of the regularization if used.
        use_regularization (bool): Flag to indicate whether regularization is used.
    """

    def __init__(self, nx2, nx3, W, use_regularization, regularization_strength = 0.005, **kwargs):
        super(EinsumTTL, self).__init__(**kwargs)
        self.nx2 = nx2
        self.nx3 = nx3
        self.W = W
        self.kernels = []

        self.regularization_strength = regularization_strength
        self.use_regularization = use_regularization
        if self.use_regularization:
            regularizer = EinsumTTLRegularizer(self.regularization_strength, self.nx3 // 4)
        else:
            regularizer = None

        initializer_values_ = [1.0, 0.5, 2.0, 3.0] * W
        initializer_values = initializer_values_[:W]
        for i in range(W):
            self.kernels.append(self.add_weight(
                name = f'w{i + 1}',
                shape = (nx3 // 2,),
                regularizer = regularizer,
                initializer = cosine_initializer(initializer_values[i])
            ))

    def call(self, inputs):
        parts = []
        for w in self.kernels:
            part_a = tf.einsum('abc,c->ab', inputs[:, :self.nx2, :self.nx3 // 4], w[:self.nx3 // 4]) + \
                     tf.einsum('abc,c->ab', inputs[:, :self.nx2, self.nx3 // 4:self.nx3 // 2],
                               tf.reverse(w[:self.nx3 // 4], axis = [0]))
            part_b = tf.einsum('abc,c->ab', inputs[:, self.nx2:, :self.nx3 // 4], w[self.nx3 // 4:self.nx3 // 2]) + \
                     tf.einsum('abc,c->ab', inputs[:, self.nx2:, self.nx3 // 4:self.nx3 // 2],
                               tf.reverse(w[self.nx3 // 4:self.nx3 // 2], axis = [0]))
            parts.extend([part_a, part_b])

        return tf.concat(parts, axis = 1)

    def get_config(self):
        return {'use_regularization': self.use_regularization,
                'regularization_strength': self.regularization_strength}


class TTL(tf.keras.layers.Layer):
    """
    TTL (Tensor Train Layer) is a custom TensorFlow Keras layer that builds a model
    based on the given configuration. This layer is designed to work with
    tensor train decomposition in neural networks.

    Attributes:
        config (dict): Configuration dictionary containing parameters for the model.

            'nx1', 'nx2', 'nx3': Integers, dimensions of the finite-difference grid

            'shape1': List of integers, defines the shape of the output tensor in the tensor train format.
                The length of shape1 must match the length of shape2.
            'shape2': List of integers, specifies the shape of the input tensor in the tensor train format.
                The length of shape2 must match the length of shape1.
            'ranks': List of integers, represents the ranks in the tensor train decomposition.
                The length of this list determines the complexity and the number of parameters in the tensor train layer.
            'W' (int): Number of weight vectors to use in the initial EinsumTTL layer. Setting W = 0 means that
                no EinsumTTL layer is used.

            'use_regularization' (boolean, optional, default: False): Indicates whether regularization is used in the EinsumTTL.
            'regularization_strength' (float, optional, default: 0): Strength of the regularization

        model (tf.keras.Sequential): The Sequential model built based on the provided configuration.

    Methods:
        load_config(self, config): Loads configuration
        build_model(self): Builds the layer
        call(inputs): Method for the forward pass of the layer.
    """

    def __init__(self, config, **kwargs):
        super(TTL, self).__init__(**kwargs)
        self.model = Sequential()
        self.nx1 = None
        self.nx2 = None
        self.nx3 = None
        self.shape1 = None
        self.shape2 = None
        self.ranks = None
        self.W = None
        self.use_regularization = None
        self.regularization_strength = None
        self._required_keys = ['nx1', 'nx2', 'nx3', 'shape1', 'shape2', 'ranks', 'W']

        config.setdefault('use_regularization', False)
        config.setdefault('regularization_strength', 0.0)
        self.load_config(config)
        self.config = config
        self.build_model()

    def load_config(self, config):
        missing_keys = [key for key in self._required_keys if key not in config]
        if missing_keys:
            raise KeyError(f"Missing keys in config: {', '.join(missing_keys)}")

        if not isinstance(config['use_regularization'], bool):
            raise TypeError('use_regularization must be a boolean.')
        else:
            self.use_regularization = config['use_regularization']

        self.regularization_strength = config['regularization_strength']

        for key in ['nx1', 'nx2', 'nx3', 'W']:
            if not isinstance(config[key], int):
                raise TypeError(f"{key} must be an integer.")

        for key in ['nx1', 'nx2', 'nx3']:
            if config[key] <= 0:
                raise ValueError(f"{key} must be positive.")

        if config['W'] < 0:
            raise ValueError("W must be non-negative.")

        nx1, nx2, nx3 = config['nx1'], config['nx2'], config['nx3']
        self.nx1 = nx1
        self.nx2 = nx2
        self.nx3 = nx3

        W = config['W']
        self.W = W

        input_dim = 2 * nx2 * W
        if W == 0:
            input_dim = nx2 * nx3

        shape1, shape2 = config['shape1'], config['shape2']
        if len(shape1) != len(shape2):
            raise ValueError(
                f'shape1 and shape2 must have the same length. '
                f'Received: shape1 = {shape1}, shape2 = {shape2}.'
            )
        elif np.prod(np.array(shape1)) != nx1 * nx2:
            raise ValueError(
                f'prod(shape1) must be equal to the output dimension of the TTL '
                f'(nx1 * nx2,). Received: prod(shape1) = {np.prod(np.array(shape1))}, '
                f'nx1 * nx2 = {nx1 * nx2}.'
            )
        elif np.prod(np.array(shape2)) != input_dim:
            raise ValueError(
                f'prod(shape2) must be equal to the input dimension of the TTL '
                f'(2 * nx2 * W, or nx2 * nx3 if W = 0). '
                f'Received: prod(shape2) = {np.prod(np.array(shape2))}, required input dimension = {input_dim}.'
            )
        else:
            self.shape1 = shape1
            self.shape2 = shape2

        self.ranks = config['ranks']

    def build_model(self):
        if self.W == 0:
            self.model.add(Flatten(input_shape = (2 * self.nx2, self.nx3 // 2)))
        else:
            self.model.add(EinsumTTL(self.nx2, self.nx3, self.W, self.use_regularization,
                                     regularization_strength = self.regularization_strength,
                                     input_shape = (2 * self.nx2, self.nx3 // 2)))
            self.model.add(Flatten())
        tt_layer = t3f.nn.KerasDense(input_dims = self.shape2, output_dims = self.shape1,
                                     tt_rank = self.ranks, use_bias = False, activation = 'linear')
        self.model.add(tt_layer)
        self.model.add(Reshape((self.nx1, self.nx2)))

    def call(self, inputs):
        return self.model(inputs)
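
A minimal sketch of SoftmaxEmbeddingLayer in isolation (hypothetical values; K = 5 stacking weights produced by d = 3 dense layers):

    import numpy as np
    import tensorflow as tf
    from stnn.nn.stnn_layers import SoftmaxEmbeddingLayer

    layer = SoftmaxEmbeddingLayer(output_dim = 5, d = 3)
    params = tf.constant(np.random.rand(4, 3), dtype = tf.float32)  # batch of (ell, a1, a2)
    weights = layer(params)                                         # shape (4, 5)
    print(np.allclose(tf.reduce_sum(weights, axis = -1).numpy(), 1.0))  # True: softmax weights sum to 1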
stnn/pde/__init__.py
ADDED
File without changes
stnn/pde/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (168 Bytes)
stnn/pde/circle.py
ADDED
@@ -0,0 +1,156 @@
from .common import *


def u_dot_thetavec(r, theta, w):
    """
    Dot product of u = (cos(w), sin(w)) with the coordinate vector for theta.

    Args:
        r (float or array-like): The radial coordinate(s).
        theta (float or array-like): The angular coordinate(s).
        w (float or array-like): The w coordinate(s).

    Returns:
        numpy.ndarray: The calculated dot product for each point.
    """
    return r * np.sin(w - theta)


def u_dot_thetahat(theta, w):
    """
    Dot product of u = (cos(w), sin(w)) with the unit vector for theta.

    Args:
        theta (float or array-like): The angular (theta) coordinate(s).
        w (float or array-like): The w coordinate(s).

    Returns:
        numpy.ndarray: The calculated dot product for each point.
    """
    return np.sin(w - theta)


def u_dot_rvec(theta, w):
    """
    Dot product of u = (cos(w), sin(w)) with the coordinate vector for r.

    Args:
        theta (float or array-like): The angular (theta) coordinate(s).
        w (float or array-like): The w coordinate(s).

    Returns:
        numpy.ndarray: The calculated dot product for each point.
    """
    return np.cos(w - theta)


def get_system_circle(config):
    """
    For a circular geometry, constructs the matrices, grids, and other quantities corresponding to the PDE system
    specified by 'config'.

    Args:
        config (dict): Configuration dictionary containing the system parameters.

    Returns:
        tuple: A tuple containing matrices, grids, etc. for the PDE system
    """
    required_keys = ['nx1', 'nx2', 'nx3', 'ell', 'a2']
    optional_keys = []

    missing_keys = [key for key in required_keys if key not in config]
    if missing_keys:
        raise KeyError(f"Missing keys in config: {', '.join(missing_keys)}")

    unused_keys = [key for key in config if key not in required_keys + optional_keys]
    if unused_keys:
        warnings.warn(f"Unused keys in config: {', '.join(unused_keys)}")

    for key in ['nx1', 'nx2', 'nx3']:
        if not isinstance(config[key], int):
            raise TypeError(f"{key} must be an integer.")

    for key in ['nx1', 'nx2', 'nx3', 'ell', 'a2']:
        if config[key] <= 0:
            raise ValueError(f"{key} must be positive.")

    if config['a2'] < 1.0:
        raise ValueError('a2 must be greater than 1.')

    nr, ntheta, nw = config['nx1'], config['nx2'], config['nx3']
    R1 = 1.0
    R2 = config['a2']
    ell = config['ell']

    # 1D grids
    theta, w = get_angular_grids(ntheta, nw)
    # r grid: non-uniform spacing and Dirichlet boundary conditions
    y = np.linspace(-np.pi / 2, np.pi / 2, nr + 2)
    r_ = (R2 - R1) * (np.sin(y) / 2 + 0.5) + R1
    dr1 = r_[1] - r_[0]
    dr2 = r_[-1] - r_[-2]
    r = r_[1:-1]

    # 1D finite-difference operators
    Dtheta_minus, Dtheta_plus = d_dx_upwind(theta, ntheta)
    D2w = d2_dx2_fourth_order(w, nw)
    Dr_minus, Dr_plus = d_dx_upwind_nonuniform(r_, nr)

    # 3D quantities. Kronecker products are used to build the 3D difference operators
    r_3D, theta_3D, w_3D = np.meshgrid(r, theta, w, indexing = 'ij')
    I_r = sp.eye(nr)
    I_theta = sp.eye(ntheta)
    I_w = sp.eye(nw)
    Dtheta_3D_minus = sp.kron(sp.kron(I_r, Dtheta_minus), I_w)
    Dtheta_3D_plus = sp.kron(sp.kron(I_r, Dtheta_plus), I_w)
    D2w_3D = sp.kron(sp.kron(I_r, I_theta), D2w)
    Dr_3D_minus = sp.kron(sp.kron(Dr_minus, I_theta), I_w)
    Dr_3D_plus = sp.kron(sp.kron(Dr_plus, I_theta), I_w)

    # Metric tensor. Note that g_12 = g_21 = 0.
    g_11 = np.ones_like(r_3D)
    g_22_over_r = r_3D  # divide out factor of r

    # Dot products
    dp_r = u_dot_rvec(theta_3D, w_3D)
    dp_thetahat = u_dot_thetahat(theta_3D, w_3D)

    # Coefficient of d / dr
    Dr_3D_coeff_meshgrid = dp_r / g_11
    test_ill_conditioned(Dr_3D_coeff_meshgrid)
    Dr_3D_coeff = sp.diags(Dr_3D_coeff_meshgrid.ravel())

    # Coefficient of d / dtheta
    Dtheta_3D_coeff_meshgrid = dp_thetahat / g_22_over_r
    Dtheta_3D_coeff = sp.diags(Dtheta_3D_coeff_meshgrid.ravel())

    # Upwind differencing
    Dr_3D_upwind = upwind_operator(Dr_3D_minus, Dr_3D_plus, Dr_3D_coeff_meshgrid)
    Dtheta_3D_upwind = upwind_operator(Dtheta_3D_minus, Dtheta_3D_plus, Dtheta_3D_coeff_meshgrid)

    # Full operator
    L = Dr_3D_coeff @ Dr_3D_upwind + Dtheta_3D_coeff @ Dtheta_3D_upwind - (1 / ell) * D2w_3D

    return L, r_3D, theta_3D, w_3D, dr1, dr2, Dr_3D_coeff_meshgrid


def get_boundary_quantities_circle(theta_3D, w_3D):
    """
    Gets grid coordinates on the boundaries, as well as slice arrays
    for positive/negative angles with respect to the boundary angle.

    Args:
        theta_3D (numpy.ndarray): 3D array of theta values on the grid.
        w_3D (numpy.ndarray): 3D array of w values on the grid.

    Returns:
        tuple: Tuple of the grid coordinates and slice arrays
    """
    th1 = theta_3D[0, :, :]
    wb1 = w_3D[0, :, :]
    th2 = theta_3D[-1, :, :]
    wb2 = w_3D[-1, :, :]
    ib_slice = np.cos(th1 - wb1) > 0
    ob_slice = np.cos(th2 - wb2) < 0

    return th1, th2, wb1, wb2, ib_slice, ob_slice
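
A small usage sketch for the circular geometry (hypothetical grid sizes, chosen only to keep the operator small):

    from stnn.pde.circle import get_system_circle, get_boundary_quantities_circle

    config = {'nx1': 16, 'nx2': 32, 'nx3': 16, 'ell': 1.0, 'a2': 2.0}
    L, r_3D, theta_3D, w_3D, dr1, dr2, Dr_coeff = get_system_circle(config)
    print(L.shape)  # (nx1*nx2*nx3, nx1*nx2*nx3) = (8192, 8192)

    # Boolean masks selecting the inflow directions at each boundary
    th1, th2, wb1, wb2, ib_slice, ob_slice = get_boundary_quantities_circle(theta_3D, w_3D)
    print(ib_slice.shape)  # (nx2, nx3); half of the w directions are inward-going at the inner boundary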
stnn/pde/common.py
ADDED
@@ -0,0 +1,91 @@
import warnings
import numpy as np
import scipy.sparse as sp


def d_dx_upwind(x, nx):
    """
    Sparse matrix representation of d/dx using first-order left/right differences with periodic boundary conditions
    (note the wrap-around entries at the corners of the matrices)
    """
    dx = x[1] - x[0]
    Dx_minus = sp.diags([-1, 1], [0, 1], shape = (nx, nx)).tolil() / dx
    Dx_plus = sp.diags([-1, 1], [-1, 0], shape = (nx, nx)).tolil() / dx
    Dx_minus[-1, 0] = 1 / dx
    Dx_plus[0, -1] = -1 / dx
    Dx_minus = Dx_minus.tocsr()
    Dx_plus = Dx_plus.tocsr()
    return Dx_minus, Dx_plus


def d2_dx2_fourth_order(x, nx):
    """
    Sparse matrix representation of d^2/dx^2 using fourth-order central differences and periodic boundary conditions
    """
    dx = x[1] - x[0]
    D2x = sp.diags([-1, 16, -30, 16, -1], [-2, -1, 0, 1, 2],
                   shape = (nx, nx)).tolil() / (12 * dx**2)
    D2x[0, -1] = 16 / (12 * dx**2)
    D2x[0, -2] = -1 / (12 * dx**2)
    D2x[1, -1] = -1 / (12 * dx**2)
    D2x[-1, 0] = 16 / (12 * dx**2)
    D2x[-1, 1] = -1 / (12 * dx**2)
    D2x[-2, 0] = -1 / (12 * dx**2)
    D2x = D2x.tocsr()
    return D2x


def d_dx_upwind_nonuniform(x, nx):
    """
    Sparse matrix representation of d/dx on a nonuniform grid, using first-order left/right differences
    with Dirichlet boundary conditions.
    """
    Dx_ = np.diff(x)
    Dx_minus = np.diff(x[1:])
    Dx_minus_inv = 1 / Dx_minus
    Dx_plus_inv = 1 / Dx_
    Dx_minus = sp.diags([-Dx_minus_inv, Dx_minus_inv], [0, 1], shape = (nx, nx)).tolil()
    Dx_plus = sp.diags([-Dx_plus_inv[1:], Dx_plus_inv[:-1]], [-1, 0], shape = (nx, nx)).tolil()
    Dx_minus = Dx_minus.tocsr()
    Dx_plus = Dx_plus.tocsr()
    return Dx_minus, Dx_plus


def get_angular_grids(nx2, nx3):
    """
    x2 / x3 grids: uniform spacing and periodic boundary conditions.
    The x3 grid has an offset to ensure cos(x2 - x3) != 0.
    """
    x2 = np.linspace(-np.pi, np.pi, nx2, endpoint = False)
    x3_min, x3_max = 0 + 0.125 * (2 * np.pi / nx3), 2 * np.pi + 0.125 * (2 * np.pi / nx3)
    x3 = np.linspace(x3_min, x3_max, nx3, endpoint = False)
    return x2, x3


def upwind_operator(Dx_minus, Dx_plus, Dx_coeff):
    """
    Upwind finite difference operator.

    Args:
        Dx_minus (scipy.sparse matrix): backward (minus) finite difference operator.
        Dx_plus (scipy.sparse matrix): forward (plus) finite difference operator.
        Dx_coeff (numpy.ndarray): coefficient array

    Returns:
        scipy.sparse matrix: The upwind operator
    """
    mask_x = Dx_coeff <= 0
    Dx_masked_minus = sp.diags(mask_x.ravel().astype(int)) @ Dx_minus
    Dx_masked_plus = sp.diags((~mask_x).ravel().astype(int)) @ Dx_plus
    Dx_upwind = Dx_masked_minus + Dx_masked_plus
    return Dx_upwind


def test_ill_conditioned(Dx_coeff):
    """
    Test for ill-conditioning. The thresholds are heuristic only.
    """
    ill_conditioning_test = np.min(np.abs(Dx_coeff.ravel()))
    if ill_conditioning_test < 1e-10:
        raise ValueError(f'System is ill-conditioned; min |Dx1_coeff| = {ill_conditioning_test}')
    elif ill_conditioning_test < 1e-6:
        warnings.warn(f'System may be ill-conditioned; min |Dx1_coeff| = {ill_conditioning_test}')
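
A quick accuracy sanity check for the periodic second-derivative operator (a sketch; the dx**4 scaling is the expected truncation error for this five-point stencil):

    import numpy as np
    from stnn.pde.common import d2_dx2_fourth_order

    nx = 64
    x = np.linspace(0, 2 * np.pi, nx, endpoint = False)
    D2 = d2_dx2_fourth_order(x, nx)
    err = np.max(np.abs(D2 @ np.sin(x) + np.sin(x)))  # exact second derivative is -sin(x)
    print(err)  # O(dx**4), roughly 1e-6 at this resolution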
stnn/pde/ellipse.py
ADDED
@@ -0,0 +1,154 @@
from .common import *


def u_dot_muvec(mu, eta, w):
    """
    Dot product of u = (cos(w), sin(w)) with the coordinate vector for mu.

    Args:
        mu (float or array-like): The mu coordinate(s).
        eta (float or array-like): The eta coordinate(s).
        w (float or array-like): The w coordinate(s).

    Returns:
        numpy.ndarray: The calculated dot product for each point.
    """
    return (0.5 * np.cosh(mu) * np.cos(eta - w) + 0.5 * np.cosh(mu) * np.cos(eta + w)
            + 0.5 * np.sinh(mu) * np.cos(eta - w) - 0.5 * np.sinh(mu) * np.cos(eta + w))


def u_dot_etavec(mu, eta, w):
    """
    Dot product of u = (cos(w), sin(w)) with the coordinate vector for eta.

    Args:
        mu (float or array-like): The mu coordinate(s).
        eta (float or array-like): The eta coordinate(s).
        w (float or array-like): The w coordinate(s).

    Returns:
        numpy.ndarray: The calculated dot product for each point.
    """
    return (-0.5 * np.sinh(mu) * np.sin(eta + w) - 0.5 * np.sinh(mu) * np.sin(eta - w)
            + 0.5 * np.cosh(mu) * np.sin(eta + w) - 0.5 * np.cosh(mu) * np.sin(eta - w))


def get_system_ellipse(config):
    """
    For an elliptical geometry, constructs the matrices, grids, and other quantities corresponding to the PDE system
    specified by 'config'.

    Args:
        config (dict): Configuration dictionary containing the system parameters.

    Returns:
        tuple: A tuple containing matrices, grids, etc. for the PDE system
    """
    required_keys = ['nx1', 'nx2', 'nx3', 'ell', 'a2', 'eccentricity']
    optional_keys = []

    missing_keys = [key for key in required_keys if key not in config]
    if missing_keys:
        raise KeyError(f"Missing keys in config: {', '.join(missing_keys)}")

    unused_keys = [key for key in config if key not in required_keys + optional_keys]
    if unused_keys:
        warnings.warn(f"Unused keys in config: {', '.join(unused_keys)}")

    for key in ['nx1', 'nx2', 'nx3']:
        if not isinstance(config[key], int):
            raise TypeError(f"{key} must be an integer.")

    for key in ['nx1', 'nx2', 'nx3', 'ell']:
        if config[key] <= 0:
            raise ValueError(f"{key} must be positive.")

    if not (0 <= config['eccentricity'] < 1.0):
        raise ValueError('eccentricity must be >= 0 and < 1.')

    if config['a2'] <= config['eccentricity']:
        raise ValueError('a2 must be greater than the eccentricity.')

    nmu, neta, nw = config['nx1'], config['nx2'], config['nx3']
    minor_axis_outer = config['a2']
    ell = config['ell']
    minor_axis = 1.0 - config['eccentricity']
    major_axis = 1.0
    focal_distance = np.sqrt(major_axis**2 - minor_axis**2)
    mu1 = np.arccosh(major_axis / focal_distance)
    major_axis_outer = np.sqrt(focal_distance**2 + minor_axis_outer**2)
    mu2 = np.arccosh(major_axis_outer / focal_distance)

    # 1D grids
    eta, w = get_angular_grids(neta, nw)
    # mu grid: non-uniform spacing and Dirichlet boundary conditions
    y = np.linspace(-np.pi / 2, np.pi / 2, nmu + 2, dtype = np.float64)
    mu_ = np.log((np.exp(mu2) - np.exp(mu1)) * (np.sin(y) / 2 + 0.5) + np.exp(mu1))
    dmu1 = mu_[1] - mu_[0]
    dmu2 = mu_[-1] - mu_[-2]
    mu = mu_[1:-1]

    # 1D finite-difference operators
    Deta_minus, Deta_plus = d_dx_upwind(eta, neta)
    D2w = d2_dx2_fourth_order(w, nw)
    Dmu_minus, Dmu_plus = d_dx_upwind_nonuniform(mu_, nmu)

    # 3D quantities. Kronecker products are used to build the 3D difference operators
    mu_3D, eta_3D, w_3D = np.meshgrid(mu, eta, w, indexing = 'ij')
    I_mu = sp.eye(nmu)
    I_eta = sp.eye(neta)
    I_w = sp.eye(nw)
    Deta_3D_minus = sp.kron(sp.kron(I_mu, Deta_minus), I_w)
    Deta_3D_plus = sp.kron(sp.kron(I_mu, Deta_plus), I_w)
    D2w_3D = sp.kron(sp.kron(I_mu, I_eta), D2w)
    Dmu_3D_minus = sp.kron(sp.kron(Dmu_minus, I_eta), I_w)
    Dmu_3D_plus = sp.kron(sp.kron(Dmu_plus, I_eta), I_w)

    # Metric tensor. Note that g_12 = g_21 = 0 and g_11 = g_22.
    g_11 = focal_distance * (np.cosh(mu_3D) * np.cosh(mu_3D) * np.cos(eta_3D) * np.cos(eta_3D)
                             + np.sinh(mu_3D) * np.sinh(mu_3D) * np.sin(eta_3D) * np.sin(eta_3D))

    # Dot products
    dp_mu = u_dot_muvec(mu_3D, eta_3D, w_3D)
    dp_eta = u_dot_etavec(mu_3D, eta_3D, w_3D)

    # Coefficient of d / dmu
    Dmu_3D_coeff_meshgrid = dp_mu / g_11
    test_ill_conditioned(Dmu_3D_coeff_meshgrid)

    Dmu_3D_coeff = sp.diags(Dmu_3D_coeff_meshgrid.ravel())

    # Coefficient of d / deta
    Deta_3D_coeff_meshgrid = dp_eta / g_11
    Deta_3D_coeff = sp.diags(Deta_3D_coeff_meshgrid.ravel())

    # Upwind differencing
    Dmu_3D_upwind = upwind_operator(Dmu_3D_minus, Dmu_3D_plus, Dmu_3D_coeff_meshgrid)
    Deta_3D_upwind = upwind_operator(Deta_3D_minus, Deta_3D_plus, Deta_3D_coeff_meshgrid)

    # Full operator
    L = Dmu_3D_coeff @ Dmu_3D_upwind + Deta_3D_coeff @ Deta_3D_upwind - (1 / ell) * D2w_3D

    return L, mu_3D, eta_3D, w_3D, dmu1, dmu2, Dmu_3D_coeff_meshgrid, major_axis_outer


def get_boundary_quantities_ellipse(mu_3D, eta_3D, w_3D):
    """
    Gets grid coordinates on the boundaries, as well as slice arrays
    for positive/negative angles with respect to the boundary angle.

    Args:
        mu_3D (numpy.ndarray): 3D array of mu values on the grid.
        eta_3D (numpy.ndarray): 3D array of eta values on the grid.
        w_3D (numpy.ndarray): 3D array of w values on the grid.

    Returns:
        tuple: Tuple of the grid coordinates and slice arrays
    """
    eta_2D_ib = eta_3D[0, ...]
    eta_2D_ob = eta_3D[-1, ...]
    w_2D_ib = w_3D[0, ...]
    w_2D_ob = w_3D[-1, ...]
    ib_slice = u_dot_muvec(mu_3D, eta_3D, w_3D)[0, ...] > 0
    ob_slice = u_dot_muvec(mu_3D, eta_3D, w_3D)[-1, ...] < 0
    return eta_2D_ib, eta_2D_ob, w_2D_ib, w_2D_ob, ib_slice, ob_slice
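
The sum-to-product identities behind these dot products can be checked numerically (a quick sketch):

    import numpy as np
    from stnn.pde.ellipse import u_dot_muvec, u_dot_etavec

    mu, eta, w = np.random.rand(3, 100)
    # u_dot_muvec  == cosh(mu)cos(eta)cos(w) + sinh(mu)sin(eta)sin(w)
    print(np.allclose(u_dot_muvec(mu, eta, w),
                      np.cosh(mu) * np.cos(eta) * np.cos(w) + np.sinh(mu) * np.sin(eta) * np.sin(w)))  # True
    # u_dot_etavec == cosh(mu)cos(eta)sin(w) - sinh(mu)sin(eta)cos(w)
    print(np.allclose(u_dot_etavec(mu, eta, w),
                      np.cosh(mu) * np.cos(eta) * np.sin(w) - np.sinh(mu) * np.sin(eta) * np.cos(w)))  # True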
stnn/pde/pde_system.py
ADDED
@@ -0,0 +1,223 @@
import numpy as np

from stnn.data.function_generators import generate_random_functions
from .circle import get_system_circle, get_boundary_quantities_circle, u_dot_thetahat
from .ellipse import (get_system_ellipse, get_boundary_quantities_ellipse, u_dot_etavec)


class PDESystem:
    """
    Constructs the PDE system given input parameters. The finite-difference matrices, grids, and other relevant
    quantities are available as attributes.

    Constructor Args:
        params (dict): Configuration dictionary containing the parameters that define the PDE system.

    Attributes:

        ib_slice (numpy.ndarray): Boolean array defining nodes adjacent to the inner boundary
        ob_slice (numpy.ndarray): Boolean array defining nodes adjacent to the outer boundary

        x2_ib (numpy.ndarray): The x2 coordinate values at the inner boundary.
        x2_ob (numpy.ndarray): The x2 coordinate values at the outer boundary.
        x3_ib (numpy.ndarray): The x3 coordinate values at the inner boundary.
        x3_ob (numpy.ndarray): The x3 coordinate values at the outer boundary.

        Dx1_coeff (numpy.ndarray): Coefficients for the advection operator in the radial direction. Used for
            converting boundary conditions to the r.h.s. of the linear system defining a boundary-value problem.

        dx1a (numpy.ndarray): Grid spacing adjacent to the inner boundary
        dx1b (numpy.ndarray): Grid spacing adjacent to the outer boundary

        L (numpy.ndarray): Finite-difference representation of the linear operator defining the PDE.

        x1 (numpy.ndarray): The grid values in the radial coordinate (r or mu)
        x2 (numpy.ndarray): The grid values in the angular coordinate (theta or eta)
        x3 (numpy.ndarray): The grid values in the w coordinate

        a1 (float): Minor axis of the inner boundary
        a2 (float): Minor axis of the outer boundary
        b1 (float): Major axis of the inner boundary
        b2 (float): Major axis of the outer boundary

        _coords (str): The type of coordinate system used ('ellipse' or 'circle'). This affects how the grids and
            other geometric properties are calculated.

        params (dict): The configuration dictionary containing the PDE parameters.
    """
    def __init__(self, params):
        self._required_keys = ['nx1', 'nx2', 'nx3', 'ell', 'a2', 'eccentricity']
        self._optional_keys = []
        self.ib_slice = None
        self.ob_slice = None
        self.x2_ib = None
        self.x2_ob = None
        self.x3_ib = None
        self.x3_ob = None
        self.Dx1_coeff = None
        self.dx1b = None
        self.dx1a = None
        self.L = None
        self.x1 = None
        self.x2 = None
        self.x3 = None
        self.a1 = None
        self.a2 = None
        self.b1 = None
        self.b2 = None
        self._coords = None
        self.params = params
        self.initialize()

    def initialize(self):
        """
        Constructs the system matrices and vectors based on the stored configuration.

        Depending on the 'eccentricity' parameter in `self.params`, the coordinate system is set
        to either 'circle' or 'ellipse'; the domain parametrization and finite-difference grid are defined
        accordingly.
        """
        params = self.params
        missing_keys = [key for key in self._required_keys if key not in params]
        if missing_keys:
            raise KeyError(f"Missing keys in config: {', '.join(missing_keys)}")

        # The functions 'get_system_circle' and 'get_system_ellipse' have a fair amount
        # of overlap and probably should be combined, but for now they are kept separate
        # for simplicity and readability.
        if params['eccentricity'] < 1e-7:
            self._coords = 'circle'
            L, x1, x2, x3, dx1a, dx1b, Dx1_coeff = get_system_circle(params)
            x2_ib, x2_ob, x3_ib, x3_ob, ib_slice, ob_slice = get_boundary_quantities_circle(x2, x3)
            self.b2 = params['a2']
        else:
            self._coords = 'ellipse'
            (L, x1, x2, x3, dx1a, dx1b, Dx1_coeff, major_axis_outer) = get_system_ellipse(params)
            x2_ib, x2_ob, x3_ib, x3_ob, ib_slice, ob_slice = get_boundary_quantities_ellipse(x1, x2, x3)
            self.b2 = major_axis_outer

        self.a1 = 1.0 - params['eccentricity']
        self.a2 = params['a2']
        self.b1 = 1.0

        self.x1 = x1
        self.x2 = x2
        self.x3 = x3

        self.L = L

        self.dx1a = dx1a
        self.dx1b = dx1b
        self.Dx1_coeff = Dx1_coeff

        self.x2_ib = x2_ib
        self.x2_ob = x2_ob
        self.x3_ib = x3_ib
        self.x3_ob = x3_ob
        self.ib_slice = ib_slice
        self.ob_slice = ob_slice

    def generate_random_bc(self, func_gen_id):
        """
        Generates random boundary conditions for the PDE system.

        Args:
            func_gen_id (int): Integer representing the type of 'function generator' used to construct the
                boundary conditions.

        Returns:
            tuple: A tuple containing:
                - ibf_data: Inner boundary data
                - obf_data: Outer boundary data
                - b: 3D array for passing to the GMRES solver. 'b' contains the boundary data but is defined
                  on the full 3D grid.
                - bf: Flattened boundary data before it is permuted

        The boundary conditions are defined on the inner and outer boundaries of the domain and are denoted
        by 'ibf_data' and 'obf_data'. The function passes 'ibf_data' and 'obf_data' through 'convert_boundary_data',
        which converts them into formats suitable for passing into the GMRES solver and STNN model (e.g., by
        reshaping and permutation operations).
        """

        # Note the change of variable (x2, x3) -> (x2, x2 - x3).
        ibf_data = generate_random_functions(1, self.x2_ib, self.x2_ib - self.x3_ib,
                                             max_freq = self.params['nx3'], func_gen_id = func_gen_id)[0, self.ib_slice]
        obf_data = generate_random_functions(1, self.x2_ob, self.x2_ob - self.x3_ob,
                                             max_freq = self.params['nx3'], func_gen_id = func_gen_id)[0, self.ob_slice]

        # Combine boundary data in single vector
        bf = np.concatenate([ibf_data, obf_data], axis = -1).flatten()

        # Permute 'ibf_data' and 'obf_data' and construct 'b'
        ibf_data, obf_data, b = self.convert_boundary_data(ibf_data, obf_data)

        return ibf_data, obf_data, b, bf

    def convert_boundary_data(self, ibf_data, obf_data):
        """
        Converts boundary data into formats suitable for passing into the GMRES solver and STNN model.

        Args:
            ibf_data: Inner boundary data
            obf_data: Outer boundary data

        Returns:
            tuple: A tuple containing:
                - ibf_data: Inner boundary data, permuted to match the input structure of the EinsumTTL layer.
                - obf_data: Outer boundary data, permuted to match the input structure of the EinsumTTL layer.
                - b: 3D array for passing to the GMRES solver. 'b' contains the boundary data but is defined
                  on the full 3D grid.
        """
        nx1, nx2, nx3 = self.params['nx1'], self.params['nx2'], self.params['nx3']
        b = np.zeros((nx1, nx2, nx3), dtype = np.float64)
        b[0, self.ib_slice] = self.Dx1_coeff[0, self.ib_slice] * (ibf_data / self.dx1a)
        b[-1, self.ob_slice] = -self.Dx1_coeff[-1, self.ob_slice] * (obf_data / self.dx1b)

        if self._coords == 'ellipse':
            sin_angle = u_dot_etavec(self.x1, self.x2, self.x3)
        elif self._coords == 'circle':
            sin_angle = u_dot_thetahat(self.x2, self.x3)
        else:
            raise ValueError(f'"_coords" attribute should be either "ellipse" or "circle"; instead received {self._coords}')

        # reshape and permute 'ibf_data' and 'obf_data'
        sin_angle_i = sin_angle[0, self.ib_slice].reshape(nx2, nx3 // 2)
        sin_angle_o = sin_angle[-1, self.ob_slice].reshape(nx2, nx3 // 2)
        W_I = np.argsort(sin_angle_i, axis = 1)
        W_O = np.argsort(sin_angle_o, axis = 1)
        ibf_data = ibf_data.reshape((nx2, nx3 // 2))
        obf_data = obf_data.reshape((nx2, nx3 // 2))
        for n in range(nx2):
            ibf_data[n, :] = ibf_data[n, W_I[n, :]]
            obf_data[n, :] = obf_data[n, W_O[n, :]]

        return ibf_data, obf_data, b

    def get_xy_grids(self):
        """
        Converts the native grids of the PDE system to xy coordinates (no interpolation).

        The function also applies a wrap-around in the x2 domain for plotting purposes, ensuring
        continuity across the (periodic) domain.

        Returns:
            tuple of numpy.ndarray: A tuple containing two 2D numpy arrays:
                - x_grid: The x-coordinates grid.
                - y_grid: The y-coordinates grid.
        """
        x1_1D = self.x1[:, 0, 0]
        x2_1D = self.x2[0, :, 0]
        x2_1D = np.append(x2_1D, np.array([np.pi - 1e-3]))  # wrap around for plotting
        x1_2D, x2_2D = np.meshgrid(x1_1D, x2_1D, indexing = 'ij')
        if self._coords == 'ellipse':
            focal_distance = np.sqrt(self.b1**2 - self.a1**2)
            x_grid = focal_distance * np.sinh(x1_2D) * np.cos(x2_2D)
            y_grid = focal_distance * np.cosh(x1_2D) * np.sin(x2_2D)
        elif self._coords == 'circle':
            x_grid = x1_2D * np.cos(x2_2D)
            y_grid = x1_2D * np.sin(x2_2D)
        else:
            raise ValueError(f'"_coords" should be either "ellipse" or "circle"; instead received {self._coords}')

        return x_grid, y_grid
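
A minimal end-to-end sketch of PDESystem (hypothetical parameter values; the grid sizes are reduced from the deployed T5 configuration, so the ill-conditioning heuristics may warn on some grids):

    from stnn.pde.pde_system import PDESystem

    params = {'nx1': 64, 'nx2': 64, 'nx3': 32, 'ell': 1.0, 'a2': 4.0, 'eccentricity': 0.5}
    system = PDESystem(params)
    print(system.L.shape)        # (nx1*nx2*nx3, nx1*nx2*nx3)
    print(system.a1, system.b2)  # inner minor axis 1 - eccentricity; outer major axis from get_system_ellipse
    x_grid, y_grid = system.get_xy_grids()  # physical (x, y) grids, e.g. for plotting the solution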
stnn/tests/test_circle.py
ADDED
@@ -0,0 +1,63 @@
import unittest
import copy
import numpy as np
import scipy.sparse as sp
from stnn.pde.circle import get_system_circle


class TestGetSystemCircle(unittest.TestCase):

    def setUp(self):
        self.config = {
            'nx1': 10,
            'nx2': 20,
            'nx3': 30,
            'a2': 2.0,
            'ell': 1.5,
        }
        self.saved_config = copy.deepcopy(self.config)
        self._required_keys = ['nx1', 'nx2', 'nx3', 'ell', 'a2']
        self._optional_keys = []

    def test_valid_output(self):
        L, r_3D, theta_3D, w_3D, dr1, dr2, Dr_3D_coeff_meshgrid = get_system_circle(self.config)

        # Test types
        self.assertIsInstance(L, sp.csr_matrix)
        self.assertIsInstance(r_3D, np.ndarray)
        self.assertIsInstance(theta_3D, np.ndarray)
        self.assertIsInstance(w_3D, np.ndarray)
        self.assertIsInstance(Dr_3D_coeff_meshgrid, np.ndarray)

        # Test shapes
        self.assertEqual(r_3D.shape, (self.config['nx1'], self.config['nx2'], self.config['nx3']))
        self.assertEqual(theta_3D.shape, (self.config['nx1'], self.config['nx2'], self.config['nx3']))
        self.assertEqual(w_3D.shape, (self.config['nx1'], self.config['nx2'], self.config['nx3']))

        # Test values
        self.assertTrue(dr1 > 0)
        self.assertTrue(dr2 > 0)

    def test_missing_keys(self):
        for key in self._required_keys:
            del self.config[key]
            with self.assertRaises(KeyError):
                get_system_circle(self.config)
            self.config[key] = self.saved_config[key]

    def test_invalid_parameters(self):
        for key in self._required_keys:
            self.config[key] = 0  # None of the required keys should be zero.
            with self.assertRaises(ValueError):
                get_system_circle(self.config)
            self.config[key] = self.saved_config[key]

    def test_unused_params_warning(self):
        copied_params = copy.deepcopy(self.config)
        copied_params['unusedkey'] = 0
        with self.assertWarns(UserWarning) as _:
            get_system_circle(copied_params)


if __name__ == '__main__':
    unittest.main()
stnn/tests/test_dependencies.py
ADDED
@@ -0,0 +1,25 @@
import unittest
import importlib


class TestDependencies(unittest.TestCase):
    def test_required_modules_installed(self):
        required_modules = [
            'numpy',
            'tensorflow',
            't3f',
            'scipy',
            'h5py',
            'matplotlib',
            'pydot',
            'openvino',
        ]

        for module in required_modules:
            with self.subTest(module = module):
                importlib.import_module(module)


if __name__ == '__main__':
    unittest.main()
stnn/tests/test_differential_ops.py
ADDED
@@ -0,0 +1,115 @@
import unittest
import numpy as np
from stnn.pde.common import d_dx_upwind, d2_dx2_fourth_order, d_dx_upwind_nonuniform


def assert_derivative(operator, function, expected_derivative, boundary = None, rtol = None, atol = None):
    """
    Assert the derivative of a function using a given finite-difference operator.

    Args:
        operator (np.ndarray or sparse matrix): Differential operator matrix.
        function (np.ndarray): Values of the function to differentiate.
        expected_derivative (np.ndarray): Expected result of the derivative.
        boundary (int): Specifies if boundary elements should be excluded, and how many.
        rtol: Relative tolerance.
        atol: Absolute tolerance.
    """
    observed_derivative = operator @ function
    if boundary is not None:
        observed_derivative = observed_derivative[boundary:-boundary]
        expected_derivative = expected_derivative[boundary:-boundary]

    np.testing.assert_allclose(observed_derivative, expected_derivative, rtol = rtol, atol = atol)


class TestDifferentialOperators(unittest.TestCase):

    def setUp(self):
        # ----- Set up grids

        # grid for radial coordinate (non-periodic, non-uniform spacing)
        self.nx = 100000
        zs = np.linspace(-np.pi / 2, np.pi / 2, self.nx + 2)
        R1, R2 = 0.5, 1.4
        self.z_ = (R2 - R1) * (np.sin(zs) / 2 + 0.5) + R1
        self.z = self.z_[1:-1]

        # grid for angular coordinates (periodic, uniform spacing)
        self.y = np.linspace(0, 2 * np.pi, self.nx, endpoint = False)
        dy = (2 * np.pi) / self.nx

        # ----- Set tolerances

        # Tolerances for "exact" tests, i.e., where the finite differences do not have truncation error
        self.atol = 1e-6
        self.rtol = 1e-6

        # Tolerances for some inexact tests
        self.atol_inexact = 1e-3
        self.rtol_z = 3 * np.max(np.diff(self.z_))  # relative tolerance of 3*dx for first-order one-sided differences
        self.rtol_y = 3 * dy  # relative tolerance of 3*dx for first-order one-sided differences
        self.rtol_y2 = 3 * dy**4  # relative tolerance of 3*dy**4 for fourth-order central differences

        # ----- Test inputs

        # "Exact" test inputs
        self.f1 = -2.3 * np.ones(self.nx)
        self.f2 = 0.8 * self.z
        self.f3 = 0.8 * self.y
        self.f4 = -0.1 * self.y * self.y

        # Inexact test inputs
        self.g1 = self.z**2 - self.z**3
        self.g2 = np.cos(self.y)**2 - np.sin(self.y)**3
        self.g3 = np.sin(2 * self.y)**2 - np.cos(self.y)**3

    def test_d_dx_upwind(self):
        dx_m, dx_p = d_dx_upwind(self.y, self.nx)

        assert_derivative(dx_m, self.f1, np.zeros(self.nx), rtol = self.rtol, atol = self.atol)
        assert_derivative(dx_p, self.f1, np.zeros(self.nx), rtol = self.rtol, atol = self.atol)

        assert_derivative(dx_m, self.f3, 0.8 * np.ones(self.nx), boundary = 1, rtol = self.rtol, atol = self.atol)
        assert_derivative(dx_p, self.f3, 0.8 * np.ones(self.nx), boundary = 1, rtol = self.rtol, atol = self.atol)

        expected_dg2 = -2 * np.cos(self.y) * np.sin(self.y) - 3 * np.sin(self.y)**2 * np.cos(self.y)
        assert_derivative(dx_m, self.g2, expected_dg2, rtol = self.rtol_y, atol = self.atol_inexact)
        assert_derivative(dx_p, self.g2, expected_dg2, rtol = self.rtol_y, atol = self.atol_inexact)

        expected_dg3 = 4 * np.sin(2 * self.y) * np.cos(2 * self.y) + 3 * np.cos(self.y)**2 * np.sin(self.y)
        assert_derivative(dx_m, self.g3, expected_dg3, rtol = self.rtol_y, atol = self.atol_inexact)
        assert_derivative(dx_p, self.g3, expected_dg3, rtol = self.rtol_y, atol = self.atol_inexact)

    def test_d2_dx2_fourth_order(self):
        d2x = d2_dx2_fourth_order(self.y, self.nx)

        assert_derivative(d2x, self.f1, np.zeros(self.nx), rtol = self.rtol, atol = self.atol)
        assert_derivative(d2x, self.f3, np.zeros(self.nx), boundary = 4, rtol = self.rtol, atol = self.atol)
        assert_derivative(d2x, self.f4, -0.2 * np.ones(self.nx), boundary = 4, rtol = self.rtol, atol = self.atol)

        expected_dg2 = 2 * np.sin(self.y)**2 - 2 * np.cos(self.y)**2 - \
                       6 * np.sin(self.y) * np.cos(self.y)**2 + 3 * np.sin(self.y)**3
        assert_derivative(d2x, self.g2, expected_dg2, rtol = self.rtol_y2, atol = self.atol_inexact)

        expected_dg3 = 8 * np.cos(2 * self.y)**2 - 8 * np.sin(2 * self.y)**2 - \
                       6 * np.cos(self.y) * np.sin(self.y)**2 + 3 * np.cos(self.y)**3
        assert_derivative(d2x, self.g3, expected_dg3, rtol = self.rtol_y2, atol = self.atol_inexact)

    def test_d_dx_upwind_nonuniform(self):
        dx_minus, dx_plus = d_dx_upwind_nonuniform(self.z_, self.nx)

        assert_derivative(dx_minus, self.f1, np.zeros(self.nx), boundary = 1, rtol = self.rtol, atol = self.atol)
        assert_derivative(dx_plus, self.f1, np.zeros(self.nx), boundary = 1, rtol = self.rtol, atol = self.atol)

        assert_derivative(dx_minus, self.f2, 0.8 * np.ones(self.nx), boundary = 1, rtol = self.rtol, atol = self.atol)
        assert_derivative(dx_plus, self.f2, 0.8 * np.ones(self.nx), boundary = 1, rtol = self.rtol, atol = self.atol)

        expected_dg1 = 2 * self.z - 3 * self.z**2
        assert_derivative(dx_minus, self.g1, expected_dg1, boundary = 1, rtol = self.rtol_z, atol = self.atol_inexact)
        assert_derivative(dx_plus, self.g1, expected_dg1, boundary = 1, rtol = self.rtol_z, atol = self.atol_inexact)


if __name__ == '__main__':
    unittest.main()
stnn/tests/test_ellipse.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import unittest
import copy
import numpy as np
import scipy.sparse as sp
from stnn.pde.ellipse import get_system_ellipse


class TestGetSystemEllipse(unittest.TestCase):

    def setUp(self):
        self.config = {
            'nx1': 10,
            'nx2': 20,
            'nx3': 30,
            'a2': 2.0,
            'ell': 1.5,
            'eccentricity': 0.5
        }
        self.saved_config = copy.deepcopy(self.config)
        self._required_keys = ['nx1', 'nx2', 'nx3', 'ell', 'a2', 'eccentricity']
        self._optional_keys = []

    def test_valid_output(self):
        L, mu_3D, eta_3D, w_3D, dmu1, dmu2, Dmu_3D_coeff_meshgrid, major_axis_outer = get_system_ellipse(self.config)

        # Test types
        self.assertIsInstance(L, sp.csr_matrix)
        self.assertIsInstance(mu_3D, np.ndarray)
        self.assertIsInstance(eta_3D, np.ndarray)
        self.assertIsInstance(w_3D, np.ndarray)
        self.assertIsInstance(Dmu_3D_coeff_meshgrid, np.ndarray)

        # Test shapes
        self.assertEqual(mu_3D.shape, (self.config['nx1'], self.config['nx2'], self.config['nx3']))
        self.assertEqual(eta_3D.shape, (self.config['nx1'], self.config['nx2'], self.config['nx3']))
        self.assertEqual(w_3D.shape, (self.config['nx1'], self.config['nx2'], self.config['nx3']))
        N = self.config['nx1'] * self.config['nx2'] * self.config['nx3']
        self.assertEqual(L.shape, (N, N))

        # Test values
        self.assertTrue(dmu1 > 0)
        self.assertTrue(dmu2 > 0)
        self.assertTrue(major_axis_outer > self.config['a2'])

    def test_missing_keys(self):
        for key in self._required_keys:
            del self.config[key]
            with self.assertRaises(KeyError):
                get_system_ellipse(self.config)
            self.config[key] = self.saved_config[key]

    def test_invalid_parameters(self):
        for key in self._required_keys:
            self.config[key] = -1  # None of the required keys should be negative.
            with self.assertRaises(ValueError):
                get_system_ellipse(self.config)
            self.config[key] = self.saved_config[key]

    def test_unused_params_warning(self):
        copied_params = copy.deepcopy(self.config)
        copied_params['unusedkey'] = 0
        with self.assertWarns(UserWarning) as _:
            get_system_ellipse(copied_params)


if __name__ == '__main__':
    unittest.main()
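A quick usage sketch for the same entry point, reusing the test's configuration (parameter values are illustrative):

from stnn.pde.ellipse import get_system_ellipse

config = {'nx1': 10, 'nx2': 20, 'nx3': 30, 'a2': 2.0, 'ell': 1.5, 'eccentricity': 0.5}
L, mu_3D, eta_3D, w_3D, dmu1, dmu2, Dmu_coeff, major_axis_outer = get_system_ellipse(config)
N = config['nx1'] * config['nx2'] * config['nx3']
assert L.shape == (N, N)   # sparse operator over the flattened (nx1, nx2, nx3) grid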
stnn/tests/test_file.py
ADDED
@@ -0,0 +1,121 @@
import unittest
import numpy as np
import h5py
import tempfile
from stnn.data.preprocessing import get_data_from_file, load_data, load_training_data


class TestGetDataFromFile(unittest.TestCase):

    def setUp(self):
        self.temp_file = tempfile.NamedTemporaryFile(delete = False)
        self.nx1, self.nx2, self.nx3 = 30, 20, 16
        self.Nsamples = 10
        with h5py.File(self.temp_file.name, 'w') as f:
            f.create_dataset('ell', data = np.random.rand(self.Nsamples))
            f.create_dataset('a1', data = np.random.rand(self.Nsamples))
            f.create_dataset('a2', data = np.random.rand(self.Nsamples))
            f.create_dataset('rho', data = np.random.rand(self.Nsamples, self.nx1, self.nx2))
            f.create_dataset('ibf', data = np.random.rand(self.Nsamples, self.nx2, self.nx3 // 2))
            f.create_dataset('obf', data = np.random.rand(self.Nsamples, self.nx2, self.nx3 // 2))

        self.temp_file1 = tempfile.NamedTemporaryFile(delete = False)
        with h5py.File(self.temp_file1.name, 'w') as f:
            f.create_dataset('ell', data = np.random.rand(self.Nsamples))
            f.create_dataset('a1', data = np.random.rand(self.Nsamples))
            f.create_dataset('a2', data = np.random.rand(self.Nsamples))
            f.create_dataset('rho', data = np.random.rand(self.Nsamples, self.nx1, self.nx2))
            f.create_dataset('ibf', data = np.random.rand(self.Nsamples, self.nx2, self.nx3 // 2))
            f.create_dataset('obf', data = np.random.rand(self.Nsamples, self.nx2, self.nx3 // 2))

        self.bad_file = tempfile.NamedTemporaryFile(delete = False)
        with h5py.File(self.bad_file.name, 'w') as f:
            f.create_dataset('ell', data = np.random.rand(self.Nsamples))
            f.create_dataset('a1', data = np.random.rand(self.Nsamples))
            f.create_dataset('a2', data = np.random.rand(self.Nsamples))
            f.create_dataset('rho', data = np.random.rand(self.Nsamples, self.nx1, self.nx2))

    def tearDown(self):
        self.temp_file.close()
        self.temp_file1.close()
        self.bad_file.close()

    def test_missing_datasets(self):
        with self.assertRaises(ValueError):
            get_data_from_file(self.bad_file.name, self.nx2, self.nx3)

    def test_data_extraction_shapes(self):
        result = get_data_from_file(self.temp_file.name, self.nx2, self.nx3)
        self.assertEqual(result[0].shape, (self.Nsamples,))
        self.assertEqual(result[1].shape, (self.Nsamples,))
        self.assertEqual(result[2].shape, (self.Nsamples,))
        self.assertEqual(result[3].shape, (self.Nsamples, 2 * self.nx2, self.nx3 // 2))
        self.assertEqual(result[4].shape, (self.Nsamples, self.nx1, self.nx2))

    def test_nrange_parameter(self):
        Nrange = (2, 5)
        result = get_data_from_file(self.temp_file.name, self.nx2, self.nx3, Nrange = Nrange)
        expected_size = Nrange[1] - Nrange[0]
        self.assertEqual(result[0].shape, (expected_size,))
        self.assertEqual(result[1].shape, (expected_size,))

    def test_list_input(self):
        file_list = [self.temp_file.name, self.temp_file1.name]
        Nrange_list = [(0, -1), (0, -1)]
        with self.assertRaises(TypeError):
            # noinspection PyTypeChecker
            _ = get_data_from_file(file_list, self.nx2, self.nx3, Nrange = Nrange_list)

    def test_invalid_Nrange(self):
        Nrange_list = [(0, -1), (0, -1)]
        with self.assertRaises(TypeError):
            _ = get_data_from_file(self.temp_file.name, self.nx2, self.nx3, Nrange = Nrange_list)

        for Nrange in [(0, 1, 1), 1, (1,), (1.5, 3), (3, 1.5), (1.5, 1.5), 'x']:
            with self.assertRaises(TypeError):
                _ = get_data_from_file(self.temp_file.name, self.nx2, self.nx3, Nrange = Nrange)
            with self.assertRaises(TypeError):
                _ = get_data_from_file(self.temp_file.name, self.nx2, self.nx3, Nrange = list(Nrange))

    def test_good_data_load(self):
        files = [self.temp_file.name, self.temp_file1.name]
        Nrange_list = [(0, None), (0, self.Nsamples)]
        ell1, ell2, a1, a2 = 0.1, 2.0, 1.0, 5.0

        params, bf, rho = load_data(files, self.nx2, self.nx3, ell1, ell2, a1, a2, Nrange_list = Nrange_list)
        self.assertEqual(params.shape, (2 * self.Nsamples, 3))
        self.assertEqual(bf.shape, (2 * self.Nsamples, 2 * self.nx2, self.nx3 // 2))
        self.assertEqual(rho.shape, (2 * self.Nsamples, self.nx1, self.nx2))

        test_size = 0.3
        (params_train, bf_train, rho_train,
         params_test, bf_test, rho_test) = load_training_data(files, self.nx2, self.nx3,
                                                              ell1, ell2, a1, a2, test_size = test_size,
                                                              Nrange_list = Nrange_list)
        Ntest = int(test_size * 2 * self.Nsamples)
        Ntrain = 2 * self.Nsamples - Ntest
        self.assertEqual(params_train.shape, (Ntrain, 3))
        self.assertEqual(bf_train.shape, (Ntrain, 2 * self.nx2, self.nx3 // 2))
        self.assertEqual(rho_train.shape, (Ntrain, self.nx1, self.nx2))
        self.assertEqual(params_test.shape, (Ntest, 3))
        self.assertEqual(bf_test.shape, (Ntest, 2 * self.nx2, self.nx3 // 2))
        self.assertEqual(rho_test.shape, (Ntest, self.nx1, self.nx2))

    def test_bad_data_load(self):
        files = [self.temp_file.name, self.temp_file1.name]
        Nrange_list = (0, -1)
        ell1, ell2, a1, a2 = 0.1, 2.0, 1.0, 5.0
        with self.assertRaises(TypeError):
            _ = load_data(files, self.nx2, self.nx3, ell1, ell2, a1, a2, Nrange_list = Nrange_list)
        with self.assertRaises(TypeError):
            _ = load_data(files, self.nx2, self.nx3, ell1, ell2, a1, a2, Nrange_list = list(Nrange_list))

        Nrange_list = [(0, -1), (0, -1)]
        for test_size in [-1, 0.0, 1.5]:
            with self.assertRaises(ValueError):
                _ = load_training_data(files, self.nx2, self.nx3,
                                       ell1, ell2, a1, a2, test_size = test_size, Nrange_list = Nrange_list)


if __name__ == '__main__':
    unittest.main()
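For orientation, a sketch of the loading pipeline exercised above; the file paths are illustrative, and the four positional arguments after nx3 are the (ell1, ell2, a1, a2) values the tests pass through:

from stnn.data.preprocessing import load_training_data

files = ['samples_0.h5', 'samples_1.h5']   # HDF5 files with the datasets created in setUp above
nx2, nx3 = 20, 16
(params_train, bf_train, rho_train,
 params_test, bf_test, rho_test) = load_training_data(files, nx2, nx3,
                                                      0.1, 2.0, 1.0, 5.0,   # ell1, ell2, a1, a2
                                                      test_size = 0.3,
                                                      Nrange_list = [(0, None), (0, None)])
# Per the shape checks above: params_*: (N, 3); bf_*: (N, 2 * nx2, nx3 // 2); rho_*: (N, nx1, nx2)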
stnn/tests/test_pde_system.py
ADDED
@@ -0,0 +1,54 @@
import copy
import unittest

import numpy as np

from stnn.pde.pde_system import PDESystem


class TestPDESystem(unittest.TestCase):

    def setUp(self):
        self.params = {
            'nx1': 50,
            'nx2': 100,
            'nx3': 75,
            'a1': 1.0,
            'a2': 2.0,
            'ell': 0.1,
            'eccentricity': 0.5
        }
        self.system = PDESystem(self.params)

    def test_initialization(self):
        self.assertEqual(self.system.params, self.params)

    def test_attribute_types(self):
        self.assertIsInstance(self.system.ib_slice, np.ndarray)
        self.assertIsInstance(self.system.ob_slice, np.ndarray)

    def test_attribute_values(self):
        self.assertEqual(self.system.a1, 1.0 - self.params['eccentricity'])

    def test_coordinate_system(self):
        expected_coords = 'ellipse' if self.params['eccentricity'] != 0 else 'circle'
        self.assertEqual(self.system._coords, expected_coords)

    def test_unused_params_warning(self):
        copied_params = copy.deepcopy(self.params)
        copied_params['unusedkey'] = 0
        with self.assertWarns(UserWarning) as _:
            _ = PDESystem(copied_params)

    def test_L(self):
        # If f(x1, x2, x3) is constant, then L * f should be 0 except adjacent to the boundary.
        L = self.system.L
        nx1, nx2, nx3 = self.params['nx1'], self.params['nx2'], self.params['nx3']
        f0 = 2.3 * np.ones((nx1, nx2, nx3))
        result = L @ f0.ravel()
        result = result.reshape((nx1, nx2, nx3))
        np.testing.assert_allclose(result[1:-1, ...], 0, atol = 1e-7)


if __name__ == '__main__':
    unittest.main()
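The `test_L` check doubles as a usage sketch: build a `PDESystem` from a parameter dictionary, then apply its sparse operator `L` to a flattened field. Standalone, with the test's parameters:

import numpy as np
from stnn.pde.pde_system import PDESystem

params = {'nx1': 50, 'nx2': 100, 'nx3': 75, 'a1': 1.0, 'a2': 2.0, 'ell': 0.1, 'eccentricity': 0.5}
system = PDESystem(params)
f = np.ones((params['nx1'], params['nx2'], params['nx3']))
residual = (system.L @ f.ravel()).reshape(f.shape)
# Interior rows of L annihilate constants; only boundary-adjacent rows are nonzero.
print(np.abs(residual[1:-1]).max())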
stnn/tests/test_preprocessing.py
ADDED
@@ -0,0 +1,99 @@
import unittest
import numpy as np
from stnn.data.preprocessing import train_test_split


class TestTrainTestSplit(unittest.TestCase):

    def setUp(self):
        self.X_array = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
        self.Y_array = np.array([1, 2, 3, 4])
        self.X_list = [self.X_array, self.X_array]
        self.Y_list = [self.Y_array, self.Y_array]
        self.X_list_bad = [self.Y_array, self.X_array]
        self.Y_list_bad = [self.Y_array, self.X_array]

    def test_basic_functionality_array(self):
        X_train, X_test, Y_train, Y_test = train_test_split(self.X_array, self.Y_array, test_size = 0.25)
        self.assertEqual(len(X_train), 3)
        self.assertEqual(len(X_test), 1)
        self.assertEqual(len(Y_train), 3)
        self.assertEqual(len(Y_test), 1)

    def test_basic_functionality_list(self):
        X_train, X_test, Y_train, Y_test = train_test_split(self.X_list, self.Y_list, test_size = 0.25)
        self.assertEqual(len(X_train[0]), 3)
        self.assertEqual(len(X_test[0]), 1)
        self.assertEqual(len(Y_train[0]), 3)
        self.assertEqual(len(Y_test[0]), 1)

    def test_return_type_consistency_array(self):
        X_train, X_test, Y_train, Y_test = train_test_split(self.X_array, self.Y_array, test_size = 0.25)
        self.assertIsInstance(X_train, np.ndarray)
        self.assertIsInstance(X_test, np.ndarray)
        self.assertIsInstance(Y_train, np.ndarray)
        self.assertIsInstance(Y_test, np.ndarray)

        X_train, X_test, Y_train, Y_test = train_test_split([self.X_array], [self.Y_array], test_size = 0.25)
        self.assertIsInstance(X_train, list)
        self.assertIsInstance(X_test, list)
        self.assertIsInstance(Y_train, list)
        self.assertIsInstance(Y_test, list)

    def test_return_type_consistency_list(self):
        X_train, X_test, Y_train, Y_test = train_test_split(self.X_list, self.Y_list, test_size = 0.25)
        self.assertIsInstance(X_train, list)
        self.assertIsInstance(X_test, list)
        self.assertIsInstance(Y_train, list)
        self.assertIsInstance(Y_test, list)

        # noinspection PyTypeChecker
        X_train, X_test, Y_train, Y_test = train_test_split(tuple(self.X_list), tuple(self.Y_list), test_size = 0.25)
        self.assertIsInstance(X_train, list)
        self.assertIsInstance(X_test, list)
        self.assertIsInstance(Y_train, list)
        self.assertIsInstance(Y_test, list)

    def test_random_state(self):
        X_train1, X_test1, Y_train1, Y_test1 = train_test_split(self.X_array, self.Y_array, test_size = 0.25,
                                                                random_state = 42)
        X_train2, X_test2, Y_train2, Y_test2 = train_test_split(self.X_array, self.Y_array, test_size = 0.25,
                                                                random_state = 42)
        np.testing.assert_array_equal(X_train1, X_train2)
        np.testing.assert_array_equal(X_test1, X_test2)
        np.testing.assert_array_equal(Y_train1, Y_train2)
        np.testing.assert_array_equal(Y_test1, Y_test2)

    def test_invalid_test_size(self):
        with self.assertRaises(ValueError):
            train_test_split(self.X_array, self.Y_array, test_size = -0.1)
        with self.assertRaises(ValueError):
            train_test_split(self.X_array, self.Y_array, test_size = 1.5)

    def test_inconsistent_length(self):
        X = np.array([[1, 2], [3, 4]])
        Y = np.array([1, 2, 3])
        with self.assertRaises(ValueError):
            train_test_split(X, Y)
        with self.assertRaises(ValueError):
            train_test_split(self.X_list_bad, self.Y_list_bad)
        with self.assertRaises(ValueError):
            train_test_split(self.X_list_bad, self.Y_list)
        with self.assertRaises(ValueError):
            train_test_split(self.X_list, self.Y_list_bad)

    def test_empty(self):
        X_empty = np.zeros(0)
        Y_empty = np.zeros(0)
        with self.assertRaises(ValueError):
            train_test_split(X_empty, Y_empty)
        with self.assertRaises(ValueError):
            train_test_split([X_empty], [])
        with self.assertRaises(ValueError):
            train_test_split([], [Y_empty])
        with self.assertRaises(ValueError):
            train_test_split([X_empty], [Y_empty])


if __name__ == '__main__':
    unittest.main()
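As the type-consistency tests document, this `train_test_split` mirrors the scikit-learn interface but additionally accepts parallel lists of arrays and splits them jointly. A short sketch (array contents are illustrative):

import numpy as np
from stnn.data.preprocessing import train_test_split

params = np.random.rand(100, 3)
bf = np.random.rand(100, 40, 8)
rho = np.random.rand(100, 30, 20)
# Lists in, lists out: every array is split with the same shuffled indices.
X_train, X_test, Y_train, Y_test = train_test_split([params, bf], [rho], test_size = 0.25, random_state = 42)
params_train, bf_train = X_train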
stnn/tests/test_stats.py
ADDED
@@ -0,0 +1,41 @@
import numpy as np
import unittest
import os
from stnn.utils.stats import get_stats


class TestGetStats(unittest.TestCase):

    def setUp(self):
        self.rho = np.array([[1, 2], [3, 4]])
        self.rho_pred = np.array([[1, 2], [3, 4]])
        self.filename = 'test_stats.npz'

    def test_correctness(self):
        get_stats(self.rho, self.rho_pred, self.filename)
        with np.load(self.filename) as data:
            self.assertAlmostEqual(data['max_loss'], 0.0, places = 5)
            self.assertEqual(data['avg_loss'], 0.0)
            self.assertEqual(data['N'], self.rho.shape[0])

    def test_file_creation(self):
        get_stats(self.rho, self.rho_pred, self.filename)
        self.assertTrue(os.path.exists(self.filename))

    def test_file_content(self):
        get_stats(self.rho, self.rho_pred, self.filename)
        with np.load(self.filename) as data:
            self.assertIn('max_loss', data)
            self.assertIn('avg_loss', data)
            self.assertIn('N', data)

    def test_invalid_input(self):
        with self.assertRaises(ValueError):
            get_stats(np.array([1, 2]), np.array([[1, 2], [3, 4]]))

    def tearDown(self):
        if os.path.exists(self.filename):
            os.remove(self.filename)


if __name__ == '__main__':
    unittest.main()
stnn/tests/test_stnn_config.py
ADDED
@@ -0,0 +1,61 @@
import unittest
import numpy as np
import copy
from stnn.nn.stnn import build_stnn


class TestBuildSTNN(unittest.TestCase):

    def setUp(self):
        self.config = {
            'K': 1,
            'nx1': 8,
            'nx2': 8,
            'nx3': 8,
            'd': 8,
            'W': 3,
            'shape1': [1, 2, 3],
            'shape2': [2, 2, 2],
            'ranks': [1, 2, 2, 1],
        }
        self.saved_config = copy.deepcopy(self.config)
        self._required_keys = ['nx1', 'nx2', 'nx3', 'K', 'd', 'shape1', 'shape2', 'ranks', 'W']
        self._optional_keys = ['use_regularization', 'regularization_strength']

    def test_missing_keys(self):
        for key in self._required_keys:
            del self.config[key]
            with self.assertRaises(KeyError):
                build_stnn(self.config)
            self.config[key] = self.saved_config[key]

    def test_invalid_values(self):
        for key in ['K', 'd', 'W', 'nx1', 'nx2', 'nx3']:
            for value in [1.5, 'a', None, np.nan]:
                with self.subTest(value = value):
                    self.config[key] = value
                    with self.assertRaises(TypeError):
                        build_stnn(self.config)
                    self.config[key] = self.saved_config[key]
            value = -1
            with self.subTest(value = value):
                self.config[key] = value
                with self.assertRaises(ValueError):
                    build_stnn(self.config)
                self.config[key] = self.saved_config[key]

        self.config['nx3'] = 7  # not divisible by 2
        with self.assertRaises(ValueError):
            build_stnn(self.config)
        self.config['nx3'] = self.saved_config['nx3']

    def test_positive_values(self):
        for key in ['K', 'nx1', 'nx2', 'nx3', 'd']:
            self.config[key] = 0
            with self.assertRaises(ValueError):
                build_stnn(self.config)
            self.config[key] = self.saved_config[key]


if __name__ == '__main__':
    unittest.main()
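The setUp dictionary is the baseline these tests treat as valid, so it also serves as a minimal build sketch (assuming TensorFlow is available, per requirements.txt):

from stnn.nn.stnn import build_stnn

config = {
    'K': 1, 'nx1': 8, 'nx2': 8, 'nx3': 8, 'd': 8, 'W': 3,
    'shape1': [1, 2, 3], 'shape2': [2, 2, 2], 'ranks': [1, 2, 2, 1],
}
model = build_stnn(config)
model.summary()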
stnn/tests/test_ttl.py
ADDED
@@ -0,0 +1,51 @@
import unittest
from stnn.nn.stnn_layers import TTL


class TestTTL(unittest.TestCase):

    def test_missing_config_keys(self):
        with self.assertRaises(KeyError):
            TTL(config = {})  # Empty config

    def test_invalid_use_regularization_type(self):
        config = {'use_regularization': 'not a boolean', 'nx1': 1, 'nx2': 1, 'nx3': 1, 'shape1': [1], 'shape2': [1],
                  'ranks': [1], 'W': 1}
        with self.assertRaises(TypeError):
            TTL(config = config)

    def test_invalid_nx_values(self):
        config = {'nx1': 5, 'nx2': 5, 'nx3': 8, 'use_regularization': False, 'shape1': [1], 'shape2': [1],
                  'ranks': [1], 'W': 1}
        for nx in ['nx1', 'nx2', 'nx3']:
            for value in [-1, 0]:
                config.update({nx: value})
                with self.assertRaises(ValueError):
                    TTL(config = config)

    def test_invalid_W_value(self):
        config = {'use_regularization': False, 'nx1': 1, 'nx2': 1, 'nx3': 1, 'shape1': [1], 'shape2': [1],
                  'ranks': [1], 'W': -1}
        with self.assertRaises(ValueError):
            TTL(config = config)

    def test_shape_length_mismatch(self):
        config = {'use_regularization': False, 'nx1': 1, 'nx2': 1, 'nx3': 1, 'shape1': [1, 2], 'shape2': [1],
                  'ranks': [1], 'W': 1}
        with self.assertRaises(ValueError):
            TTL(config = config)

    def test_incorrect_shape1_product(self):
        config = {'use_regularization': False, 'nx1': 2, 'nx2': 2, 'nx3': 1, 'shape1': [1, 3], 'shape2': [4],
                  'ranks': [1], 'W': 1}
        with self.assertRaises(ValueError):
            TTL(config = config)

    def test_incorrect_shape2_product(self):
        config = {'use_regularization': False, 'nx1': 1, 'nx2': 2, 'nx3': 1, 'shape1': [2], 'shape2': [1, 5],
                  'ranks': [1], 'W': 1}
        with self.assertRaises(ValueError):
            TTL(config = config)


if __name__ == '__main__':
    unittest.main()
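Each test module above exposes the standard `unittest` entry point, so the whole suite should also run via discovery from the repository root:

python -m unittest discover stnn/tests -v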
stnn/utils/__init__.py
ADDED
File without changes
stnn/utils/input_output.py
ADDED
@@ -0,0 +1,135 @@
import h5py
from scipy.sparse import csr_matrix
import numpy as np
import json
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2


def save_to_hdf5(filename, datasets, start_idx, end_idx):
    """
    Saves subsets of datasets to an HDF5 file, either by creating new datasets or appending to existing ones.

    Args:
        filename (str): The name of the HDF5 file where the data will be saved.
        datasets (dict): A dictionary where keys are dataset names and values are the corresponding data arrays.
        start_idx (int): The starting index of the data slice to be saved.
        end_idx (int): The ending index (exclusive) of the data slice to be saved.

    This function will create new datasets if they do not already exist. If a dataset already exists, it will be
    resized to accommodate the new data, and the data slice will be appended.
    """
    print(f'Saving data from n = {start_idx} to n = {end_idx}...')
    with h5py.File(filename, 'a') as f:
        for name, data in datasets.items():
            if name not in f:
                f.create_dataset(name, data = data[start_idx:end_idx], maxshape = (None,) + data.shape[1:],
                                 chunks = (1,) + data.shape[1:])
            else:
                f[name].resize((f[name].shape[0] + end_idx - start_idx,) + f[name].shape[1:])
                f[name][-(end_idx - start_idx):] = data[start_idx:end_idx]
    print('Done.')


def write_sparse_matrix_hdf5(filename, sparse_matrix, dataset_name = 'sparse_matrix', format_ = 'csr'):
    """
    Writes a sparse matrix to an HDF5 file in a specified format.

    Args:
        filename (str): Name of the output file.
        sparse_matrix (scipy.sparse matrix): Sparse matrix to be written to the file.
        dataset_name (str, optional): Name of the HDF5 dataset. Defaults to 'sparse_matrix'.
        format_ (str, optional): The format of the sparse matrix. Currently only supports 'csr' (Compressed Sparse Row).
    """
    if format_ == 'csr':
        with h5py.File(filename, 'w') as f:
            g = f.create_group(dataset_name)
            g.create_dataset('data', data = sparse_matrix.data)
            g.create_dataset('indices', data = sparse_matrix.indices)
            g.create_dataset('indptr', data = sparse_matrix.indptr)
            g.create_dataset('shape', data = np.array(sparse_matrix.shape))
            g.attrs['format'] = 'csr'
    else:
        raise ValueError(f'Unsupported sparse matrix format: {format_}')


def read_sparse_matrix_hdf5(filename, dataset_name = 'sparse_matrix'):
    """
    Reads a sparse matrix from an HDF5 file.

    Args:
        filename (str): Name of the input file.
        dataset_name (str, optional): Name of the dataset containing the matrix. Defaults to 'sparse_matrix'.

    Returns:
        scipy.sparse matrix: The sparse matrix read from the file.
    """
    with h5py.File(filename, 'r') as f:
        g = f[dataset_name]
        data = g['data'][:]
        indices = g['indices'][:]
        indptr = g['indptr'][:]
        shape = tuple(g['shape'][:])
        format_ = g.attrs['format']
    if format_ == 'csr':
        return csr_matrix((data, indices, indptr), shape = shape)
    else:
        raise ValueError(f'Unsupported sparse matrix format: {format_}')


def data_dump(bf, rho, rho_pred, params_dict):
    """
    Writes data and config to file for later use.
    """
    np.savez('sample_data.npz', rho = rho, rho_pred = rho_pred, bf = bf)
    with open('sample_config.json', 'w', encoding = 'utf-8') as f:
        json.dump(params_dict, f)


def save_as_frozen_graph(model, saved_model_dir):
    """
    Converts a TensorFlow model into a 'frozen' SavedModel format.

    Args:
        model: TensorFlow model to be frozen.
        saved_model_dir (str): The directory path where the frozen model will be saved.

    Returns:
        TFModelServer object: An instance of the TFModelServer class, which can be used for serving the model.

    The function converts model variables to constants and is required for converting the model to the intermediate
    representation (IR) used by openvino.
    """
    # Create input specifications for the model
    input_specs = [tf.TensorSpec([1] + model.inputs[i].shape[1:].as_list(), model.inputs[i].dtype) for i in
                   range(len(model.inputs))]

    # Create a concrete function from the model
    full_model = tf.function(lambda x: model(x))
    full_model = full_model.get_concrete_function(input_specs)

    # Convert the model to a frozen function
    frozen_func = convert_variables_to_constants_v2(full_model)

    # Define a new module with a method that has a `@tf.function` decorator with input signatures
    class TFModelServer(tf.Module):
        def __init__(self, frozen_func):
            super().__init__()
            self.frozen_func = frozen_func

        @tf.function(input_signature = input_specs)
        def serve(self, *args):
            return self.frozen_func(*args)

    # Create an instance of TFModelServer with the frozen function
    model_server = TFModelServer(frozen_func)

    # Save the module as a SavedModel
    tf.saved_model.save(model_server, saved_model_dir, signatures = {"serving_default": model_server.serve})

    return model_server


def log_and_print(message):
    print(message)
    with open('log.txt', 'a') as log_file:
        log_file.write(message + '\n')
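A round-trip sketch for the two sparse-matrix helpers above (the matrix and filename are illustrative):

import scipy.sparse as sp
from stnn.utils.input_output import write_sparse_matrix_hdf5, read_sparse_matrix_hdf5

A = sp.random(100, 100, density = 0.01, format = 'csr')
write_sparse_matrix_hdf5('operator.h5', A)   # stores data/indices/indptr/shape plus a 'format' attribute
B = read_sparse_matrix_hdf5('operator.h5')
assert (A != B).nnz == 0                      # exact round trip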
stnn/utils/network_visualization.py
ADDED
@@ -0,0 +1,121 @@
import pydot
import re
from keras.models import Model
from keras.layers import Layer, InputLayer


# May be necessary to manually add Graphviz to PATH, e.g.
# import os
# os.environ["PATH"] += os.pathsep + r'C:\Program Files\Graphviz\bin'

def visualize_model(model, layer_labels = None, layer_colors = None, groupings = None, exclude_input_layer = False,
                    verbose = False, output_filename = 'model_graph.png'):
    """
    Creates a visual graph of a keras model. There is an option to group certain layers into subgraphs
    (argument 'groupings').

    Args:
        model: A Keras Model instance.
        layer_labels (optional): List of labels for each layer. Defaults to layer names.
        layer_colors (optional): List of colors for each layer. Defaults to white for all layers.
        groupings (optional): Dictionary specifying groups of layers. Each key is a group name,
            and its value is a list of layer names belonging to that group.
        exclude_input_layer (optional): Boolean indicating whether to exclude the input layer from the graph.
        verbose (boolean, optional): Whether to print verbose output. Defaults to False.
        output_filename (optional): Name of the output file for saving the generated graph.

    Output:
        Image file with name 'output_filename'.
    """
    if not isinstance(model, Model):
        raise ValueError("model should be a Keras model instance")
    num_layers = len(model.layers)

    # Default labels and colors if not provided
    if not layer_labels:
        layer_labels = [layer.name for layer in model.layers]
    if not layer_colors:
        default_color = 'white'
        layer_colors = [default_color] * num_layers

    # Create a directed graph
    graph = pydot.Dot(graph_type = 'digraph', rankdir = 'LR')

    # Create nodes for each layer and add to subgraphs if specified
    subgraphs = {}
    layer_id_map = {}
    for i, layer in enumerate(model.layers):
        # Exclude the input layer if specified
        if exclude_input_layer and isinstance(layer, InputLayer):
            continue

        # Create a node for the layer
        layer_id = str(id(layer))
        layer_id_map[layer] = layer_id
        label = layer_labels[i]
        color = layer_colors[i]

        node = pydot.Node(layer_id, label = label, style = 'filled', fillcolor = color, shape = 'box')

        # Check for groupings and add the node to the appropriate subgraph or main graph
        group_name = None
        if groupings:
            for group, members in groupings.items():
                if layer.name in members:
                    group_name = group
                    break

        if group_name:
            if group_name not in subgraphs:
                subgraph = pydot.Cluster(group_name, label = group_name, style = 'dashed', fontsize = 24)
                subgraphs[group_name] = subgraph
            subgraphs[group_name].add_node(node)
        else:
            graph.add_node(node)

    # Add subgraphs to the main graph
    for subgraph in subgraphs.values():
        graph.add_subgraph(subgraph)

    # Add edges based on layer connections
    for layer in model.layers:
        if exclude_input_layer and isinstance(layer, InputLayer):
            continue
        # Handle custom or non-standard layers
        if hasattr(layer, '_inbound_nodes'):
            inbound_nodes = layer._inbound_nodes
        else:
            # If the layer doesn't have '_inbound_nodes', skip edge creation
            continue

        for inbound_node in inbound_nodes:
            inbound_layers = inbound_node.inbound_layers
            if not isinstance(inbound_layers, list):
                inbound_layers = [inbound_layers]
            for inbound_layer in inbound_layers:
                if isinstance(inbound_layer, Layer) and inbound_layer in layer_id_map:
                    src_id = layer_id_map[inbound_layer]
                    dest_id = layer_id_map[layer]
                    if (re.search('sequential', inbound_layer.name, flags = re.IGNORECASE) or
                            re.search(r'operators__.getitem_[0-9]+$', inbound_layer.name, flags = re.IGNORECASE)):
                        graph.add_edge(pydot.Edge(src_id, dest_id, style = 'invis'))
                    else:
                        graph.add_edge(pydot.Edge(src_id, dest_id))
                    if verbose:
                        print(f"Added edge from {inbound_layer.name} to {layer.name}")

    graph.set_graph_defaults(sep = '+125,125')
    try:
        graph.write_png(output_filename)
    except FileNotFoundError as e:
        print(f'\nFailed to create network visualization using pydot and graphviz. Please ensure that '
              'the output filename is valid, and graphviz is installed and included in the system PATH variable. '
              f'Original error: {e}')
    except Exception as e:
        print(f'\nFailed to create network visualization using pydot and graphviz. Original error: {e}')
    else:
        print(f'Model visualization saved to {output_filename}')
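A usage sketch for `visualize_model`, grouping two layers into one dashed cluster (the toy model is illustrative, and Graphviz must be on the PATH as noted above):

from keras.models import Sequential
from keras.layers import Dense
from stnn.utils.network_visualization import visualize_model

model = Sequential([Dense(16, activation = 'relu', input_shape = (8,), name = 'hidden'),
                    Dense(1, name = 'out')])
visualize_model(model, groupings = {'mlp': ['hidden', 'out']}, exclude_input_layer = True,
                output_filename = 'toy_graph.png')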
stnn/utils/plotting.py
ADDED
@@ -0,0 +1,93 @@
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.cm import get_cmap


def plot_comparison(system, bf, rho, rho_pred, fontscale = 1, output_filename = 'comparison.png',
                    wspace = -0.1, hspace = 0.5):
    """
    Side-by-side comparison of contour plots for arguments 'rho' and 'rho_pred'.
    Also plots the boundary data (argument 'bf') as a contour plot.

    Args:
        system: PDESystem object
        bf (np.ndarray): 2D array representing the boundary data on the (x2, x3) grid.
        rho (np.ndarray): 2D array, known rho
        rho_pred (np.ndarray): 2D array, predicted rho
        fontscale (int, optional): A scaling factor for the font size in plots.
        output_filename (str, optional): name of the output file
        wspace (float, optional): horizontal spacing between subplots
        hspace (float, optional): vertical spacing between subplots

    Returns:
        Does not return a value. The figure is saved to a file with name given by 'output_filename'. The default
        is 'comparison.png'.
    """
    # System parameters
    ell = system.params['ell']
    a2 = system.a2
    e = system.params['eccentricity']
    nx1, nx2, nx3 = system.params['nx1'], system.params['nx2'], system.params['nx3']

    # Get x, y grids from 'PDESystem' object
    x, y = system.get_xy_grids()

    # Relative error
    err = np.linalg.norm(rho - rho_pred)
    rel_err = err / np.linalg.norm(rho)

    # Wrap around values for continuity
    rho = np.append(rho, rho[:, 0:1], axis = 1)
    rho_pred = np.append(rho_pred, rho_pred[:, 0:1], axis = 1)

    plots = [rho, rho_pred]

    # Color bar limits
    vmin = np.nanmin(rho_pred[:, :])
    vmax = np.nanmax(rho_pred[:, :])
    vmin = min(vmin, np.nanmin(rho[:, :]))
    vmax = max(vmax, np.nanmax(rho[:, :]))

    # Figure layout
    cbar_coords = (0.89, 0.11, 0.03, 0.77)
    fig = plt.figure(figsize = (24, 24))
    gs = fig.add_gridspec(11, 2, hspace = hspace, wspace = wspace)
    axs = [fig.add_subplot(gs[:11, 0]), fig.add_subplot(gs[:5, 1]), fig.add_subplot(gs[6:11, 1])]

    # Boundary data contour plot
    bf_plot = np.nan * np.ones((2 * nx2, nx3))
    bf_plot[:nx2, :nx3 // 2] = bf[:nx2, :]
    bf_plot[nx2:, nx3 // 2:] = bf[nx2:, :]
    im = axs[0].imshow(bf_plot)
    axs[0].set_yticks([1, nx2 // 4, nx2 // 2 - 1, nx2 // 2 + 1, 3 * nx2 // 4, nx2])
    axs[0].set_yticklabels(['-pi', '0', 'pi', '', '0', 'pi'])
    axs[0].set_xticks([1, nx3 // 2, nx3])
    axs[0].set_xticklabels(['0', 'pi', '2pi'])
    axs[0].set_title('Boundary data\n\n', fontsize = fontscale * 48, y = 0.91)
    for label in axs[0].get_xticklabels() + axs[0].get_yticklabels():
        label.set_fontsize(fontscale * 48)
    divider = make_axes_locatable(axs[0])
    cax = divider.append_axes('bottom', size = "3%", pad = 1.5)
    cb = fig.colorbar(im, cax = cax, orientation = 'horizontal')
    cb.ax.tick_params(labelsize = fontscale * 48)
    cax.xaxis.set_ticks_position('bottom')

    # rho(x, y) contour plots
    titles = ['Direct Solution', 'Tensor network']
    for i, ax in enumerate(axs[1:]):
        z = plots[i]
        im = ax.contourf(x, y, z, levels = np.linspace(vmin, vmax, 100), cmap = get_cmap('hsv'))
        ax.set_title(titles[i], fontsize = fontscale * 48)
        for label in ax.get_xticklabels() + ax.get_yticklabels():
            label.set_fontsize(fontscale * 32)
        ax.set_aspect(1.0)

    cbar_ax = fig.add_axes(cbar_coords)
    cb = fig.colorbar(im, cax = cbar_ax, pad = 0.05)
    cb.ax.tick_params(labelsize = fontscale * 32)

    suptitle = f'ell = {ell:.3f}; a2 = {a2:.3f}; e = {e:.3f}; Relative error: {rel_err:.3f}'
    plt.suptitle(suptitle, fontsize = fontscale * 48, x = 0.1, y = 0.97, horizontalalignment = 'left')
    plt.savefig(output_filename)
    plt.close()
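A sketch of how `plot_comparison` is presumably driven; the random fields stand in for a solved and a predicted density, and the shapes follow the conventions in the tests:

import numpy as np
from stnn.pde.pde_system import PDESystem
from stnn.utils.plotting import plot_comparison

params = {'nx1': 50, 'nx2': 100, 'nx3': 75, 'a1': 1.0, 'a2': 2.0, 'ell': 0.1, 'eccentricity': 0.5}
system = PDESystem(params)
bf = np.random.rand(2 * params['nx2'], params['nx3'] // 2)   # boundary data
rho = np.random.rand(params['nx1'], params['nx2'])
rho_pred = rho + 0.01 * np.random.rand(*rho.shape)            # stand-in for a model prediction
plot_comparison(system, bf, rho, rho_pred, output_filename = 'comparison.png')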
stnn/utils/stats.py
ADDED
@@ -0,0 +1,27 @@
import numpy as np


def get_stats(rho, rho_pred, output_filename = 'stats.npz'):
    """
    Calculates statistical metrics for model predictions and saves them to a file.

    This function computes the normalized loss for each instance in the dataset, identifies the maximum loss and its
    index, and calculates the average loss. These statistics are then saved to an NPZ file.

    Args:
        rho (numpy.ndarray): True values.
        rho_pred (numpy.ndarray): Predicted values.
        output_filename (str, optional): The name of the file where the statistics will be saved. Defaults to 'stats.npz'.
    """
    if rho.shape != rho_pred.shape:
        raise ValueError('rho and rho_pred must have the same shape.')

    y_true_flattened = rho.reshape(rho.shape[0], -1)
    y_pred_flattened = rho_pred.reshape(rho_pred.shape[0], -1)
    loss = np.linalg.norm(y_true_flattened - y_pred_flattened, axis = 1) / np.linalg.norm(y_true_flattened, axis = 1)
    max_loss = np.max(loss)
    max_loss_index = np.argmax(loss)
    print(f'Maximum loss: {max_loss}')
    print(f'Index of instance with max loss: {max_loss_index}')
    print(f'Average loss: {np.average(loss)}')
    np.savez(output_filename, max_loss = max_loss, avg_loss = np.average(loss), N = loss.shape[0])
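Finally, a quick sketch of `get_stats` on a batch of predictions (arrays are illustrative; the first axis indexes samples):

import numpy as np
from stnn.utils.stats import get_stats

rho = np.random.rand(10, 30, 20)
rho_pred = rho + 0.05 * np.random.rand(10, 30, 20)
get_stats(rho, rho_pred, output_filename = 'stats.npz')
with np.load('stats.npz') as data:
    print(float(data['max_loss']), float(data['avg_loss']), int(data['N']))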