Spaces:
Sleeping
Sleeping
import json | |
import gradio as gr | |
import matplotlib.pyplot as plt | |
import matplotlib.ticker as ticker | |
import numpy as np | |
import sympy | |
from matplotlib.cm import get_cmap | |
from stnn.nn import stnn | |
from stnn.pde.pde_system import PDESystem | |
def adjust_to_nice_number(value, round_down = False):
    """
    Snap ``value`` to a "nice" number of the form m * 10**k with m in {1, 2, 5, 10}.
    Used for colorbar tickmarks.

    By default the mantissa is rounded up to the next of 1, 2, 5, 10. With
    ``round_down=True`` it is mapped with the cut-offs 1.5, 3 and 7 instead
    (below 1.5 -> 1, below 3 -> 2, below 7 -> 5, otherwise 10). Negative inputs
    always use the cut-off mapping on their magnitude and keep their sign.
    Zero is returned unchanged.
    """
    if value == 0:
        return value

    negative = value < 0
    if negative:
        round_down = True
        value = -value

    exponent = np.floor(np.log10(value))  # power of ten of the leading digit
    mantissa = value / 10**exponent  # leading digit(s)

    if round_down:
        mantissa_nice = next(
            (nice for bound, nice in ((1.5, 1), (3, 2), (7, 5)) if mantissa < bound), 10)
    else:
        mantissa_nice = next(
            (nice for bound, nice in ((1, 1), (2, 2), (5, 5)) if mantissa <= bound), 10)

    if mantissa_nice == 10 and not round_down:
        result = 10**(exponent + 1)
    else:
        result = mantissa_nice * 10**exponent

    return -result if negative else result
def find_nice_values(min_val_raw, max_val, num_values = 4):
    """
    Calculate about 'num_values' evenly spaced "nice" values within [min_val_raw, max_val].
    Used for colorbar tickmarks.

    The first value is min_val_raw snapped with adjust_to_nice_number(). The raw minimum
    is used instead when snapping moves it by less than 1/num_values of the range, or
    moves it to or beyond max_val. The spacing is the raw spacing rounded to the closest
    of 1, 2, 5 or 10 times a power of ten. One extra value is appended when the sequence
    stops more than one spacing short of max_val. Values outside the range are dropped.

    Returns:
        list of values. Returns [] for a non-finite or inverted range, and
        [min_val_raw] when both ends coincide.

    Raises:
        ValueError: if num_values < 2 (the spacing would divide by zero).
    """
    if num_values < 2:
        raise ValueError(f'num_values must be at least 2 (got {num_values}).')
    if not (np.isfinite(min_val_raw) and np.isfinite(max_val)) or max_val < min_val_raw:
        return []
    if max_val == min_val_raw:
        # Constant range: the spacing would be zero and log10(0) is undefined.
        return [min_val_raw]

    min_val = adjust_to_nice_number(min_val_raw)
    frac_val = (min_val - min_val_raw) / (max_val - min_val_raw)
    # Fall back to the raw minimum when snapping barely moved it, or pushed it to/past
    # max_val. In the latter case the spacing below would be <= 0 and every value NaN.
    if frac_val < 1 / num_values or frac_val >= 1:
        min_val = min_val_raw

    # Rough spacing, then round it to 1, 2, 5 or 10 times its power of ten
    raw_spacing = (max_val - min_val) / (num_values - 1)
    magnitude = np.floor(np.log10(raw_spacing))
    nice_factors = np.array([1, 2, 5, 10])
    normalized_spacing = raw_spacing / (10**magnitude)
    closest_factor = nice_factors[np.argmin(np.abs(nice_factors - normalized_spacing))]
    nice_spacing = closest_factor * (10**magnitude)
    nice_values = min_val + nice_spacing * np.arange(num_values)

    # Extend by one value if the sequence stops more than one spacing below max_val
    if nice_values[-1] < max_val - nice_spacing:
        nice_values = np.append(nice_values, [nice_values[-1] + nice_spacing])

    return [val for val in nice_values if min_val <= val <= max_val]
