Spaces:
Sleeping
Sleeping
import numpy as np | |
import matplotlib.pyplot as plt | |
from mpl_toolkits.axes_grid1 import make_axes_locatable | |
from matplotlib.cm import get_cmap | |
def _relative_error(rho, rho_pred):
    """Return ||rho - rho_pred|| / ||rho|| over entries that are finite in both arrays.

    Non-finite entries (e.g. NaN grid points) are excluded so they do not turn the
    whole error into NaN. For fully finite inputs this equals the Frobenius-norm
    ratio. A zero reference norm yields inf/nan (with a NumPy RuntimeWarning).
    """
    rho = np.asarray(rho)
    rho_pred = np.asarray(rho_pred)
    finite = np.isfinite(rho) & np.isfinite(rho_pred)
    err = np.linalg.norm(rho[finite] - rho_pred[finite])
    ref = np.linalg.norm(rho[finite])
    return float(err / ref)


def _shared_levels(a, b, num=100):
    """Return `num` evenly spaced contour levels covering the NaN-ignoring range of `a` and `b`."""
    vmin = min(np.nanmin(a), np.nanmin(b))
    vmax = max(np.nanmax(a), np.nanmax(b))
    if not vmax > vmin:
        # contourf requires strictly increasing levels; a constant field would
        # otherwise produce `num` identical levels and raise.
        pad = 0.5 * max(abs(vmin), 1.0)
        vmin, vmax = vmin - pad, vmax + pad
    return np.linspace(vmin, vmax, num)


def _boundary_image(bf, nx2, nx3):
    """Lay boundary data out as a (2*nx2, nx3) image.

    Rows [:nx2] of `bf` fill the top-left block (columns [:nx3 // 2]), rows [nx2:]
    fill the bottom-right block (columns [nx3 // 2:]); all other cells are NaN.
    """
    image = np.full((2 * nx2, nx3), np.nan)
    image[:nx2, :nx3 // 2] = bf[:nx2, :]
    image[nx2:, nx3 // 2:] = bf[nx2:, :]
    return image


def plot_comparison(system, bf, rho, rho_pred, fontscale = 1, output_filename = 'comparison.png',
                    wspace = -0.1, hspace = 0.5):
    """
    Side-by-side comparison of contour plots for arguments 'rho' and 'rho_pred'.
    Also plots the boundary data (argument 'bf') as an image.

    Args:
        system: PDESystem object. Reads `system.params['ell']`,
            `system.params['eccentricity']`, `system.params['nx2']`,
            `system.params['nx3']`, `system.a2` and `system.get_xy_grids()`.
        bf (np.ndarray): 2D array representing the boundary data on the (x2, x3) grid.
            The block assignments in `_boundary_image` expect shape
            (2 * nx2, nx3 // 2) with nx3 even (other shapes only work via broadcasting).
        rho (np.ndarray): 2D array, known rho.
        rho_pred (np.ndarray): 2D array, predicted rho (same shape as `rho`).
        fontscale (float, optional): Scaling factor applied to all font sizes.
        output_filename (str, optional): Name of the output file.
        wspace (float, optional): Horizontal spacing between subplots.
        hspace (float, optional): Vertical spacing between subplots.

    Returns:
        None. The figure is saved to `output_filename` (default 'comparison.png')
        and then closed, even if plotting fails part-way.
    """
    # System parameters
    ell = system.params['ell']
    a2 = system.a2
    ecc = system.params['eccentricity']
    nx2, nx3 = system.params['nx2'], system.params['nx3']
    # NOTE(review): presumably x, y match rho's shape plus one wrap-around column
    # (see below) - TODO confirm against PDESystem.get_xy_grids.
    x, y = system.get_xy_grids()

    # Relative error on the original (un-wrapped) fields.
    rel_err = _relative_error(rho, rho_pred)

    # Repeat the first column at the end so the contour closes across the seam.
    rho = np.concatenate([rho, rho[:, :1]], axis = 1)
    rho_pred = np.concatenate([rho_pred, rho_pred[:, :1]], axis = 1)

    # Shared levels so both contour plots use one colour scale / one colorbar.
    levels = _shared_levels(rho, rho_pred)

    fig = plt.figure(figsize = (24, 24))
    try:
        gs = fig.add_gridspec(11, 2, hspace = hspace, wspace = wspace)
        axs = [fig.add_subplot(gs[:11, 0]), fig.add_subplot(gs[:5, 1]), fig.add_subplot(gs[6:11, 1])]

        # Boundary data image (left column)
        im = axs[0].imshow(_boundary_image(bf, nx2, nx3))
        # NOTE(review): the image has 2 * nx2 rows with its two data blocks split at
        # row nx2, but these ticks split the labelled range at row nx2 // 2 - verify
        # the intended tick positions.
        axs[0].set_yticks([1, nx2 // 4, nx2 // 2 - 1, nx2 // 2 + 1, 3 * nx2 // 4, nx2])
        axs[0].set_yticklabels(['-pi', '0', 'pi', '', '0', 'pi'])
        axs[0].set_xticks([1, nx3 // 2, nx3])
        axs[0].set_xticklabels(['0', 'pi', '2pi'])
        axs[0].set_title('Boundary data\n\n', fontsize = fontscale * 48, y = 0.91)
        axs[0].tick_params(labelsize = fontscale * 48)
        divider = make_axes_locatable(axs[0])
        cax = divider.append_axes('bottom', size = "3%", pad = 1.5)
        cb = fig.colorbar(im, cax = cax, orientation = 'horizontal')
        cb.ax.tick_params(labelsize = fontscale * 48)
        cax.xaxis.set_ticks_position('bottom')

        # rho(x, y) contour plots (right column). A colormap name is passed instead of
        # matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
        titles = ('Direct Solution', 'Tensor network')
        for ax, field, title in zip(axs[1:], (rho, rho_pred), titles):
            im = ax.contourf(x, y, field, levels = levels, cmap = 'hsv')
            ax.set_title(title, fontsize = fontscale * 48)
            ax.tick_params(labelsize = fontscale * 32)
            ax.set_aspect(1.0)

        # One colorbar serves both contour plots because they share `levels`.
        cbar_ax = fig.add_axes((0.89, 0.11, 0.03, 0.77))
        cb = fig.colorbar(im, cax = cbar_ax)
        cb.ax.tick_params(labelsize = fontscale * 32)

        suptitle = f'ell = {ell:.3f}; a2 = {a2:.3f}; e = {ecc:.3f}; Relative error: {rel_err:.3f}'
        fig.suptitle(suptitle, fontsize = fontscale * 48, x = 0.1, y = 0.97, horizontalalignment = 'left')
        fig.savefig(output_filename)
    finally:
        # Always release the (large) figure, even if a plotting call raised.
        plt.close(fig)