"""
Gradio demo for the Stacked Tensorial Neural Network (STNN) model.

At import time this module reads 'T5_config.json', builds the STNN model, loads
'T5_weights.h5', builds the Gradio interface and launches it.
"""
import json

import gradio as gr
import matplotlib.ticker as ticker
import numpy as np
import sympy
from matplotlib.figure import Figure

from stnn.nn import stnn
from stnn.pde.pde_system import PDESystem


def adjust_to_nice_number(value, round_down = False):
    """
    Adjust 'value' to a nearby "nice" number: 1, 2, 5 or 10 times a power of ten.

    With round_down=False the magnitude is rounded up to the next nice number
    (e.g. 3 -> 5). With round_down=True the leading digits are mapped using the
    thresholds 1.5, 3 and 7 (e.g. 2.5 -> 2, but 8 -> 10). Negative values always use the
    round_down thresholds on their magnitude and keep their sign. Zero is returned as is.
    """
    if value == 0:
        return value

    is_negative = False
    if value < 0:
        round_down = True
        is_negative = True
        value = -value

    exponent = np.floor(np.log10(value))  # Exponent of 10
    fractional_part = value / 10**exponent  # Leading digit(s), roughly in [1, 10)

    if round_down:
        if fractional_part < 1.5:
            nice_fractional = 1
        elif fractional_part < 3:
            nice_fractional = 2
        elif fractional_part < 7:
            nice_fractional = 5
        else:
            nice_fractional = 10
    else:
        if fractional_part <= 1:
            nice_fractional = 1
        elif fractional_part <= 2:
            nice_fractional = 2
        elif fractional_part <= 5:
            nice_fractional = 5
        else:
            nice_fractional = 10

    nice_value = nice_fractional * 10**exponent if round_down or nice_fractional != 10 else 10**(exponent + 1)

    if is_negative:
        nice_value = -nice_value

    return nice_value


def find_nice_values(min_val_raw, max_val, num_values = 4):
    """
    Calculate about 'num_values' evenly spaced "nice" values within [min_val_raw, max_val].
    Used for colorbar tickmarks.

    The spacing is the number of the form 1, 2, 5 or 10 times a power of ten closest to
    (max_val - min_val_raw) / (num_values - 1); the returned values are the multiples of
    that spacing that lie inside the range, so their count can differ somewhat from
    'num_values'.

    Returns:
        List of values in increasing order. [] if either limit is not finite, and
        [min_val_raw] if the range is empty (max_val <= min_val_raw).
    """
    if not (np.isfinite(min_val_raw) and np.isfinite(max_val)):
        return []
    span = max_val - min_val_raw
    if span <= 0:
        return [min_val_raw]

    # Rough spacing between values, snapped to 1, 2, 5 or 10 x 10**magnitude
    raw_spacing = span / max(num_values - 1, 1)
    magnitude = np.floor(np.log10(raw_spacing))
    nice_factors = np.array([1, 2, 5, 10])
    normalized_spacing = raw_spacing / (10**magnitude)
    closest_factor = nice_factors[np.argmin(np.abs(nice_factors - normalized_spacing))]
    nice_spacing = closest_factor * (10**magnitude)

    # Start at the first multiple of the spacing that is >= min_val_raw, so every value is "nice"
    first_val = np.ceil(min_val_raw / nice_spacing) * nice_spacing
    # Small tolerance so a value landing exactly on max_val is not lost to round-off
    count = int(np.floor((max_val - first_val) / nice_spacing + 1e-9)) + 1
    return list(first_val + nice_spacing * np.arange(max(count, 0)))


def format_tick_label(val):
    """
    Format a colorbar tick value, using scientific notation for large/small values.

    Values with |val| >= 1000 or |val| < 0.01 use one-decimal scientific notation. Other
    values get enough decimals that typical tick values are not truncated, e.g.
    150 -> '150', 15 -> '15.0', 1.5 -> '1.50', 0.25 -> '0.25', 0.05 -> '0.050'.
    Zero and non-finite values are formatted with plain str().
    """
    if val == 0 or not np.isfinite(val):
        return f'{val}'
    exponent = int(np.floor(np.log10(np.abs(val))))
    if exponent > 2 or exponent < -2:
        return f'{val:.1e}'
    # Decimal places per power of ten; the signed exponent keeps values in [0.01, 0.1)
    # from being rounded to '0'.
    decimals = {2: 0, 1: 1, 0: 2, -1: 2, -2: 3}[exponent]
    return f'{val:.{decimals}f}'


