Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import os | |
def plot_entropies(
    patch_lengths: torch.Tensor,
    scores: torch.Tensor,
    tokens: torch.Tensor,
    chars: str,
    threshold: float,
):
    """Plot per-byte entropy scores with patch boundaries and a threshold line.

    The x-axis is indexed by byte position. ``chars`` is shown with ``"<"``
    prepended and spaces replaced by ``"_"``. Each displayed character becomes
    an x tick label placed at the index of the first byte of its UTF-8
    encoding.

    Args:
        patch_lengths: Length of each patch in bytes. The tensor is flattened
            and must sum to the number of scores.
        scores: One value per byte, plotted as the y series. The tensor is
            flattened.
        tokens: Not read by this function; kept for interface compatibility.
            NOTE(review): presumably the byte ids aligned with ``scores``.
        chars: Text whose UTF-8 bytes, after the prepended ``"<"``, must
            correspond one-to-one with ``scores``.
        threshold: y-value of the horizontal threshold line.

    Returns:
        The matplotlib ``Figure``. When both ``chars`` and ``scores`` are empty,
        a placeholder figure reading "No data to plot." is returned. The figure
        is created through pyplot, so it stays registered with pyplot until the
        caller closes it (``plt.close(fig)``).

    Raises:
        ValueError: If ``chars`` cannot be UTF-8 encoded, if the UTF-8 byte
            count of the display string differs from ``len(scores)``, or if
            ``patch_lengths`` does not sum to ``len(scores)``.
    """
    patch_lengths_np = patch_lengths.cpu().numpy().flatten()
    scores_np = scores.cpu().float().numpy().flatten()
    num_total_bytes_from_scores = len(scores_np)

    # Nothing to plot. This tests `chars`, not the display string: the display
    # string always carries the "<" prefix and so is never empty.
    if num_total_bytes_from_scores == 0 and not chars:
        fig, ax = plt.subplots(figsize=(15, 5))
        ax.text(0.5, 0.5, "No data to plot.", ha='center', va='center', fontsize=12)
        ax.set_xlabel("Characters (on underlying byte sequence)")
        ax.set_ylabel("Entropy of Next Byte")
        ax.set_ylim(bottom=0)
        ax.set_xlim(left=-0.5, right=0.5)
        return fig

    # Prepare display string (prepend '<', replace spaces with '_').
    display_string_processed_chars = chars.replace(" ", "_")
    display_string = "<" + display_string_processed_chars
    display_chars_list = list(display_string)

    # UTF-8 byte length of each display character. A character spans that many
    # consecutive entries of `scores`.
    try:
        byte_counts_per_display_char = [len(c.encode('utf-8')) for c in display_chars_list]
    except UnicodeEncodeError as e:
        raise ValueError(
            f"Could not encode characters in 'chars' string using UTF-8. "
            f"Problematic part: '{display_string_processed_chars}'. Error: {e}"
        ) from e

    # --- Validations ---
    total_display_bytes = sum(byte_counts_per_display_char)
    if total_display_bytes != num_total_bytes_from_scores:
        raise ValueError(
            f"Mismatch in byte counts: Sum of UTF-8 bytes for display_string "
            f"('{display_string}' -> {total_display_bytes} bytes) "
            f"does not match length of scores tensor ({num_total_bytes_from_scores}). "
            f"Ensure 'chars' (and the prepended '<') correctly correspond to the byte sequence "
            f"represented by 'scores'/'tokens'."
        )
    total_patch_bytes = patch_lengths_np.sum()
    if total_patch_bytes != num_total_bytes_from_scores:
        raise ValueError(
            f"Sum of patch_lengths ({total_patch_bytes}) "
            f"does not match length of scores ({num_total_bytes_from_scores})."
        )

    # --- Plot scores, one point per byte ---
    fig, ax = plt.subplots(figsize=(15, 5))
    x_byte_indices = np.arange(num_total_bytes_from_scores)
    ax.plot(x_byte_indices, scores_np, marker='.', linestyle='-', color='steelblue', label='Scores per byte')

    # --- Patch boundaries ---
    # A patch ending at cumulative length k owns bytes up to index k - 1, and
    # the next patch starts at byte k. The separator therefore goes halfway
    # between them, at k - 0.5. The final cumulative length marks the end of
    # the data and gets no line.
    patch_end_byte_cumulative_lengths = np.cumsum(patch_lengths_np)
    for boundary_len in patch_end_byte_cumulative_lengths[:-1]:
        ax.axvline(x=boundary_len - 0.5, color='grey', linestyle='--', linewidth=1)

    # --- Threshold line and its label ---
    ax.axhline(y=threshold, color='red', linestyle='--', linewidth=1)
    # The label sits near the left edge, 0.1 data units above the line.
    ax.text(0.05, threshold + 0.1, 'Entropy Threshold', color='red')

    # --- X ticks: one per display character, at the index of its first byte ---
    char_label_positions = []
    byte_offset = 0
    for n_bytes in byte_counts_per_display_char:
        char_label_positions.append(byte_offset)
        byte_offset += n_bytes
    ax.set_xticks(char_label_positions)
    ax.set_xticklabels(display_chars_list, rotation=0, fontsize=8)

    # --- Axes configuration ---
    ax.set_ylabel("Entropy of Next Byte", fontsize=12)
    ax.set_xlabel("Characters (on underlying byte sequence)", fontsize=12)
    ax.set_ylim(bottom=0)
    # Validation guarantees at least one byte (the '<' prefix), so this range
    # covers every byte index with half a unit of padding on each side.
    ax.set_xlim(left=-0.5, right=num_total_bytes_from_scores - 0.5)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    fig.tight_layout()
    return fig