import os

import matplotlib.pyplot as plt
import numpy as np
import torch


def _empty_entropy_figure():
    """Return a placeholder figure used when there is neither text nor byte scores to plot."""
    fig, ax = plt.subplots(figsize=(15, 5))
    ax.text(0.5, 0.5, "No data to plot.", ha="center", va="center", fontsize=12)
    ax.set_xlabel("Characters (on underlying byte sequence)")
    ax.set_ylabel("Entropy of Next Byte")
    ax.set_ylim(bottom=0)
    ax.set_xlim(left=-0.5, right=0.5)  # Default xlim for an empty plot
    return fig


def plot_entropies(
    patch_lengths: torch.Tensor,
    scores: torch.Tensor,
    tokens: torch.Tensor,
    chars: str,
    threshold: float,
):
    """Plot per-byte entropy scores with patch boundaries and character tick labels.

    Each score is drawn at integer x = byte index. Dashed grey vertical lines are
    drawn between the last byte of one patch and the first byte of the next. A
    dashed red horizontal line marks ``threshold``. The x-axis is labelled with the
    characters of ``"<" + chars`` (spaces shown as ``_``), each tick placed on the
    first byte of that character's UTF-8 encoding.

    Args:
        patch_lengths: Number of bytes in each patch. Flattened; must sum to the
            number of scores.
        scores: One value per byte. Flattened and converted to float before
            plotting.
        tokens: Not read by this function. It is kept for interface compatibility
            (presumably the byte tokens the scores belong to; confirm with callers).
        chars: Text whose UTF-8 bytes, after a leading ``'<'`` is prepended, must
            correspond one-to-one with ``scores``. The ``'<'`` presumably stands
            for a start/BOS byte; verify against the caller.
        threshold: y-value at which the threshold line is drawn.

    Returns:
        The ``matplotlib`` Figure. If ``chars`` is empty and ``scores`` has no
        elements, this is a placeholder figure reading "No data to plot.".

    Raises:
        ValueError: If ``chars`` cannot be UTF-8 encoded. Also raised if the
            UTF-8 byte count of the display string differs from ``len(scores)``,
            or if ``patch_lengths`` does not sum to ``len(scores)``.
    """
    patch_lengths_np = patch_lengths.cpu().numpy().flatten()
    scores_np = scores.cpu().float().numpy().flatten()
    num_total_bytes_from_scores = len(scores_np)

    # The display string always gains a leading '<', so it is never empty.
    # Genuinely empty input must therefore be detected from `chars` itself.
    if num_total_bytes_from_scores == 0 and not chars:
        return _empty_entropy_figure()

    # Prepare the display string (prepend '<', make spaces visible as '_').
    display_string_processed_chars = chars.replace(" ", "_")
    display_string = "<" + display_string_processed_chars
    display_chars_list = list(display_string)

    # UTF-8 byte length of every displayed character. Lone surrogates in a str
    # make encode() raise UnicodeEncodeError.
    try:
        byte_counts = np.array(
            [len(c.encode("utf-8")) for c in display_chars_list], dtype=np.int64
        )
    except UnicodeEncodeError as e:
        raise ValueError(
            f"Could not encode characters in 'chars' string using UTF-8. "
            f"Problematic part: '{display_string_processed_chars}'. Error: {e}"
        ) from e

    # --- Validations ---
    total_char_bytes = int(byte_counts.sum())
    if total_char_bytes != num_total_bytes_from_scores:
        raise ValueError(
            f"Mismatch in byte counts: Sum of UTF-8 bytes for display_string "
            f"('{display_string}' -> {total_char_bytes} bytes) "
            f"does not match length of scores tensor ({num_total_bytes_from_scores}). "
            f"Ensure 'chars' (and the prepended '<') correctly correspond to the byte sequence "
            f"represented by 'scores'/'tokens'."
        )
    if patch_lengths_np.sum() != num_total_bytes_from_scores:
        raise ValueError(
            f"Sum of patch_lengths ({patch_lengths_np.sum()}) "
            f"does not match length of scores ({num_total_bytes_from_scores})."
        )

    # --- Plotting ---
    fig, ax = plt.subplots(figsize=(15, 5))
    x_byte_indices = np.arange(num_total_bytes_from_scores)
    ax.plot(
        x_byte_indices,
        scores_np,
        marker=".",
        linestyle="-",
        color="steelblue",
        label="Scores per byte",
    )

    # Byte i is drawn at x = i. A patch that ends after `n` cumulative bytes is
    # therefore separated from the next patch at x = n - 0.5, midway between
    # bytes n-1 and n. The final cumulative length is the end of the data, so
    # it gets no line.
    for boundary_len in np.cumsum(patch_lengths_np)[:-1]:
        ax.axvline(x=boundary_len - 0.5, color="grey", linestyle="--", linewidth=1)

    # Threshold line and label. The x coordinate is an axes fraction (5% in from
    # the left edge); the y coordinate is in data units, so the label tracks
    # the threshold value.
    ax.axhline(y=threshold, color="red", linestyle="--", linewidth=1)
    ax.annotate(
        "Entropy Threshold",
        xy=(0.05, threshold),
        xytext=(0.05, threshold + 0.1),
        xycoords=("axes fraction", "data"),
        textcoords=("axes fraction", "data"),
        color="red",
    )

    # Each character's tick sits on the first byte of its UTF-8 sequence.
    # That byte index is the cumulative byte count minus the character's own length.
    char_label_positions = np.cumsum(byte_counts) - byte_counts
    ax.set_xticks(char_label_positions)
    ax.set_xticklabels(display_chars_list, rotation=0, fontsize=8)

    # --- Axes configuration ---
    ax.set_ylabel("Entropy of Next Byte", fontsize=12)
    ax.set_xlabel("Characters (on underlying byte sequence)", fontsize=12)
    ax.set_ylim(bottom=0)
    # Validation guarantees at least one byte here (the display string is never
    # empty), so show every byte from -0.5 to last_index + 0.5.
    ax.set_xlim(left=-0.5, right=num_total_bytes_from_scores - 0.5)

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    fig.tight_layout()
    return fig