# blt-entropy-patcher / bytelatent/plotting/entropy_figure_via_matplot_lib.py
# Commit 8eea094 by lucalp: "Fixing proper UTF-8 representation"
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
def plot_entropies(  # Renamed from plot_entropies_revised for final output
    patch_lengths: torch.Tensor,
    scores: torch.Tensor,
    tokens: torch.Tensor,  # Length used via scores. Content implicitly for UTF-8 assumption.
    chars: str,
    threshold: float,
):
    """Plot per-byte next-byte entropy with patch boundaries and character labels.

    The x-axis indexes the underlying byte sequence: one point per entry of
    ``scores``. ``chars`` is mapped onto that sequence after a ``"<"`` is
    prepended and spaces are replaced by ``"_"``. Each displayed character is
    used as the x-tick label at the index of its first UTF-8 byte.

    Args:
        patch_lengths: Length (in bytes) of each patch; flattened before use.
            Must sum to ``len(scores)``. Zero-length entries are tolerated and
            draw no boundary line.
        scores: Entropy value per byte; flattened and converted to float.
        tokens: Not read by this function; kept for interface compatibility.
        chars: Text whose UTF-8 encoding, together with the prepended ``"<"``,
            must span exactly ``len(scores)`` bytes.
        threshold: Entropy threshold, drawn as a horizontal red dashed line.

    Returns:
        A new ``matplotlib.figure.Figure``. It is created via ``plt.subplots``
        and therefore stays registered with pyplot; callers producing many
        figures should ``plt.close(fig)`` when finished with it. When both
        ``chars`` and ``scores`` are empty, a placeholder figure reading
        "No data to plot." is returned instead.

    Raises:
        ValueError: If ``chars`` cannot be UTF-8 encoded, if the UTF-8 byte
            count of the display string differs from ``len(scores)``, or if
            ``patch_lengths`` does not sum to ``len(scores)``.
    """
    # detach() so tensors that still require grad can be converted to NumPy.
    patch_lengths_np = patch_lengths.detach().cpu().numpy().flatten()
    scores_np = scores.detach().cpu().float().numpy().flatten()
    num_total_bytes_from_scores = len(scores_np)

    # The "<" below is always prepended, so the display string is never empty;
    # the "nothing to plot" case must be detected from the raw inputs instead.
    if not chars and num_total_bytes_from_scores == 0:
        return _no_data_figure()

    # Prepare display string (prepend '<', replace spaces with '_').
    display_string_processed_chars = chars.replace(" ", "_")
    display_string = "<" + display_string_processed_chars
    display_chars_list = list(display_string)

    # Number of UTF-8 bytes each displayed character occupies in the byte stream.
    try:
        byte_counts_per_display_char = [len(c.encode("utf-8")) for c in display_chars_list]
    except UnicodeEncodeError as e:  # e.g. lone surrogates in `chars`
        raise ValueError(
            f"Could not encode characters in 'chars' string using UTF-8. "
            f"Problematic part: '{display_string_processed_chars}'. Error: {e}"
        ) from e

    # --- Validations ---
    total_display_bytes = sum(byte_counts_per_display_char)
    if total_display_bytes != num_total_bytes_from_scores:
        raise ValueError(
            f"Mismatch in byte counts: Sum of UTF-8 bytes for display_string "
            f"('{display_string}' -> {total_display_bytes} bytes) "
            f"does not match length of scores tensor ({num_total_bytes_from_scores}). "
            f"Ensure 'chars' (and the prepended '<') correctly correspond to the byte sequence "
            f"represented by 'scores'/'tokens'."
        )
    if patch_lengths_np.sum() != num_total_bytes_from_scores:
        raise ValueError(
            f"Sum of patch_lengths ({patch_lengths_np.sum()}) "
            f"does not match length of scores ({num_total_bytes_from_scores})."
        )
    # From here on num_total_bytes_from_scores >= 1: the display string always
    # contains at least the one-byte "<" and its byte count matched above.

    fig, ax = plt.subplots(figsize=(15, 5))

    # --- Scores, one point per byte ---
    x_byte_indices = np.arange(num_total_bytes_from_scores)
    ax.plot(x_byte_indices, scores_np, marker=".", linestyle="-", color="steelblue", label="Scores per byte")

    # --- Vertical patch boundary lines, placed between adjacent bytes ---
    for boundary_x in _patch_boundary_positions(patch_lengths_np, num_total_bytes_from_scores):
        ax.axvline(x=boundary_x, color="grey", linestyle="--", linewidth=1)

    # --- Horizontal threshold line and its label ---
    ax.axhline(y=threshold, color="red", linestyle="--", linewidth=1)
    # Blended transform: x in axes fraction, y in data units, so the label sits
    # near the left edge just above the threshold regardless of the x-range.
    label_transform = ax.get_yaxis_transform()
    ax.annotate(
        "Entropy Threshold",
        xy=(0.05, threshold),
        xytext=(0.05, threshold + 0.1),
        xycoords=label_transform,
        textcoords=label_transform,
        color="red",
    )

    # --- X-axis ticks: each character labelled at the start of its bytes ---
    byte_counts = np.asarray(byte_counts_per_display_char)
    char_label_positions = (np.cumsum(byte_counts) - byte_counts).tolist()
    ax.set_xticks(char_label_positions)
    ax.set_xticklabels(display_chars_list, rotation=0, fontsize=8)

    # --- Axes configuration ---
    ax.set_ylabel("Entropy of Next Byte", fontsize=12)
    ax.set_xlabel("Characters (on underlying byte sequence)", fontsize=12)
    ax.set_ylim(bottom=0)
    # Show every byte with half a slot of padding on each side.
    ax.set_xlim(left=-0.5, right=num_total_bytes_from_scores - 0.5)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    fig.tight_layout()
    return fig


def _no_data_figure():
    """Return a placeholder figure stating that there is nothing to plot."""
    fig, ax = plt.subplots(figsize=(15, 5))
    # Axes coordinates so the message is centered in the plot area.
    ax.text(0.5, 0.5, "No data to plot.", ha="center", va="center", fontsize=12, transform=ax.transAxes)
    ax.set_xlabel("Characters (on underlying byte sequence)")
    ax.set_ylabel("Entropy of Next Byte")
    ax.set_ylim(bottom=0)
    ax.set_xlim(left=-0.5, right=0.5)
    return fig


def _patch_boundary_positions(patch_lengths_np, num_total_bytes):
    """Return sorted x positions, between bytes, of the internal patch boundaries."""
    cumulative = np.cumsum(patch_lengths_np)
    # Keep only boundaries strictly inside the sequence: zero-length patches
    # would otherwise yield duplicate lines or lines on the plot's edges.
    internal = np.unique(cumulative[(cumulative > 0) & (cumulative < num_total_bytes)])
    # A patch ending at byte b-1 and the next starting at byte b are separated
    # at b - 0.5, midway between the two plotted points.
    return internal - 0.5