File size: 5,990 Bytes
41ea791
 
8eea094
 
41ea791
8eea094
 
 
 
 
 
 
41ea791
 
 
8eea094
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad774a9
 
41ea791
8eea094
 
 
 
 
 
 
 
 
41ea791
8eea094
 
41ea791
8eea094
 
 
41ea791
8eea094
 
 
 
 
 
41ea791
8eea094
 
41ea791
 
 
8eea094
 
41ea791
2af55e5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
import numpy as np
import matplotlib.pyplot as plt
import os

def plot_entropies(
    patch_lengths: torch.Tensor,
    scores: torch.Tensor,
    tokens: torch.Tensor,
    chars: str,
    threshold: float,
):
    """Plot per-byte entropy scores with patch boundaries and character ticks.

    The x-axis indexes the underlying byte sequence (one point per element of
    ``scores``). Each character of the display string ``"<" + chars`` (with
    spaces shown as ``"_"``) is used as a tick label placed on the first byte
    of its UTF-8 encoding, so multi-byte characters span several x positions.

    Args:
        patch_lengths: Length in bytes of each patch. It is flattened and must
            sum to ``len(scores)``.
        scores: One entropy value per byte. It is flattened and cast to float.
        tokens: Not read by this function; kept for interface compatibility.
        chars: Text whose UTF-8 bytes, after a prepended ``"<"``, correspond
            one-to-one with ``scores``.
        threshold: Entropy threshold drawn as a horizontal red line.

    Returns:
        The created ``matplotlib.figure.Figure``. If ``chars`` is empty and
        ``scores`` is empty, a placeholder figure reading "No data to plot." is
        returned.

    Raises:
        ValueError: If the display string cannot be UTF-8 encoded, if its UTF-8
            byte count differs from ``len(scores)``, or if ``patch_lengths``
            does not sum to ``len(scores)``.
    """
    patch_lengths_np = patch_lengths.cpu().numpy().flatten()
    scores_np = scores.cpu().float().numpy().flatten()
    num_total_bytes_from_scores = len(scores_np)

    # Empty input. The display string always contains the prepended "<", so
    # emptiness must be judged on ``chars`` itself, not on the display string.
    if not chars and num_total_bytes_from_scores == 0:
        fig, ax = plt.subplots(figsize=(15, 5))
        ax.text(0.5, 0.5, "No data to plot.", ha='center', va='center', fontsize=12)
        ax.set_xlabel("Characters (on underlying byte sequence)")
        ax.set_ylabel("Entropy of Next Byte")
        ax.set_ylim(bottom=0)
        ax.set_xlim(left=-0.5, right=0.5)
        return fig

    # Display string: prepend '<' and make spaces visible as '_'.
    display_string_processed_chars = chars.replace(" ", "_")
    display_string = "<" + display_string_processed_chars
    display_chars_list = list(display_string)

    # Number of UTF-8 bytes each displayed character occupies.
    try:
        byte_counts_per_display_char = [len(c.encode('utf-8')) for c in display_chars_list]
    except UnicodeEncodeError as e:  # e.g. lone surrogates in ``chars``
        raise ValueError(
            f"Could not encode characters in 'chars' string using UTF-8. "
            f"Problematic part: '{display_string_processed_chars}'. Error: {e}"
        ) from e

    # --- Validations ---
    total_display_bytes = sum(byte_counts_per_display_char)
    if total_display_bytes != num_total_bytes_from_scores:
        raise ValueError(
            f"Mismatch in byte counts: Sum of UTF-8 bytes for display_string "
            f"('{display_string}' -> {total_display_bytes} bytes) "
            f"does not match length of scores tensor ({num_total_bytes_from_scores}). "
            f"Ensure 'chars' (and the prepended '<') correctly correspond to the byte sequence "
            f"represented by 'scores'/'tokens'."
        )

    if patch_lengths_np.sum() != num_total_bytes_from_scores:
        raise ValueError(
            f"Sum of patch_lengths ({patch_lengths_np.sum()}) "
            f"does not match length of scores ({num_total_bytes_from_scores})."
        )

    # From here on there is at least one byte (the '<' alone needs one).
    fig, ax = plt.subplots(figsize=(15, 5))
    x_byte_indices = np.arange(num_total_bytes_from_scores)

    # --- Scores, one point per byte ---
    ax.plot(x_byte_indices, scores_np, marker='.', linestyle='-', color='steelblue', label='Scores per byte')

    # --- Patch boundaries ---
    # A cumulative length is the index of the first byte of the next patch, so
    # the separator goes halfway between byte (boundary - 1) and byte boundary.
    # The final cumulative length is the end of the data and gets no line.
    patch_end_byte_cumulative_lengths = np.cumsum(patch_lengths_np)
    for boundary_len in patch_end_byte_cumulative_lengths[:-1]:
        ax.axvline(x=boundary_len - 0.5, color='grey', linestyle='--', linewidth=1)

    # --- Threshold line and label ---
    # x is given as a fraction of the axes width and y in data units, so the
    # label sits near the left edge just above the threshold line.
    ax.axhline(y=threshold, color='red', linestyle='--', linewidth=1)
    ax.annotate(
        'Entropy Threshold',
        xy=(0.05, threshold),
        xytext=(0.05, threshold + 0.1),
        xycoords=('axes fraction', 'data'),
        textcoords=('axes fraction', 'data'),
        color='red',
    )

    # --- Character tick labels, each on the first byte of its encoding ---
    char_label_positions = np.cumsum([0] + byte_counts_per_display_char[:-1])
    ax.set_xticks(char_label_positions)
    ax.set_xticklabels(display_chars_list, rotation=0, fontsize=8)

    # --- Axes configuration ---
    ax.set_ylabel("Entropy of Next Byte", fontsize=12)
    ax.set_xlabel("Characters (on underlying byte sequence)", fontsize=12)
    ax.set_ylim(bottom=0)
    # Half a byte of padding on each side so the first and last points are fully visible.
    ax.set_xlim(left=-0.5, right=num_total_bytes_from_scores - 0.5)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    fig.tight_layout()
    return fig