File size: 12,245 Bytes
d68c650
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
import json
import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import sympy
from matplotlib.cm import get_cmap
from stnn.nn import stnn
from stnn.pde.pde_system import PDESystem


def adjust_to_nice_number(value, round_down = False):
	"""
	Snap a value to a "nice" number (1, 2 or 5 times a power of ten, or the next power of ten).
	Used for colorbar tickmarks.

	Parameters:
		value: number to adjust. Zero is returned unchanged.
		round_down: if False (default), |value| is rounded *up* to the next nice number
			(e.g. 1.2 -> 2, 3 -> 5, 6 -> 10). If True, the leading digit is snapped to the nearest
			nice factor using the cut-offs 1.5, 3 and 7 (e.g. 1.2 -> 1, 2.5 -> 2, 8 -> 10), so despite
			the name the result can be larger than |value|.

	Negative values always use the ``round_down=True`` rule on their magnitude; the sign is restored
	afterwards.

	Returns:
		The adjusted value, with the same sign as ``value``.
	"""
	if value == 0:
		return value

	is_negative = value < 0
	if is_negative:
		round_down = True
		value = -value

	exponent = np.floor(np.log10(value))  # power of ten of the leading digit
	fractional_part = value / 10**exponent  # leading digit(s), nominally in [1, 10)

	if round_down:
		# Nearest nice factor
		if fractional_part < 1.5:
			nice_fractional = 1
		elif fractional_part < 3:
			nice_fractional = 2
		elif fractional_part < 7:
			nice_fractional = 5
		else:
			nice_fractional = 10
	else:
		# Smallest nice factor >= fractional_part
		if fractional_part <= 1:
			nice_fractional = 1
		elif fractional_part <= 2:
			nice_fractional = 2
		elif fractional_part <= 5:
			nice_fractional = 5
		else:
			nice_fractional = 10

	# A factor of 10 means "next power of ten"; 10**(exponent + 1) avoids an extra rounding step
	if nice_fractional == 10:
		nice_value = 10**(exponent + 1)
	else:
		nice_value = nice_fractional * 10**exponent
	return -nice_value if is_negative else nice_value


def find_nice_values(min_val_raw, max_val, num_values = 4):
	"""
	Calculate roughly 'num_values' evenly spaced "nice" values within [min_val_raw, max_val].
	Used for colorbar tickmarks.

	The start value is either ``min_val_raw`` or ``adjust_to_nice_number(min_val_raw)`` (see the
	comment below). The spacing is (max_val - start) / (num_values - 1), snapped to the closest of
	1, 2, 5 or 10 times a power of ten. Extra values are appended until the last one is within one
	spacing of ``max_val``, and values outside the range are dropped. The result can therefore have
	more or fewer than ``num_values`` entries.

	Returns:
		A list of tick values. If ``max_val == min_val_raw`` the result is ``[min_val_raw]``;
		if ``max_val < min_val_raw`` it is ``[]``.
	"""
	# Degenerate range: there is no spacing to compute (the division below would be by zero)
	if max_val <= min_val_raw:
		return [min_val_raw] if max_val == min_val_raw else []

	# Candidate start: min_val_raw snapped by adjust_to_nice_number (rounded up for positive values)
	min_val = adjust_to_nice_number(min_val_raw)
	frac_val = (min_val - min_val_raw) / (max_val - min_val_raw)
	# Keep the snapped start only when it is at least 1/num_values of the range above min_val_raw
	# AND still below max_val. A start at or above max_val would make the spacing non-positive
	# (log10 -> NaN), which yields NaN ticks that are all filtered out (e.g. range 1.1 .. 1.2).
	if frac_val < 1 / num_values or frac_val >= 1:
		min_val = min_val_raw
	raw_spacing = (max_val - min_val) / (num_values - 1)

	# Snap the spacing to the closest of 1, 2, 5, 10 times its order of magnitude
	magnitude = np.floor(np.log10(raw_spacing))
	nice_factors = np.array([1, 2, 5, 10])
	normalized_spacing = raw_spacing / (10**magnitude)
	closest_factor = nice_factors[np.argmin(np.abs(nice_factors - normalized_spacing))]
	nice_spacing = closest_factor * (10**magnitude)

	nice_values = min_val + nice_spacing * np.arange(num_values)

	# The snapped spacing can be smaller than raw_spacing, leaving the top of the range without
	# ticks; extend until the last value is within one spacing of max_val.
	while nice_values[-1] < max_val - nice_spacing:
		nice_values = np.append(nice_values, [nice_values[-1] + nice_spacing])

	return [val for val in nice_values if min_val <= val <= max_val]


