Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,990 Bytes
41ea791 8eea094 41ea791 8eea094 41ea791 8eea094 ad774a9 41ea791 8eea094 41ea791 8eea094 41ea791 8eea094 41ea791 8eea094 41ea791 8eea094 41ea791 8eea094 41ea791 2af55e5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
def _empty_entropy_figure():
    """Return a figure showing a centred 'No data to plot.' placeholder."""
    fig, ax = plt.subplots(figsize=(15, 5))
    # Axes coordinates keep the message centred regardless of the axis limits.
    ax.text(0.5, 0.5, "No data to plot.", ha='center', va='center', fontsize=12,
            transform=ax.transAxes)
    ax.set_xlabel("Characters (on underlying byte sequence)")
    ax.set_ylabel("Entropy of Next Byte")
    ax.set_ylim(bottom=0)
    ax.set_xlim(left=-0.5, right=0.5)
    return fig


def plot_entropies(
    patch_lengths: torch.Tensor,
    scores: torch.Tensor,
    tokens: torch.Tensor,
    chars: str,
    threshold: float,
):
    """Plot per-byte entropy scores with patch boundaries and a threshold line.

    The x-axis indexes the bytes that ``scores`` describes. The display string
    is ``'<' + chars`` with spaces shown as ``'_'``. Each of its characters
    labels the x position of the first UTF-8 byte of that character.

    Args:
        patch_lengths: Patch sizes in bytes. After flattening, they must sum
            to the number of scores.
        scores: One entropy value per byte. Flattened and cast to float
            before plotting.
        tokens: Not read by this function. Kept for interface compatibility.
        chars: Text whose UTF-8 bytes, preceded by the one-byte ``'<'``,
            correspond one-to-one with ``scores``.
        threshold: Entropy threshold, drawn as a horizontal red dashed line.

    Returns:
        The created matplotlib ``Figure``. If both ``chars`` and ``scores``
        are empty, a placeholder figure reading "No data to plot." is
        returned instead.

    Raises:
        ValueError: If ``chars`` cannot be UTF-8 encoded. Also raised if the
            display string's byte count differs from ``len(scores)``, or if
            ``sum(patch_lengths)`` differs from ``len(scores)``.
    """
    patch_lengths_np = patch_lengths.cpu().numpy().flatten()
    scores_np = scores.cpu().float().numpy().flatten()
    num_bytes = len(scores_np)

    # The display string always contains the prepended '<', so "no data" has
    # to be detected on the raw `chars` argument.
    if not chars and num_bytes == 0:
        return _empty_entropy_figure()

    processed_chars = chars.replace(" ", "_")
    display_string = "<" + processed_chars
    display_chars = list(display_string)

    try:
        byte_counts = [len(c.encode('utf-8')) for c in display_chars]
    except UnicodeEncodeError as e:  # e.g. lone surrogates in `chars`
        raise ValueError(
            f"Could not encode characters in 'chars' string using UTF-8. "
            f"Problematic part: '{processed_chars}'. Error: {e}"
        ) from e

    # --- Validations ---
    total_display_bytes = sum(byte_counts)
    if total_display_bytes != num_bytes:
        raise ValueError(
            f"Mismatch in byte counts: Sum of UTF-8 bytes for display_string "
            f"('{display_string}' -> {total_display_bytes} bytes) "
            f"does not match length of scores tensor ({num_bytes}). "
            f"Ensure 'chars' (and the prepended '<') correctly correspond to the byte sequence "
            f"represented by 'scores'/'tokens'."
        )
    if patch_lengths_np.sum() != num_bytes:
        raise ValueError(
            f"Sum of patch_lengths ({patch_lengths_np.sum()}) "
            f"does not match length of scores ({num_bytes})."
        )
    # From here on num_bytes >= 1, because '<' alone contributes one byte.

    fig, ax = plt.subplots(figsize=(15, 5))

    # --- Scores, one point per byte ---
    ax.plot(np.arange(num_bytes), scores_np, marker='.', linestyle='-',
            color='steelblue', label='Scores per byte')

    # --- Patch boundaries ---
    # A patch ending at cumulative length L separates byte L-1 from byte L,
    # so its line goes at x = L - 0.5. Zero-length patches repeat a boundary,
    # which np.unique removes. Boundaries at 0 or at the end would coincide
    # with the axis edges, so only interior boundaries are drawn.
    boundaries = np.unique(np.cumsum(patch_lengths_np))
    for boundary in boundaries[(boundaries > 0) & (boundaries < num_bytes)]:
        ax.axvline(x=boundary - 0.5, color='grey', linestyle='--', linewidth=1)

    # --- Threshold line and label ---
    ax.axhline(y=threshold, color='red', linestyle='--', linewidth=1)
    # Blended coordinates: x is 5% of the axes width, y is in data units,
    # so the label sits just above the threshold line at any zoom level.
    ax.annotate(
        'Entropy Threshold',
        xy=(0.05, threshold),
        xycoords=('axes fraction', 'data'),
        xytext=(0.05, threshold + 0.1),
        textcoords=('axes fraction', 'data'),
        color='red',
    )

    # --- Character tick labels at the first byte of each character ---
    tick_positions = np.cumsum([0] + byte_counts[:-1])
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(display_chars, rotation=0, fontsize=8)

    # --- Axes configuration ---
    ax.set_ylabel("Entropy of Next Byte", fontsize=12)
    ax.set_xlabel("Characters (on underlying byte sequence)", fontsize=12)
    ax.set_ylim(bottom=0)
    ax.set_xlim(left=-0.5, right=num_bytes - 0.5)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Lay out this figure explicitly rather than relying on pyplot's
    # "current figure" state.
    fig.tight_layout()
    return fig
|