def plot_simple(system, rho, fontscale = 1):
    """
    Plot the density 'rho' as filled contours over the (x, y) grid of 'system'.

    Args:
        system: PDESystem providing 'b2' and 'get_xy_grids()'.
        rho: 2D array of density values. Its first column is appended as an extra last
            column before plotting; presumably the second axis is the periodic angular
            direction -- TODO confirm against PDESystem.get_xy_grids().
        fontscale: Multiplier applied to the title and tick-label font sizes.

    Returns:
        matplotlib.figure.Figure containing the plot and a colorbar.

    Raises:
        ValueError: if 'rho' is all NaN or contains infinite values.
    """
    # Major axis of outer boundary
    b2 = system.b2

    # Get x, y grids from 'PDESystem' object
    x, y = system.get_xy_grids()

    # Wrap around values for continuity
    rho = np.append(rho, rho[:, 0:1], axis = 1)

    # Color bar limits
    vmin = np.nanmin(rho)
    vmax = np.nanmax(rho)
    if not (np.isfinite(vmin) and np.isfinite(vmax)):
        raise ValueError('Cannot plot the predicted solution: it is all NaN or contains infinite values.')
    if vmax == vmin:
        # contourf needs strictly increasing levels, so widen the range of a constant field
        pad = 0.1 * abs(vmin) if vmin != 0 else 1.0
        vmin, vmax = vmin - pad, vmax + pad

    # A standalone Figure is not registered with pyplot's global figure manager, so a
    # figure built per request can be garbage-collected once no longer referenced.
    fig = Figure(figsize = (5, 5))
    ax = fig.add_subplot()
    im = ax.contourf(x, y, rho, levels = np.linspace(vmin, vmax, 100), cmap = 'hsv')
    ax.set_title('rho(x,y)', fontsize = fontscale * 16)
    ax.tick_params(labelsize = fontscale * 12)
    ax.set_aspect(1.0)
    fac = 1.05
    ax.set_xlim([-fac * b2, fac * b2])
    ax.set_ylim([-fac * b2, fac * b2])

    cbar = fig.colorbar(im, ax = ax, shrink = 0.8)

    # Set colorbar ticks and labels to "nice" values
    nice_values = find_nice_values(vmin, vmax, num_values = 5)
    cbar.set_ticks(nice_values)
    cbar.ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda val, pos: format_tick_label(val)))

    return fig


def evaluate_2d_expression(expr_str, xvals, yvals):
    """
    Evaluate an expression in the variables 's' and 't' on numpy arrays.

    Args:
        expr_str: Expression string parsed by sympy, e.g. "1 + 0.1 * cos(4*s)".
        xvals: Values substituted for 's'.
        yvals: Values substituted for 't'.

    Returns:
        Array of expression values. A constant expression is broadcast to the shape of 'xvals'.

    Note:
        SECURITY: sympy.sympify evaluates its input with eval(). 'expr_str' comes directly
        from the web UI, so this is not safe against malicious input on a public deployment.
    """
    s, t = sympy.symbols('s t')
    expr = sympy.sympify(expr_str)
    func = sympy.lambdify((s, t), expr, modules = ['numpy'])
    result = func(xvals, yvals)
    # Constant expressions evaluate to a scalar (Python or numpy); expand them to the grid
    if np.ndim(result) == 0:
        return result * np.ones(np.shape(xvals))
    return result