def format_tick_label(val):
	"""
	Format a colorbar tick value with about three significant figures.

	Values with |val| >= 1000 or |val| < 0.01 use scientific notation with one decimal
	(e.g. '2.5e+03', '2.5e-03'). Otherwise fixed-point notation is used with 2 - k decimals,
	where k = floor(log10(|val|)) (e.g. '250', '25.0', '2.50', '0.250', '0.0500').
	Zero and non-finite values are formatted with a plain f'{val}'.
	"""
	if val == 0 or not np.isfinite(val):
		return f'{val}'

	# Signed decimal exponent: positive for large values, negative for small ones. Keeping the
	# sign (instead of taking abs()) stops values such as 0.05 from being rendered as '0'.
	exponent = int(np.floor(np.log10(np.abs(val))))
	if abs(exponent) > 2:
		return f'{val:.1e}'
	return f'{val:.{2 - exponent}f}'


def plot_simple(system, rho, fontscale = 1):
	"""
	Draw a filled-contour plot of ``rho`` over the domain described by ``system``.

	Parameters:
		system: ``PDESystem``-like object providing ``b2`` (major axis of the outer boundary, used
			for the axis limits) and ``get_xy_grids()`` (x and y grids for the contour plot).
		rho: 2D array of values. Its first column is appended as an extra last column so the plot is
			continuous where the second axis wraps around (presumably the periodic angular direction --
			TODO confirm against PDESystem).
		fontscale: multiplier for the title and axis tick-label font sizes.

	Returns:
		A ``matplotlib.figure.Figure`` with the contour plot and a colorbar with "nice" ticks.
	"""
	# Build the figure without pyplot: pyplot keeps every figure it creates alive until it is
	# explicitly closed, which leaks memory when this runs once per request in the web app.
	from matplotlib.figure import Figure

	# Major axis of outer boundary
	b2 = system.b2

	# Get x, y grids from 'PDESystem' object
	x, y = system.get_xy_grids()

	# Wrap around values for continuity
	rho = np.append(rho, rho[:, 0:1], axis = 1)

	# Color bar limits (NaNs ignored)
	vmin = np.nanmin(rho)
	vmax = np.nanmax(rho)

	fig = Figure(figsize = (5, 5))
	ax = fig.add_subplot()

	# A colormap name string avoids matplotlib.cm.get_cmap, which is deprecated/removed in recent matplotlib
	im = ax.contourf(x, y, rho, levels = np.linspace(vmin, vmax, 100), cmap = 'hsv')
	ax.set_title('rho(x,y)', fontsize = fontscale * 16)
	# tick_params also applies to tick labels created later (e.g. after the limits change)
	ax.tick_params(labelsize = fontscale * 12)
	ax.set_aspect(1.0)
	fac = 1.05
	ax.set_xlim([-fac * b2, fac * b2])
	ax.set_ylim([-fac * b2, fac * b2])

	cbar = fig.colorbar(im, shrink = 0.8)

	# Set colorbar ticks and labels to "nice" values
	nice_values = find_nice_values(vmin, vmax, num_values = 5)
	cbar.set_ticks(nice_values)
	cbar.ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda val, pos: format_tick_label(val)))

	return fig


def evaluate_2d_expression(expr_str, xvals, yvals):
	"""
	Evaluate a SymPy expression in the variables ``s`` and ``t`` element-wise on numpy arrays.

	Parameters:
		expr_str: expression string, e.g. "1 + 0.1 * cos(4*s)". ``s`` is bound to ``xvals`` and
			``t`` to ``yvals``.
		xvals: values of ``s``.
		yvals: values of ``t``; presumably the same shape as ``xvals`` (TODO confirm for callers).

	Returns:
		The evaluated expression. A scalar numeric result (e.g. from an expression that depends on
		neither ``s`` nor ``t``) is broadcast to an array with the shape of ``xvals``.

	Security: ``sympy.sympify`` uses ``eval`` internally and SymPy documents it as unsafe for
	unsanitized input. In this app ``expr_str`` comes straight from a Gradio textbox, so a malicious
	user can run arbitrary Python on the server. Sandbox the process or use a restricted parser
	before exposing the demo publicly.
	"""
	s, t = sympy.symbols('s t')
	expr = sympy.sympify(expr_str)
	f = sympy.lambdify((s, t), expr, modules = ['numpy'])
	result = f(xvals, yvals)  # evaluate once and reuse the result
	# Constant expressions come back as a scalar (Python or numpy number); broadcast to the grid.
	# Non-numeric results (e.g. unevaluated SymPy objects) are returned unchanged.
	if isinstance(result, (int, float, complex, np.number)):
		return result * np.ones(np.shape(xvals))
	return result