def format_tick_label(val):
    """
    Format a colorbar tick value with about three significant digits.

    Let e = floor(log10(|val|)). When -2 <= e <= 2 the value is written in fixed-point
    notation with 2 - e decimals (e.g. 0.0500, 0.250, 2.50, 25.0, 250). Otherwise it uses
    scientific notation with one decimal (e.g. 2.5e+03, 5.0e-04). Zero and non-finite
    values are formatted with plain f'{val}'.
    """
    if val == 0 or not np.isfinite(val):
        return f'{val}'
    exponent = int(np.floor(np.log10(np.abs(val))))
    if abs(exponent) > 2:
        return f'{val:.1e}'
    # Base the decimal count on the signed exponent. Using |exponent| gives small
    # values fewer decimals: 0.05 would print as "0", and 0.2 and 0.25 would collide.
    return f'{val:.{2 - exponent}f}'
def plot_simple(system, rho, fontscale = 1):
    """
    Plot the density rho(x, y) as a filled contour map with a colorbar.

    Args:
        system: PDESystem-like object. Only ``system.b2`` (used for the axis limits)
            and ``system.get_xy_grids()`` (the x, y grids) are used here.
        rho: 2D array of density values. Its first column is appended as an extra last
            column so the contours close up. Presumably axis 1 is the periodic angular
            direction of the grid — TODO confirm against PDESystem.get_xy_grids.
        fontscale: multiplier for the title and tick-label font sizes.

    Returns:
        matplotlib.figure.Figure

    Raises:
        ValueError: if rho has no finite range (all NaN, or containing infinities).
    """
    # Build the figure without pyplot. pyplot keeps every figure it creates alive until it
    # is explicitly closed, which leaks memory when this runs once per request in a server.
    from matplotlib.figure import Figure

    # Major axis of outer boundary
    b2 = system.b2
    x, y = system.get_xy_grids()
    # Wrap around values for continuity
    rho = np.append(rho, rho[:, 0:1], axis = 1)

    # Color bar limits
    vmin = np.nanmin(rho)
    vmax = np.nanmax(rho)
    if not (np.isfinite(vmin) and np.isfinite(vmax)):
        raise ValueError(f'Cannot plot rho: its value range [{vmin}, {vmax}] is not finite.')
    if vmax == vmin:
        # contourf needs strictly increasing levels, so widen the range of a constant field.
        pad = 0.05 * abs(vmin) if vmin != 0 else 1.0
        vmin, vmax = vmin - pad, vmax + pad

    fig = Figure(figsize = (5, 5))
    ax = fig.add_subplot()
    # The colormap is passed by name: matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
    im = ax.contourf(x, y, rho, levels = np.linspace(vmin, vmax, 100), cmap = 'hsv')
    ax.set_title('rho(x,y)', fontsize = fontscale * 16)
    ax.tick_params(labelsize = fontscale * 12)
    ax.set_aspect(1.0)
    fac = 1.05  # 5% margin around the outer boundary
    ax.set_xlim([-fac * b2, fac * b2])
    ax.set_ylim([-fac * b2, fac * b2])

    cbar = fig.colorbar(im, shrink = 0.8)
    # Set colorbar ticks and labels to "nice" values
    cbar.set_ticks(find_nice_values(vmin, vmax, num_values = 5))
    cbar.ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda val, pos: format_tick_label(val)))
    return fig
def evaluate_2d_expression(expr_str, xvals, yvals):
    """
    Evaluate a SymPy expression in the variables ``s`` and ``t`` on numeric grids.

    ``expr_str`` is parsed with ``sympy.sympify`` and compiled with ``sympy.lambdify``
    using the NumPy backend. ``s`` is bound to ``xvals`` and ``t`` to ``yvals``.

    Returns:
        The evaluated values. If the expression yields a scalar (e.g. it depends on
        neither ``s`` nor ``t``), the scalar is broadcast to ``np.shape(xvals)``.

    Raises:
        Propagates any exception raised while parsing or evaluating the expression.
    """
    # SECURITY NOTE(review): sympy.sympify() evaluates its input with eval(). In this app
    # expr_str comes straight from a user-facing text box, so arbitrary input reaches
    # eval(). Do not expose this to untrusted users without sandboxing.
    s, t = sympy.symbols('s t')
    expr = sympy.sympify(expr_str)
    func = sympy.lambdify((s, t), expr, modules = ['numpy'])
    result = func(xvals, yvals)  # evaluate once and reuse
    if np.ndim(result) == 0:
        # lambdify returns a bare scalar for constant expressions; expand it to the grid.
        return result * np.ones(np.shape(xvals))
    return result
