caleb2's picture
initial commit
d68c650
import json
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import sympy
from matplotlib.cm import get_cmap
from matplotlib.figure import Figure

from stnn.nn import stnn
from stnn.pde.pde_system import PDESystem
def adjust_to_nice_number(value, round_down = False):
    """
    Snap ``value`` to a "nice" number: 1, 2, 5 or 10 times a power of ten. Used for colorbar tickmarks.

    With ``round_down=False`` (the default), a positive value is rounded up to the next
    nice number; a value that is already nice is kept. With ``round_down=True``, the
    leading digits are snapped using the cutoffs 1.5, 3 and 7. Despite the parameter's
    name, the result can therefore be above or below the input (3.2 -> 5, 1.2 -> 1).
    Negative values always use the ``round_down=True`` cutoffs on their magnitude and
    keep their sign. Zero is returned unchanged.
    """
    if value == 0:
        return value
    negative = value < 0
    if negative:
        round_down = True
        value = -value
    exponent = np.floor(np.log10(value))  # power of ten of the leading digit
    leading = value / 10**exponent  # leading digit(s), roughly in [1, 10)
    if round_down:
        cutoffs = ((1.5, 1), (3, 2), (7, 5))
        nice_fractional = next((nice for limit, nice in cutoffs if leading < limit), 10)
    else:
        cutoffs = ((1, 1), (2, 2), (5, 5))
        nice_fractional = next((nice for limit, nice in cutoffs if leading <= limit), 10)
    if not round_down and nice_fractional == 10:
        nice_value = 10**(exponent + 1)
    else:
        nice_value = nice_fractional * 10**exponent
    return -nice_value if negative else nice_value
def find_nice_values(min_val_raw, max_val, num_values = 4):
    """
    Calculate about ``num_values`` evenly spaced "nice" values within
    ``[min_val_raw, max_val]``. Used for colorbar tickmarks.

    The starting value is ``adjust_to_nice_number(min_val_raw)``. The raw minimum is
    used instead when that rounding moves the start by less than ``1 / num_values`` of
    the range, moves it below ``min_val_raw``, or pushes it to or past ``max_val``.
    The spacing is ``(max_val - start) / (num_values - 1)`` snapped to the closest of
    1, 2, 5 or 10 times a power of ten. One extra value is appended if there is room
    below ``max_val``, and values overshooting ``max_val`` are dropped.

    Returns
    -------
    list
        The tick values. If ``max_val == min_val_raw`` this is ``[min_val_raw]``; for
        an inverted or NaN range it is empty.
    """
    span = max_val - min_val_raw
    if not span > 0:
        # No spacing can be computed (and the fraction below would divide by zero).
        return [min_val_raw] if span == 0 else []
    min_val = adjust_to_nice_number(min_val_raw)
    frac_val = (min_val - min_val_raw) / span
    # Fall back to the raw minimum if the rounded start is "too close" (or below it),
    # or if rounding up overshot the whole range. The latter would otherwise make the
    # spacing negative, log10 would return NaN, and no ticks would survive.
    if frac_val < 1 / num_values or min_val >= max_val:
        min_val = min_val_raw
    raw_spacing = (max_val - min_val) / (num_values - 1)
    # Snap the spacing to 1, 2, 5 or 10 times its order of magnitude
    magnitude = np.floor(np.log10(raw_spacing))
    nice_factors = np.array([1, 2, 5, 10])
    normalized_spacing = raw_spacing / (10**magnitude)
    closest_factor = nice_factors[np.argmin(np.abs(nice_factors - normalized_spacing))]
    nice_spacing = closest_factor * (10**magnitude)
    nice_values = min_val + nice_spacing * np.arange(num_values)
    # If the spacing was rounded down there may be room for one more value below max_val
    if nice_values[-1] < max_val - nice_spacing:
        nice_values = np.append(nice_values, [nice_values[-1] + nice_spacing])
    # Drop values that overshoot the range (spacing rounded up)
    return [val for val in nice_values if min_val <= val <= max_val]