# NOTE(review): The triple-quoted string below is disabled code kept for reference. It is a direct
# (non-neural) solve of the same PDE with GMRES that returns (figure, GMRES info); it is never
# executed. As written it relies on names that are not imported in this file (timeit, csr_matrix,
# asarray, asnumpy, spx, xp, warnings) -- presumably a SciPy or CuPy sparse backend, TODO confirm --
# and would need those imports before being re-enabled. Also verify the `tol=` keyword passed to
# gmres against the installed library version: recent SciPy releases renamed it to `rtol`.
'''
# Currently unused in gradio interface
def direct_solution(ell, a2, eccentricity, ibc_str, obc_str, max_krylov_dim, max_iterations):
	# Direct solution
	start = timeit.default_timer()

	pde_config = {}
	for key in ['nx1', 'nx2', 'nx3']:
		pde_config[key] = stnn_config[key]
	pde_config['ell'] = ell
	pde_config['eccentricity'] = eccentricity
	pde_config['a2'] = a2
	system = PDESystem(pde_config)

	try:
		ibf_data = evaluate_2d_expression(ibc_str, system.x2_ib, system.x2_ib - system.x3_ib)[system.ib_slice]
	except:
		raise ValueError(f"Failed to parse the expression `{ibc_str}` for the boundary condition @ the inner boundary.")
	try:
		obf_data = evaluate_2d_expression(obc_str, system.x2_ob, system.x2_ob - system.x3_ob)[system.ob_slice]
	except:
		raise ValueError(f"Failed to parse the expression `{obc_str}` for the boundary condition @ the outer boundary.")

	if np.any(np.isnan(ibf_data)):
		raise ValueError(f"The expression `{ibc_str}` evaluates to nan at one or more grid points.")

	if np.any(np.isnan(obf_data)):
		raise ValueError(f"The expression `{obc_str}` evaluates to nan at one or more grid points.")

	ibf_data, obf_data, b = system.convert_boundary_data(ibf_data, obf_data)

	L_xp = csr_matrix(system.L)  # Sparse matrix representation of the PDE operator
	nx1, nx2, nx3 = system.params['nx1'], system.params['nx2'], system.params['nx3']
	b_xp = asarray(b.reshape((nx1 * nx2 * nx3,)))  # r.h.s. vector

	def callback(res):
		print(f'GMRES residual: {res}')

	f_xp, info = spx.linalg.gmres(L_xp, b_xp, maxiter=max_iterations, tol=1e-7, restart=max_krylov_dim, callback=callback)

	residual = (xp.linalg.norm(b_xp - L_xp @ f_xp) / xp.linalg.norm(b_xp))

	if info > 0:
		warnings.simplefilter('always')
		warnings.warn(f'GMRES solver did not converge. Number of iterations: {info}; residual: {residual}', RuntimeWarning)

	f = asnumpy(f_xp)
	rho_direct = np.sum(f.reshape((nx1, nx2, nx3)), axis=-1)
	direct_time = timeit.default_timer() - start
	print(f'Done with direct solution. Time: {direct_time} seconds.')

	fig = plot_simple(system, rho_direct)
	return fig, info
'''


def _evaluate_boundary_condition(expr_str, s_vals, t_vals, index, location):
	"""
	Evaluate one user-supplied boundary-condition expression on a boundary grid and validate it.

	Parameters:
		expr_str: expression in ``s`` and ``t`` (see ``evaluate_2d_expression``).
		s_vals, t_vals: grids of ``s`` and ``t`` values on the boundary.
		index: index/slice applied to the evaluated grid (the system's boundary slice).
		location: 'inner' or 'outer'; used only in the error message.

	Returns:
		The evaluated, indexed boundary data.

	Raises:
		ValueError: if the expression cannot be evaluated or indexed, or if any value is NaN or infinite.
	"""
	try:
		data = evaluate_2d_expression(expr_str, s_vals, t_vals)[index]
	except Exception as err:
		raise ValueError(
			f"Failed to parse the expression `{expr_str}` for the boundary condition @ the {location} boundary."
		) from err

	# Infinite values (e.g. 1/s at s = 0) would propagate through the model just like NaNs
	if not np.all(np.isfinite(data)):
		raise ValueError(f"The expression `{expr_str}` evaluates to NaN or infinity at one or more grid points.")
	return data


