# blt-entropy-patcher/bytelatent/plotting/entropy_figure_via_matplot_lib.py
# Last commit ad774a9 by luca-peric: "more finishing touches"
import os
import torch
import matplotlib.pyplot as plt
import numpy as np
def plot_entropies(patch_lengths: torch.Tensor, scores: torch.Tensor, chars: str, threshold: float):
    """Plot per-token entropy scores with patch boundaries and a threshold line.

    A leading ``"<"`` token is prepended to ``chars``, and spaces are shown as
    ``"_"`` so they remain visible as x-axis tick labels. The token sequence
    therefore has ``len(chars) + 1`` entries, one per score.

    Args:
        patch_lengths: Length of each patch. Moved to CPU and flattened; must
            sum to ``len(chars) + 1``.
        scores: One entropy score per token (including the leading ``"<"``).
            Moved to CPU, cast to float32 and flattened before plotting.
        chars: Text whose characters label the x-axis, one tick per character.
            NOTE(review): lengths are compared per Python character; if
            ``scores``/``patch_lengths`` are per byte, non-ASCII text will fail
            validation — confirm against callers.
        threshold: y-value at which a dashed red "Entropy Threshold" line is drawn.

    Returns:
        The matplotlib ``Figure``. It is created via ``plt.subplots`` and thus
        stays registered with pyplot; callers producing many plots should
        ``plt.close(fig)`` once done with it.

    Raises:
        ValueError: If ``scores`` and the token sequence differ in length, or
            if ``patch_lengths`` does not sum to the number of tokens.

    Example (saving the result)::

        fig = plot_entropies(patch_lengths, scores, text, threshold)
        fig.savefig("token_score_plot.png", dpi=300, bbox_inches="tight")
    """
    patch_lengths_np = patch_lengths.cpu().numpy().flatten()
    # .float() makes e.g. bfloat16 scores convertible to numpy.
    scores_np = scores.cpu().float().numpy().flatten()
    # "<" occupies the first score position; "_" makes spaces visible as labels.
    tokens = list("<" + chars.replace(" ", "_"))

    if len(scores_np) != len(tokens):
        raise ValueError("Length of scores and tokens tensors must be the same.")
    if patch_lengths_np.sum() != len(tokens):
        raise ValueError(f"Sum of patch_lengths ({patch_lengths_np.sum()}) "
                         f"does not match the length of tokens/scores ({len(tokens)}).")

    x_indices = np.arange(len(tokens))
    # The cumulative sum gives each patch's exclusive end index, i.e. the index
    # of the first token of the following patch. The final value is just the
    # end of the data, so it is dropped.
    patch_starts = np.cumsum(patch_lengths_np)[:-1]

    fig, ax = plt.subplots(figsize=(15, 5))
    ax.plot(x_indices, scores_np, marker='.', linestyle='-', color='steelblue', label='Scores')

    # A grey dashed line marks the first token of every patch after the first.
    for start in patch_starts:
        ax.axvline(x=start, color='grey', linestyle='--', linewidth=1)

    ax.axhline(y=threshold, color='red', linestyle='--', linewidth=1)
    # Label in data coordinates: just above the threshold line, at the first token.
    ax.text(0.05, threshold + 0.1, 'Entropy Threshold', color='red')

    # One tick per token, labelled with the (unrotated) character.
    ax.set_xticks(x_indices)
    ax.set_xticklabels(tokens, rotation=0, fontsize=8)

    ax.set_ylabel("Entropy of Next Byte", fontsize=12)
    ax.set_xlabel("Tokens", fontsize=12)

    ax.set_ylim(bottom=0)
    # One unit of horizontal padding on each side of the first/last token.
    ax.set_xlim(left=x_indices[0] - 1.0, right=x_indices[-1] + 1.0)

    # Hide top/right spines for a cleaner look.
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Lay out the figure we created, not whatever pyplot currently considers
    # the "current" figure (global state, unsafe with concurrent callers).
    fig.tight_layout()
    return fig