caleb2 commited on
Commit
d68c650
·
1 Parent(s): 942221a

initial commit

Browse files
README.md CHANGED
@@ -10,4 +10,11 @@ pinned: false
10
  license: mit
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
10
  license: mit
11
  ---
12
 
13
+ ## Stacked Tensorial Neural Network (STNN) demo
14
+ This demo uses the model architecture from [arXiv:2312.14979](https://arxiv.org/abs/2312.14979)
15
+ to solve a parametric PDE problem on an elliptical annular domain. See the paper for a
16
+ detailed description of the problem and its applications.
17
+
18
+ The [GitHub repo](https://github.com/caleb399/stacked_tensorial_nn) contains additional examples, including
19
+ instructions for solving the PDE using a conventional iterative method (GMRES). Due to the long runtime of
20
+ solving the PDE in this way, it is not included in the demo.
T5_config.json ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "K": 20,
3
+ "d": 8,
4
+ "W": 2,
5
+ "ranks": [
6
+ 1,
7
+ 16,
8
+ 16,
9
+ 16,
10
+ 16,
11
+ 16,
12
+ 7,
13
+ 1
14
+ ],
15
+ "shape1": [
16
+ 4,
17
+ 4,
18
+ 4,
19
+ 4,
20
+ 4,
21
+ 4,
22
+ 4
23
+ ],
24
+ "shape2": [
25
+ 4,
26
+ 2,
27
+ 2,
28
+ 2,
29
+ 2,
30
+ 2,
31
+ 2
32
+ ],
33
+ "nx1": 256,
34
+ "nx2": 64,
35
+ "nx3": 32,
36
+ "ell_min": 0.01,
37
+ "ell_max": 100.0,
38
+ "a2_min": 2.0,
39
+ "a2_max": 20.0
40
+ }
T5_weights.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51b60beefbbd6e5c8ade77d6f4f64c54146f34dc79e2b91605e0864ecbfcbd07
3
+ size 957976
app.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import gradio as gr
3
+ import matplotlib.pyplot as plt
4
+ import matplotlib.ticker as ticker
5
+ import numpy as np
6
+ import sympy
7
+ from matplotlib.cm import get_cmap
8
+ from stnn.nn import stnn
9
+ from stnn.pde.pde_system import PDESystem
10
+
11
+
12
def adjust_to_nice_number(value, round_down = False):
    """
    Snap ``value`` to a "nice" number of the form {1, 2, 5, 10} x 10**k.

    Used to pick readable colorbar tick positions. Zero is returned unchanged.
    Negative inputs are processed by magnitude, always with the ``round_down``
    thresholds, and get their sign restored at the end.

    Args:
        value: Number to adjust.
        round_down: Use the lower set of thresholds (1.5 / 3 / 7) instead of
            the default ones (1 / 2 / 5) when choosing the nice mantissa.

    Returns:
        The adjusted value.
    """
    if value == 0:
        return value

    negative = value < 0
    if negative:
        # Negative values always take the round_down thresholds.
        round_down = True
        value = -value

    exponent = np.floor(np.log10(value))  # decade containing the value
    mantissa = value / 10**exponent  # leading digit(s)

    # Pick the first nice mantissa whose threshold admits the value; anything
    # beyond the last threshold maps to 10.
    if round_down:
        nice_fractional = next(
            (nice for bound, nice in ((1.5, 1), (3, 2), (7, 5)) if mantissa < bound), 10)
    else:
        nice_fractional = next(
            (nice for bound, nice in ((1, 1), (2, 2), (5, 5)) if mantissa <= bound), 10)

    if round_down or nice_fractional != 10:
        nice_value = nice_fractional * 10**exponent
    else:
        nice_value = 10**(exponent + 1)

    return -nice_value if negative else nice_value
50
+
51
+
52
def find_nice_values(min_val_raw, max_val, num_values = 4):
    """
    Calculate evenly spaced "nice" values within the given range. Used for colorbar tickmarks.

    Args:
        min_val_raw: Lower end of the range.
        max_val: Upper end of the range.
        num_values: Target number of values; one extra value is appended when
            the sequence would otherwise stop more than one step short of
            ``max_val``.

    Returns:
        list: Values inside ``[start, max_val]``, where ``start`` is either
        ``min_val_raw`` or its "nice" adjustment. A zero-width range yields
        ``[min_val_raw]``; an inverted or NaN range yields ``[]``.
    """
    span = max_val - min_val_raw
    if not span > 0:
        # Degenerate range: avoid dividing by zero / taking log10 of a
        # non-positive spacing further down.
        return [min_val_raw] if span == 0 else []

    # Prefer a "nice" starting value, but only when it sits far enough into the
    # range to be worth it and still lies below max_val. Otherwise the spacing
    # would become non-positive and no tick would survive the final filter.
    min_val = adjust_to_nice_number(min_val_raw)
    frac_val = (min_val - min_val_raw) / span
    if frac_val < 1 / num_values or min_val >= max_val:
        min_val = min_val_raw
    raw_spacing = (max_val - min_val) / (num_values - 1)

    # Round the spacing to the closest of 1, 2, 5, 10 times its order of magnitude
    magnitude = np.floor(np.log10(raw_spacing))
    nice_factors = np.array([1, 2, 5, 10])
    normalized_spacing = raw_spacing / (10**magnitude)
    closest_factor = nice_factors[np.argmin(np.abs(nice_factors - normalized_spacing))]
    nice_spacing = closest_factor * (10**magnitude)

    nice_values = min_val + nice_spacing * np.arange(num_values)

    # Extend by one step if the sequence stops more than one spacing short of max_val
    if nice_values[-1] < max_val - nice_spacing:
        nice_values = np.append(nice_values, [nice_values[-1] + nice_spacing])

    return [val for val in nice_values if min_val <= val <= max_val]
79
+
80
+
81
def format_tick_label(val):
    """
    Format a tick value, switching to scientific notation for large/small values.

    The number of decimals depends on the decade of ``val``:
    hundreds -> ``.0f``, tens -> ``.1f``, units -> ``.2f``, tenths -> ``.1f``,
    hundredths -> ``.2f``; everything else uses ``.1e``. Zero is returned via
    plain ``str`` formatting.

    Args:
        val: Tick value.

    Returns:
        str: The formatted label.
    """
    if val == 0:
        return f'{val}'

    # Keep the sign of the exponent: small values need MORE decimals, so they
    # must not share a branch with the large values of the same |exponent|
    # (previously 0.05 was formatted with '.0f' and displayed as "0").
    exponent = np.floor(np.log10(np.abs(val)))
    if np.abs(exponent) > 2:
        return f'{val:.1e}'
    if exponent == 2:
        return f'{val:.0f}'
    if exponent == 1 or exponent == -1:
        return f'{val:.1f}'
    # exponent is 0 or -2
    return f'{val:.2f}'
97
+
98
+
99
def plot_simple(system, rho, fontscale = 1):
    """
    Draw a filled contour plot of ``rho`` on the grid provided by ``system``.

    Args:
        system: Object exposing ``b2`` (major axis of the outer boundary) and
            ``get_xy_grids()``; a ``PDESystem`` in this app.
        rho: 2D array of values to plot. Its first column is appended again at
            the end so the plot closes on itself.
            NOTE(review): assumes the grids returned by ``get_xy_grids()``
            already include that wrap-around column - confirm.
        fontscale: Multiplier applied to the title and tick-label font sizes.

    Returns:
        The Matplotlib figure.
    """
    # Major axis of outer boundary
    b2 = system.b2

    # Get x, y grids from 'PDESystem' object
    x, y = system.get_xy_grids()

    # Wrap around values for continuity
    rho = np.append(rho, rho[:, 0:1], axis = 1)

    # Color bar limits
    vmin = np.nanmin(rho)
    vmax = np.nanmax(rho)

    fig, ax = plt.subplots(figsize = (5, 5))

    # Pass the colormap by name: matplotlib.cm.get_cmap is deprecated since
    # Matplotlib 3.7 and removed in 3.9.
    im = ax.contourf(x, y, rho, levels = np.linspace(vmin, vmax, 100), cmap = 'hsv')
    ax.set_title('rho(x,y)', fontsize = fontscale * 16)
    for label in ax.get_xticklabels() + ax.get_yticklabels():
        label.set_fontsize(fontscale * 12)
    ax.set_aspect(1.0)

    # Leave a 5% margin around the outer boundary
    fac = 1.05
    ax.set_xlim([-fac * b2, fac * b2])
    ax.set_ylim([-fac * b2, fac * b2])

    cbar = fig.colorbar(im, shrink = 0.8)

    # Set colorbar ticks and labels to "nice" values
    nice_values = find_nice_values(vmin, vmax, num_values = 5)
    cbar.set_ticks(nice_values)
    cbar.ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda x, pos: format_tick_label(x)))

    return fig
133
+
134
+
135
def evaluate_2d_expression(expr_str, xvals, yvals):
    """
    Evaluate a text expression in the symbols ``s`` and ``t`` on a grid.

    Args:
        expr_str (str): Expression using the symbols ``s`` and ``t``.
        xvals (numpy.ndarray): Values substituted for ``s``.
        yvals (numpy.ndarray): Values substituted for ``t``.

    Returns:
        The evaluated expression. A real scalar result (constant expression) is
        expanded to an array of shape ``xvals.shape``.
    """
    s, t = sympy.symbols('s t')
    # NOTE(security): sympy.sympify() uses eval() internally, and expr_str
    # comes straight from a user-facing text box. Restrict or sanitize the
    # input before exposing this to untrusted users.
    expr = sympy.sympify(expr_str)
    f = sympy.lambdify((s, t), expr, modules = ['numpy'])

    # Evaluate once and reuse the result
    result = f(xvals, yvals)
    if isinstance(result, (int, float, np.integer, np.floating)):
        # Constant expressions lambdify to a plain scalar; expand to the grid
        return result * np.ones(xvals.shape)
    return result