# NOTE(review): Disabled code kept as a module-level string literal; it is never executed.
# It solves the PDE directly with GMRES instead of using the STNN model.
# Before re-enabling it, note three problems:
#   - It refers to names this file does not import or define (timeit, csr_matrix,
#     asarray, spx, xp, warnings, asnumpy).
#   - It uses bare `except:` clauses.
#   - If `spx` is meant to be scipy.sparse, recent SciPy versions renamed gmres's `tol`
#     argument to `rtol`. Confirm against the installed version.
'''
# Currently unused in gradio interface
def direct_solution(ell, a2, eccentricity, ibc_str, obc_str, max_krylov_dim, max_iterations):
    # Direct solution
    start = timeit.default_timer()
    pde_config = {}
    for key in ['nx1', 'nx2', 'nx3']:
        pde_config[key] = stnn_config[key]
    pde_config['ell'] = ell
    pde_config['eccentricity'] = eccentricity
    pde_config['a2'] = a2
    system = PDESystem(pde_config)
    try:
        ibf_data = evaluate_2d_expression(ibc_str, system.x2_ib, system.x2_ib - system.x3_ib)[system.ib_slice]
    except:
        raise ValueError(f"Failed to parse the expression `{ibc_str}` for the boundary condition @ the inner boundary.")
    try:
        obf_data = evaluate_2d_expression(obc_str, system.x2_ob, system.x2_ob - system.x3_ob)[system.ob_slice]
    except:
        raise ValueError(f"Failed to parse the expression `{obc_str}` for the boundary condition @ the outer boundary.")
    if np.any(np.isnan(ibf_data)):
        raise ValueError(f"The expression `{ibc_str}` evaluates to nan at one or more grid points.")
    if np.any(np.isnan(obf_data)):
        raise ValueError(f"The expression `{obc_str}` evaluates to nan at one or more grid points.")
    ibf_data, obf_data, b = system.convert_boundary_data(ibf_data, obf_data)
    L_xp = csr_matrix(system.L) # Sparse matrix representation of the PDE operator
    nx1, nx2, nx3 = system.params['nx1'], system.params['nx2'], system.params['nx3']
    b_xp = asarray(b.reshape((nx1 * nx2 * nx3,))) # r.h.s. vector
    def callback(res):
        print(f'GMRES residual: {res}')
    f_xp, info = spx.linalg.gmres(L_xp, b_xp, maxiter=max_iterations, tol=1e-7, restart=max_krylov_dim, callback=callback)
    residual = (xp.linalg.norm(b_xp - L_xp @ f_xp) / xp.linalg.norm(b_xp))
    if info > 0:
        warnings.simplefilter('always')
        warnings.warn(f'GMRES solver did not converge. Number of iterations: {info}; residual: {residual}', RuntimeWarning)
    f = asnumpy(f_xp)
    rho_direct = np.sum(f.reshape((nx1, nx2, nx3)), axis=-1)
    direct_time = timeit.default_timer() - start
    print(f'Done with direct solution. Time: {direct_time} seconds.')
    fig = plot_simple(system, rho_direct)
    return fig, info
'''
def _evaluate_boundary_condition(expr_str, x2, x3, index, location):
    """
    Evaluate one user-supplied boundary-condition expression and validate the result.

    The expression is evaluated at (s, t) = (x2, x2 - x3) with evaluate_2d_expression,
    and the result is indexed with ``index``.

    Args:
        expr_str: SymPy expression in ``s`` and ``t``.
        x2, x3: boundary grid arrays from the PDESystem (e.g. ``system.x2_ib``).
        index: index applied to the evaluated array (e.g. ``system.ib_slice``).
        location: "inner" or "outer". Used only in the error message.

    Raises:
        ValueError: if evaluation fails, or the data contain NaN or infinite values.
    """
    try:
        data = evaluate_2d_expression(expr_str, x2, x2 - x3)[index]
    except Exception as err:
        raise ValueError(f"Failed to parse the expression `{expr_str}` for the boundary condition @ the {location} boundary.") from err
    if np.any(np.isnan(data)):
        raise ValueError(f"The expression `{expr_str}` evaluates to NaN at one or more grid points.")
    # Infinite values (e.g. 1/s at s = 0) would otherwise be passed on to the model.
    if np.any(np.isinf(data)):
        raise ValueError(f"The expression `{expr_str}` evaluates to an infinite value at one or more grid points.")
    return data


