File size: 2,993 Bytes
41ea791
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad774a9
 
 
 
 
 
 
 
41ea791
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2af55e5
 
 
 
41ea791
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import torch
import matplotlib.pyplot as plt
import numpy as np


def plot_entropies(patch_lengths: torch.Tensor, scores: torch.Tensor, chars: str, threshold: float):
    """Plot per-position scores with patch boundaries and a threshold line.

    The x-axis carries one tick per character of ``"<" + chars`` (spaces are
    shown as ``_``). The leading ``"<"`` labels one extra position in front of
    the text (presumably a BOS token -- confirm against the caller).

    Args:
        patch_lengths: Tensor of patch lengths; flattened before use. Its sum
            must equal ``len(chars) + 1``.
        scores: Tensor with one score per position; flattened and cast to
            float. Must contain ``len(chars) + 1`` values.
        chars: The text whose characters label the x-axis.
        threshold: y-value of the horizontal red dashed line.

    Returns:
        The matplotlib figure. It is neither shown nor saved here; callers can
        persist it with e.g. ``fig.savefig(path, dpi=300, bbox_inches="tight")``.

    Raises:
        ValueError: If the number of scores, or the sum of ``patch_lengths``,
            does not match the number of x-axis labels.
    """
    # detach() first: Tensor.numpy() refuses tensors that require grad.
    patch_lengths_np = patch_lengths.detach().cpu().numpy().flatten()
    scores_np = scores.detach().cpu().float().numpy().flatten()

    chars = chars.replace(" ", "_")
    tokens_np = np.array(list("<" + chars))

    if len(scores_np) != len(tokens_np):
        raise ValueError("Length of scores and tokens tensors must be the same.")
    if patch_lengths_np.sum() != len(tokens_np):
        raise ValueError(f"Sum of patch_lengths ({patch_lengths_np.sum()}) "
                        f"does not match the length of tokens/scores ({len(tokens_np)}).")

    x_indices = np.arange(len(tokens_np))

    # Cumulative patch lengths give the exclusive end index of each patch,
    # which is also the index of the first position of the following patch.
    patch_boundaries = np.cumsum(patch_lengths_np)

    # --- Plotting ---
    fig, ax = plt.subplots(figsize=(15, 5))

    # Scores as a blue line with point markers.
    ax.plot(x_indices, scores_np, marker='.', linestyle='-', color='steelblue', label='Scores')

    # One dashed vertical line per patch boundary, drawn at x == boundary
    # (on the first position of the next patch). The last boundary is skipped
    # because it is the end of the data.
    for boundary in patch_boundaries[:-1]:
        ax.axvline(x=boundary, color='grey', linestyle='--', linewidth=1)

    ax.axhline(y=threshold, color='red', linestyle='--', linewidth=1)
    # No arrowprops are passed, so only the label text is drawn. It is placed
    # by `xytext` in data coordinates: x=0.05, slightly above the threshold.
    ax.annotate(
        'Entropy Threshold',
        xy=(0.05, threshold),
        xytext=(0.05, threshold + 0.1),
        xycoords='axes fraction',
        textcoords='data',
        color='red',
    )

    # One tick per position, labelled with its character (kept horizontal).
    ax.set_xticks(x_indices)
    ax.set_xticklabels(tokens_np, rotation=0, fontsize=8)

    ax.set_ylabel("Entropy of Next Byte", fontsize=12)
    ax.set_xlabel("Tokens", fontsize=12)

    # Start the y-axis at 0 and pad the x-axis by one unit on each side.
    ax.set_ylim(bottom=0)
    ax.set_xlim(left=x_indices[0] - 1.0, right=x_indices[-1] + 1.0)

    # Hide the top and right spines for a cleaner look.
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    plt.tight_layout()
    return fig