Fixing proper UTF-8 representation
Files changed:
- app.py (+154 -39)
- bytelatent/plotting/entropy_figure_via_matplot_lib.py (+105 -55)
app.py CHANGED
@@ -1,3 +1,4 @@
+from collections import defaultdict
 import spaces
 import math
 import os
@@ -31,7 +32,7 @@ class Config:
 
     # Bytelatent Specific
    BLT_WEIGHTS_DIR: str = "hf-weights"
-    BLT_MAX_BYTES_FOR_DEMO:
+    BLT_MAX_BYTES_FOR_DEMO: int = 512
 
     # Gradio
     DEFAULT_PROMPT: str = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
@@ -119,45 +120,141 @@ class BytelatentProcessor:
             logging.warning("Skipping Bytelatent setup as libraries are unavailable.")
 
     def _create_highlight_data(self, patch_lengths: torch.Tensor, tokens: torch.Tensor) -> Tuple[List[Tuple[str, str]], int]:
-        """Generates data for gr.HighlightedText based on bytelatent patches
+        """Generates data for gr.HighlightedText based on bytelatent patches,
+        formatting each byte's display text as 'char-byte_index'."""
+
+        if not self.is_available or self.tokenizer is None:
+            return [("Bytelatent processing unavailable.", "Error")], 0
+        if patch_lengths.numel() == 0 and tokens.numel() == 0:  # No data at all
+            return [("No tokens or patches.", "Info")], 0
+        if tokens.numel() == 0:  # No tokens to process
+            # Count patches even if no tokens, as per original logic for patch_count
+            actual_patch_count = 0
+            for length in patch_lengths.tolist():
+                if length > 0:
+                    actual_patch_count += 1
+            return [("No tokens provided to highlight.", "Info")], actual_patch_count
 
         patch_lengths_list = patch_lengths.tolist()
-        all_token_ids = tokens.tolist()
-            patch_token_ids = all_token_ids[current_token_index : current_token_index + length]
-            if not patch_token_ids: continue
-            highlighted_data.append((patch_text, patch_label))
-            patch_count += 1
-            current_token_index += length
-
-        # Handle remainder tokens if any
-        if current_token_index < len(all_token_ids):
-            remaining_tokens = all_token_ids[current_token_index:]
-            try:
-                remaining_text = self.tokenizer.decode(remaining_tokens)
-                label = "BL Remainder"
-            except Exception:
-                remaining_text = f"[Decode Error: {len(remaining_tokens)} remaining tokens]"
-                label = "Error"
-            highlighted_data.append((remaining_text, label))
-            logging.warning(f"Bytelatent token mismatch. Consumed {current_token_index}, total {len(all_token_ids)}. Remainder added.")
+        all_token_ids = tokens.tolist()  # These are byte representations (integer IDs)
+
+        highlighted_data: List[Tuple[str, str]] = []
+
+        # Calculate original patch count (number of non-empty patches)
+        actual_patch_count = 0
+        for length in patch_lengths_list:
+            if length > 0:
+                actual_patch_count += 1
+
+        # Create a map from global token index to its patch label
+        token_to_patch_label = [""] * len(all_token_ids)
+        current_token_processed_for_patches = 0
+        patch_idx_counter = 0
+        for length in patch_lengths_list:
+            if length <= 0:
+                continue
+            patch_label = f"BL Patch {patch_idx_counter + 1}"
+            patch_idx_counter += 1
+            for _ in range(length):
+                if current_token_processed_for_patches < len(all_token_ids):
+                    token_to_patch_label[current_token_processed_for_patches] = patch_label
+                    current_token_processed_for_patches += 1
+
+        # Handle remainder tokens label
+        if current_token_processed_for_patches < len(all_token_ids):
+            remainder_label = "BL Remainder"
+            logging.warning(
+                f"Bytelatent patch lengths sum ({current_token_processed_for_patches}) "
+                f"is less than total tokens ({len(all_token_ids)}). "
+                f"Remainder tokens will be labelled '{remainder_label}'."
+            )
+            for k in range(current_token_processed_for_patches, len(all_token_ids)):
+                token_to_patch_label[k] = remainder_label
+        elif current_token_processed_for_patches > len(all_token_ids) and len(all_token_ids) > 0:
+            logging.warning(
+                f"Bytelatent patch lengths sum ({current_token_processed_for_patches}) "
+                f"exceeds total tokens ({len(all_token_ids)}). "
+                f"Patch label mapping might be affected."
+            )
+
+        global_token_idx = 0
+        while global_token_idx < len(all_token_ids):
+            char_representation = ""
+            decoded_byte_ids: List[int] = []
+
+            # Handle the special case for token ID 1, often representing '<' or similar
+            # This assumes token ID 1 should always be treated as a single character '<'.
+            # Adjust if your tokenizer handles ID 1 differently or if it can be part of a multi-byte sequence.
+            if all_token_ids[global_token_idx] == 1:
+                char_representation = "<"  # As per user's original code snippet's implication
+                decoded_byte_ids = [1]
+            else:
+                # Iteratively try to decode a character (1 to 4 bytes for UTF-8)
+                for length_to_try in range(1, 5):
+                    if global_token_idx + length_to_try > len(all_token_ids):
+                        break  # Not enough tokens left for this length
+
+                    current_ids_to_try = all_token_ids[global_token_idx : global_token_idx + length_to_try]
+
+                    try:
+                        temp_decode_text = self.tokenizer.decode(current_ids_to_try)
+
+                        if temp_decode_text:  # Successfully decoded something
+                            # This means `current_ids_to_try` forms a valid character(s).
+                            # We take the first successful decode, assuming it's the shortest complete char.
+                            char_representation = temp_decode_text
+                            decoded_byte_ids = current_ids_to_try
+                            break  # Found a character
+                    except Exception as e:
+                        # Decoding failed (e.g., incomplete sequence for this length_to_try).
+                        # Log this if it's unexpected for a particular tokenizer.
+                        # logging.debug(f"Decode attempt failed for {current_ids_to_try}: {e}")
+                        pass  # Continue to try with more bytes.
+
+            # After trying to decode:
+            if char_representation and decoded_byte_ids:
+                num_bytes_in_char = len(decoded_byte_ids)
+                # Ensure char_representation is treated as a single conceptual unit here.
+                # If tokenizer.decode can return multiple characters for a short byte sequence,
+                # this might need adjustment. For UTF-8, one char is expected.
+                processed_char_text = char_representation.splitlines()[0]  # Take first char if multiple, or clean up
+
+                for j in range(num_bytes_in_char):
+                    current_byte_abs_idx = global_token_idx + j
+                    # Boundary check, though loop structure should prevent out-of-bounds
+                    if current_byte_abs_idx < len(all_token_ids):
+                        label = token_to_patch_label[current_byte_abs_idx] if current_byte_abs_idx < len(token_to_patch_label) else "Error: Label Missing"
+                        display_text = f"{processed_char_text}-{j+1}".replace(" ", "_")
+                        highlighted_data.append((display_text, label))
+                    else:  # Should ideally not be reached
+                        logging.error(f"Critical: Token index {current_byte_abs_idx} out of bounds for labeling.")
+                global_token_idx += num_bytes_in_char
+            else:
+                # Fallback: Could not form a character starting at global_token_idx.
+                # Treat the current byte as a standalone problematic byte.
+                current_byte_abs_idx = global_token_idx
+                label = token_to_patch_label[current_byte_abs_idx] if current_byte_abs_idx < len(token_to_patch_label) else "Error: Label Missing"
+
+                problem_byte_id = all_token_ids[current_byte_abs_idx]
+                display_text = f"err_byte({problem_byte_id})-1"
+
+                # Attempt to get a direct representation if tokenizer can provide one for the single byte
+                try:
+                    single_byte_char_attempt = self.tokenizer.decode([problem_byte_id])
+                    if single_byte_char_attempt and single_byte_char_attempt != "\ufffd":  # Replacement char
+                        display_text = f"{single_byte_char_attempt}-1"
+                except Exception:
+                    pass  # Stick with the err_byte display_text
+
+                highlighted_data.append((display_text.replace(" ", "_"), label))
+                logging.warning(
+                    f"Token ID {problem_byte_id} at index {current_byte_abs_idx} "
+                    f"could not be part of a validly decoded character using iterative decode. Fallback: '{display_text}'."
+                )
+                global_token_idx += 1
+
+        return highlighted_data, actual_patch_count
 
     def process(self, prompt: str, max_bytes: float) -> Tuple[Optional[matplotlib.figure.Figure], List[Tuple[str, str]], int, str]:
         """Processes the prompt using the loaded Bytelatent model."""
@@ -204,7 +301,13 @@
             # Run Bytelatent patching
             try:
                 logging.info(f"Running Bytelatent entropy model patching on {len(prompt_bl.encode('utf-8'))} bytes...")
-                results = patcher_nocache(
+                results = patcher_nocache(
+                    [prompt_bl],
+                    tokenizer=self.tokenizer,
+                    patcher=self.patcher,
+                    max_prompt_len=512,
+                    max_gen_len=256,
+                )
                 status += "Bytelatent patching executed.\n"
 
                 if not results:
@@ -216,7 +319,12 @@
             patch_lengths, scores, tokens = batch_patch_lengths[0], batch_scores[0], batch_tokens[0]
 
             # Create highlighted text data
+            _highlighted_data, patch_count = self._create_highlight_data(patch_lengths, tokens)
+            ind_highlighted_data = [(text.replace("-1", ""), label) for text, label in _highlighted_data]
+            grouped_data = defaultdict(str)
+            for text, label in ind_highlighted_data:
+                grouped_data[label] += text
+            highlighted_data = [(text, label) for label, text in grouped_data.items()]
 
             # Create plot
             fig = None
@@ -228,7 +336,14 @@
                 logging.warning(f"Error decoding full BLT token sequence for plot: {decode_err}. Using (truncated) input prompt for plot axis.")
                 decoded_output_for_plot = prompt_bl
 
-                fig = plot_entropies(patch_lengths, scores, decoded_output_for_plot, threshold=self.patcher.threshold)
+                # fig = plot_entropies(patch_lengths, scores, decoded_output_for_plot, threshold=self.patcher.threshold)
+                fig = plot_entropies(
+                    patch_lengths,
+                    scores,
+                    tokens,
+                    chars=decoded_output_for_plot,
+                    threshold=self.patcher.threshold
+                )
                 status += f"Bytelatent plot generated. Found {patch_count} patches.\n"
             else:
                 status += "Plotting unavailable.\n"
@@ -418,7 +533,7 @@ with gr.Blocks(theme=Config.GRADIO_THEME) as iface:
                 placeholder="Enter text here...",
                 # Max length is for UI input; Bytelatent truncation happens in backend
                 lines=5,
-                info=
+                info=f"Note: Entropy-based Patcher processing is limited to {Config.BLT_MAX_BYTES_FOR_DEMO} bytes for this demo."
             )
             submit_button = gr.Button("Generate Visualizations", variant="primary")
             status_output = gr.Textbox(label="Processing Status", interactive=False, lines=10)  # More space for detailed status
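Note on the approach: the heart of this commit is the iterative decode loop in _create_highlight_data. Starting at each byte position, it tries to decode 1, 2, 3, then 4 consecutive byte tokens and accepts the shortest prefix that yields a valid character, so a multi-byte UTF-8 character produces one display entry per byte ('char-byte_index'). Below is a minimal standalone sketch of that idea, not part of the commit, assuming each token ID is one raw UTF-8 byte and using Python's built-in bytes.decode in place of self.tokenizer.decode:

    # Sketch only: groups a byte sequence into UTF-8 characters the way the
    # commit's loop does; byte_ids stands in for BLT byte-token IDs (assumption).
    def group_bytes_into_chars(byte_ids):
        """Returns (character, byte_count) pairs; undecodable bytes become U+FFFD."""
        out = []
        i = 0
        while i < len(byte_ids):
            decoded = None
            for width in range(1, 5):  # UTF-8 characters are 1 to 4 bytes long
                if i + width > len(byte_ids):
                    break  # not enough bytes left for this width
                try:
                    decoded = (bytes(byte_ids[i:i + width]).decode("utf-8"), width)
                    break  # shortest valid prefix wins
                except UnicodeDecodeError:
                    continue  # prefix incomplete or invalid; try one byte more
            if decoded is None:
                decoded = ("\ufffd", 1)  # nothing decoded; emit replacement, skip one byte
            out.append(decoded)
            i += decoded[1]
        return out

    print(group_bytes_into_chars(list("héllo".encode("utf-8"))))
    # [('h', 1), ('é', 2), ('l', 1), ('l', 1), ('o', 1)]

Once these per-byte entries come back, process strips the '-1' suffix and concatenates entries that share a patch label via defaultdict(str), so each patch is rendered as one contiguous highlighted span.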
bytelatent/plotting/entropy_figure_via_matplot_lib.py CHANGED
@@ -1,73 +1,123 @@
-import os
 import torch
-import matplotlib.pyplot as plt
 import numpy as np
+import matplotlib.pyplot as plt
+import os
 
+def plot_entropies(  # Renamed from plot_entropies_revised for final output
+    patch_lengths: torch.Tensor,
+    scores: torch.Tensor,
+    tokens: torch.Tensor,  # Length used via scores. Content implicitly for UTF-8 assumption.
+    chars: str,
+    threshold: float
+):
     patch_lengths_np = patch_lengths.cpu().numpy().flatten()
     scores_np = scores.cpu().float().numpy().flatten()
-    chars = chars.replace(" ", "_")
-    tokens_np = np.array([char for char in "<"+chars])
-
-    if len(scores_np) != len(tokens_np):
-        raise ValueError("Length of scores and tokens tensors must be the same.")
-    if patch_lengths_np.sum() != len(tokens_np):
-        raise ValueError(f"Sum of patch_lengths ({patch_lengths_np.sum()}) "
-                         f"does not match the length of tokens/scores ({len(tokens_np)}).")
-
-    x_indices = np.arange(len(tokens_np))
 
+    num_total_bytes_from_scores = len(scores_np)
+
+    # Prepare display string (prepend '<', replace spaces with '_')
+    display_string_processed_chars = chars.replace(" ", "_")
+    display_string = "<" + display_string_processed_chars
+    display_chars_list = list(display_string)
+    num_display_chars = len(display_chars_list)
+
+    if num_display_chars == 0 and num_total_bytes_from_scores == 0:
+        fig, ax = plt.subplots(figsize=(15, 5))
+        ax.text(0.5, 0.5, "No data to plot.", ha='center', va='center', fontsize=12)
+        ax.set_xlabel("Characters (on underlying byte sequence)")
+        ax.set_ylabel("Entropy of Next Byte")
+        ax.set_ylim(bottom=0)
+        ax.set_xlim(left=-0.5, right=0.5)  # Default xlim for empty plot
+        return fig
+    elif num_display_chars == 0 and num_total_bytes_from_scores > 0:
+        # Edge case: scores exist but no characters to map them to (implies an issue)
+        # For now, proceed with byte plot but no char labels. Or raise error.
+        # Assuming display_chars_list should not be empty if scores_np is not.
+        # This case should ideally be caught by byte_counts_per_display_char validation if it were run.
+        # If display_chars_list is truly empty but scores are not, an error should be raised by validation.
+        pass  # Will be caught by validation if sum(byte_counts) != len(scores)
+
+    # Calculate byte counts for each character in the display string (assuming UTF-8)
+    try:
+        byte_counts_per_display_char = [len(c.encode('utf-8')) for c in display_chars_list]
+    except UnicodeEncodeError as e:
+        raise ValueError(
+            f"Could not encode characters in 'chars' string using UTF-8. "
+            f"Problematic part: '{display_string_processed_chars}'. Error: {e}"
+        )
+
+    # --- Validations ---
+    if sum(byte_counts_per_display_char) != num_total_bytes_from_scores:
+        # This condition also handles num_display_chars == 0 but num_total_bytes_from_scores > 0
+        raise ValueError(
+            f"Mismatch in byte counts: Sum of UTF-8 bytes for display_string "
+            f"('{display_string}' -> {sum(byte_counts_per_display_char)} bytes) "
+            f"does not match length of scores tensor ({num_total_bytes_from_scores}). "
+            f"Ensure 'chars' (and the prepended '<') correctly correspond to the byte sequence "
+            f"represented by 'scores'/'tokens'."
+        )
+
+    if patch_lengths_np.sum() != num_total_bytes_from_scores:
+        raise ValueError(
+            f"Sum of patch_lengths ({patch_lengths_np.sum()}) "
+            f"does not match length of scores ({num_total_bytes_from_scores})."
+        )
+
+    # --- Plotting Setup ---
+    fig, ax = plt.subplots(figsize=(15, 5))  # Fixed size as requested
+    x_byte_indices = np.arange(num_total_bytes_from_scores)
+
+    # --- Plot Scores (Horizontally per byte) ---
+    # Original plot line style from user's code: marker='.', linestyle='-'
+    ax.plot(x_byte_indices, scores_np, marker='.', linestyle='-', color='steelblue', label='Scores per byte')
+
+    # --- Plot Vertical Patch Boundary Lines ---
+    # Using (cumulative_length - 0.5) logic for lines between byte elements.
+    # This matches the intent of `boundary - 1 + 0.5` from user's original code snippet.
+    patch_end_byte_cumulative_lengths = np.cumsum(patch_lengths_np)
+    for boundary_len in patch_end_byte_cumulative_lengths[:-1]:  # Exclude the last boundary (end of all data)
+        ax.axvline(x=boundary_len, color='grey', linestyle='--', linewidth=1)
+
+    # --- Horizontal Threshold Line and Annotation ---
+    ax.axhline(y=threshold, color='red', linestyle='--', linewidth=1)
+    ax.annotate(f'Entropy Threshold',            # Original text from user's code
+                xy=(0.05, threshold),            # Original xy from user's code
+                xytext=(0.05, threshold + 0.1),  # Original xytext from user's code
+                xycoords='axes fraction',        # Original xycoords
+                textcoords='data',               # Original textcoords
                 color='red'
                 )
 
+    # --- X-axis Ticks and Labels (Character labels at start of their byte sequences) ---
+    char_label_positions = []
+    char_labels_for_ticks = []
+    current_byte_tracker = 0
+    if num_display_chars > 0:  # Ensure byte_counts_per_display_char is not empty
+        for i_char in range(num_display_chars):
+            char_label_positions.append(current_byte_tracker)
+            char_labels_for_ticks.append(display_chars_list[i_char])
+            current_byte_tracker += byte_counts_per_display_char[i_char]
 
+    ax.set_xticks(char_label_positions)
+    ax.set_xticklabels(char_labels_for_ticks, rotation=0, fontsize=8)  # User's original rotation and fontsize
 
-    ax.set_ylabel("Entropy of Next Byte", fontsize=12)
-    ax.set_xlabel("Tokens", fontsize=12)
+    # --- Axes Configuration ---
+    ax.set_ylabel("Entropy of Next Byte", fontsize=12)  # User's original
+    ax.set_xlabel("Characters (on underlying byte sequence)", fontsize=12)  # Descriptive X-axis label
 
+    ax.set_ylim(bottom=0)  # User's original y-axis bottom limit
+    # Set x-axis limits to show all bytes clearly from -0.5 to last_byte_idx + 0.5
+    if num_total_bytes_from_scores > 0:
+        ax.set_xlim(left=-0.5, right=num_total_bytes_from_scores - 0.5)
+    else:  # Handle case of no bytes (e.g. if chars was empty and scores was empty)
+        ax.set_xlim(left=-0.5, right=0.5)
 
+    # Spines (as per user's original code removing top and right)
     ax.spines['top'].set_visible(False)
     ax.spines['right'].set_visible(False)
 
+    # Grid: User's original code did not explicitly add grid lines.
+
     plt.tight_layout()
     return fig
-    # output_filename = "token_score_plot.png"
-    # fig.savefig(output_filename, dpi=300, bbox_inches='tight')  # Save the figure
-    # print(f"Plot saved to {os.path.abspath(output_filename)}")  # Print confirmation with full path
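Note on the plotting invariant: scores are per byte, so the rewritten plot_entropies checks that the per-character UTF-8 byte counts of the display string sum to len(scores_np), then places each character's tick label at the first byte index it occupies. A small standalone sketch of that mapping, not part of the commit, assuming plain str.encode as the byte-count source (the same assumption the commit makes):

    # Sketch only: x-tick positions so each character labels its first byte,
    # mirroring the char_label_positions loop in plot_entropies above.
    def char_tick_positions(display_string):
        positions, labels = [], []
        byte_cursor = 0
        for ch in display_string:
            positions.append(byte_cursor)  # tick at the character's first byte
            labels.append(ch)
            byte_cursor += len(ch.encode("utf-8"))
        return positions, labels, byte_cursor  # byte_cursor must equal len(scores_np)

    positions, labels, total_bytes = char_tick_positions("<héllo")
    print(positions)    # [0, 1, 2, 4, 5, 6]  ('é' occupies byte indices 2 and 3)
    print(total_bytes)  # 7 bytes for 6 characters

If the byte total and the score length disagree, the function raises a ValueError instead of silently mislabeling the axis, which is what the byte-count validation above enforces.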