143
+
144
+
145
+ '''
146
+ # Currently unused in gradio interface
147
+ def direct_solution(ell, a2, eccentricity, ibc_str, obc_str, max_krylov_dim, max_iterations):
148
+ # Direct solution
149
+ start = timeit.default_timer()
150
+
151
+ pde_config = {}
152
+ for key in ['nx1', 'nx2', 'nx3']:
153
+ pde_config[key] = stnn_config[key]
154
+ pde_config['ell'] = ell
155
+ pde_config['eccentricity'] = eccentricity
156
+ pde_config['a2'] = a2
157
+ system = PDESystem(pde_config)
158
+
159
+ try:
160
+ ibf_data = evaluate_2d_expression(ibc_str, system.x2_ib, system.x2_ib - system.x3_ib)[system.ib_slice]
161
+ except:
162
+ raise ValueError(f"Failed to parse the expression `{ibc_str}` for the boundary condition @ the inner boundary.")
163
+ try:
164
+ obf_data = evaluate_2d_expression(obc_str, system.x2_ob, system.x2_ob - system.x3_ob)[system.ob_slice]
165
+ except:
166
+ raise ValueError(f"Failed to parse the expression `{obc_str}` for the boundary condition @ the outer boundary.")
167
+
168
+ if np.any(np.isnan(ibf_data)):
169
+ raise ValueError(f"The expression `{ibc_str}` evaluates to nan at one or more grid points.")
170
+
171
+ if np.any(np.isnan(obf_data)):
172
+ raise ValueError(f"The expression `{obc_str}` evaluates to nan at one or more grid points.")
173
+
174
+ ibf_data, obf_data, b = system.convert_boundary_data(ibf_data, obf_data)
175
+
176
+ L_xp = csr_matrix(system.L) # Sparse matrix representation of the PDE operator
177
+ nx1, nx2, nx3 = system.params['nx1'], system.params['nx2'], system.params['nx3']
178
+ b_xp = asarray(b.reshape((nx1 * nx2 * nx3,))) # r.h.s. vector
179
+
180
+ def callback(res):
181
+ print(f'GMRES residual: {res}')
182
+
183
+ f_xp, info = spx.linalg.gmres(L_xp, b_xp, maxiter=max_iterations, tol=1e-7, restart=max_krylov_dim, callback=callback)
184
+
185
+ residual = (xp.linalg.norm(b_xp - L_xp @ f_xp) / xp.linalg.norm(b_xp))
186
+
187
+ if info > 0:
188
+ warnings.simplefilter('always')
189
+ warnings.warn(f'GMRES solver did not converge. Number of iterations: {info}; residual: {residual}', RuntimeWarning)
190
+
191
+ f = asnumpy(f_xp)
192
+ rho_direct = np.sum(f.reshape((nx1, nx2, nx3)), axis=-1)
193
+ direct_time = timeit.default_timer() - start
194
+ print(f'Done with direct solution. Time: {direct_time} seconds.')
195
+
196
+ fig = plot_simple(system, rho_direct)
197
+ return fig, info
198
+ '''
199
+
200
+
201
def predict_pde_solution(ell, a2, eccentricity, ibc_str, obc_str):
    """
    Predict the density with the STNN model for one set of inputs and plot it.

    Args:
        ell: PDE parameter; must be > 0.
        a2: Minor axis of the outer boundary; must be greater than ``eccentricity``.
        eccentricity: Eccentricity of the inner boundary.
        ibc_str (str): Expression in ``s`` and ``t`` for the inner boundary condition.
        obc_str (str): Expression in ``s`` and ``t`` for the outer boundary condition.

    Returns:
        The Matplotlib figure produced by ``plot_simple``.

    Raises:
        ValueError: If a parameter is out of range, a boundary expression
            cannot be parsed/evaluated, or it evaluates to NaN on the grid.
    """
    if ell <= 0:
        raise ValueError(f'ell must be greater than 0 (here, {ell}).')
    if a2 <= eccentricity:
        raise ValueError(f'Outer minor axis must be greater than the eccentricity (here, {eccentricity}).')

    pde_config = {key: stnn_config[key] for key in ('nx1', 'nx2', 'nx3')}
    pde_config['ell'] = ell
    pde_config['eccentricity'] = eccentricity
    pde_config['a2'] = a2
    system = PDESystem(pde_config)

    # Evaluate the user expressions on the boundary grids. The second argument
    # is x2 - x3 on the respective boundary. Chain the original exception so
    # the root cause stays visible in the logs.
    try:
        ibf_data = evaluate_2d_expression(ibc_str, system.x2_ib, system.x2_ib - system.x3_ib)[system.ib_slice]
    except Exception as err:
        raise ValueError(
            f"Failed to parse the expression `{ibc_str}` for the boundary condition @ the inner boundary.") from err
    try:
        obf_data = evaluate_2d_expression(obc_str, system.x2_ob, system.x2_ob - system.x3_ob)[system.ob_slice]
    except Exception as err:
        raise ValueError(
            f"Failed to parse the expression `{obc_str}` for the boundary condition @ the outer boundary.") from err

    if np.any(np.isnan(ibf_data)):
        raise ValueError(f"The expression `{ibc_str}` evaluates to NaN at one or more grid points.")

    if np.any(np.isnan(obf_data)):
        raise ValueError(f"The expression `{obc_str}` evaluates to NaN at one or more grid points.")

    # Permute and reshape boundary data to the format expected by the STNN model
    ibf_data, obf_data, _ = system.convert_boundary_data(ibf_data, obf_data)

    # Currently unused in gradio interface:
    # ibf_data, obf_data, b, _ = system.generate_random_bc(func_gen_id)

    # Load some relevant quantities from the config dictionaries
    ell_min, ell_max = stnn_config['ell_min'], stnn_config['ell_max']
    a2_min, a2_max = stnn_config['a2_min'], stnn_config['a2_max']
    nx2, nx3 = pde_config['nx2'], pde_config['nx3']

    # Combine boundary data in a single array: inner boundary first, then outer
    bf = np.zeros((1, 2 * nx2, nx3 // 2))
    bf[:, :nx2, :] = ibf_data[np.newaxis, ...]
    bf[:, nx2:, :] = obf_data[np.newaxis, ...]

    # Normalize and combine parameters; column order is (a2, ell, eccentricity)
    params = np.zeros((1, 3))
    params[0, 0] = (a2 - a2_min) / (a2_max - a2_min)
    params[0, 1] = (ell - ell_min) / (ell_max - ell_min)
    params[0, 2] = eccentricity

    rho = model.predict([params, bf])
    return plot_simple(system, rho[0, ...])
255
+
256
+
257
# Load the STNN configuration and build the model with its trained weights
with open('T5_config.json', 'r', encoding = 'utf-8') as json_file:
    stnn_config = json.load(json_file)

model = stnn.build_stnn(stnn_config)
model.load_weights('T5_weights.h5')

# Gradio interface. Backslashes in the LaTeX snippets are written as '\\' so the
# literals contain no invalid escape sequences (a SyntaxWarning on Python 3.12+).
with gr.Blocks() as demo:
    gr.Markdown("# Stacked Tensorial Neural Network (STNN) demo"
                "\nThis demo uses the model architecture from [arXiv:2312.14979](https://arxiv.org/abs/2312.14979) "
                "to solve a parametric PDE problem on an elliptical annular domain. "
                "See the paper for a detailed description of the problem and its applications."
                "<br/>The [GitHub repo](https://github.com/caleb399/stacked_tensorial_nn) contains additional examples, "
                "including instructions for solving the PDE using a conventional iterative method (GMRES). "
                "Due to the long runtime of solving the PDE in this way, it is not included in the demo.")
    gr.Markdown("<br/>The PDE is "
                "$\\ell \\left( \\boldsymbol{\\hat{u}} \\cdot \\nabla \\right) f(\\boldsymbol{r}, w) = \\partial_{ww} f(\\boldsymbol{r}, w)$, "
                "where $\\ell$ is a parameter and $\\boldsymbol{\\hat{u}} = (\\cos w, \\sin w)$. "
                "Here, $\\boldsymbol{r}$ is the 2D position vector, and $w$ is an angular coordinate unrelated to "
                "the spatial domain. The model predicts the density !\\rho(\\boldsymbol{r}) = \\int f(\\boldsymbol{r}, w) dw! "
                "on elliptical annular domains parameterized as shown below. ",
                latex_delimiters = [{"left": "$", "right": "$", "display": False},
                                    {"left": "!", "right": "!", "display": True}])
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                "## PDE Parameters \n The model was trained on solutions of the PDE with $\\ell$ between 0.01 and 100, $a$ between 2 and 20, "
                "and $ecc$ between 0 and 0.8.",
                latex_delimiters = [{"left": "$", "right": "$", "display": False},
                                    {"left": "!", "right": "!", "display": True}])
            ell_input = gr.Number(label = "ell (must be > 0)", value = 1.0)
            eccentricity_input = gr.Number(
                label = "ecc: eccentricity of the inner boundary (must be >= 0 and <= 0.999)",
                value = 0.5, minimum = 0.0, maximum = 0.999)
            a2_input = gr.Number(label = "a: Minor axis of outer boundary (must be > eccentricity)", value = 2.0)
            gr.Markdown(
                "## Boundary Conditions \n $(s, t)$ are angular coordinates parameterizing the PDE domain, "
                "related to $\\boldsymbol{r}$ and $w$ by a coordinate transformation. "
                "Specifically, $s$ is the polar elliptical coordinate along the boundary (inner or outer), with values "
                "between $-\\pi$ and $\\pi$, while $t = s - w$. Boundary conditions are generated from grid points "
                "distributed uniformly over the allowable values of $s$ and $t$."
                "<br/><br/>For the PDE problem to be well-posed, boundary data should only be specified where "
                "$\\boldsymbol{\\hat{u}} \\cdot \\boldsymbol{\\hat{n}} > 0$, where $\\boldsymbol{\\hat{n}}$ is the "
                "inward-pointing unit normal vector. This requirement constrains the allowable values of $t$"
                " and is automatically enforced when building boundary conditions from the user-specified expressions below.",
                latex_delimiters = [{"left": "$", "right": "$", "display": False}])

            inner_boundary = gr.Textbox(label = "Inner boundary condition", value = "0.5 * (1 + sign(cos(s)))")
            outer_boundary = gr.Textbox(label = "Outer boundary condition", value = "1 + 0.1 * cos(4*s)")

            submit_button = gr.Button("Submit")

        with gr.Column():
            gr.Markdown("## Predicted Solution")
            predicted_output_plot = gr.Plot()

    # Input order must match the signature of predict_pde_solution
    submit_button.click(
        fn = predict_pde_solution,
        inputs = [ell_input, a2_input, eccentricity_input, inner_boundary, outer_boundary],
        outputs = [predicted_output_plot]
    )

demo.launch()
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ tensorflow>=2.15.0
2
+ numpy~=1.26.0
3
+ t3f~=1.2.0
4
+ scipy~=1.12.0
5
+ h5py~=3.10.0
6
+ matplotlib~=3.8.2
7
+ pydot~=1.4.2
8
+ openvino~=2023.3.0
9
+ pyyaml>=6.0.1
10
+ sympy
stnn/__init__.py ADDED
File without changes
stnn/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (164 Bytes). View file
 
stnn/data/__init__.py ADDED
File without changes
stnn/data/function_generators.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
def generate_piecewise_linear_function(num_pieces, lower, upper, delta = 0.3):
    """
    Generates a random, nonnegative piece-wise linear function on the interval (lower, upper)
    whose values at both ends of the interval coincide.

    Args:
        num_pieces (int): Number of linear pieces in the function.
        lower (float): The lower range of the interval
        upper (float): The upper range of the interval
        delta (float): Maximum change of the function value between neighbouring grid points (i.e., modulates
                    the slopes of the piecewise functions). Larger values mean more variability. Default is 0.3.

    Returns:
        function: A piece-wise linear function of a single scalar argument.
    """
    # Equally spaced break points in the interval [lower, upper]
    x_points = np.linspace(lower, upper, num_pieces + 1)

    # Random walk for the y-values at the break points; 'delta' bounds each step
    y_points = np.zeros(num_pieces + 1)
    y_points[0] = np.random.uniform(-1, 1)
    for n in range(1, y_points.shape[0]):
        y_points[n] = y_points[n - 1] + delta * np.random.uniform(-1, 1)

    # Match the first and last values so the function can be continued periodically
    y_points[0] = y_points[-1]

    # Ensure y values are nonnegative
    min_y = y_points.min()
    if min_y < 0:
        y_points -= min_y
    y_points += np.random.uniform(0, 0.5)  # random (constant) offset

    def piecewise_linear(x):
        """
        Evaluates the piece-wise linear function at a given x.

        Args:
            x (float): The x-coordinate at which to evaluate the function.

        Returns:
            float: The y-coordinate of the function at x.
        """
        for i in range(num_pieces):
            if x_points[i] <= x < x_points[i + 1]:
                # Linear interpolation between the two points
                slope = (y_points[i + 1] - y_points[i]) / (x_points[i + 1] - x_points[i])
                return slope * (x - x_points[i]) + y_points[i]
        # x == upper (or x outside the interval): same value as at 'lower'
        return y_points[0]

    return piecewise_linear
51
+
52
+
53
def generate_piecewise_constant_function(num_pieces, lower, upper):
    """
    Build a random step function on the interval (lower, upper).

    The interval is split into ``num_pieces`` equal cells and each cell gets
    an independent random level. The value at ``upper`` equals the level of
    the first cell, so both ends of the interval agree.

    Args:
        num_pieces (int): Number of constant pieces in the function.
        lower (float): The lower range of the interval
        upper (float): The upper range of the interval

    Returns:
        function: A piece-wise constant function of a single scalar argument.
    """
    # Cell edges, equally spaced over [lower, upper]
    edges = np.linspace(lower, upper, num_pieces + 1)

    # One random level per cell; rand() * 2 gives values in [0, 2)
    levels = np.random.rand(num_pieces) * 2 - 0

    # Repeat the first level at the end so the last edge matches the first one
    levels = np.append(levels, levels[0])

    def piecewise_constant(x):
        """
        Evaluates the piece-wise constant function at a given x.

        Args:
            x (float): The x-coordinate at which to evaluate the function.

        Returns:
            float: The y-coordinate of the function at x.
        """
        matches = (levels[i] for i in range(num_pieces) if edges[i] <= x < edges[i + 1])
        # Fall back to the first level for x == upper (or x outside the interval)
        return next(matches, levels[0])

    return piecewise_constant
90
+
91
+
92
def generate_piecewise_bc(x2_grid, x3_grid, num_pieces):
    """
    Build a separable random function on the (x2, x3) grid.

    Two independent random piecewise linear functions are generated, one per
    coordinate, each spanning the min..max range of its grid. The result is
    their product evaluated on the grid.

    Args:
        x2_grid (numpy.ndarray): A 2D array, x2 grid values; only the first column is sampled
        x3_grid (numpy.ndarray): A 2D array, x3 grid values; only the first row is sampled
        num_pieces (int): The number of pieces in each piecewise linear function.

    Returns:
        numpy.ndarray: A 2D array of shape (x2_grid.shape[0], x3_grid.shape[1]).
    """
    x2_fun = generate_piecewise_linear_function(num_pieces, lower = x2_grid.min(), upper = x2_grid.max())
    x3_fun = generate_piecewise_linear_function(num_pieces, lower = x3_grid.min(), upper = x3_grid.max())

    # 1D profiles along each axis; zeros_like keeps the dtype of the grids
    x2_profile = np.zeros_like(x2_grid[:, 0])
    x2_profile[:] = [x2_fun(point) for point in x2_grid[:, 0]]
    x3_profile = np.zeros_like(x3_grid[0, :])
    x3_profile[:] = [x3_fun(point) for point in x3_grid[0, :]]

    # Outer product of the two profiles ('ij' ordering)
    return x2_profile[:, np.newaxis] * x3_profile[np.newaxis, :]
115
+
116
+
117
def random_2d_gaussian(theta, phi):
    """
    Evaluate a randomly parameterized 2D Gaussian G(x, y) on a grid, where
        x = np.cos(0.5 * freq_x * theta - phase_x)
        y = np.cos(0.5 * freq_y * phi - phase_y)
    The phases and widths are sampled at random. Note that the frequencies are
    drawn with ``np.random.randint(1, 2)``, which always yields 1.

    Args:
        theta (numpy.ndarray): 2D array, meshgrid of the first coordinate
        phi (numpy.ndarray): 2D array, meshgrid of the second coordinate

    Returns:
        numpy.ndarray: A 2D array representing the values of the Gaussian on the grid.

    Raises:
        ValueError: If the covariance matrix has a negative eigenvalue.
    """
    # Random draws, in a fixed order: phases, frequencies, then widths
    phase_x, phase_y = (np.random.uniform(0, 2 * np.pi) for _ in range(2))
    freq_x, freq_y = (np.random.randint(1, 2) for _ in range(2))

    x = np.cos(0.5 * freq_x * theta - phase_x)
    y = np.cos(0.5 * freq_y * phi - phase_y)

    sigma_x = np.random.uniform(0.1, 3.0)
    sigma_y = np.random.uniform(0.1, 1.0)
    rho = 0  # correlation between x and y (currently fixed to zero)

    cross = rho * sigma_x * sigma_y
    covariance_matrix = np.array([[sigma_x**2, cross],
                                  [cross, sigma_y**2]])
    if np.any(np.linalg.eigvals(covariance_matrix) < 0):
        raise ValueError('Covariance matrix is not positive semi-definite.')

    # Coefficients of the quadratic form in the exponent
    inv_sigma_xx = 1.0 / sigma_x**2
    inv_sigma_yy = 1.0 / sigma_y**2
    inv_sigma_xy = -rho / (sigma_x * sigma_y)

    quadratic_form = inv_sigma_xx * x**2 + inv_sigma_yy * y**2 + 2 * inv_sigma_xy * x * y
    return np.exp(-0.5 * quadratic_form)
157
+
158
+
159
def generate_random_functions(N, X, Y, num_terms = 16, min_freq = 1, max_freq = 16, func_gen_id = 0):
    """
    Generates N random nonnegative functions on a 2D grid as a sum of cosine products (Fourier-type
    series), with different types of modulation applied to the amplitudes.

    Args:
        N (int): Number of functions to generate.
        X (numpy.ndarray): 2D array representing the values of the first coordinate on the grid
        Y (numpy.ndarray): 2D array representing the values of the second coordinate on the grid
        num_terms (int, optional): Number of terms in the Fourier series expansion. Default is 16.
        min_freq (int, optional): Minimum frequency for the Fourier series terms. Default is 1.
        max_freq (int, optional): Maximum frequency for the Fourier series terms. Default is 16.
        func_gen_id (int, optional): Type of function to generate based on the decay of the expansion coefficients
                                    as frequency is increased. Values can range from -1 to 4. Default is 0.
                                    -1: uniform random values (no series); 0: no decay; 1: 1/kx or 1/ky (picked
                                    at random per term); 2: 1/(kx*ky); 3: 1/(kx*ky)**2; 4: Gaussian decay with
                                    random coefficients.

    Returns:
        numpy.ndarray: A 3D numpy array of shape (N,) + X.shape containing the function values.

    Raises:
        ValueError: If max_freq is less than min_freq or if an invalid func_gen_id is provided.
    """
    # Check if the maximum frequency is less than the minimum frequency
    if max_freq < min_freq:
        raise ValueError('max_freq cannot be less than min_freq')

    # Validate the generator id up front, so an invalid value is reported even
    # when N == 0 or num_terms == 0 and before any random numbers are consumed.
    if func_gen_id not in (-1, 0, 1, 2, 3, 4):
        raise ValueError(
            f'Invalid func_gen_id. Should be an integer in the range [-1, 4], but received {func_gen_id}')

    # Generate uniformly distributed functions if func_gen_id is -1
    if func_gen_id == -1:
        return np.random.uniform(0, 1, size = (N,) + X.shape)

    # Initialize the batch of functions with zeros
    F_batch = np.zeros((N,) + X.shape)

    # NOTE: the order of the random draws below is kept fixed so that seeded
    # runs stay reproducible.
    for n in range(N):
        # Add a cosine term with a half frequency with 20% chance
        if np.random.uniform(0, 1) < 0.2:
            amp_cos_half = np.random.uniform(0, 1)  # Amplitude for cosine term
            phase_cos_half = np.random.uniform(0, 2 * np.pi)  # Phase shift for cosine term
            F_batch[n] += amp_cos_half * np.cos(0.5 * X - phase_cos_half)

        # Fourier series
        for _ in range(num_terms):
            amplitude = np.random.uniform(-1, 1)  # Random amplitude of this term
            kx, ky = np.random.randint(min_freq, max_freq + 1, 2)  # Frequencies for x and y components
            phase_x = np.random.uniform(0, 2 * np.pi)  # Phase shift for x-component
            phase_y = np.random.uniform(0, 2 * np.pi)  # Phase shift for y-component

            # Apply the amplitude decay selected by func_gen_id (0: no decay)
            if func_gen_id == 1:
                if np.random.uniform(0, 1) < 0.5:
                    amplitude = amplitude / kx
                else:
                    amplitude = amplitude / ky
            elif func_gen_id == 2:
                amplitude = amplitude / (kx * ky)
            elif func_gen_id == 3:
                amplitude = amplitude / (kx * kx * ky * ky)
            elif func_gen_id == 4:
                # Gaussian decay with random coefficients
                sxx = np.random.uniform(0.1, 1.0)
                syy = np.random.uniform(0.1, 1.0)
                sxy = np.random.uniform(0.1, 1.0)
                amplitude = amplitude * np.exp(-(sxx * kx**2 + syy * ky**2 + sxy * kx * ky))

            # Add the term to the nth function in the batch
            F_batch[n] += amplitude * np.cos(kx * X - phase_x) * np.cos(ky * Y - phase_y)

        # Shift the function so that it is nonnegative
        minF = np.min(F_batch[n])
        if minF < 0:
            F_batch[n] -= minF

    return F_batch
stnn/data/preprocessing.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import h5py
2
+ import numpy as np
3
+
4
# If STRICT_WARNING = True, the program exits when negative values are detected in ibf, obf, or rho
# This is important to check because negative values are unphysical.
STRICT_WARNING = True


def verify_nonnegative(fname, ibf, obf, rho):
    """
    Check ibf, obf, and rho for negative values.

    A warning is printed for every array that contains negative values. If any
    warning was printed and STRICT_WARNING is set, the program exits with a
    nonzero status.

    Args:
        fname (str): Name of the file the arrays came from (used in messages only).
        ibf (numpy.ndarray): Boundary data on the inner boundary.
        obf (numpy.ndarray): Boundary data on the outer boundary.
        rho (numpy.ndarray): Density.

    Raises:
        SystemExit: If negative values are found and STRICT_WARNING is True.
    """
    found_warning = False
    # Check every array (not just the first offender) so that all problems are reported
    for label, values in (('ibf', ibf), ('obf', obf), ('rho', rho)):
        if np.any(values < 0):
            print(f'Warning: negative values detected in array "{label}" in {fname}; min val: {values.min()}')
            found_warning = True

    if found_warning and STRICT_WARNING:
        # __file__ is a plain string (it has no '.name' attribute)
        print(f'Exiting program. To avoid exiting on this warning, set STRICT_WARNING to False in {__file__}')
        raise SystemExit(1)
27
+
28
+
29
def get_data_from_file(fname, nx2, nx3, Nrange = None):
    """
    Retrieves training data from the given HDF5 file. Assumes that the PDE parameters
    are stored in datasets with their respective names, i.e., 'ell', 'a1', 'a2'. Likewise,
    the density rho(x1,x2) and boundary data ibf(x2,x3) / obf(x2,x3) are stored in datasets
    'rho', 'ibf', and 'obf'.

    Args:
        fname (str): Path to the HDF5 file containing the data.
        nx2 (int): Second grid dimension
        nx3 (int): Third grid dimension
        Nrange (tuple, optional): A tuple or list of two integers (or None) specifying the range of samples
                                to extract (start, end). Defaults to None (all samples).

    Returns:
        tuple: (a2, ell, eccentricity, bf, rho), where eccentricity is computed as 1 - a1 and bf has shape
            (N, 2 * nx2, nx3 // 2) with 'ibf' in the first nx2 rows and 'obf' in the remaining rows.

    Raises:
        TypeError: If fname is not a string or Nrange is malformed.
        ValueError: If the file does not contain the required datasets.
    """
    if not isinstance(fname, str):
        raise TypeError('Filename must be a string.')

    if Nrange is None:
        N1, N2 = None, None
    else:
        is_valid = (isinstance(Nrange, (tuple, list))
                    and len(Nrange) == 2
                    and all(i is None or isinstance(i, int) for i in Nrange))
        if not is_valid:
            raise TypeError('Nrange must be a length-2 tuple or list of integers.')
        N1, N2 = Nrange

    dset_names = ['ell', 'a1', 'a2', 'rho', 'ibf', 'obf']
    with h5py.File(fname, 'r') as input_file:
        # Check that all datasets are present
        missing_keys = [key for key in dset_names if key not in input_file]
        if missing_keys:
            raise ValueError(f"Missing / incorrectly labeled datasets in file '{fname}'. "
                             f"Could not find datasets: {', '.join(missing_keys)}")

        ell = input_file['ell'][N1:N2]
        a2 = input_file['a2'][N1:N2]  # minor axis of outer boundary
        a1 = input_file['a1'][N1:N2]  # minor axis of inner boundary
        eccentricity = np.ones_like(a1) - a1  # eccentricity of inner boundary
        rho = input_file['rho'][N1:N2]
        ibf = input_file['ibf'][N1:N2]  # boundary data on inner boundary
        obf = input_file['obf'][N1:N2]  # boundary data on outer boundary

    verify_nonnegative(fname, ibf, obf, rho)

    # Combine 'ibf' and 'obf' into single array
    N = rho.shape[0]
    bf = np.zeros((N, 2 * nx2, nx3 // 2), dtype = np.float32)
    bf[:, :nx2, :] = ibf
    bf[:, nx2:, :] = obf

    return a2, ell, eccentricity, bf, rho
90
+
91
+
92
def reshape_and_stack(a2, ell, ecc):
    """
    Combine the three parameter arrays into one 2D array.

    Each input is reshaped to a single column; the columns are then placed
    side by side in the order (a2, ell, ecc).

    Returns:
        numpy.ndarray: Array with three columns.
    """
    columns = [values.reshape((-1, 1)) for values in (a2, ell, ecc)]
    return np.hstack(columns)
97
+
98
+
99
def apply_normalization(bf, rho):
    """
    Scale each sample of 'bf' and 'rho' by the mean absolute value of its 'rho'.

    Both arrays are modified IN PLACE and also returned. The mean is taken over
    axes 1 and 2 of 'rho', giving one scale factor per sample (axis 0).

    Returns:
        tuple: The (modified) arrays bf and rho.
    """
    # One factor per sample, shaped to broadcast over the remaining two axes
    scale = np.average(np.abs(rho), axis = (1, 2)).reshape((-1, 1, 1))
    np.divide(bf, scale, out = bf)
    np.divide(rho, scale, out = rho)
    return bf, rho
105
+
106
+
107
def load_data(files, nx2, nx3, ell_min, ell_max, a2_min, a2_max,
              Nrange_list = None, params_slice = None, normalize_data = False):
    """
    Read data from one or more files and prepare it for use with the STNN.

    Args:
        files (str or list of str): File path(s) containing the data
        nx2 (int): Second grid dimension
        nx3 (int): Third grid dimension
        ell_min / ell_max (float): Minimum / maximum value of 'ell' over parameter space
        a2_min / a2_max (float): Minimum / maximum value of 'a2' over parameter space
        Nrange_list (list of tuples, optional): Per-file (start, end) ranges of samples to extract. If given,
                                                must have the same number of elements as 'files'. Defaults
                                                to None (all samples of every file).
        params_slice (optional): Index (e.g. boolean array or slice) selecting a subset of the samples.
                                Defaults to None.
        normalize_data (bool, optional): Flag to normalize 'bf' and 'rho'. Defaults to False.

    Returns:
        tuple: (params, bf, rho). The columns of 'params' are a2 and ell, both mapped
            linearly using the given min / max values, followed by the eccentricity.

    Raises:
        ValueError: If 'files' is empty or its length differs from that of 'Nrange_list'.
    """
    file_is_sequence = isinstance(files, (list, tuple))
    if file_is_sequence and len(files) == 0:
        raise ValueError('List of files provided to "load_data" is empty.')
    if not file_is_sequence:
        files = [files]

    if Nrange_list is None or len(Nrange_list) == 0:
        # Default: no range restriction for any file
        Nrange_list = [None] * len(files)
    else:
        # User-specified; check shapes
        if not isinstance(Nrange_list, (list, tuple)):
            Nrange_list = [Nrange_list]
        if len(files) != len(Nrange_list):
            raise ValueError('List of input files must have same length as list of Nrange tuples')

    # Read every file, then regroup the per-file results by quantity
    per_file = [get_data_from_file(path, nx2, nx3, Nrange = span)
                for path, span in zip(files, Nrange_list)]
    a2_parts, ell_parts, ecc_parts, bf_parts, rho_parts = (list(part) for part in zip(*per_file))

    a2 = np.concatenate(a2_parts)
    ell = np.concatenate(ell_parts)
    ecc = np.concatenate(ecc_parts)
    bf = np.vstack(bf_parts)
    rho = np.vstack(rho_parts)

    # Map ell and a2 values onto [0, 1]
    ell = (ell - ell_min) / (ell_max - ell_min)
    a2 = (a2 - a2_min) / (a2_max - a2_min)

    params = reshape_and_stack(a2, ell, ecc)

    if params_slice is not None:
        # Keep only the selected samples
        params = params[params_slice, ...]
        bf = bf[params_slice, ...]
        rho = rho[params_slice, ...]

    if normalize_data:
        bf, rho = apply_normalization(bf, rho)

    return params, bf, rho
179
+
180
+
181
def load_training_data(file_list, nx2, nx3, ell_min, ell_max, a2_min, a2_max, Nrange_list = None,
                       params_slice = None, test_size = 0.1, random_state = 23, normalize_data = True):
    """
    Load data via 'load_data', split it into train and test subsets, and print a summary.

    Args:
        file_list (list of str): List of file paths containing the data.
        nx2 (int): Second grid dimension.
        nx3 (int): Third grid dimension.
        ell_min / ell_max (float): Minimum / maximum value of 'ell' over parameter space.
        a2_min / a2_max (float): Minimum / maximum value of 'a2' over parameter space.
        Nrange_list (list, optional): Per-file range selection, forwarded to 'load_data'.
            Defaults to None.
        params_slice (optional): Sample selection, forwarded to 'load_data'. Defaults to None.
        test_size (float, optional): Size of the test/validation dataset as a fraction of
            the total dataset size. Defaults to 0.1.
        random_state (int, optional): Random seed used to select the train-test split.
            Defaults to 23.
        normalize_data (bool, optional): Flag to normalize 'bf' and 'rho'. Defaults to True.

    Returns:
        tuple: (params_train, bf_train, rho_train, params_test, bf_test, rho_test).

    Note:
        The printed sample counts and min/max ranges are computed from the full loaded
        'params' array (train and test together), even though the labels say "(train)".
    """
    params, bf, rho = load_data(file_list, nx2, nx3, ell_min, ell_max, a2_min, a2_max,
                                Nrange_list = Nrange_list, params_slice = params_slice,
                                normalize_data = normalize_data)

    # 'rho' plays the role of X; 'params' and 'bf' are split alongside it as a list
    rho_train, rho_test, others_train, others_test = train_test_split(
        rho, [params, bf], test_size = test_size, random_state = random_state)
    params_train, bf_train = others_train[0], others_train[1]
    params_test, bf_test = others_test[0], others_test[1]

    print('Finished loading training X:')
    shape_report = (
        ('params_train', params_train),
        ('bf_train', bf_train),
        ('rho_train', rho_train),
        ('params_test', params_test),
        ('bf_test', bf_test),
        ('rho_test', rho_test),
    )
    for label, array in shape_report:
        print(f' {label}.shape:\t{array.shape}')

    # Undo the [0, 1] scaling of a2 and ell before reporting their extent.
    a2_values = a2_min + (a2_max - a2_min) * params[:, 0]
    ell_values = ell_min + (ell_max - ell_min) * params[:, 1]
    ecc_values = params[:, 2]

    min_a2, max_a2 = np.min(a2_values), np.max(a2_values)
    min_ell, max_ell = np.min(ell_values), np.max(ell_values)
    min_ecc, max_ecc = np.min(ecc_values), np.max(ecc_values)

    print('')
    print(f' Number of circle samples (train):\t{np.sum(ecc_values < 1e-7)}')
    print(f' Number of ellipse samples (train):\t{np.sum(ecc_values > 0)}')
    print(' Min .. Max in training X:')
    print(f' ell:\t{min_ell:.2f} .. {max_ell:.2f}')
    print(f' a2:\t{min_a2:.2f} .. {max_a2:.2f}')
    print(f' ecc:\t{min_ecc:.2f} .. {max_ecc:.2f}')
    print('-------------------------------------------')

    return params_train, bf_train, rho_train, params_test, bf_test, rho_test
247
+
248
+
249
def train_test_split(X, Y, test_size = 0.1, random_state = None):
    """
    Split (X, Y) pairs into random train and test subsets.

    Args:
        X (np.ndarray or list of arrays): Input dataset.
        Y (np.ndarray or list of arrays): Labels for the dataset.
        test_size (float): Proportion of the dataset to include in the test split.
        random_state (int, optional): Seed for the shuffle applied before splitting. When
            given, a private generator is used, so the process-wide NumPy random state is
            left untouched. When None, the global NumPy generator is used.

    Returns:
        X_train, X_test, Y_train, Y_test: Train-test split of the dataset, in the same
        format as the inputs. For example, if 'X' is an array and 'Y' is a list of arrays,
        then X_train and X_test are arrays, and Y_train and Y_test are lists of arrays.

    Raises:
        ValueError: If X or Y is empty, if the arrays disagree on the number of samples,
            or if the test subset would hold fewer than 1 or more than all samples.

    Note: This function is included primarily to reduce module dependency requirements.
        sklearn.model_selection.train_test_split has similar functionality and may be
        preferred for performance-critical applications.
    """
    if len(X) == 0 or len(Y) == 0:
        raise ValueError("Input arrays/lists X and Y cannot be empty.")

    input_X_is_array = isinstance(X, np.ndarray)
    input_Y_is_array = isinstance(Y, np.ndarray)

    if input_X_is_array:
        X = [X]
    if input_Y_is_array:
        Y = [Y]

    total_samples = X[0].shape[0]

    # Check for consistent number of samples across all datasets
    if any(arr.shape[0] != total_samples for arr in (*X, *Y)):
        raise ValueError('Inconsistent number of samples.')

    Ntest = int(test_size * total_samples)
    if Ntest < 1 or Ntest > total_samples:
        raise ValueError('Size of test dataset cannot be less than 1 or greater than the total number of samples.')

    # A private RandomState seeded with 'random_state' yields the same permutation as
    # seeding the global generator would, without clobbering the caller's random stream.
    rng = np.random if random_state is None else np.random.RandomState(random_state)
    indices = np.arange(total_samples)
    rng.shuffle(indices)

    # Index each array directly with the train/test part of the permutation; this avoids
    # materializing a fully shuffled copy of every array first.
    train_idx = indices[:-Ntest]
    test_idx = indices[-Ntest:]

    X_train = [x[train_idx] for x in X]
    X_test = [x[test_idx] for x in X]
    Y_train = [y[train_idx] for y in Y]
    Y_test = [y[test_idx] for y in Y]

    # Convert back to arrays if original input was array
    if input_X_is_array:
        X_train, X_test = X_train[0], X_test[0]
    if input_Y_is_array:
        Y_train, Y_test = Y_train[0], Y_test[0]

    return X_train, X_test, Y_train, Y_test
stnn/data/test_functions.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ """
4
+ Interface to a subset of the test functions listed at
5
+ https://en.wikipedia.org/wiki/Test_functions_for_optimization
6
+ """
7
+
8
+
9
def rastrigin(x, y):
    """
    Two-dimensional Rastrigin function:
    A * n + sum_i (x_i**2 - A * cos(2 * pi * x_i)), with A = 10 and n = 2.

    Args:
        x, y (float or array-like): Coordinates at which to evaluate the function.

    Returns:
        float or np.ndarray: Function value(s); 0 at the origin.
    """
    args = (x, y)
    A = 10
    # The constant term scales with the number of dimensions (len(args) == 2),
    # not with the number of sample points held in x.
    return A * len(args) + sum([(xi**2 - A * np.cos(2 * np.pi * xi)) for xi in args])
13
+
14
+
15
def ackley(x, y):
    """
    Ackley function:
    -20 exp(-0.2 sqrt(0.5 (x^2 + y^2))) - exp(0.5 (cos(2 pi x) + cos(2 pi y))) + e + 20.
    """
    radial_term = -20 * np.exp(-0.2 * np.sqrt(0.5 * (x**2 + y**2)))
    cosine_term = np.exp(0.5 * (np.cos(2 * np.pi * x) + np.cos(2 * np.pi * y)))
    return radial_term - cosine_term + np.e + 20
18
+
19
+
20
def sphere(x, y):
    """Sphere function: x^2 + y^2."""
    x_squared = x**2
    y_squared = y**2
    return x_squared + y_squared
22
+
23
+
24
def rosenbrock(x, y):
    """Rosenbrock function: 100 (y - x^2)^2 + (1 - x)^2."""
    valley_term = 100 * (y - x**2)**2
    offset_term = (1 - x)**2
    return valley_term + offset_term
26
+
27
+
28
def beale(x, y):
    """Beale function: (1.5 - x + x y)^2 + (2.25 - x + x y^2)^2 + (2.625 - x + x y^3)^2."""
    first = (1.5 - x + x * y)**2
    second = (2.25 - x + x * y**2)**2
    third = (2.625 - x + x * y**3)**2
    return first + second + third
30
+
31
+
32
def goldstein_price(x, y):
    """
    Goldstein-Price function, the product of
    1 + (x + y + 1)^2 (19 - 14x + 3x^2 - 14y + 6xy + 3y^2) and
    30 + (2x - 3y)^2 (18 - 32x + 12x^2 + 48y - 36xy + 27y^2).
    """
    left_factor = 1 + (x + y + 1)**2 * (19 - 14 * x + 3 * x**2 - 14 * y + 6 * x * y + 3 * y**2)
    right_factor = 30 + (2 * x - 3 * y)**2 * (18 - 32 * x + 12 * x**2 + 48 * y - 36 * x * y + 27 * y**2)
    return left_factor * right_factor
35
+
36
+
37
def booth(x, y):
    """Booth function: (x + 2y - 7)^2 + (2x + y - 5)^2."""
    first = (x + 2 * y - 7)**2
    second = (2 * x + y - 5)**2
    return first + second
39
+
40
+
41
def bukin(x, y):
    """Bukin function N.6: 100 sqrt(|y - 0.01 x^2|) + 0.01 |x + 10|."""
    ridge_term = 100 * np.sqrt(abs(y - 0.01 * x**2))
    slope_term = 0.01 * abs(x + 10)
    return ridge_term + slope_term
43
+
44
+
45
def matyas(x, y):
    """Matyas function: 0.26 (x^2 + y^2) - 0.48 x y."""
    bowl_term = 0.26 * (x**2 + y**2)
    cross_term = 0.48 * x * y
    return bowl_term - cross_term
47
+
48
+
49
def levi(x, y):
    """
    Levi function N.13:
    sin^2(3 pi x) + (x - 1)^2 (1 + sin^2(3 pi y)) + (y - 1)^2 (1 + sin^2(2 pi y)).
    """
    wave_term = np.sin(3 * np.pi * x)**2
    x_term = (x - 1)**2 * (1 + np.sin(3 * np.pi * y)**2)
    y_term = (y - 1)**2 * (1 + np.sin(2 * np.pi * y)**2)
    return wave_term + x_term + y_term
52
+
53
+
54
def himmelblau(x, y):
    """Himmelblau function: (x^2 + y - 11)^2 + (x + y^2 - 7)^2."""
    first = (x**2 + y - 11)**2
    second = (x + y**2 - 7)**2
    return first + second
56
+
57
+
58
def three_hump_camel(x, y):
    """Three-hump camel function: 2x^2 - 1.05x^4 + x^6 / 6 + x y + y^2."""
    x_polynomial = 2 * x**2 - 1.05 * x**4 + x**6 / 6
    return x_polynomial + x * y + y**2
60
+
61
+
62
def easom(x, y):
    """Easom function: -cos(x) cos(y) exp(-((x - pi)^2 + (y - pi)^2))."""
    squared_distance = (x - np.pi)**2 + (y - np.pi)**2
    return -np.cos(x) * np.cos(y) * np.exp(-squared_distance)
64
+
65
+
66
def cross_in_tray(x, y):
    """
    Cross-in-tray function:
    -0.0001 (|sin(x) sin(y) exp(|100 - sqrt(x^2 + y^2) / pi|)| + 1)^0.1.
    """
    radius = np.sqrt(x**2 + y**2)
    envelope = np.exp(abs(100 - radius / np.pi))
    base = abs(np.sin(x) * np.sin(y) * envelope) + 1
    return -0.0001 * base**0.1
68
+
69
+
70
def eggholder(x, y):
    """
    Eggholder function:
    -(y + 47) sin(sqrt(|x / 2 + (y + 47)|)) - x sin(sqrt(|x - (y + 47)|)).
    """
    shifted_y = y + 47
    first = -shifted_y * np.sin(np.sqrt(abs(x / 2 + shifted_y)))
    second = x * np.sin(np.sqrt(abs(x - shifted_y)))
    return first - second
72
+
73
+
74
def holder_table(x, y):
    """Holder table function: -|sin(x) cos(y) exp(|1 - sqrt(x^2 + y^2) / pi|)|."""
    radius = np.sqrt(x**2 + y**2)
    envelope = np.exp(abs(1 - radius / np.pi))
    return -abs(np.sin(x) * np.cos(y) * envelope)
76
+
77
+
78
def mccormick(x, y):
    """McCormick function: sin(x + y) + (x - y)^2 - 1.5x + 2.5y + 1."""
    wave_term = np.sin(x + y)
    quadratic_term = (x - y)**2
    return wave_term + quadratic_term - 1.5 * x + 2.5 * y + 1
80
+
81
+
82
def schaffer2(x, y):
    """Schaffer function N.2: 0.5 + (sin^2(x^2 - y^2) - 0.5) / (1 + 0.001 (x^2 + y^2))^2."""
    numerator = np.sin(x**2 - y**2)**2 - 0.5
    denominator = (1 + 0.001 * (x**2 + y**2))**2
    return 0.5 + numerator / denominator
84
+
85
+
86
def schaffer4(x, y):
    """
    Schaffer function N.4:
    0.5 + (cos^2(sin(|x^2 - y^2|)) - 0.5) / (1 + 0.001 (x^2 + y^2))^2.
    """
    numerator = np.cos(np.sin(abs(x**2 - y**2)))**2 - 0.5
    denominator = (1 + 0.001 * (x**2 + y**2))**2
    return 0.5 + numerator / denominator
88
+
89
+
90
def styblinski_tang(x, y):
    """Styblinski-Tang function in two dimensions: sum_i (x_i^4 - 16 x_i^2 + 5 x_i) / 2."""
    total = 0
    for coordinate in (x, y):
        total = total + (coordinate**4 - 16 * coordinate**2 + 5 * coordinate)
    return total / 2
93
+
94
+
95
# Registry of the available test functions; 'fun_idx' in get_test_function indexes this list.
functions = [
    rastrigin, ackley, sphere, rosenbrock, beale,
    goldstein_price, booth, bukin, matyas, levi,
    himmelblau, three_hump_camel, easom, cross_in_tray, eggholder,
    holder_table, mccormick, schaffer2, schaffer4, styblinski_tang,
]

# Names in the same order as 'functions' (each entry is the function's own name).
function_names = [func.__name__ for func in functions]

# Evaluation domain of each function, keyed by function name. A plain (min, max) tuple
# applies to both coordinates; a dict gives separate (min, max) ranges for 'x' and 'y'.
domains = dict(
    rastrigin = (-5.12, 5.12),
    ackley = (-5, 5),
    sphere = (-1, 1),
    rosenbrock = {'x': (-2, 2), 'y': (-10, 10)},
    beale = (-4.5, 4.5),
    goldstein_price = (-2, 2),
    booth = (-10, 10),
    bukin = {'x': (-15, -5), 'y': (-3, 3)},
    matyas = (-10, 10),
    levi = (-10, 10),
    himmelblau = (-5, 5),
    three_hump_camel = (-5, 5),
    easom = (-100, 100),
    cross_in_tray = (-10, 10),
    eggholder = (-512, 512),
    holder_table = (-10, 10),
    mccormick = {'x': (-1.5, 4), 'y': (-3, 4)},
    schaffer2 = (-100, 100),
    schaffer4 = (-100, 100),
    styblinski_tang = (-5, 5),
)
163
+
164
+
165
def scale_input(x, domain):
    """
    Map 'x' linearly onto 'domain', so that 0 -> domain minimum and 1 -> domain maximum.

    Args:
        x (float or array-like): Value(s) to scale.
        domain (tuple): (min, max) of the target interval.

    Returns:
        float or array-like: min + (max - min) * x.
    """
    lower, upper = domain
    width = upper - lower
    return lower + width * x
168
+
169
+
170
def get_test_function(X, Y, fun_idx):
    """
    Evaluates a test function on inputs (X, Y) after scaling them to the function's domain.

    The inputs are mapped with 'scale_input', which sends 0 to the lower end and 1 to the
    upper end of the domain registered for the function in 'domains'.

    Args:
        X (float or array-like): The X input values to be scaled and used in the function.
        Y (float or array-like): The Y input values to be scaled and used in the function.
            Ignored if the function takes only one argument.
        fun_idx (int): The index of the function to be retrieved from the predefined list 'functions'.

    Returns:
        (float or array-like), values of the function at the scaled inputs
    """
    func = functions[fun_idx]

    # A dict entry gives separate ranges per coordinate; a tuple applies to both.
    domain = domains[func.__name__]
    if isinstance(domain, dict):
        x_domain, y_domain = domain['x'], domain['y']
    else:
        x_domain = y_domain = domain

    x_scaled = scale_input(X, x_domain)
    y_scaled = scale_input(Y, y_domain)

    # Evaluate on the SCALED coordinates (previously the raw X, Y were passed, so the
    # registered domains had no effect).
    try:
        output = func(x_scaled, y_scaled)
    except TypeError:
        # Fallback for functions that accept a single argument
        output = func(x_scaled)
    return output
stnn/linalg_backend.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Module imports and wrapper functions of the linear algebra backend.
3
+
4
+ If __usecupy__ is "True" and cupy is successfully imported, then
5
+
6
+ xp --> cupy
7
+ spx --> cupyx.scipy.sparse
8
+
9
+ Otherwise,
10
+
11
+ xp --> numpy
12
+ spx --> scipy.sparse
13
+
14
+ For example, if __usecupy__ is False, then
15
+
16
+ import numpy as np
17
+ import scipy.sparse as sp
18
+
19
+ is equivalent to
20
+
21
+ from stnn.linalg_backend import xp, spx
22
+ """
23
+ __usecupy__ = True
24
+
25
+ try:
26
+ # If CuPy is not preferred or available, fall back to NumPy
27
+ if not __usecupy__:
28
+ raise ImportError
29
+ import cupy as cp
30
+ import cupyx.scipy.sparse.linalg
31
+ import cupyx.scipy.sparse as cupy_sparse
32
+
33
+ xp = cp
34
+ spx = cupy_sparse
35
+ using_cupy = True
36
+ except ImportError:
37
+ import numpy as np
38
+ import scipy.sparse.linalg
39
+ import scipy.sparse as scipy_sparse
40
+
41
+ xp = np
42
+ spx = scipy_sparse
43
+ using_cupy = False
44
+
45
+
46
def csr_matrix(L):
    """
    Create a CSR (Compressed Sparse Row) matrix.

    If CuPy is available and enabled, this function creates a CuPy CSR matrix with
    dtype float64. Otherwise, it converts the given data to a SciPy CSR matrix.

    Parameters:
        L (array_like or sparse matrix): 2-D array or sparse matrix to convert.

    Returns:
        CSR matrix: The converted CSR matrix, using either CuPy or SciPy.
    """
    if using_cupy:
        return spx.csr_matrix(L, dtype=xp.float64)
    if hasattr(L, 'tocsr'):
        # Sparse input: use its own conversion (unchanged behavior)
        return L.tocsr()
    # Dense / array_like input has no .tocsr(); build the CSR matrix directly
    return spx.csr_matrix(L)
62
+
63
+
64
def asnumpy(arr):
    """
    Return 'arr' as a NumPy array.

    With the CuPy backend active the array is converted via cupy.asnumpy; with the
    NumPy backend the input is handed back unchanged.
    """
    if not using_cupy:
        return arr
    return cp.asnumpy(arr)
72
+
73
+
74
def asarray(arr):
    """
    Return 'arr' as an array of the active backend.

    With the CuPy backend active the input is converted to a float64 CuPy array;
    with the NumPy backend the input is handed back unchanged.
    """
    if not using_cupy:
        return arr
    return cp.asarray(arr, dtype=cp.float64)
stnn/nn/__init__.py ADDED
File without changes
stnn/nn/stnn.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras.layers import Input, Multiply, Add
3
+ from tensorflow.keras.models import Model
4
+
5
+ from .stnn_layers import TTL, SoftmaxEmbeddingLayer
6
+
7
+
8
def build_stnn(config):
    """
    Build a Stacked Tensorial Neural Network (STNN) as a Keras model from a configuration
    dictionary.

    K tensor-train layers (TTL) are applied to the same grid input; their outputs are
    combined as a weighted sum, with the weights produced from a 3-component parameter
    input by a SoftmaxEmbeddingLayer.

    Args:
        config (dict): Configuration dictionary for the STNN model. Must contain:
            - 'K' (int): Number of tensor networks to be stacked
            - 'd' (int): Number of dense layers in the model's SoftmaxEmbeddingLayer
            - 'nx1', 'nx2', 'nx3' (int): Dimensions of the finite-difference grid
            - 'shape1', 'shape2', 'ranks', 'W': Entries consumed by the TTL class

    Returns:
        tf.keras.Model: Model with inputs [params_input, input_tensor].

    Raises:
        KeyError: If a required key is missing from 'config'.
        TypeError: If 'nx1', 'nx2', 'nx3', 'K' or 'd' is not an integer.
        ValueError: If any of those entries is not positive, or if config['nx3'] is odd.
    """
    required_keys = ('nx1', 'nx2', 'nx3', 'K', 'd', 'shape1', 'shape2', 'ranks', 'W')
    absent = ', '.join(key for key in required_keys if key not in config)
    if absent:
        raise KeyError(f"Missing keys in config: {absent}")

    integer_keys = ('nx1', 'nx2', 'nx3', 'K', 'd')
    # All type checks run before any sign check, so a TypeError takes precedence.
    for key in integer_keys:
        if not isinstance(config[key], int):
            raise TypeError(f"{key} must be an integer.")
    for key in integer_keys:
        if config[key] <= 0:
            raise ValueError(f"{key} must be positive.")

    if config['nx3'] % 2 != 0:
        raise ValueError('Config error: nx3 must be divisible by 2.')

    n_networks = config['K']  # Number of tensor networks
    n_dense = config['d']  # Number of dense layers in SoftmaxEmbeddingLayer

    # NOTE: layers are created in a fixed order on purpose; reordering could change
    # auto-generated layer names and weight ordering.
    grid_shape = (2 * config['nx2'], config['nx3'] // 2, 1)
    input_tensor = Input(shape = grid_shape)

    # Turn the parameter vector into K stacking weights, broadcastable over the output
    embedding = SoftmaxEmbeddingLayer(n_networks, n_dense)
    params_input = Input(shape = (3,))
    stack_weights = embedding(params_input)[:, tf.newaxis, tf.newaxis, :]

    # One tensor-train layer per stacked network
    networks = [TTL(config) for _ in range(n_networks)]

    # Weight each network's output by its stacking weight, then sum
    weighted_outputs = []
    for index, network in enumerate(networks):
        network_output = network(input_tensor)
        weighted_outputs.append(Multiply()([network_output, stack_weights[..., index]]))
    combined = Add()(weighted_outputs)

    return Model(inputs = [params_input, input_tensor], outputs = combined)
stnn/nn/stnn_layers.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ from tensorflow.keras.models import Sequential
4
+ from tensorflow.keras.layers import Reshape, Flatten
5
+ import t3f
6
+
7
+ import os
8
+ import logging
9
+
10
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # FATAL
11
+ logging.getLogger('tensorflow').setLevel(logging.FATAL)
12
+
13
+
14
class SoftmaxEmbeddingLayer(tf.keras.layers.Layer):
    """
    Parameter embedding layer that generates the weights used for stacking the tensor networks.
    It passes the parameter array through 'd' dense ReLU layers and a final dense layer, then
    applies a softmax, so the 'output_dim' outputs are non-negative and sum to 1.

    Attributes:
        output_dim (int): The dimension of the output
        expansion_dim (int): The dimension used for expanding the input in intermediate layers.
        d (int): Number of dense expansion layers.
    """

    def __init__(self, output_dim, d, expansion_dim = 30, **kwargs):
        super(SoftmaxEmbeddingLayer, self).__init__(**kwargs)
        self.reduction_layer = None
        self.expansion_layers = None
        self.output_dim = output_dim
        self.expansion_dim = expansion_dim
        self.d = d  # Number of dense layers

    def build(self, input_shape):
        # Expansion layers to increase dimensionality
        self.expansion_layers = [tf.keras.layers.Dense(self.expansion_dim, activation = 'relu') for _ in range(self.d)]
        # Reduction layer to bring dimensionality back to the desired output dimension
        self.reduction_layer = tf.keras.layers.Dense(self.output_dim)

    def call(self, inputs):
        expanded = inputs
        for layer in self.expansion_layers:
            expanded = layer(expanded)
        return tf.nn.softmax(self.reduction_layer(expanded))

    def get_config(self):
        # Include the base-layer config and every constructor argument ('d' is required
        # by __init__), so that from_config(get_config()) can rebuild the layer.
        config = super(SoftmaxEmbeddingLayer, self).get_config()
        config.update({
            'output_dim': self.output_dim,
            'd': self.d,
            'expansion_dim': self.expansion_dim,
        })
        return config
46
+
47
+
48
class EinsumTTLRegularizer(tf.keras.regularizers.Regularizer):
    """
    Regularizer for the kernels of the EinsumTTL layer. It penalizes the absolute
    differences between neighbouring entries of the weight vector, separately within
    x[:midpoint] and x[midpoint:], so jumps across 'midpoint' are not penalized.

    Attributes:
        strength (float): The regularization strength.
        midpoint (int): Index demarcating the inner and outer boundaries, i.e. x[:midpoint]
            contains data for the inner boundary, and x[midpoint:] contains data for the
            outer boundary.
    """

    def __init__(self, strength, midpoint):
        self.strength = strength
        self.midpoint = midpoint

    def __call__(self, x):
        m = self.midpoint
        # NOTE(review): the slices stop one element short of each half, so the last
        # neighbour difference in each half is not penalized - confirm this is intended.
        inner_steps = tf.abs(x[1:m - 1] - x[0:m - 2])
        outer_steps = tf.abs(x[m + 1:2 * m - 1] - x[m:2 * m - 2])
        return self.strength * tf.reduce_sum(inner_steps + outer_steps)

    def get_config(self):
        return {'strength': self.strength, 'midpoint': self.midpoint}
71
+
72
+
73
def cosine_initializer(kx = 1.0):
    """
    Build a weight initializer for the kernels of the EinsumTTL layer.

    The returned initializer samples shape[0] points x uniformly spaced on [-pi, pi] and
    sets the weights to -(a * |cos(kx * x)| + b * |cos(2 * kx * x)|), where the amplitudes
    a and b are drawn from uniform distributions on [-0.1, 0.3) and [-0.05, 0.05).

    Args:
        kx (float, optional): Frequency of the cosine terms. Defaults to 1.0.

    Returns:
        _initializer: Weight initializer function taking (shape, dtype).
    """

    def _initializer(shape, dtype = None):
        grid = np.linspace(-np.pi, np.pi, shape[0])
        # Draw the two amplitudes in this order to keep the random stream unchanged
        base_amplitude = np.random.uniform(-0.1, 0.3)
        harmonic_amplitude = np.random.uniform(-0.05, 0.05)
        profile = base_amplitude * np.abs(np.cos(kx * grid)) \
            + harmonic_amplitude * np.abs(np.cos(2.0 * kx * grid))
        return tf.convert_to_tensor(-profile, dtype = dtype)

    return _initializer
92
+
93
+
94
class EinsumTTL(tf.keras.layers.Layer):
    """
    Layer that contracts the input tensor over its last dimension with W weight vectors
    before passing it to the TTL. If regularization is enabled, it applies an
    `EinsumTTLRegularizer` to the kernels.

    Attributes:
        (nx2, nx3) (integers): Shape parameters characterizing input tensor dimensions.
            The shape of the input tensor (excluding batch) is (2*nx2, nx3//2).
        W (int): Number of einsum contractions
        kernels (list): List of weight vectors, one per einsum contraction
        regularization_strength (float): The strength of the regularization if used.
        use_regularization (bool): Flag to indicate whether regularization is used.
    """

    def __init__(self, nx2, nx3, W, use_regularization, regularization_strength = 0.005, **kwargs):
        super(EinsumTTL, self).__init__(**kwargs)
        self.nx2 = nx2
        self.nx3 = nx3
        self.W = W
        self.kernels = []

        self.regularization_strength = regularization_strength
        self.use_regularization = use_regularization
        if self.use_regularization:
            regularizer = EinsumTTLRegularizer(self.regularization_strength, self.nx3 // 4)
        else:
            regularizer = None

        # Cosine frequencies for the kernels; the 4-value pattern repeats when W > 4
        initializer_values_ = [1.0, 0.5, 2.0, 3.0] * W
        initializer_values = initializer_values_[:W]
        for i in range(W):
            self.kernels.append(self.add_weight(
                name = f'w{i + 1}',
                shape = (nx3 // 2,),
                regularizer = regularizer,
                initializer = cosine_initializer(initializer_values[i])
            ))

    def call(self, inputs):
        quarter = self.nx3 // 4
        half = self.nx3 // 2
        parts = []
        for w in self.kernels:
            # First nx2 rows use the first half of the kernel; the second block of
            # columns is contracted with the same weights in reversed order.
            part_a = tf.einsum('abc,c->ab', inputs[:, :self.nx2, :quarter], w[:quarter]) + \
                tf.einsum('abc,c->ab', inputs[:, :self.nx2, quarter:half],
                          tf.reverse(w[:quarter], axis = [0]))
            # Remaining rows use the second half of the kernel in the same way.
            part_b = tf.einsum('abc,c->ab', inputs[:, self.nx2:, :quarter], w[quarter:half]) + \
                tf.einsum('abc,c->ab', inputs[:, self.nx2:, quarter:half],
                          tf.reverse(w[quarter:half], axis = [0]))
            parts.extend([part_a, part_b])

        return tf.concat(parts, axis = 1)

    def get_config(self):
        # Include the base-layer config and every constructor argument (nx2, nx3 and W
        # are required by __init__), so that from_config(get_config()) can rebuild the layer.
        config = super(EinsumTTL, self).get_config()
        config.update({
            'nx2': self.nx2,
            'nx3': self.nx3,
            'W': self.W,
            'use_regularization': self.use_regularization,
            'regularization_strength': self.regularization_strength,
        })
        return config
148
+
149
+
150
class TTL(tf.keras.layers.Layer):
    """
    TTL (Tensor Train Layer) is a custom TensorFlow Keras layer that builds a Sequential
    model based on the given configuration: an optional EinsumTTL contraction, a flatten
    step, a t3f tensor-train dense layer, and a reshape to (nx1, nx2).

    Attributes:
        config (dict): Configuration dictionary containing parameters for the model.

            'nx1', 'nx2', 'nx3': Integers, dimensions of the finite-difference grid

            'shape1': List of integers, defines the shape of the output tensor in the tensor train format.
                      prod(shape1) must equal nx1 * nx2, and len(shape1) must equal len(shape2).
            'shape2': List of integers, specifies the shape of the input tensor in the tensor train format.
                      prod(shape2) must equal 2 * nx2 * W (or nx2 * nx3 if W = 0).
            'ranks': List of integers, the ranks in the tensor train decomposition.
            'W' (int): Number of weight vectors to use in the initial EinsumTTL layer. Setting W = 0
                       means that no EinsumTTL is used.

            'use_regularization' (boolean, optional, default: False): Indicates whether regularization
                       is used in the EinsumTTL.
            'regularization_strength' (float, optional, default: 0): Strength of the regularization

        model (tf.keras.Sequential): The Sequential model built based on the provided configuration.

    Methods:
        load_config(self, config): Validates and loads the configuration
        build_model(self): Builds the layer
        call(inputs): Method for the forward pass of the layer.
    """

    def __init__(self, config, **kwargs):
        super(TTL, self).__init__(**kwargs)
        self.model = Sequential()
        self.nx1 = None
        self.nx2 = None
        self.nx3 = None
        self.shape1 = None
        self.shape2 = None
        self.ranks = None
        self.W = None
        self.use_regularization = None
        self.regularization_strength = None
        self._required_keys = ['nx1', 'nx2', 'nx3', 'shape1', 'shape2', 'ranks', 'W']

        # Note: the defaults are written into the caller's dict
        config.setdefault('use_regularization', False)
        config.setdefault('regularization_strength', 0.0)
        self.load_config(config)
        self.config = config
        self.build_model()

    def load_config(self, config):
        """
        Validate 'config' and copy its entries onto the layer.

        Raises:
            KeyError: If a required key is missing.
            TypeError: If 'use_regularization' is not a bool, or a grid size / W is not an int.
            ValueError: If a grid size is not positive, W is negative, or shape1/shape2 are
                inconsistent with each other or with the grid sizes.
        """
        missing_keys = [key for key in self._required_keys if key not in config]
        if missing_keys:
            raise KeyError(f"Missing keys in config: {', '.join(missing_keys)}")

        if not isinstance(config['use_regularization'], bool):
            raise TypeError('use_regularization must be a boolean.')
        else:
            self.use_regularization = config['use_regularization']

        # Honor the configured strength (it used to be overwritten with 0.0, which
        # silently disabled the regularizer even when use_regularization was True).
        self.regularization_strength = config.get('regularization_strength', 0.0)

        for key in ['nx1', 'nx2', 'nx3', 'W']:
            if not isinstance(config[key], int):
                raise TypeError(f"{key} must be an integer.")

        for key in ['nx1', 'nx2', 'nx3']:
            if config[key] <= 0:
                raise ValueError(f"{key} must be positive.")

        if config['W'] < 0:
            raise ValueError("W must be non-negative.")

        nx1, nx2, nx3 = config['nx1'], config['nx2'], config['nx3']
        self.nx1 = nx1
        self.nx2 = nx2
        self.nx3 = nx3

        W = config['W']
        self.W = W

        # Size of the flattened input to the tensor-train layer
        input_dim = 2 * nx2 * W
        if W == 0:
            input_dim = nx2 * nx3

        shape1, shape2 = config['shape1'], config['shape2']
        if len(shape1) != len(shape2):
            raise ValueError(
                f'shape1 and shape2 must have the same length. '
                f'Received: shape1 = {shape1}, shape2 = {shape2}.'
            )
        elif np.prod(np.array(shape1)) != nx1 * nx2:
            raise ValueError(
                f'prod(shape1) must be equal to the output dimension of the TTL '
                f'(nx1 * nx2,). Received: prod(shape1) = {np.prod(np.array(shape1))}, '
                f'nx1 * nx2 = {nx1 * nx2}.'
            )
        elif np.prod(np.array(shape2)) != input_dim:
            raise ValueError(
                f'prod(shape2) must be equal to the input dimension of the TTL '
                f'(2 * nx2 * W or nx2 * nx3 if W = 0). '
                f'Received: prod(shape2) = {np.prod(np.array(shape2))}, required input dimension = {input_dim}.'
            )
        else:
            self.shape1 = shape1
            self.shape2 = shape2

        self.ranks = config['ranks']

    def build_model(self):
        """Assemble the Sequential model from the loaded configuration."""
        if self.W == 0:
            self.model.add(Flatten(input_shape = (2 * self.nx2, self.nx3 // 2)))
        else:
            self.model.add(EinsumTTL(self.nx2, self.nx3, self.W, self.use_regularization,
                                     regularization_strength = self.regularization_strength,
                                     input_shape = (2 * self.nx2, self.nx3 // 2)))
            self.model.add(Flatten())
        tt_layer = t3f.nn.KerasDense(input_dims = self.shape2, output_dims = self.shape1,
                                     tt_rank = self.ranks, use_bias = False, activation = 'linear')
        self.model.add(tt_layer)
        self.model.add(Reshape((self.nx1, self.nx2)))

    def call(self, inputs):
        return self.model(inputs)
stnn/pde/__init__.py ADDED
File without changes
stnn/pde/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (168 Bytes). View file
 
stnn/pde/circle.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .common import *
2
+
3
+
4
def u_dot_thetavec(r, theta, w):
    """
    Dot product of u = (cos(w), sin(w)) with the coordinate vector for theta,
    evaluated as r * sin(w - theta).

    Args:
        r (float or array-like): The radial coordinate(s).
        theta (float or array-like): The angular coordinate(s).
        w (float or array-like): The w coordinate(s).

    Returns:
        numpy.ndarray: The calculated dot product for each point.
    """
    angle_offset = w - theta
    return r * np.sin(angle_offset)
17
+
18
+
19
def u_dot_thetahat(theta, w):
    """
    Dot product of u = (cos(w), sin(w)) with the unit vector for theta,
    evaluated as sin(w - theta).

    Args:
        theta (float or array-like): The angular (theta) coordinate(s).
        w (float or array-like): The w coordinate(s).

    Returns:
        numpy.ndarray: The calculated dot product for each point.
    """
    angle_offset = w - theta
    return np.sin(angle_offset)
31
+
32
+
33
def u_dot_rvec(theta, w):
    """
    Dot product of u = (cos(w), sin(w)) with the coordinate vector for r,
    evaluated as cos(w - theta).

    Args:
        theta (float or array-like): The angular (theta) coordinate(s).
        w (float or array-like): The w coordinate(s).

    Returns:
        numpy.ndarray: The calculated dot product for each point.
    """
    angle_offset = w - theta
    return np.cos(angle_offset)
45
+
46
+
47
def get_system_circle(config):
    """
    For a circular (annular) geometry with inner radius 1 and outer radius 'a2', constructs
    the matrices, grids, and other quantities corresponding to the PDE system specified by
    "config".

    Args:
        config (dict): Configuration dictionary containing the system parameters:
            - 'nx1', 'nx2', 'nx3' (int): Number of grid points in r, theta and w
            - 'ell' (float): Positive parameter; the w-diffusion term is scaled by 1 / ell
            - 'a2' (float): Outer radius; must be strictly greater than 1

    Returns:
        tuple: (L, r_3D, theta_3D, w_3D, dr1, dr2, Dr_3D_coeff_meshgrid), where L is the
        sparse operator, the *_3D arrays are the coordinate meshgrids, dr1 / dr2 are the
        first / last radial grid spacings, and Dr_3D_coeff_meshgrid is the coefficient
        of d/dr on the grid.

    Raises:
        KeyError: If a required key is missing from 'config'.
        TypeError: If 'nx1', 'nx2' or 'nx3' is not an integer.
        ValueError: If any of 'nx1', 'nx2', 'nx3', 'ell', 'a2' is not positive, or a2 <= 1.
    """
    required_keys = ['nx1', 'nx2', 'nx3', 'ell', 'a2']
    optional_keys = []

    missing_keys = [key for key in required_keys if key not in config]
    if missing_keys:
        raise KeyError(f"Missing keys in config: {', '.join(missing_keys)}")

    unused_keys = [key for key in config if key not in required_keys + optional_keys]
    if unused_keys:
        warnings.warn(f"Unused keys in config: {', '.join(unused_keys)}")

    for key in ['nx1', 'nx2', 'nx3']:
        if not isinstance(config[key], int):
            raise TypeError(f"{key} must be an integer.")

    for key in ['nx1', 'nx2', 'nx3', 'ell', 'a2']:
        if config[key] <= 0:
            raise ValueError(f"{key} must be positive.")

    # a2 == 1 would make the outer radius equal to the inner one: a zero-width radial
    # grid with zero spacings, so it is rejected together with a2 < 1.
    if config['a2'] <= 1.0:
        raise ValueError('a2 must be greater than 1.')

    nr, ntheta, nw = config['nx1'], config['nx2'], config['nx3']
    R1 = 1.0
    R2 = config['a2']
    ell = config['ell']

    # 1D grids
    theta, w = get_angular_grids(ntheta, nw)
    # r grid: non-uniform spacing and Dirichlet boundary conditions.
    # r_ includes the two boundary points; r holds the nr interior points only.
    y = np.linspace(-np.pi / 2, np.pi / 2, nr + 2)
    r_ = (R2 - R1) * (np.sin(y) / 2 + 0.5) + R1
    dr1 = r_[1] - r_[0]
    dr2 = r_[-1] - r_[-2]
    r = r_[1:-1]

    # 1D finite-difference operators
    Dtheta_minus, Dtheta_plus = d_dx_upwind(theta, ntheta)
    D2w = d2_dx2_fourth_order(w, nw)
    Dr_minus, Dr_plus = d_dx_upwind_nonuniform(r_, nr)

    # 3D quantities. Kronecker products are used to build the 3D difference operators
    r_3D, theta_3D, w_3D = np.meshgrid(r, theta, w, indexing = 'ij')
    I_r = sp.eye(nr)
    I_theta = sp.eye(ntheta)
    I_w = sp.eye(nw)
    Dtheta_3D_minus = sp.kron(sp.kron(I_r, Dtheta_minus), I_w)
    Dtheta_3D_plus = sp.kron(sp.kron(I_r, Dtheta_plus), I_w)
    D2w_3D = sp.kron(sp.kron(I_r, I_theta), D2w)
    Dr_3D_minus = sp.kron(sp.kron(Dr_minus, I_theta), I_w)
    Dr_3D_plus = sp.kron(sp.kron(Dr_plus, I_theta), I_w)

    # Metric tensor. Note that g_12 = g_21 = 0.
    g_11 = np.ones_like(r_3D)
    g_22_over_r = r_3D  # divide out factor of r

    # Dot products
    dp_r = u_dot_rvec(theta_3D, w_3D)
    dp_thetahat = u_dot_thetahat(theta_3D, w_3D)

    # Coefficient of d / dr
    Dr_3D_coeff_meshgrid = dp_r / g_11
    test_ill_conditioned(Dr_3D_coeff_meshgrid)
    Dr_3D_coeff = sp.diags(Dr_3D_coeff_meshgrid.ravel())

    # Coefficient of d / dtheta
    Dtheta_3D_coeff_meshgrid = dp_thetahat / g_22_over_r
    Dtheta_3D_coeff = sp.diags(Dtheta_3D_coeff_meshgrid.ravel())

    # Upwind differencing
    Dr_3D_upwind = upwind_operator(Dr_3D_minus, Dr_3D_plus, Dr_3D_coeff_meshgrid)
    Dtheta_3D_upwind = upwind_operator(Dtheta_3D_minus, Dtheta_3D_plus, Dtheta_3D_coeff_meshgrid)

    # Full operator
    L = Dr_3D_coeff @ Dr_3D_upwind + Dtheta_3D_coeff @ Dtheta_3D_upwind - (1 / ell) * D2w_3D

    return L, r_3D, theta_3D, w_3D, dr1, dr2, Dr_3D_coeff_meshgrid
135
+
136
+
137
def get_boundary_quantities_circle(theta_3D, w_3D):
    """
    Extract the theta and w grid values on the first and last slice along axis 0 (the two
    boundaries), together with boolean masks based on the sign of cos(theta - w) there.

    Args:
        theta_3D (numpy.ndarray): 3D array of theta values on the grid.
        w_3D (numpy.ndarray): 3D array of w values on the grid.

    Returns:
        tuple: (th1, th2, wb1, wb2, ib_slice, ob_slice), where th1 / wb1 are taken from
        index 0 and th2 / wb2 from index -1 of the first axis; ib_slice is True where
        cos(th1 - wb1) > 0 and ob_slice is True where cos(th2 - wb2) < 0.
    """
    first, last = 0, -1
    th1, wb1 = theta_3D[first, :, :], w_3D[first, :, :]
    th2, wb2 = theta_3D[last, :, :], w_3D[last, :, :]

    ib_slice = np.cos(th1 - wb1) > 0
    ob_slice = np.cos(th2 - wb2) < 0

    return th1, th2, wb1, wb2, ib_slice, ob_slice
stnn/pde/common.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ import numpy as np
3
+ import scipy.sparse as sp
4
+
5
+
6
def d_dx_upwind(x, nx):
    """
    Sparse matrix representations of d/dx on a uniform grid, using first-order one-sided
    differences with periodic boundary conditions.

    Args:
        x (array-like): Uniformly spaced grid points. Only x[0] and x[1] are read, to
            obtain the grid spacing.
        nx (int): Number of grid points; the returned matrices have shape (nx, nx).

    Returns:
        tuple: (Dx_minus, Dx_plus), both scipy.sparse CSR matrices, such that
            - (Dx_minus @ f)[i] = (f[i + 1] - f[i]) / dx
            - (Dx_plus @ f)[i] = (f[i] - f[i - 1]) / dx
            with indices taken modulo nx.
    """
    dx = x[1] - x[0]
    Dx_minus = sp.diags([-1, 1], [0, 1], shape=(nx, nx)).tolil() / dx
    Dx_plus = sp.diags([-1, 1], [-1, 0], shape=(nx, nx)).tolil() / dx

    # Periodic closure: the last node's right neighbour is node 0, and the first
    # node's left neighbour is the last node.
    Dx_minus[-1, 0] = 1 / dx
    Dx_plus[0, -1] = -1 / dx

    return Dx_minus.tocsr(), Dx_plus.tocsr()
18
+
19
+
20
def d2_dx2_fourth_order(x, nx):
    """
    Sparse matrix representation of d^2/dx^2 using fourth-order central differences and
    periodic boundary conditions.

    Args:
        x (array-like): Uniformly spaced grid points; only x[0] and x[1] are read.
        nx (int): Number of grid points; the returned matrix has shape (nx, nx).

    Returns:
        scipy.sparse CSR matrix applying the stencil (-1, 16, -30, 16, -1) / (12 dx^2).
    """
    spacing = x[1] - x[0]
    denom = 12 * spacing**2

    stencil = sp.diags([-1, 16, -30, 16, -1], [-2, -1, 0, 1, 2], shape=(nx, nx)).tolil() / denom

    # Entries of the five-point stencil that fall off the ends of the band and
    # wrap around to the opposite side of the periodic grid: (row, column, weight).
    wrapped_entries = (
        (0, -1, 16),
        (0, -2, -1),
        (1, -1, -1),
        (-1, 0, 16),
        (-1, 1, -1),
        (-2, 0, -1),
    )
    for row, col, weight in wrapped_entries:
        stencil[row, col] = weight / denom

    return stencil.tocsr()
35
+
36
+
37
def d_dx_upwind_nonuniform(x, nx):
    """
    Sparse matrix representation of d/dx on a nonuniform grid, using first-order
    left/right differences with Dirichlet boundary conditions.

    Args:
        x (array-like): Grid points. Callers in this package pass nx + 2 points, i.e. the
            nx interior nodes x[1:-1] plus one boundary point at each end.
        nx (int): Number of interior nodes; the returned matrices have shape (nx, nx).

    Returns:
        tuple: (Dx_minus, Dx_plus), both scipy.sparse CSR matrices acting on values at the
            interior nodes. Row i of Dx_minus uses the spacing x[i + 2] - x[i + 1] (node i
            and its right neighbour); row i of Dx_plus uses x[i + 1] - x[i] (node i and its
            left neighbour). Neighbours outside the interior are simply absent from the
            matrices.
    """
    inv_spacing = 1 / np.diff(x)
    # Spacings to the right of each interior node (same values as np.diff(x[1:])).
    inv_right = inv_spacing[1:]

    # sp.diags keeps only the leading entries of a diagonal that is longer than needed,
    # so the off-diagonals can be passed at full length.
    forward = sp.diags([-inv_right, inv_right], [0, 1], shape=(nx, nx))
    backward = sp.diags([-inv_spacing[1:], inv_spacing[:-1]], [-1, 0], shape=(nx, nx))

    return forward.tolil().tocsr(), backward.tolil().tocsr()
51
+
52
+
53
def get_angular_grids(nx2, nx3):
    """
    Build the uniformly spaced, periodic x2 and x3 grids.

    Args:
        nx2 (int): Number of x2 grid points, covering [-pi, pi) with the endpoint excluded.
        nx3 (int): Number of x3 grid points, covering one period of length 2*pi.

    Returns:
        tuple: (x2, x3) as 1D numpy arrays. The x3 grid is shifted by one eighth of its
            grid spacing, which is intended to keep cos(x2 - x3) away from zero.
    """
    x2 = np.linspace(-np.pi, np.pi, nx2, endpoint=False)

    # Shift of 1/8 of the x3 grid spacing applied to both ends of the interval.
    shift = 0.125 * (2 * np.pi / nx3)
    x3 = np.linspace(0 + shift, 2 * np.pi + shift, nx3, endpoint=False)

    return x2, x3
62
+
63
+
64
def upwind_operator(Dx_minus, Dx_plus, Dx_coeff):
    """
    Assemble an upwind finite-difference operator by selecting, row by row, one of two
    one-sided difference operators according to the sign of the advection coefficient.

    Args:
        Dx_minus (scipy.sparse matrix): Difference operator used on rows where the
            coefficient is <= 0.
        Dx_plus (scipy.sparse matrix): Difference operator used on rows where the
            coefficient is > 0.
        Dx_coeff (numpy.ndarray): Coefficient array; it is flattened, so its size must
            equal the number of rows of the operators.

    Returns:
        scipy.sparse matrix: The upwind operator.
    """
    nonpositive = (Dx_coeff <= 0).ravel()

    # 0/1 diagonal selectors: left-multiplying keeps the chosen rows and zeroes the rest.
    keep_minus = sp.diags(nonpositive.astype(int))
    keep_plus = sp.diags((~nonpositive).astype(int))

    return keep_minus @ Dx_minus + keep_plus @ Dx_plus
81
+
82
+
83
def test_ill_conditioned(Dx_coeff):
    """
    Test for ill-conditioning. The thresholds are heuristic only.

    Args:
        Dx_coeff (numpy.ndarray): Coefficient array; the check uses the smallest
            absolute value over all of its entries.

    Raises:
        ValueError: If the smallest absolute coefficient is below 1e-10.

    Warns:
        UserWarning: If the smallest absolute coefficient is below 1e-6 (but not below 1e-10).
    """
    ill_conditioning_test = np.min(np.abs(Dx_coeff.ravel()))
    if ill_conditioning_test < 1e-10:
        raise ValueError(f'System is ill-conditioned; min |Dx1_coeff| = {ill_conditioning_test}')
    elif ill_conditioning_test < 1e-6:
        # stacklevel=2 attributes the warning to the caller that built the coefficients.
        warnings.warn(f'System may be ill-conditioned; min |Dx1_coeff| = {ill_conditioning_test}',
                      stacklevel=2)


# This is a library helper, not a unit test. Its name matches pytest's "test_*" pattern,
# so opt out of collection in case it is star-imported into a test module.
test_ill_conditioned.__test__ = False
stnn/pde/ellipse.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .common import *
2
+
3
+
4
def u_dot_muvec(mu, eta, w):
    """
    Dot product of u = (cos(w), sin(w)) with the coordinate vector for mu.

    The expression is algebraically equal to
    cosh(mu) * cos(eta) * cos(w) + sinh(mu) * sin(eta) * sin(w).

    Args:
        mu (float or array-like): The mu coordinate(s).
        eta (float or array-like): The eta coordinate(s).
        w (float or array-like): The w coordinate(s).

    Returns:
        numpy.ndarray: The calculated dot product for each point.
    """
    half_cosh = 0.5 * np.cosh(mu)
    half_sinh = 0.5 * np.sinh(mu)
    cos_diff = np.cos(eta - w)
    cos_sum = np.cos(eta + w)
    return half_cosh * cos_diff + half_cosh * cos_sum + half_sinh * cos_diff - half_sinh * cos_sum
18
+
19
+
20
def u_dot_etavec(mu, eta, w):
    """
    Dot product of u = (cos(w), sin(w)) with the coordinate vector for eta.

    The expression is algebraically equal to
    cosh(mu) * cos(eta) * sin(w) - sinh(mu) * sin(eta) * cos(w).

    Args:
        mu (float or array-like): The mu coordinate(s).
        eta (float or array-like): The eta coordinate(s).
        w (float or array-like): The w coordinate(s).

    Returns:
        numpy.ndarray: The calculated dot product for each point.
    """
    half_sinh = 0.5 * np.sinh(mu)
    half_cosh = 0.5 * np.cosh(mu)
    sin_sum = np.sin(eta + w)
    sin_diff = np.sin(eta - w)
    return -half_sinh * sin_sum - half_sinh * sin_diff + half_cosh * sin_sum - half_cosh * sin_diff
34
+
35
+
36
def _validate_ellipse_config(config):
    """Check the configuration for get_system_ellipse; raise or warn on invalid entries."""
    required_keys = ['nx1', 'nx2', 'nx3', 'ell', 'a2', 'eccentricity']
    optional_keys = []

    missing_keys = [key for key in required_keys if key not in config]
    if missing_keys:
        raise KeyError(f"Missing keys in config: {', '.join(missing_keys)}")

    unused_keys = [key for key in config if key not in required_keys + optional_keys]
    if unused_keys:
        warnings.warn(f"Unused keys in config: {', '.join(unused_keys)}")

    for key in ['nx1', 'nx2', 'nx3']:
        if not isinstance(config[key], int):
            raise TypeError(f"{key} must be an integer.")

    for key in ['nx1', 'nx2', 'nx3', 'ell']:
        if config[key] <= 0:
            raise ValueError(f"{key} must be positive.")

    if not (0 <= config['eccentricity'] < 1.0):
        raise ValueError('eccentricity must be >= 0 and < 1.')

    if config['a2'] <= config['eccentricity']:
        raise ValueError('a2 must be greater than the eccentricity.')

    # The inner boundary has minor axis 1 - eccentricity. If the outer minor axis a2 does
    # not exceed it, then mu2 <= mu1 and the mu grid is degenerate or reversed.
    if config['a2'] <= 1.0 - config['eccentricity']:
        raise ValueError('a2 must be greater than the inner minor axis (1 - eccentricity).')


def get_system_ellipse(config):
    """
    For an elliptical geometry, constructs the matrices, grids, and other quantities
    corresponding to the PDE system specified by "config".

    Args:
        config (dict): Configuration dictionary containing the system parameters. Required keys:
            'nx1', 'nx2', 'nx3' (positive ints: grid sizes in mu, eta, w), 'ell' (positive),
            'a2' (minor axis of the outer boundary) and 'eccentricity' (in [0, 1); the inner
            boundary has major axis 1 and minor axis 1 - eccentricity).

    Returns:
        tuple: (L, mu_3D, eta_3D, w_3D, dmu1, dmu2, Dmu_3D_coeff_meshgrid, major_axis_outer), where
            - L: sparse finite-difference operator on the flattened (mu, eta, w) grid,
            - mu_3D, eta_3D, w_3D: 3D coordinate arrays (meshgrid with 'ij' indexing),
            - dmu1, dmu2: mu spacing between each boundary and its nearest interior node,
            - Dmu_3D_coeff_meshgrid: 3D array of the coefficient multiplying d/dmu,
            - major_axis_outer: major axis of the outer boundary.

    Raises:
        KeyError: If a required key is missing.
        TypeError: If a grid size is not an int.
        ValueError: If a parameter is out of range, or if the outer boundary does not
            enclose the inner one. May also be raised by test_ill_conditioned.

    Note:
        eccentricity == 0 passes validation but gives focal_distance == 0 and hence a
        division by zero below; PDESystem routes near-zero eccentricity to the circular
        geometry instead.
    """
    _validate_ellipse_config(config)

    nmu, neta, nw = config['nx1'], config['nx2'], config['nx3']
    minor_axis_outer = config['a2']
    ell = config['ell']
    minor_axis = 1.0 - config['eccentricity']
    major_axis = 1.0
    focal_distance = np.sqrt(major_axis**2 - minor_axis**2)
    mu1 = np.arccosh(major_axis / focal_distance)
    major_axis_outer = np.sqrt(focal_distance**2 + minor_axis_outer**2)
    mu2 = np.arccosh(major_axis_outer / focal_distance)

    # 1D grids
    eta, w = get_angular_grids(neta, nw)
    # mu grid: non-uniform spacing and Dirichlet boundary conditions.
    # mu_ includes the two boundary values; mu holds the nmu interior nodes only.
    y = np.linspace(-np.pi / 2, np.pi / 2, nmu + 2, dtype=np.float64)
    mu_ = np.log((np.exp(mu2) - np.exp(mu1)) * (np.sin(y) / 2 + 0.5) + np.exp(mu1))
    dmu1 = mu_[1] - mu_[0]
    dmu2 = mu_[-1] - mu_[-2]
    mu = mu_[1:-1]

    # 1D finite-difference operators
    Deta_minus, Deta_plus = d_dx_upwind(eta, neta)
    D2w = d2_dx2_fourth_order(w, nw)
    Dmu_minus, Dmu_plus = d_dx_upwind_nonuniform(mu_, nmu)

    # 3D quantities. Kronecker products are used to build the 3D difference operators
    mu_3D, eta_3D, w_3D = np.meshgrid(mu, eta, w, indexing='ij')
    I_mu = sp.eye(nmu)
    I_eta = sp.eye(neta)
    I_w = sp.eye(nw)
    Deta_3D_minus = sp.kron(sp.kron(I_mu, Deta_minus), I_w)
    Deta_3D_plus = sp.kron(sp.kron(I_mu, Deta_plus), I_w)
    D2w_3D = sp.kron(sp.kron(I_mu, I_eta), D2w)
    Dmu_3D_minus = sp.kron(sp.kron(Dmu_minus, I_eta), I_w)
    Dmu_3D_plus = sp.kron(sp.kron(Dmu_plus, I_eta), I_w)

    # Metric tensor. Note that g_12 = g_21 = 0 and g_11 = g_22.
    g_11 = focal_distance * (np.cosh(mu_3D) * np.cosh(mu_3D) * np.cos(eta_3D) * np.cos(eta_3D)
                             + np.sinh(mu_3D) * np.sinh(mu_3D) * np.sin(eta_3D) * np.sin(eta_3D))

    # Dot products
    dp_mu = u_dot_muvec(mu_3D, eta_3D, w_3D)
    dp_eta = u_dot_etavec(mu_3D, eta_3D, w_3D)

    # Coefficient of d / dmu
    Dmu_3D_coeff_meshgrid = dp_mu / g_11
    test_ill_conditioned(Dmu_3D_coeff_meshgrid)
    Dmu_3D_coeff = sp.diags(Dmu_3D_coeff_meshgrid.ravel())

    # Coefficient of d / deta
    Deta_3D_coeff_meshgrid = dp_eta / g_11
    Deta_3D_coeff = sp.diags(Deta_3D_coeff_meshgrid.ravel())

    # Upwind differencing
    Dmu_3D_upwind = upwind_operator(Dmu_3D_minus, Dmu_3D_plus, Dmu_3D_coeff_meshgrid)
    Deta_3D_upwind = upwind_operator(Deta_3D_minus, Deta_3D_plus, Deta_3D_coeff_meshgrid)

    # Full operator
    L = Dmu_3D_coeff @ Dmu_3D_upwind + Deta_3D_coeff @ Deta_3D_upwind - (1 / ell) * D2w_3D

    return L, mu_3D, eta_3D, w_3D, dmu1, dmu2, Dmu_3D_coeff_meshgrid, major_axis_outer
133
+
134
+
135
def get_boundary_quantities_ellipse(mu_3D, eta_3D, w_3D):
    """
    Gets grid coordinates on the boundaries, as well as slice arrays
    for positive/negative angles with respect to the boundary angle.

    The first and last slices along the leading axis are taken as the inner and outer
    boundary, respectively.

    Args:
        mu_3D (numpy.ndarray): 3D array of mu values on the grid.
        eta_3D (numpy.ndarray): 3D array of eta values on the grid.
        w_3D (numpy.ndarray): 3D array of w values on the grid.

    Returns:
        tuple: (eta_2D_ib, eta_2D_ob, w_2D_ib, w_2D_ob, ib_slice, ob_slice), where the
            eta / w arrays are the coordinate values on the inner (ib) and outer (ob)
            slices, ib_slice is True where u_dot_muvec > 0 on the inner slice, and
            ob_slice is True where u_dot_muvec < 0 on the outer slice.
    """
    eta_2D_ib = eta_3D[0, ...]
    eta_2D_ob = eta_3D[-1, ...]
    w_2D_ib = w_3D[0, ...]
    w_2D_ob = w_3D[-1, ...]

    # u_dot_muvec is elementwise, so it is evaluated on the two boundary slices only
    # instead of on the whole 3D grid.
    ib_slice = u_dot_muvec(mu_3D[0, ...], eta_2D_ib, w_2D_ib) > 0
    ob_slice = u_dot_muvec(mu_3D[-1, ...], eta_2D_ob, w_2D_ob) < 0

    return eta_2D_ib, eta_2D_ob, w_2D_ib, w_2D_ob, ib_slice, ob_slice
stnn/pde/pde_system.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from stnn.data.function_generators import generate_random_functions
4
+ from .circle import get_system_circle, get_boundary_quantities_circle, u_dot_thetahat
5
+ from .ellipse import (get_system_ellipse, get_boundary_quantities_ellipse, u_dot_etavec)
6
+
7
+
8
class PDESystem:
    """
    Constructs the PDE system given input parameters. The finite-difference matrices, grids, and other relevant
    quantities are available as attributes.

    Constructor Args:
        params (dict): Configuration dictionary containing the parameters that define the PDE system.

    Attributes:

        ib_slice (numpy.ndarray): Boolean array defining nodes adjacent to the inner boundary
        ob_slice (numpy.ndarray): Boolean array defining nodes adjacent to the outer boundary

        x2_ib (numpy.ndarray): The x2 coordinate values at the inner boundary.
        x2_ob (numpy.ndarray): The x2 coordinate values at the outer boundary.
        x3_ib (numpy.ndarray): The x3 coordinate values at the inner boundary.
        x3_ob (numpy.ndarray): The x3 coordinate values at the outer boundary.

        Dx1_coeff (numpy.ndarray): Coefficients for the advection operator in the radial direction. Used for
                                   converting boundary conditions to the r.h.s. of the linear system defining a
                                   boundary-value problem.

        dx1a (float): Grid spacing adjacent to the inner boundary
        dx1b (float): Grid spacing adjacent to the outer boundary

        L (scipy.sparse matrix): Finite-difference representation of the linear operator defining the PDE.

        x1 (numpy.ndarray): 3D array of grid values in the radial coordinate (r or mu)
        x2 (numpy.ndarray): 3D array of grid values in the angular coordinate (theta or eta).
        x3 (numpy.ndarray): 3D array of grid values in the w coordinate

        a1 (float): Minor axis of the inner boundary
        a2 (float): Minor axis of the outer boundary
        b1 (float): Major axis of the inner boundary
        b2 (float): Major axis of the outer boundary

        _coords (str): The type of coordinate system used ('ellipse' or 'circle'). This affects how the grids and
                       other geometric properties are calculated.

        params (dict): The configuration dictionary containing the PDE parameters.
    """

    def __init__(self, params):
        self._required_keys = ['nx1', 'nx2', 'nx3', 'ell', 'a2', 'eccentricity']
        self._optional_keys = []
        self.ib_slice = None
        self.ob_slice = None
        self.x2_ib = None
        self.x2_ob = None
        self.x3_ib = None
        self.x3_ob = None
        self.Dx1_coeff = None
        self.dx1b = None
        self.dx1a = None
        self.L = None
        self.x1 = None
        self.x2 = None
        self.x3 = None
        self.a1 = None
        self.a2 = None
        self.b1 = None
        self.b2 = None
        self._coords = None
        self.params = params
        self.initialize()

    def initialize(self):
        """
        Constructs the system matrices and vectors based on the stored configuration.

        Depending on the 'eccentricity' parameter in `self.params`, the coordinate system is set
        to either 'circle' or 'ellipse'; the domain parametrization and finite-difference grid are defined
        accordingly.

        Raises:
            KeyError: If a required key is missing from `self.params`.
        """
        params = self.params
        missing_keys = [key for key in self._required_keys if key not in params]
        if missing_keys:
            raise KeyError(f"Missing keys in config: {', '.join(missing_keys)}")

        # The functions 'get_system_circle' and 'get_system_ellipse' have a fair amount
        # of overlap and probably should be combined, but for now they are kept separate
        # for simplicity and readability.
        if params['eccentricity'] < 1e-7:
            # Near-zero eccentricity is treated as a circle.
            self._coords = 'circle'
            L, x1, x2, x3, dx1a, dx1b, Dx1_coeff = get_system_circle(params)
            x2_ib, x2_ob, x3_ib, x3_ob, ib_slice, ob_slice = get_boundary_quantities_circle(x2, x3)
            self.b2 = params['a2']
        else:
            self._coords = 'ellipse'
            (L, x1, x2, x3, dx1a, dx1b, Dx1_coeff, major_axis_outer) = get_system_ellipse(params)
            x2_ib, x2_ob, x3_ib, x3_ob, ib_slice, ob_slice = get_boundary_quantities_ellipse(x1, x2, x3)
            self.b2 = major_axis_outer

        self.a1 = 1.0 - params['eccentricity']
        self.a2 = params['a2']
        self.b1 = 1.0

        self.x1 = x1
        self.x2 = x2
        self.x3 = x3

        self.L = L

        self.dx1a = dx1a
        self.dx1b = dx1b
        self.Dx1_coeff = Dx1_coeff

        self.x2_ib = x2_ib
        self.x2_ob = x2_ob
        self.x3_ib = x3_ib
        self.x3_ob = x3_ob
        self.ib_slice = ib_slice
        self.ob_slice = ob_slice

    def generate_random_bc(self, func_gen_id):
        """
        Generates random boundary conditions for the PDE system.

        Args:
            func_gen_id (int): Integer representing the type of 'function generator' used to construct the
                               boundary conditions.

        Returns:
            tuple: A tuple containing:
                - ibf_data: Inner boundary data
                - obf_data: Outer boundary data
                - b: 3D array for passing to the GMRES solver. 'b' contains the boundary data but is defined
                     on the full 3D grid.
                - bf: Flattened boundary data before it is permuted

        The boundary conditions are defined on the inner and outer boundaries of the domain and are denoted
        by 'ibf_data' and 'obf_data'. The function passes 'ibf_data' and 'obf_data' through 'convert_boundary_data',
        which converts them into formats suitable for passing into the GMRES solver and STNN model (e.g., by
        reshaping and permutation operations).
        """

        # Note the change of variable (x2, x3) -> (x2, x2 - x3).
        ibf_data = generate_random_functions(1, self.x2_ib, self.x2_ib - self.x3_ib,
                                             max_freq=self.params['nx3'], func_gen_id=func_gen_id)[0, self.ib_slice]
        obf_data = generate_random_functions(1, self.x2_ob, self.x2_ob - self.x3_ob,
                                             max_freq=self.params['nx3'], func_gen_id=func_gen_id)[0, self.ob_slice]

        # Combine boundary data in single vector
        bf = np.concatenate([ibf_data, obf_data], axis=-1).flatten()

        # Permutes 'ibf_data' and 'obf_data' and construct 'b'
        ibf_data, obf_data, b = self.convert_boundary_data(ibf_data, obf_data)

        return ibf_data, obf_data, b, bf

    def convert_boundary_data(self, ibf_data, obf_data):
        """
        Converts boundary data into formats suitable for passing into the GMRES solver and STNN model.

        The input arrays are not modified; the permuted boundary data are returned as new arrays.

        Args:
            ibf_data: Inner boundary data, one value per True entry of `ib_slice`.
            obf_data: Outer boundary data, one value per True entry of `ob_slice`.

        Returns:
            tuple: A tuple containing:
                - ibf_data: Inner boundary data, permuted to match the input structure of the EinsumTTL layer.
                - obf_data: Outer boundary data, permuted to match the input structure of the EinsumTTL layer
                - b: 3D array for passing to the GMRES solver. 'b' contains the boundary data but is defined
                     on the full 3D grid.

        Raises:
            ValueError: If `_coords` is neither 'ellipse' nor 'circle'.
        """
        nx1, nx2, nx3 = self.params['nx1'], self.params['nx2'], self.params['nx3']
        b = np.zeros((nx1, nx2, nx3), dtype=np.float64)
        b[0, self.ib_slice] = self.Dx1_coeff[0, self.ib_slice] * (ibf_data / self.dx1a)
        b[-1, self.ob_slice] = -self.Dx1_coeff[-1, self.ob_slice] * (obf_data / self.dx1b)

        if self._coords == 'ellipse':
            sin_angle = u_dot_etavec(self.x1, self.x2, self.x3)
        elif self._coords == 'circle':
            sin_angle = u_dot_thetahat(self.x2, self.x3)
        else:
            raise ValueError(f'"_coords" attribute should be either "ellipse" or "circle"; instead received {self._coords}')

        # reshape and permute 'ibf_data' and 'obf_data'
        # The reshape assumes each x2 row of a boundary mask selects exactly nx3 // 2 nodes.
        half_nx3 = nx3 // 2
        sin_angle_i = sin_angle[0, self.ib_slice].reshape(nx2, half_nx3)
        sin_angle_o = sin_angle[-1, self.ob_slice].reshape(nx2, half_nx3)
        W_I = np.argsort(sin_angle_i, axis=1)
        W_O = np.argsort(sin_angle_o, axis=1)

        # Sort every x2 row by increasing 'sin_angle'. take_along_axis returns new arrays,
        # so the caller's inputs are left untouched.
        ibf_data = np.take_along_axis(np.asarray(ibf_data).reshape((nx2, half_nx3)), W_I, axis=1)
        obf_data = np.take_along_axis(np.asarray(obf_data).reshape((nx2, half_nx3)), W_O, axis=1)

        return ibf_data, obf_data, b

    def get_xy_grids(self):
        """
        Converts the native grids of the PDE system to xy coordinates (no interpolation).

        The function also applies a wrap-around in the x2 domain for plotting purposes, ensuring
        continuity across the (periodic) domain.

        Returns:
            tuple of numpy.ndarray: A tuple containing two 2D numpy arrays:
                - x_grid: The x-coordinates grid.
                - y_grid: The y-coordinates grid.

        Raises:
            ValueError: If `_coords` is neither 'ellipse' nor 'circle'.
        """
        x1_1D = self.x1[:, 0, 0]
        x2_1D = self.x2[0, :, 0]
        x2_1D = np.append(x2_1D, np.array([np.pi - 1e-3]))  # wrap around for plotting
        x1_2D, x2_2D = np.meshgrid(x1_1D, x2_1D, indexing='ij')
        if self._coords == 'ellipse':
            focal_distance = np.sqrt(self.b1**2 - self.a1**2)
            x_grid = focal_distance * np.sinh(x1_2D) * np.cos(x2_2D)
            y_grid = focal_distance * np.cosh(x1_2D) * np.sin(x2_2D)
        elif self._coords == 'circle':
            x_grid = x1_2D * np.cos(x2_2D)
            y_grid = x1_2D * np.sin(x2_2D)
        else:
            raise ValueError(f'"_coords" should be either "ellipse" or "circle"; instead received {self._coords}')

        return x_grid, y_grid
stnn/tests/test_circle.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import copy
3
+ import numpy as np
4
+ import scipy.sparse as sp
5
+ from stnn.pde.circle import get_system_circle
6
+
7
+
8
class TestGetSystemCircle(unittest.TestCase):
    """Tests for get_system_circle: output structure, config validation, and warnings."""

    def setUp(self):
        self.config = {
            'nx1': 10,
            'nx2': 20,
            'nx3': 30,
            'a2': 2.0,
            'ell': 1.5,
        }
        self.saved_config = copy.deepcopy(self.config)
        self._required_keys = ['nx1', 'nx2', 'nx3', 'ell', 'a2']
        self._optional_keys = []

    def test_valid_output(self):
        L, r_3D, theta_3D, w_3D, dr1, dr2, Dr_3D_coeff_meshgrid = get_system_circle(self.config)

        # Test types
        self.assertIsInstance(L, sp.csr_matrix)
        self.assertIsInstance(r_3D, np.ndarray)
        self.assertIsInstance(theta_3D, np.ndarray)
        self.assertIsInstance(w_3D, np.ndarray)
        self.assertIsInstance(Dr_3D_coeff_meshgrid, np.ndarray)

        # Test shapes
        grid_shape = (self.config['nx1'], self.config['nx2'], self.config['nx3'])
        self.assertEqual(r_3D.shape, grid_shape)
        self.assertEqual(theta_3D.shape, grid_shape)
        self.assertEqual(w_3D.shape, grid_shape)
        # The operator acts on the flattened 3D grid.
        N = self.config['nx1'] * self.config['nx2'] * self.config['nx3']
        self.assertEqual(L.shape, (N, N))

        # Test values
        self.assertGreater(dr1, 0)
        self.assertGreater(dr2, 0)

    def test_missing_keys(self):
        # Each key is removed from a fresh copy, so one failing case cannot affect
        # the others, and subTest reports every failing key instead of only the first.
        for key in self._required_keys:
            with self.subTest(key=key):
                config = copy.deepcopy(self.saved_config)
                del config[key]
                with self.assertRaises(KeyError):
                    get_system_circle(config)

    def test_invalid_parameters(self):
        for key in self._required_keys:
            with self.subTest(key=key):
                config = copy.deepcopy(self.saved_config)
                config[key] = 0  # None of the required keys should be zero.
                with self.assertRaises(ValueError):
                    get_system_circle(config)

    def test_unused_params_warning(self):
        copied_params = copy.deepcopy(self.config)
        copied_params['unusedkey'] = 0
        with self.assertWarns(UserWarning):
            get_system_circle(copied_params)
60
+
61
+
62
+ if __name__ == '__main__':
63
+ unittest.main()
stnn/tests/test_dependencies.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import importlib
3
+
4
+
5
class TestDependencies(unittest.TestCase):
    """Checks that every package listed as required can be imported."""

    def test_required_modules_installed(self):
        """Import each required module; a missing one is reported in its own subTest."""
        required_modules = [
            'numpy',
            'tensorflow',
            't3f',
            'scipy',
            'h5py',
            'matplotlib',
            'pydot',
            'openvino',
        ]

        for module in required_modules:
            with self.subTest(module=module):
                importlib.import_module(module)
22
+
23
+
24
+ if __name__ == '__main__':
25
+ unittest.main()
stnn/tests/test_differential_ops.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import numpy as np
3
+ from stnn.pde.common import d_dx_upwind, d2_dx2_fourth_order, d_dx_upwind_nonuniform
4
+
5
+
6
def assert_derivative(operator, function, expected_derivative, boundary=None, rtol=None, atol=None):
    """
    Assert the derivative of a function using a given finite-difference operator

    Args:
        operator (np.ndarray or sparse matrix): Differential operator matrix.
        function (np.ndarray): Values of the function to differentiate
        expected_derivative (np.ndarray): Expected result of the derivative.
        boundary (int): Number of elements to exclude at each end of the arrays before
            comparing. None or 0 compares the full arrays.
        rtol: Relative tolerance. None uses the default of np.testing.assert_allclose.
        atol: Absolute tolerance. None uses the default of np.testing.assert_allclose.

    Raises:
        AssertionError: If the observed and expected derivatives differ beyond the tolerances.
    """
    observed_derivative = operator @ function
    # A bare [boundary:-boundary] with boundary == 0 would select nothing and make the
    # comparison pass vacuously, so only slice when there is something to exclude.
    if boundary:
        observed_derivative = observed_derivative[boundary:-boundary]
        expected_derivative = expected_derivative[boundary:-boundary]

    # Forward only the tolerances that were given; assert_allclose cannot handle None.
    tolerances = {}
    if rtol is not None:
        tolerances['rtol'] = rtol
    if atol is not None:
        tolerances['atol'] = atol

    np.testing.assert_allclose(observed_derivative, expected_derivative, **tolerances)
24
+
25
+
26
class TestDifferentialOperators(unittest.TestCase):
    """Accuracy checks for the 1D finite-difference operators in stnn.pde.common."""

    def setUp(self):
        # ----- Grids
        self.nx = 100000

        # Non-periodic grid with non-uniform spacing (radial coordinate).
        # z_ includes the two end points; z holds the interior nodes only.
        R1, R2 = 0.5, 1.4
        zs = np.linspace(-np.pi / 2, np.pi / 2, self.nx + 2)
        self.z_ = (R2 - R1) * (np.sin(zs) / 2 + 0.5) + R1
        self.z = self.z_[1:-1]

        # Periodic grid with uniform spacing (angular coordinates)
        self.y = np.linspace(0, 2 * np.pi, self.nx, endpoint=False)
        dy = (2 * np.pi) / self.nx

        # ----- Tolerances
        # "Exact" tests: inputs for which the finite differences have no truncation error
        self.atol = 1e-6
        self.rtol = 1e-6

        # Inexact tests
        self.atol_inexact = 1e-3
        self.rtol_z = 3 * np.max(np.diff(self.z_))  # 3*dx for first-order one-sided differences
        self.rtol_y = 3 * dy  # 3*dx for first-order one-sided differences
        self.rtol_y2 = 3 * dy**4  # 3*dy**4 for fourth-order central differences

        # ----- Test inputs
        # "Exact" inputs: constant, linear, and quadratic functions
        self.f1 = -2.3 * np.ones(self.nx)
        self.f2 = 0.8 * self.z
        self.f3 = 0.8 * self.y
        self.f4 = -0.1 * self.y * self.y

        # Inexact inputs
        self.g1 = self.z**2 - self.z**3
        self.g2 = np.cos(self.y)**2 - np.sin(self.y)**3
        self.g3 = np.sin(2 * self.y)**2 - np.cos(self.y)**3

    def test_d_dx_upwind(self):
        operators = d_dx_upwind(self.y, self.nx)
        exact = {'rtol': self.rtol, 'atol': self.atol}
        inexact = {'rtol': self.rtol_y, 'atol': self.atol_inexact}
        sin_y, cos_y = np.sin(self.y), np.cos(self.y)

        # Constant input: derivative vanishes everywhere
        for op in operators:
            assert_derivative(op, self.f1, np.zeros(self.nx), **exact)

        # Linear input is not periodic, so the wrap-around rows are excluded
        for op in operators:
            assert_derivative(op, self.f3, 0.8 * np.ones(self.nx), boundary=1, **exact)

        expected_dg2 = -2 * cos_y * sin_y - 3 * sin_y**2 * cos_y
        for op in operators:
            assert_derivative(op, self.g2, expected_dg2, **inexact)

        expected_dg3 = 4 * np.sin(2 * self.y) * np.cos(2 * self.y) + 3 * cos_y**2 * sin_y
        for op in operators:
            assert_derivative(op, self.g3, expected_dg3, **inexact)

    def test_d2_dx2_fourth_order(self):
        d2x = d2_dx2_fourth_order(self.y, self.nx)
        exact = {'rtol': self.rtol, 'atol': self.atol}
        inexact = {'rtol': self.rtol_y2, 'atol': self.atol_inexact}
        sin_y, cos_y = np.sin(self.y), np.cos(self.y)

        assert_derivative(d2x, self.f1, np.zeros(self.nx), **exact)
        # Non-periodic polynomial inputs: skip the rows affected by the wrap-around
        assert_derivative(d2x, self.f3, np.zeros(self.nx), boundary=4, **exact)
        assert_derivative(d2x, self.f4, -0.2 * np.ones(self.nx), boundary=4, **exact)

        expected_dg2 = (2 * sin_y**2 - 2 * cos_y**2
                        - 6 * sin_y * cos_y**2 + 3 * sin_y**3)
        assert_derivative(d2x, self.g2, expected_dg2, **inexact)

        expected_dg3 = (8 * np.cos(2 * self.y)**2 - 8 * np.sin(2 * self.y)**2
                        - 6 * cos_y * sin_y**2 + 3 * cos_y**3)
        assert_derivative(d2x, self.g3, expected_dg3, **inexact)

    def test_d_dx_upwind_nonuniform(self):
        operators = d_dx_upwind_nonuniform(self.z_, self.nx)
        exact = {'rtol': self.rtol, 'atol': self.atol}
        inexact = {'rtol': self.rtol_z, 'atol': self.atol_inexact}

        # The first and last rows are excluded in every case below.
        for op in operators:
            assert_derivative(op, self.f1, np.zeros(self.nx), boundary=1, **exact)

        for op in operators:
            assert_derivative(op, self.f2, 0.8 * np.ones(self.nx), boundary=1, **exact)

        expected_dg1 = 2 * self.z - 3 * self.z**2
        for op in operators:
            assert_derivative(op, self.g1, expected_dg1, boundary=1, **inexact)
112
+
113
+
114
+ if __name__ == '__main__':
115
+ unittest.main()
stnn/tests/test_ellipse.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import copy
3
+ import numpy as np
4
+ import scipy.sparse as sp
5
+ from stnn.pde.ellipse import get_system_ellipse
6
+
7
+
8
class TestGetSystemCircle(unittest.TestCase):
    """
    Tests for get_system_ellipse: output structure, config validation, and warnings.

    NOTE(review): the class name looks copy-pasted from test_circle.py; it is kept
    unchanged here so that existing test selectors keep working.
    """

    def setUp(self):
        self.config = {
            'nx1': 10,
            'nx2': 20,
            'nx3': 30,
            'a2': 2.0,
            'ell': 1.5,
            'eccentricity': 0.5
        }
        self.saved_config = copy.deepcopy(self.config)
        self._required_keys = ['nx1', 'nx2', 'nx3', 'ell', 'a2', 'eccentricity']
        self._optional_keys = []

    def test_valid_output(self):
        L, mu_3D, eta_3D, w_3D, dmu1, dmu2, Dmu_3D_coeff_meshgrid, major_axis_outer = get_system_ellipse(self.config)

        # Test types
        self.assertIsInstance(L, sp.csr_matrix)
        self.assertIsInstance(mu_3D, np.ndarray)
        self.assertIsInstance(eta_3D, np.ndarray)
        self.assertIsInstance(w_3D, np.ndarray)
        self.assertIsInstance(Dmu_3D_coeff_meshgrid, np.ndarray)

        # Test shapes
        grid_shape = (self.config['nx1'], self.config['nx2'], self.config['nx3'])
        self.assertEqual(mu_3D.shape, grid_shape)
        self.assertEqual(eta_3D.shape, grid_shape)
        self.assertEqual(w_3D.shape, grid_shape)
        N = self.config['nx1'] * self.config['nx2'] * self.config['nx3']
        self.assertEqual(L.shape, (N, N))

        # Test values
        self.assertGreater(dmu1, 0)
        self.assertGreater(dmu2, 0)
        self.assertGreater(major_axis_outer, self.config['a2'])

    def test_missing_keys(self):
        # Each key is removed from a fresh copy, so one failing case cannot affect
        # the others, and subTest reports every failing key instead of only the first.
        for key in self._required_keys:
            with self.subTest(key=key):
                config = copy.deepcopy(self.saved_config)
                del config[key]
                with self.assertRaises(KeyError):
                    get_system_ellipse(config)

    def test_invalid_parameters(self):
        for key in self._required_keys:
            with self.subTest(key=key):
                config = copy.deepcopy(self.saved_config)
                config[key] = -1  # None of the required keys should be negative.
                with self.assertRaises(ValueError):
                    get_system_ellipse(config)

    def test_unused_params_warning(self):
        copied_params = copy.deepcopy(self.config)
        copied_params['unusedkey'] = 0
        with self.assertWarns(UserWarning):
            get_system_ellipse(copied_params)
64
+
65
+
66
+ if __name__ == '__main__':
67
+ unittest.main()
stnn/tests/test_file.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import numpy as np
3
+ import h5py
4
+ import tempfile
5
+ from stnn.data.preprocessing import get_data_from_file, load_data, load_training_data
6
+
7
+
8
class TestGetDataFromFile(unittest.TestCase):
    """Tests for get_data_from_file, load_data and load_training_data on small synthetic HDF5 files."""

    def setUp(self):
        # Local import: the module header does not import os.
        import os

        self.nx1, self.nx2, self.nx3 = 30, 20, 16
        self.Nsamples = 10

        # All fixture files live in one temporary directory that is removed after each test,
        # so nothing is left behind on disk and no open handle competes with h5py for the path.
        tmpdir = tempfile.TemporaryDirectory()
        self.addCleanup(tmpdir.cleanup)

        self.temp_path = os.path.join(tmpdir.name, 'good_0.h5')
        self.temp_path1 = os.path.join(tmpdir.name, 'good_1.h5')
        self.bad_path = os.path.join(tmpdir.name, 'bad.h5')

        self._write_file(self.temp_path)
        self._write_file(self.temp_path1)
        # The "bad" file has no 'ibf' / 'obf' datasets.
        self._write_file(self.bad_path, with_boundary_data = False)

    def _write_file(self, path, with_boundary_data = True):
        """Write a random dataset file; optionally leave out the 'ibf' and 'obf' datasets."""
        with h5py.File(path, 'w') as f:
            for name in ('ell', 'a1', 'a2'):
                f.create_dataset(name, data = np.random.rand(self.Nsamples))
            f.create_dataset('rho', data = np.random.rand(self.Nsamples, self.nx1, self.nx2))
            if with_boundary_data:
                for name in ('ibf', 'obf'):
                    f.create_dataset(name, data = np.random.rand(self.Nsamples, self.nx2, self.nx3 // 2))

    def test_missing_datasets(self):
        with self.assertRaises(ValueError):
            get_data_from_file(self.bad_path, self.nx2, self.nx3)

    def test_data_extraction_shapes(self):
        result = get_data_from_file(self.temp_path, self.nx2, self.nx3)
        self.assertEqual(result[0].shape, (self.Nsamples,))
        self.assertEqual(result[1].shape, (self.Nsamples,))
        self.assertEqual(result[2].shape, (self.Nsamples,))
        self.assertEqual(result[3].shape, (self.Nsamples, 2 * self.nx2, self.nx3 // 2))
        self.assertEqual(result[4].shape, (self.Nsamples, self.nx1, self.nx2))

    def test_nrange_parameter(self):
        Nrange = (2, 5)
        result = get_data_from_file(self.temp_path, self.nx2, self.nx3, Nrange = Nrange)
        expected_size = Nrange[1] - Nrange[0]
        # Every returned array must be restricted to the requested sample range.
        for index in range(5):
            with self.subTest(result_index = index):
                self.assertEqual(result[index].shape[0], expected_size)

    def test_list_input(self):
        file_list = [self.temp_path, self.temp_path1]
        Nrange_list = [(0, -1), (0, -1)]
        with self.assertRaises(TypeError):
            # noinspection PyTypeChecker
            _ = get_data_from_file(file_list, self.nx2, self.nx3, Nrange = Nrange_list)

    def test_invalid_Nrange(self):
        Nrange_list = [(0, -1), (0, -1)]
        with self.assertRaises(TypeError):
            _ = get_data_from_file(self.temp_path, self.nx2, self.nx3, Nrange = Nrange_list)

        for Nrange in [(0, 1, 1), 1, (1.5, 3), (3, 1.5), (1.5, 1.5), 'x']:
            with self.subTest(Nrange = Nrange):
                with self.assertRaises(TypeError):
                    _ = get_data_from_file(self.temp_path, self.nx2, self.nx3, Nrange = Nrange)

            # The list variant only makes sense for iterable values. Converting outside of
            # assertRaises guarantees the TypeError comes from get_data_from_file, not list().
            if isinstance(Nrange, int):
                continue
            Nrange_as_list = list(Nrange)
            with self.subTest(Nrange = Nrange_as_list):
                with self.assertRaises(TypeError):
                    _ = get_data_from_file(self.temp_path, self.nx2, self.nx3, Nrange = Nrange_as_list)

    def test_good_data_load(self):
        files = [self.temp_path, self.temp_path1]
        Nrange_list = [(0, None), (0, self.Nsamples)]
        ell1, ell2, a1, a2 = 0.1, 2.0, 1.0, 5.0

        params, bf, rho = load_data(files, self.nx2, self.nx3, ell1, ell2, a1, a2, Nrange_list = Nrange_list)
        self.assertEqual(params.shape, (2 * self.Nsamples, 3))
        self.assertEqual(bf.shape, (2 * self.Nsamples, 2 * self.nx2, self.nx3 // 2))
        self.assertEqual(rho.shape, (2 * self.Nsamples, self.nx1, self.nx2))

        test_size = 0.3
        (params_train, bf_train, rho_train,
         params_test, bf_test, rho_test) = load_training_data(files, self.nx2, self.nx3,
                                                              ell1, ell2, a1, a2, test_size = test_size,
                                                              Nrange_list = Nrange_list)
        Ntest = int(test_size * 2 * self.Nsamples)
        Ntrain = 2 * self.Nsamples - Ntest
        self.assertEqual(params_train.shape, (Ntrain, 3))
        self.assertEqual(bf_train.shape, (Ntrain, 2 * self.nx2, self.nx3 // 2))
        self.assertEqual(rho_train.shape, (Ntrain, self.nx1, self.nx2))
        self.assertEqual(params_test.shape, (Ntest, 3))
        self.assertEqual(bf_test.shape, (Ntest, 2 * self.nx2, self.nx3 // 2))
        self.assertEqual(rho_test.shape, (Ntest, self.nx1, self.nx2))

    def test_bad_data_load(self):
        files = [self.temp_path, self.temp_path1]
        Nrange_list = (0, -1)
        ell1, ell2, a1, a2 = 0.1, 2.0, 1.0, 5.0
        with self.assertRaises(TypeError):
            _ = load_data(files, self.nx2, self.nx3, ell1, ell2, a1, a2, Nrange_list = Nrange_list)
        with self.assertRaises(TypeError):
            _ = load_data(files, self.nx2, self.nx3, ell1, ell2, a1, a2, Nrange_list = list(Nrange_list))

        Nrange_list = [(0, -1), (0, -1)]
        for test_size in [-1, 0.0, 1.5]:
            with self.subTest(test_size = test_size):
                with self.assertRaises(ValueError):
                    _ = load_training_data(files, self.nx2, self.nx3,
                                           ell1, ell2, a1, a2, test_size = test_size, Nrange_list = Nrange_list)
118
+
119
+
120
+ if __name__ == '__main__':
121
+ unittest.main()
stnn/tests/test_pde_system.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import unittest
3
+
4
+ import numpy as np
5
+
6
+ from stnn.pde.pde_system import PDESystem
7
+
8
+
9
class TestPDESystem(unittest.TestCase):
    """Unit tests for a PDESystem built from a fixed set of grid and geometry parameters."""

    def setUp(self):
        self.params = dict(nx1 = 50, nx2 = 100, nx3 = 75, a1 = 1.0, a2 = 2.0, ell = 0.1, eccentricity = 0.5)
        self.system = PDESystem(self.params)

    def test_initialization(self):
        # The system should expose the parameters it was constructed with.
        self.assertEqual(self.system.params, self.params)

    def test_attribute_types(self):
        for attribute in ('ib_slice', 'ob_slice'):
            self.assertIsInstance(getattr(self.system, attribute), np.ndarray)

    def test_attribute_values(self):
        expected_a1 = 1.0 - self.params['eccentricity']
        self.assertEqual(self.system.a1, expected_a1)

    def test_coordinate_system(self):
        if self.params['eccentricity'] != 0:
            expected_coords = 'ellipse'
        else:
            expected_coords = 'circle'
        self.assertEqual(self.system._coords, expected_coords)

    def test_unused_params_warning(self):
        params_with_extra = copy.deepcopy(self.params)
        params_with_extra['unusedkey'] = 0
        with self.assertWarns(UserWarning):
            _ = PDESystem(params_with_extra)

    def test_L(self):
        # If f(x1, x2, x3) is constant, then L * f should be 0 except adjacent to the boundary.
        grid_shape = (self.params['nx1'], self.params['nx2'], self.params['nx3'])
        constant_field = np.full(grid_shape, 2.3)
        applied = (self.system.L @ constant_field.ravel()).reshape(grid_shape)
        # Drop the first and last index along x1 before comparing against zero.
        np.testing.assert_allclose(applied[1:-1, ...], 0, atol = 1e-7)
51
+
52
+
53
+ if __name__ == '__main__':
54
+ unittest.main()
stnn/tests/test_preprocessing.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import numpy as np
3
+ from stnn.data.preprocessing import train_test_split
4
+
5
+
6
class TestTrainTestSplit(unittest.TestCase):
    """Behavioral tests for train_test_split with single-array inputs and list-of-array inputs."""

    def setUp(self):
        features = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
        targets = np.array([1, 2, 3, 4])
        self.X_array = features
        self.Y_array = targets
        self.X_list = [features, features]
        self.Y_list = [targets, targets]
        # Inputs that test_inconsistent_length expects train_test_split to reject.
        self.X_list_bad = [targets, features]
        self.Y_list_bad = [targets, features]

    def _assert_all_instances(self, parts, expected_type):
        """Assert, in order, that every element of 'parts' is an instance of 'expected_type'."""
        for part in parts:
            self.assertIsInstance(part, expected_type)

    def test_basic_functionality_array(self):
        X_train, X_test, Y_train, Y_test = train_test_split(self.X_array, self.Y_array, test_size = 0.25)
        # 4 samples with test_size = 0.25 -> 3 for training, 1 for testing.
        for part, expected_len in ((X_train, 3), (X_test, 1), (Y_train, 3), (Y_test, 1)):
            self.assertEqual(len(part), expected_len)

    def test_basic_functionality_list(self):
        X_train, X_test, Y_train, Y_test = train_test_split(self.X_list, self.Y_list, test_size = 0.25)
        for part, expected_len in ((X_train, 3), (X_test, 1), (Y_train, 3), (Y_test, 1)):
            self.assertEqual(len(part[0]), expected_len)

    def test_return_type_consistency_array(self):
        # Array in -> array out.
        X_train, X_test, Y_train, Y_test = train_test_split(self.X_array, self.Y_array, test_size = 0.25)
        self._assert_all_instances((X_train, X_test, Y_train, Y_test), np.ndarray)

        # Single-element list in -> list out.
        X_train, X_test, Y_train, Y_test = train_test_split([self.X_array], [self.Y_array], test_size = 0.25)
        self._assert_all_instances((X_train, X_test, Y_train, Y_test), list)

    def test_return_type_consistency_list(self):
        X_train, X_test, Y_train, Y_test = train_test_split(self.X_list, self.Y_list, test_size = 0.25)
        self._assert_all_instances((X_train, X_test, Y_train, Y_test), list)

        # Tuples of arrays are expected to come back as lists as well.
        # noinspection PyTypeChecker
        X_train, X_test, Y_train, Y_test = train_test_split(tuple(self.X_list), tuple(self.Y_list), test_size = 0.25)
        self._assert_all_instances((X_train, X_test, Y_train, Y_test), list)

    def test_random_state(self):
        X_train1, X_test1, Y_train1, Y_test1 = train_test_split(self.X_array, self.Y_array, test_size = 0.25,
                                                                random_state = 42)
        X_train2, X_test2, Y_train2, Y_test2 = train_test_split(self.X_array, self.Y_array, test_size = 0.25,
                                                                random_state = 42)
        # The same seed must reproduce the same split.
        first_split = (X_train1, X_test1, Y_train1, Y_test1)
        second_split = (X_train2, X_test2, Y_train2, Y_test2)
        for first, second in zip(first_split, second_split):
            np.testing.assert_array_equal(first, second)

    def test_invalid_test_size(self):
        for bad_size in (-0.1, 1.5):
            with self.assertRaises(ValueError):
                train_test_split(self.X_array, self.Y_array, test_size = bad_size)

    def test_inconsistent_length(self):
        X = np.array([[1, 2], [3, 4]])
        Y = np.array([1, 2, 3])
        rejected_pairs = (
            (X, Y),
            (self.X_list_bad, self.Y_list_bad),
            (self.X_list_bad, self.Y_list),
            (self.X_list, self.Y_list_bad),
        )
        for X_arg, Y_arg in rejected_pairs:
            with self.assertRaises(ValueError):
                train_test_split(X_arg, Y_arg)

    def test_empty(self):
        X_empty = np.zeros(0)
        Y_empty = np.zeros(0)
        rejected_pairs = (
            (X_empty, Y_empty),
            ([X_empty], []),
            ([], [Y_empty]),
            ([X_empty], [Y_empty]),
        )
        for X_arg, Y_arg in rejected_pairs:
            with self.assertRaises(ValueError):
                train_test_split(X_arg, Y_arg)
96
+
97
+
98
+ if __name__ == '__main__':
99
+ unittest.main()
stnn/tests/test_stats.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import unittest
3
+ import os
4
+ from stnn.utils.stats import get_stats
5
+
6
+
7
class TestGetStats(unittest.TestCase):
    """Tests for get_stats using a prediction that matches the reference exactly."""

    def setUp(self):
        self.rho = np.array([[1, 2], [3, 4]])
        self.rho_pred = np.array([[1, 2], [3, 4]])
        self.filename = 'test_stats.npz'

    def tearDown(self):
        # Remove the statistics file if a test produced one.
        if os.path.exists(self.filename):
            os.remove(self.filename)

    def test_correctness(self):
        get_stats(self.rho, self.rho_pred, self.filename)
        with np.load(self.filename) as stats:
            # Identical inputs -> zero loss everywhere.
            self.assertAlmostEqual(stats['max_loss'], 0.0, places = 5)
            self.assertEqual(stats['avg_loss'], 0.0)
            self.assertEqual(stats['N'], self.rho.shape[0])

    def test_file_creation(self):
        get_stats(self.rho, self.rho_pred, self.filename)
        self.assertTrue(os.path.exists(self.filename))

    def test_file_content(self):
        get_stats(self.rho, self.rho_pred, self.filename)
        with np.load(self.filename) as stats:
            for expected_key in ('max_loss', 'avg_loss', 'N'):
                self.assertIn(expected_key, stats)

    def test_invalid_input(self):
        mismatched_rho = np.array([1, 2])
        reference_pred = np.array([[1, 2], [3, 4]])
        with self.assertRaises(ValueError):
            get_stats(mismatched_rho, reference_pred)
39
+
40
+ if __name__ == '__main__':
41
+ unittest.main()
stnn/tests/test_stnn_config.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import numpy as np
3
+ import copy
4
+ from stnn.nn.stnn import build_stnn
5
+
6
+
7
class TestBuildSTNN(unittest.TestCase):
    """Validation tests for the configuration dictionary accepted by build_stnn.

    Every case builds its own modified copy of the baseline config, so exactly one entry
    is invalid per call and a failing case cannot leak a corrupted config into later cases.
    """

    def setUp(self):
        self.config = {
            'K': 1,
            'nx1': 8,
            'nx2': 8,
            'nx3': 8,
            'd': 8,
            'W': 3,
            'shape1': [1, 2, 3],
            'shape2': [2, 2, 2],
            'ranks': [1, 2, 2, 1],
        }
        self.saved_config = copy.deepcopy(self.config)
        self._required_keys = ['nx1', 'nx2', 'nx3', 'K', 'd', 'shape1', 'shape2', 'ranks', 'W']
        self._optional_keys = ['use_regularization', 'regularization_strength']

    def _config_with(self, key, value):
        """Return a fresh copy of the baseline config with a single entry replaced."""
        config = copy.deepcopy(self.saved_config)
        config[key] = value
        return config

    def test_missing_keys(self):
        for key in self._required_keys:
            config = copy.deepcopy(self.saved_config)
            del config[key]
            with self.subTest(missing_key = key):
                with self.assertRaises(KeyError):
                    build_stnn(config)

    def test_invalid_values(self):
        for key in ['K', 'd', 'W', 'nx1', 'nx2', 'nx3']:
            # Wrong types are expected to be reported as TypeError...
            for value in [1.5, 'a', None, np.nan]:
                with self.subTest(key = key, value = value):
                    with self.assertRaises(TypeError):
                        build_stnn(self._config_with(key, value))
            # ...while a negative integer is expected to be a ValueError.
            with self.subTest(key = key, value = -1):
                with self.assertRaises(ValueError):
                    build_stnn(self._config_with(key, -1))

        with self.subTest(key = 'nx3', value = 7):
            with self.assertRaises(ValueError):
                build_stnn(self._config_with('nx3', 7))  # not divisible by 2

    def test_positive_values(self):
        for key in ['K', 'nx1', 'nx2', 'nx3', 'd']:
            with self.subTest(key = key):
                with self.assertRaises(ValueError):
                    build_stnn(self._config_with(key, 0))
58
+
59
+
60
+ if __name__ == '__main__':
61
+ unittest.main()
stnn/tests/test_ttl.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ from stnn.nn.stnn_layers import TTL
3
+
4
+
5
class TestTTL(unittest.TestCase):
    """Tests that the TTL layer rejects invalid configuration dictionaries."""

    def test_missing_config_keys(self):
        with self.assertRaises(KeyError):
            TTL(config = {})  # Empty config

    def test_invalid_use_regularization_type(self):
        config = {'use_regularization': 'not a boolean', 'nx1': 1, 'nx2': 1, 'nx3': 1, 'shape1': [1], 'shape2': [1],
                  'ranks': [1], 'W': 1}
        with self.assertRaises(TypeError):
            TTL(config = config)

    def test_invalid_nx_values(self):
        base_config = {'nx1': 5, 'nx2': 5, 'nx3': 8, 'use_regularization': False, 'shape1': [1], 'shape2': [1],
                       'ranks': [1], 'W': 1}
        for nx in ['nx1', 'nx2', 'nx3']:
            for value in [-1, 0]:
                # Start from the base config every time so only one grid size is invalid
                # per call; otherwise an earlier invalid value would mask the one under test.
                config = dict(base_config)
                config[nx] = value
                with self.subTest(key = nx, value = value):
                    with self.assertRaises(ValueError):
                        TTL(config = config)

    def test_invalid_W_value(self):
        config = {'use_regularization': False, 'nx1': 1, 'nx2': 1, 'nx3': 1, 'shape1': [1], 'shape2': [1], 'ranks': [1],
                  'W': -1}
        with self.assertRaises(ValueError):
            TTL(config = config)

    def test_shape_length_mismatch(self):
        config = {'use_regularization': False, 'nx1': 1, 'nx2': 1, 'nx3': 1, 'shape1': [1, 2], 'shape2': [1],
                  'ranks': [1], 'W': 1}
        with self.assertRaises(ValueError):
            TTL(config = config)

    def test_incorrect_shape1_product(self):
        config = {'use_regularization': False, 'nx1': 2, 'nx2': 2, 'nx3': 1, 'shape1': [1, 3], 'shape2': [4],
                  'ranks': [1], 'W': 1}
        with self.assertRaises(ValueError):
            TTL(config = config)

    def test_incorrect_shape2_product(self):
        config = {'use_regularization': False, 'nx1': 1, 'nx2': 2, 'nx3': 1, 'shape1': [2], 'shape2': [1, 5],
                  'ranks': [1], 'W': 1}
        with self.assertRaises(ValueError):
            TTL(config = config)
48
+
49
+
50
+ if __name__ == '__main__':
51
+ unittest.main()
stnn/utils/__init__.py ADDED
File without changes
stnn/utils/input_output.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import h5py
2
+ from scipy.sparse import csr_matrix
3
+ import numpy as np
4
+ import json
5
+ import tensorflow as tf
6
+ from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
7
+
8
+
9
def save_to_hdf5(filename, datasets, start_idx, end_idx):
    """
    Saves subsets of datasets to an HDF5 file, either by creating new datasets or appending to existing ones.

    Args:
        filename (str): The name of the HDF5 file where the data will be saved.
        datasets (dict): A dictionary where keys are dataset names and values are the corresponding data arrays.
        start_idx (int): The starting index of the data slice to be saved.
        end_idx (int): The ending index (exclusive) of the data slice to be saved.

    This function will create new datasets if they do not already exist. If a dataset already exists, it will be resized
    to accommodate the new data, and the data slice will be appended. The amount appended is the length of the slice
    actually taken from the data, so an empty slice leaves an existing dataset untouched.
    """
    print(f'Saving data from n = {start_idx} to n = {end_idx}...')
    with h5py.File(filename, 'a') as f:
        for name, data in datasets.items():
            new_rows = data[start_idx:end_idx]
            if name not in f:
                # First write: make the leading axis unlimited so later calls can append.
                f.create_dataset(name, data = new_rows, maxshape = (None,) + data.shape[1:],
                                 chunks = (1,) + data.shape[1:])
                continue

            n_new = len(new_rows)
            if n_new == 0:
                # Nothing to append. (A negative-offset slice of length 0 would address the whole dataset.)
                continue

            dataset = f[name]
            n_old = dataset.shape[0]
            dataset.resize((n_old + n_new,) + dataset.shape[1:])
            dataset[n_old:] = new_rows
    print('Done.')
32
+
33
+
34
def write_sparse_matrix_hdf5(filename, sparse_matrix, dataset_name = 'sparse_matrix', format_ = 'csr'):
    """
    Stores a sparse matrix in an HDF5 file as a group holding its raw component arrays.

    Args:
        filename (str): Name of the output file. The file is opened in write mode.
        sparse_matrix (scipy.sparse matrix): Sparse matrix to be written to the file.
        dataset_name (str, optional): Name of the HDF5 group that receives the matrix. Defaults to 'sparse_matrix'.
        format_ (str, optional): The format of the sparse matrix. Currently only supports 'csr' (Compressed Sparse Row).

    Raises:
        ValueError: If 'format_' is anything other than 'csr'.
    """
    # Reject unsupported formats before touching the file system.
    if format_ != 'csr':
        raise ValueError(f'Unsupported sparse matrix format: {format_}')

    with h5py.File(filename, 'w') as f:
        group = f.create_group(dataset_name)
        for component in ('data', 'indices', 'indptr'):
            group.create_dataset(component, data = getattr(sparse_matrix, component))
        group.create_dataset('shape', data = np.array(sparse_matrix.shape))
        group.attrs['format'] = 'csr'
54
+
55
+
56
def read_sparse_matrix_hdf5(filename, dataset_name = 'sparse_matrix'):
    """
    Loads a sparse matrix from the component arrays stored in an HDF5 group.

    Args:
        filename (str): Name of the file to read from.
        dataset_name (str, optional): Name of the group containing the matrix. Defaults to 'sparse_matrix'.

    Returns:
        scipy.sparse.csr_matrix: The sparse matrix read from the file.

    Raises:
        ValueError: If the 'format' attribute stored with the matrix is not 'csr'.
    """
    with h5py.File(filename, 'r') as f:
        group = f[dataset_name]
        # Copy everything into memory while the file is still open.
        data, indices, indptr = (group[component][:] for component in ('data', 'indices', 'indptr'))
        shape = tuple(group['shape'][:])
        format_ = group.attrs['format']

    if format_ != 'csr':
        raise ValueError(f'Unsupported sparse matrix format: {format_}')
    return csr_matrix((data, indices, indptr), shape = shape)
78
+
79
+
80
def data_dump(bf, rho, rho_pred, params_dict):
    """
    Writes data and config to file for later use.

    The arrays are stored in 'sample_data.npz' under the keys 'rho', 'rho_pred' and 'bf'; the
    configuration is stored as JSON in 'sample_config.json'. Both files are written to the
    current working directory.

    Args:
        bf: Boundary data array.
        rho: Reference solution array.
        rho_pred: Predicted solution array.
        params_dict (dict): Configuration to store; it must be JSON-serializable.
    """
    np.savez('sample_data.npz', rho = rho, rho_pred = rho_pred, bf = bf)
    # json.dump takes the object first and an open, writable text file second.
    with open('sample_config.json', 'w', encoding = 'utf-8') as config_file:
        json.dump(params_dict, config_file)
86
+
87
+
88
def save_as_frozen_graph(model, saved_model_dir):
    """
    Freezes a TensorFlow model (variables converted to constants) and writes it out as a SavedModel.

    Args:
        model: TensorFlow model to be frozen.
        saved_model_dir (str): The directory path where the frozen model will be saved.

    Returns:
        TFModelServer object: The module wrapping the frozen function, whose 'serve' method is exported
        as the 'serving_default' signature.

    The function converts model variables to constants and is required for converting the model to the intermediate
    representation (IR) used by openvino.
    """
    # One TensorSpec per model input; the leading (batch) dimension is fixed to 1.
    input_specs = [tf.TensorSpec([1] + model_input.shape[1:].as_list(), model_input.dtype)
                   for model_input in model.inputs]

    # Trace the model into a concrete function for those input specs
    traced_model = tf.function(lambda x: model(x))
    concrete_model = traced_model.get_concrete_function(input_specs)

    # Replace the variables of the traced graph with constants
    frozen_func = convert_variables_to_constants_v2(concrete_model)

    # tf.Module wrapper whose 'serve' method carries an explicit input signature
    class TFModelServer(tf.Module):
        def __init__(self, func):
            super().__init__()
            self.frozen_func = func

        @tf.function(input_signature = input_specs)
        def serve(self, *args):
            return self.frozen_func(*args)

    model_server = TFModelServer(frozen_func)

    # Export the wrapper, exposing 'serve' as the default serving signature
    tf.saved_model.save(model_server, saved_model_dir, signatures = {"serving_default": model_server.serve})

    return model_server
130
+
131
+
132
def log_and_print(message):
    """Print 'message' to stdout and append it, followed by a newline, to 'log.txt' in the working directory."""
    print(message)
    with open('log.txt', 'a') as log_sink:
        log_line = message + '\n'
        log_sink.write(log_line)
stnn/utils/network_visualization.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pydot
2
+ import re
3
+ from keras.models import Model
4
+ from keras.layers import Layer, InputLayer
5
+ from pygments.lexers import graphviz
6
+
7
+
8
+ # May be necessary to manually add Graphviz to PATH, e.g.
9
+ # import os
10
+ # os.environ["PATH"] += os.pathsep + r'C:\Program Files\Graphviz\bin'
11
+
12
def visualize_model(model, layer_labels = None, layer_colors = None, groupings = None, exclude_input_layer = False,
                    verbose = False, output_filename = 'model_graph.png'):
    """
    Creates a visual graph of a keras model. There is an option to group certain layers into subgraphs
    (argument 'groupings').

    Args:
        model: A Keras Model instance
        layer_labels (optional): List of labels for each layer. Defaults to layer names.
        layer_colors (optional): List of colors for each layer. Defaults to white for all layers.
        groupings (optional): Dictionary specifying groups of layers. Each key is a group name,
                            and its value is a list of layer names belonging to that group.
        exclude_input_layer (optional): Boolean indicating whether to exclude the input layer from the graph.
        verbose (boolean, optional): Whether to print verbose output. Defaults to False.
        output_filename (optional): name of the output file for saving the generated graph.

    Raises:
        ValueError: If 'model' is not a Keras Model instance.

    Output:
        Image file with name 'output_filename'. Failures while writing the image are reported with a
        printed message rather than raised.
    """
    if not isinstance(model, Model):
        raise ValueError("model should be a Keras model instance")
    num_layers = len(model.layers)

    # Default labels and colors if not provided
    if not layer_labels:
        layer_labels = [layer.name for layer in model.layers]
    if not layer_colors:
        default_color = 'white'
        layer_colors = [default_color] * num_layers

    # Create a directed graph
    graph = pydot.Dot(graph_type = 'digraph', rankdir = 'LR')

    # Create nodes for each layer and add to subgraphs if specified
    subgraphs = {}
    layer_id_map = {}
    for i, layer in enumerate(model.layers):
        # Exclude the input layer if specified
        if exclude_input_layer and isinstance(layer, InputLayer):
            continue

        # Create a node for the layer
        layer_id = str(id(layer))
        layer_id_map[layer] = layer_id
        node = pydot.Node(layer_id, label = layer_labels[i], style = 'filled', fillcolor = layer_colors[i],
                          shape = 'box')

        # The first group listing this layer's name wins; layers in no group go on the main graph
        group_name = None
        if groupings:
            group_name = next((group for group, members in groupings.items() if layer.name in members), None)

        if group_name:
            if group_name not in subgraphs:
                subgraphs[group_name] = pydot.Cluster(group_name, label = group_name, style = 'dashed',
                                                      fontsize = 24)
            subgraphs[group_name].add_node(node)
        else:
            graph.add_node(node)

    # Add subgraphs to the main graph
    for subgraph in subgraphs.values():
        graph.add_subgraph(subgraph)

    # Add edges based on layer connections
    added_edges = set()
    for layer in model.layers:
        if exclude_input_layer and isinstance(layer, InputLayer):
            continue
        # Handle custom or non-standard layers: without '_inbound_nodes' there is nothing to connect
        if not hasattr(layer, '_inbound_nodes'):
            continue

        # Visit the inbound layers of EVERY inbound node, so a layer with several inbound nodes
        # is connected to all of its sources.
        for inbound_node in layer._inbound_nodes:
            inbound_layers = inbound_node.inbound_layers
            if not isinstance(inbound_layers, list):
                inbound_layers = [inbound_layers]

            for inbound_layer in inbound_layers:
                if not (isinstance(inbound_layer, Layer) and inbound_layer in layer_id_map):
                    continue
                src_id = layer_id_map[inbound_layer]
                dest_id = layer_id_map[layer]
                # Draw each connection only once, even if several inbound nodes share a source
                if (src_id, dest_id) in added_edges:
                    continue
                added_edges.add((src_id, dest_id))

                # Edges leaving these layers are kept for layout purposes but drawn invisibly
                if (re.search('sequential', inbound_layer.name, flags = re.IGNORECASE) or
                        re.search(r'operators__.getitem_[0-9]+$', inbound_layer.name, flags = re.IGNORECASE)):
                    graph.add_edge(pydot.Edge(src_id, dest_id, style = 'invis'))
                else:
                    graph.add_edge(pydot.Edge(src_id, dest_id))
                if verbose:
                    print(f"Added edge from {inbound_layer.name} to {layer.name}")

    graph.set_graph_defaults(sep = '+125,125')
    try:
        graph.write_png(output_filename)
    except FileNotFoundError as e:
        print(f'\nFailed to create network visualization using pydot and graphviz. Please ensure that '
              'the output filename is valid, and graphviz is installed and included in the system PATH variable. '
              f'Original error: {e}')
    except Exception as e:
        print(f'\nFailed to create network visualization using pydot and graphviz. Original error: {e}')
    else:
        print(f'Model visualization saved to {output_filename}')
stnn/utils/plotting.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
4
+ from matplotlib.cm import get_cmap
5
+
6
+
7
def plot_comparison(system, bf, rho, rho_pred, fontscale = 1, output_filename = 'comparison.png',
                    wspace = -0.1, hspace = 0.5):
    """
    Side-by-side comparison of contour plots for arguments 'rho' and 'rho_pred'.
    Also plots the boundary data (argument 'bf') as a contour plot.

    Args:
        system: PDESystem object
        bf (np.ndarray): 2D array representing the boundary data on the (x2, x3) grid.
        rho (np.ndarray): 2D array, known rho
        rho_pred (np.ndarray): 2D array, predicted rho
        fontscale (int, optional): A scaling factor for the font size in plots.
        output_filename (str, optional): name of the output file
        wspace (float, optional): horizontal spacing between subplots
        hspace (float, optional): vertical spacing between subplots

    Returns:
        Does not return a value. The figure is saved to a file with name given by 'output_filename'. The default
        is 'comparison.png'.
    """
    # System parameters
    ell = system.params['ell']
    a2 = system.a2
    e = system.params['eccentricity']
    nx1, nx2, nx3 = system.params['nx1'], system.params['nx2'], system.params['nx3']

    # Get x, y grids from 'PDESystem' object
    x, y = system.get_xy_grids()

    # Relative error (computed before the wrap-around column is appended below)
    err = np.linalg.norm(rho - rho_pred)
    rel_err = err / np.linalg.norm(rho)

    # wrap around values for continuity: repeat the first column at the end of axis 1
    rho = np.append(rho, rho[:, 0:1], axis = 1)
    rho_pred = np.append(rho_pred, rho_pred[:, 0:1], axis = 1)

    plots = [rho, rho_pred]

    # Color bar limits shared by both rho plots
    vmin = min(np.nanmin(rho_pred), np.nanmin(rho))
    vmax = max(np.nanmax(rho_pred), np.nanmax(rho))

    # Figure layout: boundary data on the full left column, the two rho plots stacked on the right
    cbar_coords = (0.89, 0.11, 0.03, 0.77)
    fig = plt.figure(figsize = (24, 24))
    gs = fig.add_gridspec(11, 2, hspace = hspace, wspace = wspace)
    axs = [fig.add_subplot(gs[:11, 0]), fig.add_subplot(gs[:5, 1]), fig.add_subplot(gs[6:11, 1])]

    # boundary data contour plot; cells that are not filled from 'bf' stay NaN
    bf_plot = np.nan * np.ones((2 * nx2, nx3))
    bf_plot[:nx2, :nx3 // 2] = bf[:nx2, :]
    bf_plot[nx2:, nx3 // 2:] = bf[nx2:, :]
    im = axs[0].imshow(bf_plot)
    # NOTE(review): the image has 2 * nx2 rows but the tick positions only span up to nx2 - confirm intended.
    axs[0].set_yticks([1, nx2 // 4, nx2 // 2 - 1, nx2 // 2 + 1, 3 * nx2 // 4, nx2])
    axs[0].set_yticklabels(['-pi', '0', 'pi', '', '0', 'pi'])
    axs[0].set_xticks([1, nx3 // 2, nx3])
    axs[0].set_xticklabels(['0', 'pi', '2pi'])
    axs[0].set_title('Boundary data\n\n', fontsize = fontscale * 48, y = 0.91)
    for label in axs[0].get_xticklabels() + axs[0].get_yticklabels():
        label.set_fontsize(fontscale * 48)
    divider = make_axes_locatable(axs[0])
    cax = divider.append_axes('bottom', size = "3%", pad = 1.5)
    cb = fig.colorbar(im, cax = cax, orientation = 'horizontal')
    cb.ax.tick_params(labelsize = fontscale * 48)
    cax.xaxis.set_ticks_position('bottom')

    # rho(x, y) contour plots
    titles = ['Direct Solution', 'Tensor network']
    levels = np.linspace(vmin, vmax, 100)
    for ax, z, title in zip(axs[1:], plots, titles):
        # Pass the colormap by name: matplotlib.cm.get_cmap is not available in recent Matplotlib releases.
        im = ax.contourf(x, y, z, levels = levels, cmap = 'hsv')
        ax.set_title(title, fontsize = fontscale * 48)
        for label in ax.get_xticklabels() + ax.get_yticklabels():
            label.set_fontsize(fontscale * 32)
        ax.set_aspect(1.0)

    # Shared color bar for the rho plots, placed at fixed figure coordinates
    cbar_ax = fig.add_axes(cbar_coords)
    cb = fig.colorbar(im, cax = cbar_ax, pad = 0.05)
    cb.ax.tick_params(labelsize = fontscale * 32)

    suptitle = f'ell = {ell:.3f}; a2 = {a2:.3f}; e = {e:.3f}; Relative error: {rel_err:.3f}'
    fig.suptitle(suptitle, fontsize = fontscale * 48, x = 0.1, y = 0.97, horizontalalignment = 'left')
    fig.savefig(output_filename)
    # Close this specific figure so repeated calls do not accumulate open figures
    plt.close(fig)
stnn/utils/stats.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
def get_stats(rho, rho_pred, output_filename = 'stats.npz'):
    """
    Calculates statistical metrics for model predictions and saves them to a file.

    This function computes the normalized loss for each instance in the dataset (the norm of the error divided by
    the norm of the true values, with all axes after the first flattened), identifies the maximum loss and its
    index, and calculates the average loss. These statistics are then saved to an NPZ file under the keys
    'max_loss', 'max_loss_index', 'avg_loss' and 'N' (the number of instances).

    Args:
        rho (numpy.ndarray): True values. The first axis indexes instances.
        rho_pred (numpy.ndarray): Predicted values, same shape as 'rho'.
        output_filename (str, optional): The name of the file where the statistics will be saved. Defaults to 'stats.npz'.

    Raises:
        ValueError: If the shapes differ, or if the inputs contain no instances.
    """
    rho = np.asarray(rho)
    rho_pred = np.asarray(rho_pred)
    if rho.shape != rho_pred.shape:
        raise ValueError('rho and rho_pred must have the same shape.')
    if rho.ndim == 0 or rho.shape[0] == 0:
        raise ValueError('rho and rho_pred must contain at least one instance.')

    # One row per instance
    y_true_flattened = rho.reshape(rho.shape[0], -1)
    y_pred_flattened = rho_pred.reshape(rho_pred.shape[0], -1)
    # An instance whose true values are all zero has zero norm, so its loss is inf or nan.
    loss = np.linalg.norm(y_true_flattened - y_pred_flattened, axis = 1) / np.linalg.norm(y_true_flattened, axis = 1)
    max_loss = np.max(loss)
    max_loss_index = np.argmax(loss)
    avg_loss = np.average(loss)
    print(f'Maximum loss: {max_loss}')
    print(f'Index of instance with max loss: {max_loss_index}')
    print(f'Average loss: {avg_loss}')
    np.savez(output_filename, max_loss = max_loss, max_loss_index = max_loss_index, avg_loss = avg_loss,
             N = loss.shape[0])