def predict_pde_solution(ell, a2, eccentricity, ibc_str, obc_str):
	"""
	Predict the PDE solution with the pretrained STNN and return a plot of the predicted density.

	This is the Gradio "Submit" callback. It reads the module-level ``stnn_config`` and ``model``.

	Parameters:
		ell: PDE parameter ell; must be > 0.
		a2: minor axis of the outer boundary; must be greater than ``eccentricity``.
		eccentricity: eccentricity of the inner boundary.
		ibc_str: boundary-condition expression in ``s`` and ``t`` for the inner boundary.
		obc_str: boundary-condition expression in ``s`` and ``t`` for the outer boundary.

	Returns:
		The matplotlib figure produced by ``plot_simple`` for the predicted density.

	Raises:
		ValueError: on missing or invalid parameters, on boundary expressions that cannot be
			evaluated, or on boundary values that are NaN or infinite.
	"""
	# Empty gr.Number fields arrive as None
	if ell is None or a2 is None or eccentricity is None:
		raise ValueError('All PDE parameters (ell, a, ecc) must be provided.')
	if ell <= 0:
		raise ValueError(f'ell must be greater than 0 (got {ell}).')
	if a2 <= eccentricity:
		raise ValueError(f'Outer minor axis must be greater than the eccentricity (here, {eccentricity}).')

	# Grid sizes come from the model config so the boundary data matches the model's input shape
	pde_config = {key: stnn_config[key] for key in ('nx1', 'nx2', 'nx3')}
	pde_config['ell'] = ell
	pde_config['eccentricity'] = eccentricity
	pde_config['a2'] = a2
	system = PDESystem(pde_config)

	# The expressions are evaluated with s = x2 and t = x2 - x3 on each boundary grid
	ibf_data = _evaluate_boundary_condition(
		ibc_str, system.x2_ib, system.x2_ib - system.x3_ib, system.ib_slice, 'inner')
	obf_data = _evaluate_boundary_condition(
		obc_str, system.x2_ob, system.x2_ob - system.x3_ob, system.ob_slice, 'outer')

	# Permute and reshape boundary data to the format expected by the STNN model
	ibf_data, obf_data, _ = system.convert_boundary_data(ibf_data, obf_data)
	# Alternative (unused in the Gradio interface): random boundary data via
	# `system.generate_random_bc(func_gen_id)`, which returns a fourth value as well.

	# Load some relevant quantities from the config dictionaries
	ell_min, ell_max = stnn_config['ell_min'], stnn_config['ell_max']
	a2_min, a2_max = stnn_config['a2_min'], stnn_config['a2_max']
	nx2, nx3 = pde_config['nx2'], pde_config['nx3']

	# Stack inner (first nx2 rows) and outer (last nx2 rows) boundary data into one batch of size 1;
	# assumes each converted boundary array has shape (nx2, nx3 // 2) -- TODO confirm
	bf = np.zeros((1, 2 * nx2, nx3 // 2))
	bf[:, :nx2, :] = ibf_data[np.newaxis, ...]
	bf[:, nx2:, :] = obf_data[np.newaxis, ...]

	# Min-max normalize a2 and ell with the ranges from the config; eccentricity is passed as-is
	params = np.zeros((1, 3))
	params[0, 0] = (a2 - a2_min) / (a2_max - a2_min)
	params[0, 1] = (ell - ell_min) / (ell_max - ell_min)
	params[0, 2] = eccentricity

	rho = model.predict([params, bf])
	return plot_simple(system, rho[0, ...])


# Module-level setup: load the model configuration and pretrained weights once at import time.
# `predict_pde_solution` reads `stnn_config` and `model` as globals. Both paths are relative to the
# current working directory.
with open('T5_config.json', 'r', encoding = 'utf-8') as json_file:
    stnn_config = json.load(json_file)

model = stnn.build_stnn(stnn_config)
model.load_weights('T5_weights.h5')

# Gradio UI. LaTeX backslashes inside the Markdown strings are doubled so that commands such as
# \ell, \hat, \cdot, \partial, \int and \pi are not read as (invalid) Python escape sequences,
# which trigger SyntaxWarning on Python 3.12+. The rendered text is unchanged.
with gr.Blocks() as demo:
    gr.Markdown("# Stacked Tensorial Neural Network (STNN) demo"
                "\nThis demo uses the model architecture from [arXiv:2312.14979](https://arxiv.org/abs/2312.14979) "
                "to solve a parametric PDE problem on an elliptical annular domain. "
                "See the paper for a detailed description of the problem and its applications."
                "<br/>The [GitHub repo](https://github.com/caleb399/stacked_tensorial_nn) contains additional examples, "
                "including instructions for solving the PDE using a conventional iterative method (GMRES). "
                "Due to the long runtime of solving the PDE in this way, it is not included in the demo.")
    # "$...$" renders inline math and "!...!" renders display math in this Markdown block
    gr.Markdown("<br/>The PDE is "
                "$\\ell \\left( \\boldsymbol{\\hat{u}} \\cdot \\nabla \\right) f(\\boldsymbol{r}, w) = \\partial_{ww} f(\\boldsymbol{r}, w)$, "
                "where $\\ell$ is a parameter and $\\boldsymbol{\\hat{u}} = (\\cos w, \\sin w)$. "
                "Here, $\\boldsymbol{r}$ is the 2D position vector, and $w$ is an angular coordinate unrelated to "
                "the spatial domain. The model predicts the density !\\rho(\\boldsymbol{r}) = \\int f(\\boldsymbol{r}, w) dw! "
                "on elliptical annular domains parameterized as shown below. ",
                latex_delimiters = [{"left": "$", "right": "$", "display": False}, {"left": "!", "right": "!", "display": True}])
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                "## PDE Parameters \n The model was trained on solutions of the PDE with $\\ell$ between 0.01 and 100, $a$ between 2 and 20, "
                "and $ecc$ between 0 and 0.8.",
                latex_delimiters = [{"left": "$", "right": "$", "display": False},
                                    {"left": "!", "right": "!", "display": True}])
            ell_input = gr.Number(label = "ell (must be > 0)", value = 1.0)
            eccentricity_input = gr.Number(
                label = "ecc: eccentricity of the inner boundary (must be >= 0 and <= 0.999)",
                value = 0.5, minimum = 0.0, maximum = 0.999)
            a2_input = gr.Number(label = "a: Minor axis of outer boundary (must be > eccentricity)", value = 2.0)
            gr.Markdown(
                "## Boundary Conditions \n $(s, t)$ are angular coordinates parameterizing the PDE domain, "
                "related to $\\boldsymbol{r}$ and $w$ by a coordinate transformation. "
                "Specifically, $s$ is the polar elliptical coordinate along the boundary (inner or outer), with values "
                "between $-\\pi$ and $\\pi$, while $t = s - w$. Boundary conditions are generated from grid points "
                "distributed uniformly over the allowable values of $s$ and $t$."
                "<br/><br/>For the PDE problem to be well-posed, boundary data should only be specified where "
                "$\\boldsymbol{\\hat{u}} \\cdot \\boldsymbol{\\hat{n}} > 0$, where $\\boldsymbol{\\hat{n}}$ is the "
                "inward-pointing unit normal vector. This requirement constrains the allowable values of $t$"
                " and is automatically enforced when building boundary conditions from the user-specified expressions below.",
                latex_delimiters = [{"left": "$", "right": "$", "display": False}])

            inner_boundary = gr.Textbox(label = "Inner boundary condition", value = "0.5 * (1 + sign(cos(s)))")
            outer_boundary = gr.Textbox(label = "Outer boundary condition", value = "1 + 0.1 * cos(4*s)")

            submit_button = gr.Button("Submit")

        with gr.Column():
            gr.Markdown("## Predicted Solution")
            predicted_output_plot = gr.Plot()

    submit_button.click(
        fn = predict_pde_solution,
        inputs = [ell_input, a2_input, eccentricity_input, inner_boundary, outer_boundary],
        outputs = [predicted_output_plot]
    )

# Start the server only when run as a script, so importing this module (e.g. from tests or
# tooling that only needs `demo`) does not block on launch().
if __name__ == "__main__":
    demo.launch()