# Currently unused in gradio interface. Kept for reference only: this string is never
# executed, and its code uses names (timeit, warnings, csr_matrix, asarray, asnumpy, spx,
# xp) that this module does not import.
'''
def direct_solution(ell, a2, eccentricity, ibc_str, obc_str, max_krylov_dim, max_iterations):
    # Direct solution
    start = timeit.default_timer()
    pde_config = {}
    for key in ['nx1', 'nx2', 'nx3']:
        pde_config[key] = stnn_config[key]
    pde_config['ell'] = ell
    pde_config['eccentricity'] = eccentricity
    pde_config['a2'] = a2
    system = PDESystem(pde_config)
    try:
        ibf_data = evaluate_2d_expression(ibc_str, system.x2_ib, system.x2_ib - system.x3_ib)[system.ib_slice]
    except:
        raise ValueError(f"Failed to parse the expression `{ibc_str}` for the boundary condition @ the inner boundary.")
    try:
        obf_data = evaluate_2d_expression(obc_str, system.x2_ob, system.x2_ob - system.x3_ob)[system.ob_slice]
    except:
        raise ValueError(f"Failed to parse the expression `{obc_str}` for the boundary condition @ the outer boundary.")
    if np.any(np.isnan(ibf_data)):
        raise ValueError(f"The expression `{ibc_str}` evaluates to nan at one or more grid points.")
    if np.any(np.isnan(obf_data)):
        raise ValueError(f"The expression `{obc_str}` evaluates to nan at one or more grid points.")
    ibf_data, obf_data, b = system.convert_boundary_data(ibf_data, obf_data)
    L_xp = csr_matrix(system.L)  # Sparse matrix representation of the PDE operator
    nx1, nx2, nx3 = system.params['nx1'], system.params['nx2'], system.params['nx3']
    b_xp = asarray(b.reshape((nx1 * nx2 * nx3,)))  # r.h.s. vector

    def callback(res):
        print(f'GMRES residual: {res}')

    f_xp, info = spx.linalg.gmres(L_xp, b_xp, maxiter=max_iterations, tol=1e-7, restart=max_krylov_dim, callback=callback)
    residual = (xp.linalg.norm(b_xp - L_xp @ f_xp) / xp.linalg.norm(b_xp))
    if info > 0:
        warnings.simplefilter('always')
        warnings.warn(f'GMRES solver did not converge. Number of iterations: {info}; residual: {residual}', RuntimeWarning)
    f = asnumpy(f_xp)
    rho_direct = np.sum(f.reshape((nx1, nx2, nx3)), axis=-1)
    direct_time = timeit.default_timer() - start
    print(f'Done with direct solution. Time: {direct_time} seconds.')
    fig = plot_simple(system, rho_direct)
    return fig, info
'''


def _evaluate_boundary_condition(expr_str, x2, x3, index, location):
    """
    Evaluate a boundary-condition expression on one boundary of the PDE grid.

    The expression is evaluated with s = x2 and t = x2 - x3, then indexed with 'index'.
    'location' ("inner" or "outer") is only used in error messages.

    Raises:
        ValueError: if the expression cannot be parsed/evaluated, or if the result
            contains NaN or infinite values.
    """
    try:
        data = evaluate_2d_expression(expr_str, x2, x2 - x3)[index]
    except Exception as err:
        raise ValueError(f"Failed to parse the expression `{expr_str}` for the boundary condition @ the {location} boundary.") from err
    if not np.all(np.isfinite(data)):
        raise ValueError(f"The expression `{expr_str}` evaluates to NaN or infinity at one or more grid points.")
    return data


