import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
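# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot.get_cmap is the equivalent lookup by colormap name.
get_cmap = plt.get_cmap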


def plot_comparison(system, bf, rho, rho_pred, fontscale = 1, output_filename = 'comparison.png',
					wspace = -0.1, hspace = 0.5):
	"""
	Plot the boundary data next to a side-by-side comparison of 'rho' and 'rho_pred'.

	The boundary data (argument 'bf') is drawn as an image; 'rho' and 'rho_pred' are drawn as
	filled contour plots that share a single color scale.

	Args:
		system: PDESystem object
		bf (np.ndarray): 2D array representing the boundary data on the (x2, x3) grid.
		rho (np.ndarray): 2D array, known rho
		rho_pred (np.ndarray): 2D array, predicted rho
		fontscale (float, optional): scaling factor for the font sizes in the plots
		output_filename (str, optional): name of the output file
		wspace (float, optional): horizontal spacing between subplots
		hspace (float, optional): vertical spacing between subplots

	Returns:
		None. The figure is saved to the file named by 'output_filename' (default: 'comparison.png').
	"""
	# System parameters
	ell = system.params['ell']
	a2 = system.a2
	e = system.params['eccentricity']
	nx1, nx2, nx3 = system.params['nx1'], system.params['nx2'], system.params['nx3']

	# Get x, y grids from 'PDESystem' object
	x, y = system.get_xy_grids()

	# Relative error
	err = np.linalg.norm(rho - rho_pred)
	rel_err = err / np.linalg.norm(rho)

	# Append the first column at the end so the contours close across the wrap-around seam
	rho = np.append(rho, rho[:, 0:1], axis = 1)
	rho_pred = np.append(rho_pred, rho_pred[:, 0:1], axis = 1)

	plots = [rho, rho_pred]

	# Color bar limits
	vmin = min(np.nanmin(rho), np.nanmin(rho_pred))
	vmax = max(np.nanmax(rho), np.nanmax(rho_pred))

	# Figure layout
	cbar_coords = (0.89, 0.11, 0.03, 0.77)
	fig = plt.figure(figsize = (24, 24))
	gs = fig.add_gridspec(11, 2, hspace = hspace, wspace = wspace)
	axs = [fig.add_subplot(gs[:11, 0]), fig.add_subplot(gs[:5, 1]), fig.add_subplot(gs[6:11, 1])]

	# Boundary data, drawn as an image with imshow
	bf_plot = np.nan * np.ones((2 * nx2, nx3))
	bf_plot[:nx2, :nx3 // 2] = bf[:nx2, :]
	bf_plot[nx2:, nx3 // 2:] = bf[nx2:, :]
	im = axs[0].imshow(bf_plot)
	axs[0].set_yticks([1, nx2 // 4, nx2 // 2 - 1, nx2 // 2 + 1, 3 * nx2 // 4, nx2])
	axs[0].set_yticklabels(['-pi', '0', 'pi', '', '0', 'pi'])
	axs[0].set_xticks([1, nx3 // 2, nx3])
	axs[0].set_xticklabels(['0', 'pi', '2pi'])
	axs[0].set_title('Boundary data\n\n', fontsize = fontscale * 48, y = 0.91)
	for label in axs[0].get_xticklabels() + axs[0].get_yticklabels():
		label.set_fontsize(fontscale * 48)
	divider = make_axes_locatable(axs[0])
	cax = divider.append_axes('bottom', size = "3%", pad = 1.5)
	cb = fig.colorbar(im, cax = cax, orientation = 'horizontal')
	cb.ax.tick_params(labelsize = fontscale * 48)
	cax.xaxis.set_ticks_position('bottom')

	# rho(x, y) contour plots
	titles = ['Direct Solution', 'Tensor network']
	for i, ax in enumerate(axs[1:]):
		z = plots[i]
		im = ax.contourf(x, y, z, levels = np.linspace(vmin, vmax, 100), cmap = get_cmap('hsv'))
		ax.set_title(titles[i], fontsize = fontscale * 48)
		for label in ax.get_xticklabels() + ax.get_yticklabels():
			label.set_fontsize(fontscale * 32)
		ax.set_aspect(1.0)

	cbar_ax = fig.add_axes(cbar_coords)
	# 'pad' only affects axes that colorbar creates itself, so it is omitted when 'cax' is given
	cb = fig.colorbar(im, cax = cbar_ax)
	cb.ax.tick_params(labelsize = fontscale * 32)

	suptitle = f'ell = {ell:.3f}; a2 = {a2:.3f}; e = {e:.3f}; Relative error: {rel_err:.3f}'
	fig.suptitle(suptitle, fontsize = fontscale * 48, x = 0.1, y = 0.97, horizontalalignment = 'left')
	fig.savefig(output_filename)
	plt.close(fig)
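

# Illustrative sketch, not part of the original module: the two numerical steps
# that plot_comparison performs inline (relative error and wrap-around), pulled
# out as small helpers so they can be checked on a toy array. The names
# 'relative_error' and 'wrap_periodic' are hypothetical, and axis 1 is assumed
# to be the direction that wraps around, as in plot_comparison.
def relative_error(rho, rho_pred):
	"""Return ||rho - rho_pred|| / ||rho|| using the default (Frobenius) norm."""
	return np.linalg.norm(rho - rho_pred) / np.linalg.norm(rho)


def wrap_periodic(field):
	"""Return 'field' with its first column appended as an extra last column."""
	return np.append(field, field[:, 0:1], axis = 1)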