demo / src /visualisation.py
Xmaster6y's picture
new repo structure
0d998a6 unverified
"""
Visualisation utils.
"""
import chess
import chess.svg
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
# Matplotlib's "RdYlBu_r" colour map resampled to 1000 entries; render_heatmap
# uses it both for the square colours and for the colour bar.
COLOR_MAP = matplotlib.colormaps["RdYlBu_r"].resampled(1000)
# Alpha channel applied to every square colour produced by render_heatmap.
ALPHA = 1.0
# Linear normalisation over [0, 1] without clipping.
# NOTE(review): not referenced in the visible part of this file -- presumably
# kept for external callers; verify before removing.
NORM = matplotlib.colors.Normalize(vmin=0, vmax=1, clip=False)
def render_heatmap(
    board,
    heatmap,
    square=None,
    vmin=None,
    vmax=None,
    arrows=None,
    normalise="none",
):
    """
    Draw a per-square heatmap over a board and build the matching colour bar.

    Args:
        board: Board handed to `chess.svg.board` for rendering.
        heatmap: Per-square values, indexed from 0 to 63. Must provide
            `.abs()`, `.min()` and `.max()` (presumably a torch tensor --
            TODO confirm against the callers).
        square: Optional square name. It is parsed with `chess.parse_square`
            and passed as the `check` square of the SVG; a name that fails to
            parse is ignored.
        vmin: Lower bound of the colour range; defaults to the heatmap minimum.
        vmax: Upper bound of the colour range; defaults to the heatmap maximum.
        arrows: Optional arrows forwarded to `chess.svg.board`.
        normalise: When equal to "abs", the heatmap is divided by its largest
            absolute value (if non-zero) and the colour range is forced to
            [-1, 1], overriding `vmin` and `vmax`. Any other value leaves the
            heatmap unchanged.

    Returns:
        A pair `(svg, fig)`: the board SVG produced by `chess.svg.board` and
        the matplotlib figure holding the horizontal colour bar.
    """
    if normalise == "abs":
        peak = heatmap.abs().max()
        if peak != 0:
            heatmap = heatmap / peak
        vmin, vmax = -1, 1
    vmin = heatmap.min() if vmin is None else vmin
    vmax = heatmap.max() if vmax is None else vmax
    scale = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax, clip=False)

    def _fill(index):
        # Keep the RGB part of the mapped colour and impose the module alpha.
        rgba = COLOR_MAP(scale(heatmap[index]))
        return matplotlib.colors.to_hex((*rgba[:3], ALPHA), keep_alpha=True)

    square_fills = {index: _fill(index) for index in range(64)}

    # A thin figure whose only visible content is the colour bar.
    fig = plt.figure(figsize=(6, 0.6))
    ax = plt.gca()
    ax.axis("off")
    mappable = matplotlib.cm.ScalarMappable(norm=scale, cmap=COLOR_MAP)
    fig.colorbar(mappable, ax=ax, orientation="horizontal", fraction=1.0)

    highlighted = None
    if square is not None:
        try:
            highlighted = chess.parse_square(square)
        except ValueError:
            # Unparseable square name: render without a highlighted square.
            pass

    if arrows is None:
        arrows = []

    # Drop the figure from pyplot's current state; the object is still returned.
    plt.close()
    svg = chess.svg.board(
        board,
        check=highlighted,
        fill=square_fills,
        size=350,
        arrows=arrows,
    )
    return svg, fig
def render_policy_distribution(
    policy,
    legal_moves,
    n_bins=20,
):
    """
    Render the policy distribution histogram.

    Two overlaid, density-normalised histograms are drawn on a log-scaled
    y axis: one for the policy entries whose index is in `legal_moves` and one
    for all remaining entries. Both use the same bin edges, computed over the
    whole policy.

    Args:
        policy: Policy values, indexable by a boolean mask of length 1858
            (presumably a 1-D torch tensor over the move vocabulary -- TODO
            confirm against the callers).
        legal_moves: Container of legal move indices; only membership tests
            (`index in legal_moves`) are performed on it.
        n_bins: Number of histogram bins.

    Returns:
        The matplotlib figure holding the histograms. The figure is removed
        from pyplot's registry before being returned (as `render_heatmap`
        does), so repeated calls do not accumulate open figures; the returned
        object can still be saved or embedded.
    """
    # 1858 is presumably the size of the policy vector -- verify against the
    # model definition before changing it.
    legal_mask = torch.tensor(
        [move in legal_moves for move in range(1858)],
        dtype=torch.bool,
    )
    # Shared bin edges so that both histograms are directly comparable.
    _, bins = np.histogram(policy, bins=n_bins)

    fig = plt.figure(figsize=(6, 6))
    try:
        ax = fig.gca()
        ax.hist(
            policy[~legal_mask],
            bins=bins,
            alpha=0.5,
            density=True,
            label="Illegal moves",
        )
        ax.hist(
            policy[legal_mask],
            bins=bins,
            alpha=0.5,
            density=True,
            label="Legal moves",
        )
        ax.set_xlabel("Policy")
        ax.set_ylabel("Density")
        ax.legend()
        ax.set_yscale("log")
    finally:
        # Always deregister the figure from pyplot, even if plotting failed,
        # so that it does not stay alive in the global figure registry.
        plt.close(fig)
    return fig