def predict_pde_solution(ell, a2, eccentricity, ibc_str, obc_str):
    """
    Predict the density rho(x, y) with the STNN model and return a plot of it.

    Uses the module-level ``stnn_config`` (grid sizes and training parameter ranges)
    and ``model``.

    Args:
        ell: PDE parameter ell. Must be > 0.
        a2: minor axis of the outer boundary. Must be greater than ``eccentricity``.
        eccentricity: eccentricity of the inner boundary.
        ibc_str, obc_str: SymPy expressions in ``s`` and ``t`` for the inner and
            outer boundary conditions.

    Returns:
        The figure produced by plot_simple.

    Raises:
        ValueError: for missing or out-of-range parameters, or for boundary-condition
            expressions that cannot be evaluated or that yield NaN/infinite values.
    """
    # gr.Number yields None when a field is left empty.
    if ell is None or a2 is None or eccentricity is None:
        raise ValueError('The parameters ell, ecc, and a must all be provided.')
    if ell <= 0:
        raise ValueError(f'ell must be greater than 0 (here, {ell}).')
    if a2 <= eccentricity:
        raise ValueError(f'Outer minor axis must be greater than the eccentricity (here, {eccentricity}).')

    pde_config = {key: stnn_config[key] for key in ('nx1', 'nx2', 'nx3')}
    pde_config['ell'] = ell
    pde_config['eccentricity'] = eccentricity
    pde_config['a2'] = a2
    system = PDESystem(pde_config)

    ibf_data = _evaluate_boundary_condition(ibc_str, system.x2_ib, system.x3_ib, system.ib_slice, 'inner')
    obf_data = _evaluate_boundary_condition(obc_str, system.x2_ob, system.x3_ob, system.ob_slice, 'outer')

    # Permute and reshape boundary data to the format expected by the STNN model
    ibf_data, obf_data, _ = system.convert_boundary_data(ibf_data, obf_data)
    # Alternative, currently unused in the gradio interface:
    #     ibf_data, obf_data, b, _ = system.generate_random_bc(func_gen_id)

    # Load some relevant quantities from the config dictionaries
    ell_min, ell_max = stnn_config['ell_min'], stnn_config['ell_max']
    a2_min, a2_max = stnn_config['a2_min'], stnn_config['a2_max']
    nx2, nx3 = pde_config['nx2'], pde_config['nx3']

    # Stack inner (first nx2 rows) and outer (last nx2 rows) boundary data into a batch of one
    bf = np.zeros((1, 2 * nx2, nx3 // 2))
    bf[:, :nx2, :] = ibf_data[np.newaxis, ...]
    bf[:, nx2:, :] = obf_data[np.newaxis, ...]

    # Min-max normalize a2 and ell with the config's training ranges; eccentricity is passed as-is
    params = np.zeros((1, 3))
    params[0, 0] = (a2 - a2_min) / (a2_max - a2_min)
    params[0, 1] = (ell - ell_min) / (ell_max - ell_min)
    params[0, 2] = eccentricity

    rho = model.predict([params, bf])
    return plot_simple(system, rho[0, ...])
# Load the STNN configuration and trained weights once at import time. predict_pde_solution
# reads the grid sizes and parameter ranges from stnn_config and calls model.predict.
with open('T5_config.json', 'r', encoding = 'utf-8') as json_file:
    stnn_config = json.load(json_file)
model = stnn.build_stnn(stnn_config)
model.load_weights('T5_weights.h5')

# Gradio UI. Literals containing LaTeX are raw strings so their backslashes reach the
# Markdown renderer unchanged. In ordinary strings, sequences such as "\e", "\h", "\c",
# "\p" and "\i" are invalid escapes that warn on recent Python versions. Literals that
# need a real newline ("\n") stay non-raw.
with gr.Blocks() as demo:
    gr.Markdown("# Stacked Tensorial Neural Network (STNN) demo"
                "\nThis demo uses the model architecture from [arXiv:2312.14979](https://arxiv.org/abs/2312.14979) "
                "to solve a parametric PDE problem on an elliptical annular domain. "
                "See the paper for a detailed description of the problem and its applications."
                "<br/>The [GitHub repo](https://github.com/caleb399/stacked_tensorial_nn) contains additional examples, "
                "including instructions for solving the PDE using a conventional iterative method (GMRES). "
                "Due to the long runtime of solving the PDE in this way, it is not included in the demo.")
    gr.Markdown(r"<br/>The PDE is "
                r"$\ell \left( \boldsymbol{\hat{u}} \cdot \nabla \right) f(\boldsymbol{r}, w) = \partial_{ww} f(\boldsymbol{r}, w)$, "
                r"where $\ell$ is a parameter and $\boldsymbol{\hat{u}} = (\cos w, \sin w)$. "
                r"Here, $\boldsymbol{r}$ is the 2D position vector, and $w$ is an angular coordinate unrelated to "
                r"the spatial domain. The model predicts the density !\rho(\boldsymbol{r}) = \int f(\boldsymbol{r}, w) dw! "
                r"on elliptical annular domains parameterized as shown below. ",
                latex_delimiters = [{"left": "$", "right": "$", "display": False},
                                    {"left": "!", "right": "!", "display": True}])
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                "## PDE Parameters \n"
                r" The model was trained on solutions of the PDE with $\ell$ between 0.01 and 100, $a$ between 2 and 20, "
                "and $ecc$ between 0 and 0.8.",
                latex_delimiters = [{"left": "$", "right": "$", "display": False},
                                    {"left": "!", "right": "!", "display": True}])
            ell_input = gr.Number(label = "ell (must be > 0)", value = 1.0)
            eccentricity_input = gr.Number(
                label = "ecc: eccentricity of the inner boundary (must be >= 0 and <= 0.999)",
                value = 0.5, minimum = 0.0, maximum = 0.999)
            a2_input = gr.Number(label = "a: Minor axis of outer boundary (must be > eccentricity)", value = 2.0)
            gr.Markdown(
                "## Boundary Conditions \n $(s, t)$ are angular coordinates parameterizing the PDE domain, "
                r"related to $\boldsymbol{r}$ and $w$ by a coordinate transformation. "
                "Specifically, $s$ is the polar elliptical coordinate along the boundary (inner or outer), with values "
                r"between $-\pi$ and $\pi$, while $t = s - w$. Boundary conditions are generated from grid points "
                "distributed uniformly over the allowable values of $s$ and $t$."
                "<br/><br/>For the PDE problem to be well-posed, boundary data should only be specified where "
                r"$\boldsymbol{\hat{u}} \cdot \boldsymbol{\hat{n}} > 0$, where $\boldsymbol{\hat{n}}$ is the "
                "inward-pointing unit normal vector. This requirement constrains the allowable values of $t$"
                " and is automatically enforced when building boundary conditions from the user-specified expressions below.",
                latex_delimiters = [{"left": "$", "right": "$", "display": False}])
            inner_boundary = gr.Textbox(label = "Inner boundary condition", value = "0.5 * (1 + sign(cos(s)))")
            outer_boundary = gr.Textbox(label = "Outer boundary condition", value = "1 + 0.1 * cos(4*s)")
            submit_button = gr.Button("Submit")
        with gr.Column():
            gr.Markdown("## Predicted Solution")
            predicted_output_plot = gr.Plot()
    # Event listeners must be registered inside the Blocks context.
    submit_button.click(
        fn = predict_pde_solution,
        inputs = [ell_input, a2_input, eccentricity_input, inner_boundary, outer_boundary],
        outputs = [predicted_output_plot]
    )
demo.launch()