def format_tick_label(val):
    """
    Format a tick value, using scientific notation for large/small values.

    Values whose decimal exponent is between -2 and 2 are shown in fixed-point
    notation with three significant digits (250, 25.0, 2.50, 0.250, 0.0250). All
    other non-zero finite values use scientific notation (2.5e+03, 5.0e-03). Zero
    and non-finite values use their default string form.
    """
    if val == 0 or not np.isfinite(val):
        return f'{val}'
    exponent = int(np.floor(np.log10(np.abs(val))))
    if abs(exponent) > 2:
        return f'{val:.1e}'
    # The sign of the exponent matters here. The previous abs()-based branching
    # formatted 0.05 with zero decimals ('0'). Using 2 - exponent decimals keeps
    # three significant digits for both large and small values.
    return f'{val:.{2 - exponent}f}'
def plot_simple(system, rho, fontscale = 1):
    """
    Draw a filled-contour plot of ``rho`` on the (x, y) grids of ``system``.

    Parameters
    ----------
    system : PDESystem
        Provides the plotting grids via ``get_xy_grids()`` and ``b2``, the major axis
        of the outer boundary, which sets the axis limits.
    rho : np.ndarray
        2D array of values. It must be at least 2D because its first column is
        re-appended as the last column. It is assumed to then match the shape of the
        grids from ``get_xy_grids()`` (TODO confirm against PDESystem).
    fontscale : float
        Multiplier for the title and tick-label font sizes.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # Build a standalone Figure instead of going through pyplot. Pyplot registers
    # every figure globally (never freed here) and shares one "current axes" across
    # threads, so concurrent Gradio requests could leak figures or draw on each other.
    fig = Figure(figsize = (5, 5))
    ax = fig.add_subplot()
    # Major axis of outer boundary
    b2 = system.b2
    x, y = system.get_xy_grids()
    # Append the first column as a final column so the field is continuous across the
    # wrap-around seam (presumably the periodic angular direction; confirm).
    rho = np.append(rho, rho[:, 0:1], axis = 1)
    # Color bar limits
    vmin = np.nanmin(rho)
    vmax = np.nanmax(rho)
    if vmax <= vmin:
        # contourf needs strictly increasing levels; pad a constant field.
        vmin, vmax = vmin - 0.5, vmax + 0.5
    im = ax.contourf(x, y, rho, levels = np.linspace(vmin, vmax, 100), cmap = 'hsv')
    ax.set_title('rho(x,y)', fontsize = fontscale * 16)
    ax.tick_params(labelsize = fontscale * 12)
    ax.set_aspect(1.0)
    fac = 1.05
    ax.set_xlim([-fac * b2, fac * b2])
    ax.set_ylim([-fac * b2, fac * b2])
    cbar = fig.colorbar(im, ax = ax, shrink = 0.8)
    # Set colorbar ticks and labels to "nice" values
    nice_values = find_nice_values(vmin, vmax, num_values = 5)
    cbar.set_ticks(nice_values)
    cbar.ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda val, pos: format_tick_label(val)))
    return fig
def evaluate_2d_expression(expr_str, xvals, yvals):
    """
    Evaluate an expression in the symbols ``s`` and ``t`` element-wise.

    Parameters
    ----------
    expr_str : str
        Expression text parsed by SymPy, e.g. ``"1 + 0.1 * cos(4*s)"``.
    xvals, yvals : array-like
        Values substituted for ``s`` and ``t`` respectively.

    Returns
    -------
    np.ndarray
        The evaluated expression. A real constant (Python or NumPy int/float) is
        expanded to an array of that value with the shape of ``xvals``.
    """
    s, t = sympy.symbols('s t')
    # SECURITY NOTE(review): sympy.sympify evaluates its input with Python's eval().
    # In this app expr_str comes from a public Gradio textbox, so this permits
    # arbitrary code execution. Replace with a restricted parser before wider exposure.
    expr = sympy.sympify(expr_str)
    f = sympy.lambdify((s, t), expr, modules = ['numpy'])
    result = f(xvals, yvals)  # evaluate once; the old code called f a second time
    if isinstance(result, (int, float, np.integer, np.floating)):
        # Constant expressions lambdify to a scalar; broadcast to the grid shape.
        return result * np.ones(np.shape(xvals))
    return result
# NOTE: The triple-quoted string below is an inactive reference implementation of a
# direct (GMRES) solve. As a bare string literal it is never executed. It references
# names that are not imported at the top of this file (timeit, csr_matrix, asarray,
# spx, xp, warnings, asnumpy); those would need to be provided before re-enabling it.
'''
# Currently unused in gradio interface
def direct_solution(ell, a2, eccentricity, ibc_str, obc_str, max_krylov_dim, max_iterations):
# Direct solution
start = timeit.default_timer()
pde_config = {}
for key in ['nx1', 'nx2', 'nx3']:
pde_config[key] = stnn_config[key]
pde_config['ell'] = ell
pde_config['eccentricity'] = eccentricity
pde_config['a2'] = a2
system = PDESystem(pde_config)
try:
ibf_data = evaluate_2d_expression(ibc_str, system.x2_ib, system.x2_ib - system.x3_ib)[system.ib_slice]
except:
raise ValueError(f"Failed to parse the expression `{ibc_str}` for the boundary condition @ the inner boundary.")
try:
obf_data = evaluate_2d_expression(obc_str, system.x2_ob, system.x2_ob - system.x3_ob)[system.ob_slice]
except:
raise ValueError(f"Failed to parse the expression `{obc_str}` for the boundary condition @ the outer boundary.")
if np.any(np.isnan(ibf_data)):
raise ValueError(f"The expression `{ibc_str}` evaluates to nan at one or more grid points.")
if np.any(np.isnan(obf_data)):
raise ValueError(f"The expression `{obc_str}` evaluates to nan at one or more grid points.")
ibf_data, obf_data, b = system.convert_boundary_data(ibf_data, obf_data)
L_xp = csr_matrix(system.L) # Sparse matrix representation of the PDE operator
nx1, nx2, nx3 = system.params['nx1'], system.params['nx2'], system.params['nx3']
b_xp = asarray(b.reshape((nx1 * nx2 * nx3,))) # r.h.s. vector
def callback(res):
print(f'GMRES residual: {res}')
f_xp, info = spx.linalg.gmres(L_xp, b_xp, maxiter=max_iterations, tol=1e-7, restart=max_krylov_dim, callback=callback)
residual = (xp.linalg.norm(b_xp - L_xp @ f_xp) / xp.linalg.norm(b_xp))
if info > 0:
warnings.simplefilter('always')
warnings.warn(f'GMRES solver did not converge. Number of iterations: {info}; residual: {residual}', RuntimeWarning)
f = asnumpy(f_xp)
rho_direct = np.sum(f.reshape((nx1, nx2, nx3)), axis=-1)
direct_time = timeit.default_timer() - start
print(f'Done with direct solution. Time: {direct_time} seconds.')
fig = plot_simple(system, rho_direct)
return fig, info
'''
def _evaluate_boundary_condition(expr_str, s_vals, t_vals, data_slice, location):
    """
    Evaluate one boundary-condition expression on (s, t) grids and apply ``data_slice``.

    Any failure while parsing, evaluating or slicing is re-raised as a ValueError that
    names the expression and the boundary ``location`` ("inner" or "outer").
    """
    try:
        return evaluate_2d_expression(expr_str, s_vals, t_vals)[data_slice]
    except Exception as err:
        raise ValueError(
            f"Failed to parse the expression `{expr_str}` for the boundary condition @ the {location} boundary."
        ) from err


def predict_pde_solution(ell, a2, eccentricity, ibc_str, obc_str):
    """
    Predict the density rho with the STNN model and return a plot of it.

    Parameters
    ----------
    ell : float
        PDE parameter; must be > 0.
    a2 : float
        Minor axis of the outer boundary; must exceed ``eccentricity``.
    eccentricity : float
        Eccentricity of the inner boundary.
    ibc_str, obc_str : str
        Inner/outer boundary-condition expressions in the symbols ``s`` and ``t``.

    Returns
    -------
    The figure produced by ``plot_simple``.

    Raises
    ------
    ValueError
        For missing or invalid parameters, unparsable expressions, or expressions
        that evaluate to NaN on the grid.

    Uses the module-level ``stnn_config`` (grid sizes, parameter ranges) and ``model``.
    """
    # An empty gr.Number field may arrive as None; report it instead of a TypeError.
    if ell is None or a2 is None or eccentricity is None:
        raise ValueError('All PDE parameters (ell, ecc, a) must be specified.')
    if ell <= 0:
        raise ValueError(f'ell must be greater than 0 (got {ell}).')
    if a2 <= eccentricity:
        raise ValueError(f'Outer minor axis must be greater than the eccentricity (here, {eccentricity}).')
    pde_config = {key: stnn_config[key] for key in ('nx1', 'nx2', 'nx3')}
    pde_config['ell'] = ell
    pde_config['eccentricity'] = eccentricity
    pde_config['a2'] = a2
    system = PDESystem(pde_config)
    # Expressions are evaluated with s = x2 and t = x2 - x3 on each boundary grid.
    # Both expressions are parsed before either NaN check, as before.
    ibf_data = _evaluate_boundary_condition(
        ibc_str, system.x2_ib, system.x2_ib - system.x3_ib, system.ib_slice, 'inner')
    obf_data = _evaluate_boundary_condition(
        obc_str, system.x2_ob, system.x2_ob - system.x3_ob, system.ob_slice, 'outer')
    if np.any(np.isnan(ibf_data)):
        raise ValueError(f"The expression `{ibc_str}` evaluates to NaN at one or more grid points.")
    if np.any(np.isnan(obf_data)):
        raise ValueError(f"The expression `{obc_str}` evaluates to NaN at one or more grid points.")
    # Permute and reshape boundary data to the format expected by the STNN model
    ibf_data, obf_data, _ = system.convert_boundary_data(ibf_data, obf_data)
    # Alternative, currently unused in the gradio interface (random boundary data):
    #     ibf_data, obf_data, b, _ = system.generate_random_bc(func_gen_id)
    # Load some relevant quantities from the config dictionaries
    ell_min, ell_max = stnn_config['ell_min'], stnn_config['ell_max']
    a2_min, a2_max = stnn_config['a2_min'], stnn_config['a2_max']
    nx2, nx3 = pde_config['nx2'], pde_config['nx3']
    # Combine boundary data in a single array: inner rows first, then outer rows
    bf = np.zeros((1, 2 * nx2, nx3 // 2))
    bf[:, :nx2, :] = ibf_data[np.newaxis, ...]
    bf[:, nx2:, :] = obf_data[np.newaxis, ...]
    # Min-max normalize a2 and ell using the config ranges; eccentricity is passed as-is
    params = np.zeros((1, 3))
    params[0, 0] = (a2 - a2_min) / (a2_max - a2_min)
    params[0, 1] = (ell - ell_min) / (ell_max - ell_min)
    params[0, 2] = eccentricity
    rho = model.predict([params, bf])
    return plot_simple(system, rho[0, ...])
# Load the STNN configuration and trained weights. The files are resolved next to
# this script, so the app also works when launched from another working directory.
# Fall back to the CWD when __file__ is undefined (e.g. interactive sessions).
_APP_DIR = Path(__file__).resolve().parent if '__file__' in globals() else Path.cwd()
with open(_APP_DIR / 'T5_config.json', 'r', encoding = 'utf-8') as json_file:
    stnn_config = json.load(json_file)
model = stnn.build_stnn(stnn_config)
# Pass a str: some weight loaders check the '.h5' suffix with str methods.
model.load_weights(str(_APP_DIR / 'T5_weights.h5'))
# Gradio UI: PDE parameters and boundary-condition inputs on the left, predicted
# density plot on the right. LaTeX backslashes are doubled so the source has no
# invalid string escapes (e.g. "\e", "\p"); Python 3.12+ reports those as SyntaxWarning.
with gr.Blocks() as demo:
    gr.Markdown("# Stacked Tensorial Neural Network (STNN) demo"
                "\nThis demo uses the model architecture from [arXiv:2312.14979](https://arxiv.org/abs/2312.14979) "
                "to solve a parametric PDE problem on an elliptical annular domain. "
                "See the paper for a detailed description of the problem and its applications."
                "<br/>The [GitHub repo](https://github.com/caleb399/stacked_tensorial_nn) contains additional examples, "
                "including instructions for solving the PDE using a conventional iterative method (GMRES). "
                "Due to the long runtime of solving the PDE in this way, it is not included in the demo.")
    gr.Markdown("<br/>The PDE is "
                "$\\ell \\left( \\boldsymbol{\\hat{u}} \\cdot \\nabla \\right) f(\\boldsymbol{r}, w) = \\partial_{ww} f(\\boldsymbol{r}, w)$, "
                "where $\\ell$ is a parameter and $\\boldsymbol{\\hat{u}} = (\\cos w, \\sin w)$. "
                "Here, $\\boldsymbol{r}$ is the 2D position vector, and $w$ is an angular coordinate unrelated to "
                "the spatial domain. The model predicts the density !\\rho(\\boldsymbol{r}) = \\int f(\\boldsymbol{r}, w) dw! "
                "on elliptical annular domains parameterized as shown below. ",
                latex_delimiters = [{"left": "$", "right": "$", "display": False},
                                    {"left": "!", "right": "!", "display": True}])
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                "## PDE Parameters \n The model was trained on solutions of the PDE with $\\ell$ between 0.01 and 100, $a$ between 2 and 20, "
                "and $ecc$ between 0 and 0.8.", latex_delimiters = [{"left": "$", "right": "$", "display": False},
                                                                   {"left": "!", "right": "!", "display": True}])
            ell_input = gr.Number(label = "ell (must be > 0)", value = 1.0)
            eccentricity_input = gr.Number(
                label = "ecc: eccentricity of the inner boundary (must be >= 0 and <= 0.999)",
                value = 0.5, minimum = 0.0, maximum = 0.999)
            a2_input = gr.Number(label = "a: Minor axis of outer boundary (must be > eccentricity)", value = 2.0)
            gr.Markdown(
                "## Boundary Conditions \n $(s, t)$ are angular coordinates parameterizing the PDE domain, "
                "related to $\\boldsymbol{r}$ and $w$ by a coordinate transformation. "
                "Specifically, $s$ is the polar elliptical coordinate along the boundary (inner or outer), with values "
                "between $-\\pi$ and $\\pi$, while $t = s - w$. Boundary conditions are generated from grid points "
                "distributed uniformly over the allowable values of $s$ and $t$."
                "<br/><br/>For the PDE problem to be well-posed, boundary data should only be specified where "
                "$\\boldsymbol{\\hat{u}} \\cdot \\boldsymbol{\\hat{n}} > 0$, where $\\boldsymbol{\\hat{n}}$ is the "
                "inward-pointing unit normal vector. This requirement constrains the allowable values of $t$"
                " and is automatically enforced when building boundary conditions from the user-specified expressions below.",
                latex_delimiters = [{"left": "$", "right": "$", "display": False}])
            inner_boundary = gr.Textbox(label = "Inner boundary condition", value = "0.5 * (1 + sign(cos(s)))")
            outer_boundary = gr.Textbox(label = "Outer boundary condition", value = "1 + 0.1 * cos(4*s)")
            submit_button = gr.Button("Submit")
        with gr.Column():
            gr.Markdown("## Predicted Solution")
            predicted_output_plot = gr.Plot()
    # Argument order must match predict_pde_solution(ell, a2, eccentricity, ibc_str, obc_str)
    submit_button.click(
        fn = predict_pde_solution,
        inputs = [ell_input, a2_input, eccentricity_input, inner_boundary, outer_boundary],
        outputs = [predicted_output_plot]
    )
demo.launch()