def predict_pde_solution(ell, a2, eccentricity, ibc_str, obc_str):
    """
    Predict the density rho(x, y) with the STNN model and return a plot of it.

    Callback for the demo's "Submit" button. Uses the module-level 'stnn_config' and 'model'.

    Args:
        ell: PDE parameter ell (must be > 0).
        a2: Minor axis of the outer boundary (must be greater than 'eccentricity').
        eccentricity: Eccentricity of the inner boundary.
        ibc_str: Expression in (s, t) for the inner boundary condition.
        obc_str: Expression in (s, t) for the outer boundary condition.

    Returns:
        matplotlib Figure produced by plot_simple.

    Raises:
        ValueError: on invalid parameters or boundary-condition expressions.
    """
    if ell <= 0:
        raise ValueError(f'ell must be greater than 0 (here, {ell}).')
    if a2 <= eccentricity:
        raise ValueError(f'Outer minor axis must be greater than the eccentricity (here, {eccentricity}).')

    pde_config = {key: stnn_config[key] for key in ('nx1', 'nx2', 'nx3')}
    pde_config['ell'] = ell
    pde_config['eccentricity'] = eccentricity
    pde_config['a2'] = a2
    system = PDESystem(pde_config)

    ibf_data = _evaluate_boundary_condition(ibc_str, system.x2_ib, system.x3_ib, system.ib_slice, 'inner')
    obf_data = _evaluate_boundary_condition(obc_str, system.x2_ob, system.x3_ob, system.ob_slice, 'outer')

    # Permute and reshape boundary data to the format expected by the STNN model
    ibf_data, obf_data, _ = system.convert_boundary_data(ibf_data, obf_data)
    # Currently unused in gradio interface:
    # ibf_data, obf_data, b, _ = system.generate_random_bc(func_gen_id)

    # Load some relevant quantities from the config dictionaries
    ell_min, ell_max = stnn_config['ell_min'], stnn_config['ell_max']
    a2_min, a2_max = stnn_config['a2_min'], stnn_config['a2_max']
    nx2, nx3 = pde_config['nx2'], pde_config['nx3']

    # Combine inner and outer boundary data in a single batch-of-one array
    bf = np.zeros((1, 2 * nx2, nx3 // 2))
    bf[:, :nx2, :] = ibf_data[np.newaxis, ...]
    bf[:, nx2:, :] = obf_data[np.newaxis, ...]

    # Normalize a2 and ell to [0, 1] using the training ranges; eccentricity is passed as is
    params = np.zeros((1, 3))
    params[0, 0] = (a2 - a2_min) / (a2_max - a2_min)
    params[0, 1] = (ell - ell_min) / (ell_max - ell_min)
    params[0, 2] = eccentricity

    rho = model.predict([params, bf])
    return plot_simple(system, rho[0, ...])


with open('T5_config.json', 'r', encoding = 'utf-8') as json_file:
    stnn_config = json.load(json_file)

model = stnn.build_stnn(stnn_config)
model.load_weights('T5_weights.h5')

with gr.Blocks() as demo:
    gr.Markdown("# Stacked Tensorial Neural Network (STNN) demo"
                "\nThis demo uses the model architecture from [arXiv:2312.14979](https://arxiv.org/abs/2312.14979) "
                "to solve a parametric PDE problem on an elliptical annular domain. "
                "See the paper for a detailed description of the problem and its applications."
                "\nThe [GitHub repo](https://github.com/caleb399/stacked_tensorial_nn) contains additional examples, "
                "including instructions for solving the PDE using a conventional iterative method (GMRES). "
                "Due to the long runtime of solving the PDE in this way, it is not included in the demo.")
    gr.Markdown("\nThe PDE is "
                "$\\ell \\left( \\boldsymbol{\\hat{u}} \\cdot \\nabla \\right) f(\\boldsymbol{r}, w) = \\partial_{ww} f(\\boldsymbol{r}, w)$, "
                "where $\\ell$ is a parameter and $\\boldsymbol{\\hat{u}} = (\\cos w, \\sin w)$. "
                "Here, $\\boldsymbol{r}$ is the 2D position vector, and $w$ is an angular coordinate unrelated to "
                "the spatial domain. The model predicts the density !\\rho(\\boldsymbol{r}) = \\int f(\\boldsymbol{r}, w) dw! "
                "on elliptical annular domains parameterized as shown below. ",
                latex_delimiters = [{"left": "$", "right": "$", "display": False},
                                    {"left": "!", "right": "!", "display": True}])
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                "## PDE Parameters \n The model was trained on solutions of the PDE with $\\ell$ between 0.01 and 100, $a$ between 2 and 20, "
                "and $ecc$ between 0 and 0.8.",
                latex_delimiters = [{"left": "$", "right": "$", "display": False},
                                    {"left": "!", "right": "!", "display": True}])
            ell_input = gr.Number(label = "ell (must be > 0)", value = 1.0)
            eccentricity_input = gr.Number(
                label = "ecc: eccentricity of the inner boundary (must be >= 0 and <= 0.999)",
                value = 0.5, minimum = 0.0, maximum = 0.999)
            a2_input = gr.Number(label = "a: Minor axis of outer boundary (must be > eccentricity)", value = 2.0)
            gr.Markdown(
                "## Boundary Conditions \n $(s, t)$ are angular coordinates parameterizing the PDE domain, "
                "related to $\\boldsymbol{r}$ and $w$ by a coordinate transformation. "
                "Specifically, $s$ is the polar elliptical coordinate along the boundary (inner or outer), with values "
                "between $-\\pi$ and $\\pi$, while $t = s - w$. Boundary conditions are generated from grid points "
                "distributed uniformly over the allowable values of $s$ and $t$."
                "\n\nFor the PDE problem to be well-posed, boundary data should only be specified where "
                "$\\boldsymbol{\\hat{u}} \\cdot \\boldsymbol{\\hat{n}} > 0$, where $\\boldsymbol{\\hat{n}}$ is the "
                "inward-pointing unit normal vector. This requirement constrains the allowable values of $t$"
                " and is automatically enforced when building boundary conditions from the user-specified expressions below.",
                latex_delimiters = [{"left": "$", "right": "$", "display": False}])
            inner_boundary = gr.Textbox(label = "Inner boundary condition", value = "0.5 * (1 + sign(cos(s)))")
            outer_boundary = gr.Textbox(label = "Outer boundary condition", value = "1 + 0.1 * cos(4*s)")
            submit_button = gr.Button("Submit")
        with gr.Column():
            gr.Markdown("## Predicted Solution")
            predicted_output_plot = gr.Plot()

    submit_button.click(
        fn = predict_pde_solution,
        inputs = [ell_input, a2_input, eccentricity_input, inner_boundary, outer_boundary],
        outputs = [predicted_output_plot]
    )

demo.launch()