caleb2's picture
initial commit
d68c650
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.cm import get_cmap
def _relative_error(rho, rho_pred):
    """Return ||rho - rho_pred|| / ||rho|| using the Euclidean norm of all entries.

    Only entries that are finite in BOTH arrays contribute, so NaN-masked
    regions do not turn the whole metric into NaN. For fully finite input this
    equals ``np.linalg.norm(rho - rho_pred) / np.linalg.norm(rho)``.
    A zero reference norm yields inf (or nan for 0/0) without emitting a
    RuntimeWarning.
    """
    rho = np.asarray(rho)
    rho_pred = np.asarray(rho_pred)
    valid = np.isfinite(rho) & np.isfinite(rho_pred)
    ref = rho[valid]
    err = np.linalg.norm(ref - rho_pred[valid])
    with np.errstate(divide='ignore', invalid='ignore'):
        return err / np.linalg.norm(ref)


def _color_limits(*arrays):
    """Return a shared (vmin, vmax) over all arrays, ignoring NaNs.

    If the range is degenerate (all values equal), it is widened symmetrically,
    because contourf requires strictly increasing contour levels.
    """
    vmin = min(np.nanmin(a) for a in arrays)
    vmax = max(np.nanmax(a) for a in arrays)
    if not vmax > vmin:
        pad = 0.5 * max(abs(vmin), 1.0)
        vmin, vmax = vmin - pad, vmax + pad
    return vmin, vmax


def _plot_boundary_data(fig, ax, bf, nx2, nx3, fontscale):
    """Draw the boundary data 'bf' with imshow on 'ax', plus a horizontal colorbar below."""
    # The first nx2 rows of bf fill the left half of the canvas
    # (columns < nx3 // 2) and the remaining rows fill the right half;
    # every other cell stays NaN.
    bf_plot = np.full((2 * nx2, nx3), np.nan)
    bf_plot[:nx2, :nx3 // 2] = bf[:nx2, :]
    bf_plot[nx2:, nx3 // 2:] = bf[nx2:, :]
    im = ax.imshow(bf_plot)

    ax.set_yticks([1, nx2 // 4, nx2 // 2 - 1, nx2 // 2 + 1, 3 * nx2 // 4, nx2])
    ax.set_yticklabels(['-pi', '0', 'pi', '', '0', 'pi'])
    ax.set_xticks([1, nx3 // 2, nx3])
    ax.set_xticklabels(['0', 'pi', '2pi'])
    ax.set_title('Boundary data\n\n', fontsize=fontscale * 48, y=0.91)
    ax.tick_params(axis='both', labelsize=fontscale * 48)

    divider = make_axes_locatable(ax)
    cax = divider.append_axes('bottom', size="3%", pad=1.5)
    cb = fig.colorbar(im, cax=cax, orientation='horizontal')
    cb.ax.tick_params(labelsize=fontscale * 48)
    cax.xaxis.set_ticks_position('bottom')


def _plot_rho_panels(fig, axes, x, y, fields, titles, vmin, vmax, fontscale, cbar_coords):
    """Draw one filled-contour panel per field on a shared scale, plus ONE shared colorbar."""
    # Identical levels on every panel make the panels directly comparable,
    # and let a single colorbar describe all of them.
    levels = np.linspace(vmin, vmax, 100)
    im = None
    for ax, z, title in zip(axes, fields, titles):
        im = ax.contourf(x, y, z, levels=levels, cmap='hsv')
        ax.set_title(title, fontsize=fontscale * 48)
        ax.tick_params(axis='both', labelsize=fontscale * 32)
        ax.set_aspect(1.0)
    # Built once, after the loop: creating it per panel stacked duplicate
    # colorbar axes on the same rectangle.
    cbar_ax = fig.add_axes(cbar_coords)
    cb = fig.colorbar(im, cax=cbar_ax)
    cb.ax.tick_params(labelsize=fontscale * 32)


def plot_comparison(system, bf, rho, rho_pred, fontscale = 1, output_filename = 'comparison.png',
                    wspace = -0.1, hspace = 0.5):
    """
    Side-by-side comparison of contour plots for arguments 'rho' and 'rho_pred'.
    Also plots the boundary data (argument 'bf') as an image.

    Args:
        system: PDESystem object. Reads ``params['ell']``, ``params['eccentricity']``,
            ``params['nx2']``, ``params['nx3']``, ``a2``, and ``get_xy_grids()``,
            which must return the (x, y) coordinate grids used for contouring.
        bf (np.ndarray): 2D boundary data. Its first nx2 rows are drawn in the
            left half (columns < nx3 // 2) and the remaining rows in the right half,
            so it is expected to have 2 * nx2 rows.
        rho (np.ndarray): 2D array, known rho. May contain NaNs; they are ignored
            by the error metric and the color limits.
        rho_pred (np.ndarray): 2D array, predicted rho (same shape as 'rho').
        fontscale (float, optional): multiplier applied to every font size.
        output_filename (str, optional): name of the output file.
        wspace (float, optional): horizontal spacing between subplots.
        hspace (float, optional): vertical spacing between subplots.

    Returns:
        None. The figure is saved to 'output_filename' (default 'comparison.png')
        and then closed.
    """
    # System parameters
    ell = system.params['ell']
    a2 = system.a2
    e = system.params['eccentricity']
    nx2, nx3 = system.params['nx2'], system.params['nx3']
    x, y = system.get_xy_grids()

    # Computed before the wrap-around column is appended, so that column is
    # not counted twice.
    rel_err = _relative_error(rho, rho_pred)

    # Append the first column at the end so the contours close without a gap
    # (presumably the second axis is angular/periodic -- confirm against
    # get_xy_grids, whose grids must match this widened shape).
    rho = np.append(rho, rho[:, 0:1], axis=1)
    rho_pred = np.append(rho_pred, rho_pred[:, 0:1], axis=1)
    vmin, vmax = _color_limits(rho, rho_pred)

    # Figure layout: boundary data spans the whole left column; the two rho
    # panels are stacked in the right column.
    cbar_coords = (0.89, 0.11, 0.03, 0.77)
    fig = plt.figure(figsize=(24, 24))
    try:
        gs = fig.add_gridspec(11, 2, hspace=hspace, wspace=wspace)
        axs = [fig.add_subplot(gs[:11, 0]),
               fig.add_subplot(gs[:5, 1]),
               fig.add_subplot(gs[6:11, 1])]

        _plot_boundary_data(fig, axs[0], bf, nx2, nx3, fontscale)
        _plot_rho_panels(fig, axs[1:], x, y, [rho, rho_pred],
                         ['Direct Solution', 'Tensor network'],
                         vmin, vmax, fontscale, cbar_coords)

        suptitle = f'ell = {ell:.3f}; a2 = {a2:.3f}; e = {e:.3f}; Relative error: {rel_err:.3f}'
        fig.suptitle(suptitle, fontsize=fontscale * 48, x=0.1, y=0.97, horizontalalignment='left')
        fig.savefig(output_filename)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)