lucalp committed
Commit 8eea094 · 1 Parent(s): 5c58fd6

Fixing proper UTF-8 representation

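Byte-level token streams can split a multi-byte UTF-8 character across several byte tokens, so emitting one highlight entry per raw byte garbles non-ASCII text. As a minimal standalone sketch of the grouping idea this commit applies (illustrative only, not the app's code; group_bytes_into_chars is a hypothetical helper):

    def group_bytes_into_chars(byte_values):
        # Group raw byte values into UTF-8 characters by attempting to decode
        # progressively longer prefixes (UTF-8 characters are 1-4 bytes long).
        chars = []
        i = 0
        while i < len(byte_values):
            for width in range(1, 5):
                try:
                    chars.append(bytes(byte_values[i:i + width]).decode("utf-8"))
                    i += width
                    break
                except UnicodeDecodeError:
                    continue
            else:
                chars.append("\ufffd")  # unrecoverable byte: emit replacement character
                i += 1
        return chars

    # 'é' is two bytes (0xC3 0xA9) but one on-screen character:
    print(group_bytes_into_chars(list("café".encode("utf-8"))))  # ['c', 'a', 'f', 'é']

The code in app.py below does the same probing through tokenizer.decode instead of bytes.decode, since its byte IDs come from the Bytelatent tokenizer.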
app.py CHANGED
@@ -1,3 +1,4 @@
+from collections import defaultdict
 import spaces
 import math
 import os
@@ -31,7 +32,7 @@ class Config:
 
     # Bytelatent Specific
     BLT_WEIGHTS_DIR: str = "hf-weights"
-    BLT_MAX_BYTES_FOR_DEMO: float = math.inf if torch.cuda.is_available() else 512.0
+    BLT_MAX_BYTES_FOR_DEMO: int = 512
 
     # Gradio
     DEFAULT_PROMPT: str = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
@@ -119,45 +120,141 @@ class BytelatentProcessor:
             logging.warning("Skipping Bytelatent setup as libraries are unavailable.")
 
     def _create_highlight_data(self, patch_lengths: torch.Tensor, tokens: torch.Tensor) -> Tuple[List[Tuple[str, str]], int]:
-        """Generates data for gr.HighlightedText based on bytelatent patches."""
-        if not self.is_available or self.tokenizer is None or patch_lengths.numel() == 0:
-            return [("Bytelatent processing failed or produced no patches.", "Error")], 0
+        """Generates data for gr.HighlightedText based on bytelatent patches,
+        formatting each byte's display text as 'char-byte_index'."""
+
+        if not self.is_available or self.tokenizer is None:
+            return [("Bytelatent processing unavailable.", "Error")], 0
+        if patch_lengths.numel() == 0 and tokens.numel() == 0:  # No data at all
+            return [("No tokens or patches.", "Info")], 0
+        if tokens.numel() == 0:  # No tokens to process
+            # Count patches even if no tokens, as per original logic for patch_count
+            actual_patch_count = 0
+            for length in patch_lengths.tolist():
+                if length > 0:
+                    actual_patch_count += 1
+            return [("No tokens provided to highlight.", "Info")], actual_patch_count
+
 
         patch_lengths_list = patch_lengths.tolist()
-        all_token_ids = tokens.tolist()
-        highlighted_data = []
-        current_token_index = 0
-        patch_count = 0
+        all_token_ids = tokens.tolist()  # These are byte representations (integer IDs)
+
+        highlighted_data: List[Tuple[str, str]] = []
+
+        # Calculate original patch count (number of non-empty patches)
+        actual_patch_count = 0
+        for length in patch_lengths_list:
+            if length > 0:
+                actual_patch_count += 1
+
+        # Create a map from global token index to its patch label
+        token_to_patch_label = [""] * len(all_token_ids)
+        current_token_processed_for_patches = 0
+        patch_idx_counter = 0
+        for length in patch_lengths_list:
+            if length <= 0:
+                continue
+            patch_label = f"BL Patch {patch_idx_counter + 1}"
+            patch_idx_counter += 1
+            for _ in range(length):
+                if current_token_processed_for_patches < len(all_token_ids):
+                    token_to_patch_label[current_token_processed_for_patches] = patch_label
+                    current_token_processed_for_patches += 1
+
+        # Handle remainder tokens label
+        if current_token_processed_for_patches < len(all_token_ids):
+            remainder_label = "BL Remainder"
+            logging.warning(
+                f"Bytelatent patch lengths sum ({current_token_processed_for_patches}) "
+                f"is less than total tokens ({len(all_token_ids)}). "
+                f"Remainder tokens will be labelled '{remainder_label}'."
+            )
+            for k in range(current_token_processed_for_patches, len(all_token_ids)):
+                token_to_patch_label[k] = remainder_label
+        elif current_token_processed_for_patches > len(all_token_ids) and len(all_token_ids) > 0:
+            logging.warning(
+                f"Bytelatent patch lengths sum ({current_token_processed_for_patches}) "
+                f"exceeds total tokens ({len(all_token_ids)}). "
+                f"Patch label mapping might be affected."
+            )
 
-        for i, length in enumerate(patch_lengths_list):
-            if length <= 0: continue
-            patch_token_ids = all_token_ids[current_token_index : current_token_index + length]
-            if not patch_token_ids: continue
-
-            try:
-                patch_text = self.tokenizer.decode(patch_token_ids)
-            except Exception as decode_err:
-                logging.warning(f"Bytelatent patch decoding failed: {decode_err}")
-                patch_text = f"[Decode Error: {len(patch_token_ids)} tokens]"
-
-            patch_label = f"BL Patch {i+1}"
-            highlighted_data.append((patch_text, patch_label))
-            patch_count += 1
-            current_token_index += length
-
-        # Handle remainder tokens if any
-        if current_token_index < len(all_token_ids):
-            remaining_tokens = all_token_ids[current_token_index:]
-            try:
-                remaining_text = self.tokenizer.decode(remaining_tokens)
-                label = "BL Remainder"
-            except Exception:
-                remaining_text = f"[Decode Error: {len(remaining_tokens)} remaining tokens]"
-                label = "Error"
-            highlighted_data.append((remaining_text, label))
-            logging.warning(f"Bytelatent token mismatch. Consumed {current_token_index}, total {len(all_token_ids)}. Remainder added.")
-
-        return highlighted_data, patch_count
+        global_token_idx = 0
+        while global_token_idx < len(all_token_ids):
+            char_representation = ""
+            decoded_byte_ids: List[int] = []
+
+            # Handle the special case for token ID 1, often representing '<' or similar.
+            # This assumes token ID 1 should always be treated as a single character '<'.
+            # Adjust if your tokenizer handles ID 1 differently or if it can be part of a multi-byte sequence.
+            if all_token_ids[global_token_idx] == 1:
+                char_representation = "<"  # As per user's original code snippet's implication
+                decoded_byte_ids = [1]
+            else:
+                # Iteratively try to decode a character (1 to 4 bytes for UTF-8)
+                for length_to_try in range(1, 5):
+                    if global_token_idx + length_to_try > len(all_token_ids):
+                        break  # Not enough tokens left for this length
+
+                    current_ids_to_try = all_token_ids[global_token_idx : global_token_idx + length_to_try]
+
+                    try:
+                        temp_decode_text = self.tokenizer.decode(current_ids_to_try)
+                        if temp_decode_text:  # Successfully decoded something
+                            # This means `current_ids_to_try` forms a valid character(s).
+                            # We take the first successful decode, assuming it's the shortest complete char.
+                            char_representation = temp_decode_text
+                            decoded_byte_ids = current_ids_to_try
+                            break  # Found a character
+                    except Exception as e:
+                        # Decoding failed (e.g., incomplete sequence for this length_to_try).
+                        # Log this if it's unexpected for a particular tokenizer.
+                        # logging.debug(f"Decode attempt failed for {current_ids_to_try}: {e}")
+                        pass  # Continue to try with more bytes.
+
+            # After trying to decode:
+            if char_representation and decoded_byte_ids:
+                num_bytes_in_char = len(decoded_byte_ids)
+                # Ensure char_representation is treated as a single conceptual unit here.
+                # If tokenizer.decode can return multiple characters for a short byte sequence,
+                # this might need adjustment. For UTF-8, one char is expected.
+                processed_char_text = char_representation.splitlines()[0]  # Take first char if multiple, or clean up
+
+                for j in range(num_bytes_in_char):
+                    current_byte_abs_idx = global_token_idx + j
+                    # Boundary check, though loop structure should prevent out-of-bounds
+                    if current_byte_abs_idx < len(all_token_ids):
+                        label = token_to_patch_label[current_byte_abs_idx] if current_byte_abs_idx < len(token_to_patch_label) else "Error: Label Missing"
+                        display_text = f"{processed_char_text}-{j+1}".replace(" ", "_")
+                        highlighted_data.append((display_text, label))
+                    else:  # Should ideally not be reached
+                        logging.error(f"Critical: Token index {current_byte_abs_idx} out of bounds for labeling.")
+                global_token_idx += num_bytes_in_char
+            else:
+                # Fallback: could not form a character starting at global_token_idx.
+                # Treat the current byte as a standalone problematic byte.
+                current_byte_abs_idx = global_token_idx
+                label = token_to_patch_label[current_byte_abs_idx] if current_byte_abs_idx < len(token_to_patch_label) else "Error: Label Missing"

+                problem_byte_id = all_token_ids[current_byte_abs_idx]
+                display_text = f"err_byte({problem_byte_id})-1"

+                # Attempt to get a direct representation if tokenizer can provide one for the single byte
+                try:
+                    single_byte_char_attempt = self.tokenizer.decode([problem_byte_id])
+                    if single_byte_char_attempt and single_byte_char_attempt != "\ufffd":  # Replacement char
+                        display_text = f"{single_byte_char_attempt}-1"
+                except Exception:
+                    pass  # Stick with the err_byte display_text

+                highlighted_data.append((display_text.replace(" ", "_"), label))
+                logging.warning(
+                    f"Token ID {problem_byte_id} at index {current_byte_abs_idx} "
+                    f"could not be part of a validly decoded character using iterative decode. Fallback: '{display_text}'."
+                )
+                global_token_idx += 1
+
+        return highlighted_data, actual_patch_count
 
     def process(self, prompt: str, max_bytes: float) -> Tuple[Optional[matplotlib.figure.Figure], List[Tuple[str, str]], int, str]:
         """Processes the prompt using the loaded Bytelatent model."""
@@ -204,7 +301,13 @@ class BytelatentProcessor:
         # Run Bytelatent patching
         try:
            logging.info(f"Running Bytelatent entropy model patching on {len(prompt_bl.encode('utf-8'))} bytes...")
-            results = patcher_nocache([prompt_bl], tokenizer=self.tokenizer, patcher=self.patcher)
+            results = patcher_nocache(
+                [prompt_bl],
+                tokenizer=self.tokenizer,
+                patcher=self.patcher,
+                max_prompt_len=512,
+                max_gen_len=256,
+            )
            status += "Bytelatent patching executed.\n"
 
            if not results:
@@ -216,7 +319,12 @@ class BytelatentProcessor:
            patch_lengths, scores, tokens = batch_patch_lengths[0], batch_scores[0], batch_tokens[0]
 
            # Create highlighted text data
-            highlighted_data, patch_count = self._create_highlight_data(patch_lengths, tokens)
+            _highlighted_data, patch_count = self._create_highlight_data(patch_lengths, tokens)
+            ind_highlighted_data = [(text.replace("-1", ""), label) for text, label in _highlighted_data]
+            grouped_data = defaultdict(str)
+            for text, label in ind_highlighted_data:
+                grouped_data[label] += text
+            highlighted_data = [(text, label) for label, text in grouped_data.items()]
 
            # Create plot
            fig = None
@@ -228,7 +336,14 @@ class BytelatentProcessor:
                    logging.warning(f"Error decoding full BLT token sequence for plot: {decode_err}. Using (truncated) input prompt for plot axis.")
                    decoded_output_for_plot = prompt_bl
 
-                fig = plot_entropies(patch_lengths, scores, decoded_output_for_plot, threshold=self.patcher.threshold)
+                # fig = plot_entropies(patch_lengths, scores, decoded_output_for_plot, threshold=self.patcher.threshold)
+                fig = plot_entropies(
+                    patch_lengths,
+                    scores,
+                    tokens,
+                    chars=decoded_output_for_plot,
+                    threshold=self.patcher.threshold
+                )
                status += f"Bytelatent plot generated. Found {patch_count} patches.\n"
            else:
                status += "Plotting unavailable.\n"
@@ -418,7 +533,7 @@ with gr.Blocks(theme=Config.GRADIO_THEME) as iface:
                placeholder="Enter text here...",
                # Max length is for UI input; Bytelatent truncation happens in backend
                lines=5,
-                info="" if torch.cuda.is_available() else f"Note: Bytelatent processing is limited to ~{Config.BLT_MAX_BYTES_FOR_DEMO} bytes for this demo."
+                info=f"Note: Entropy-based Patcher processing is limited to {Config.BLT_MAX_BYTES_FOR_DEMO} bytes for this demo."
            )
            submit_button = gr.Button("Generate Visualizations", variant="primary")
            status_output = gr.Textbox(label="Processing Status", interactive=False, lines=10)  # More space for detailed status
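The regrouping step added to process() collapses the per-byte entries from _create_highlight_data back into one gr.HighlightedText entry per patch label. A toy run with made-up data (plain Python, relying on dict insertion order):

    from collections import defaultdict

    # Per-byte entries of the form ('char-byte_index', patch label), made-up data:
    per_byte = [("H-1", "BL Patch 1"), ("i-1", "BL Patch 1"), ("!-1", "BL Patch 2")]
    stripped = [(text.replace("-1", ""), label) for text, label in per_byte]
    grouped = defaultdict(str)
    for text, label in stripped:
        grouped[label] += text
    print([(text, label) for label, text in grouped.items()])
    # [('Hi', 'BL Patch 1'), ('!', 'BL Patch 2')]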
bytelatent/plotting/entropy_figure_via_matplot_lib.py CHANGED
@@ -1,73 +1,123 @@
-import os
 import torch
-import matplotlib.pyplot as plt
 import numpy as np
+import matplotlib.pyplot as plt
+import os
 
-
-def plot_entropies(patch_lengths: torch.Tensor, scores: torch.Tensor, chars: str, threshold: float):
+def plot_entropies(  # Renamed from plot_entropies_revised for final output
+    patch_lengths: torch.Tensor,
+    scores: torch.Tensor,
+    tokens: torch.Tensor,  # Length is implied by scores; content implicitly backs the UTF-8 assumption.
+    chars: str,
+    threshold: float
+):
     patch_lengths_np = patch_lengths.cpu().numpy().flatten()
     scores_np = scores.cpu().float().numpy().flatten()
-    chars = chars.replace(" ", "_")
-    tokens_np = np.array([char for char in "<"+chars])
-
-    if len(scores_np) != len(tokens_np):
-        raise ValueError("Length of scores and tokens tensors must be the same.")
-    if patch_lengths_np.sum() != len(tokens_np):
-        raise ValueError(f"Sum of patch_lengths ({patch_lengths_np.sum()}) "
-                         f"does not match the length of tokens/scores ({len(tokens_np)}).")
-
-    x_indices = np.arange(len(tokens_np))
-
-    # Calculate cumulative sums of patch lengths for vertical line positions
-    # These indicate the *end* index of each patch
-    patch_boundaries = np.cumsum(patch_lengths_np)
-
-    # --- Plotting ---
-    fig, ax = plt.subplots(figsize=(15, 5))  # Adjust figure size as needed
-
-    # Plot the scores as a blue line with markers
-    ax.plot(x_indices, scores_np, marker='.', linestyle='-', color='steelblue', label='Scores')
-
-    # Plot the vertical dotted lines at the patch boundaries
-    # We plot a line *after* each patch, so at index `boundary - 1 + 0.5`
-    # We skip the last boundary as it's the end of the data
-    for boundary in patch_boundaries[:-1]:
-        ax.axvline(x=boundary, color='grey', linestyle='--', linewidth=1)
-
-    ax.axhline(y=threshold, color='red', linestyle='--', linewidth=1)
-    ax.annotate(f'Entropy Threshold',
-                xy=(0.05, threshold),            # Position of the line
-                xytext=(0.05, threshold + 0.1),  # Text position
-                xycoords='axes fraction',        # Use axes coordinates (0-1)
-                textcoords='data',               # Use data coordinates for text
+
+    num_total_bytes_from_scores = len(scores_np)
+
+    # Prepare display string (prepend '<', replace spaces with '_')
+    display_string_processed_chars = chars.replace(" ", "_")
+    display_string = "<" + display_string_processed_chars
+    display_chars_list = list(display_string)
+    num_display_chars = len(display_chars_list)
+
+    if num_display_chars == 0 and num_total_bytes_from_scores == 0:
+        fig, ax = plt.subplots(figsize=(15, 5))
+        ax.text(0.5, 0.5, "No data to plot.", ha='center', va='center', fontsize=12)
+        ax.set_xlabel("Characters (on underlying byte sequence)")
+        ax.set_ylabel("Entropy of Next Byte")
+        ax.set_ylim(bottom=0)
+        ax.set_xlim(left=-0.5, right=0.5)  # Default xlim for an empty plot
+        return fig
+    elif num_display_chars == 0 and num_total_bytes_from_scores > 0:
+        # Edge case: scores exist but there are no characters to map them to.
+        # display_chars_list should not be empty if scores_np is not; the
+        # byte-count validation below will raise for this case.
+        pass
+
+    # Calculate byte counts for each character in the display string (assuming UTF-8)
+    try:
+        byte_counts_per_display_char = [len(c.encode('utf-8')) for c in display_chars_list]
+    except UnicodeEncodeError as e:
+        raise ValueError(
+            f"Could not encode characters in 'chars' string using UTF-8. "
+            f"Problematic part: '{display_string_processed_chars}'. Error: {e}"
+        )
+
+    # --- Validations ---
+    if sum(byte_counts_per_display_char) != num_total_bytes_from_scores:
+        # Also covers num_display_chars == 0 with num_total_bytes_from_scores > 0
+        raise ValueError(
+            f"Mismatch in byte counts: sum of UTF-8 bytes for display_string "
+            f"('{display_string}' -> {sum(byte_counts_per_display_char)} bytes) "
+            f"does not match length of scores tensor ({num_total_bytes_from_scores}). "
+            f"Ensure 'chars' (and the prepended '<') correctly correspond to the byte sequence "
+            f"represented by 'scores'/'tokens'."
+        )
+
+    if patch_lengths_np.sum() != num_total_bytes_from_scores:
+        raise ValueError(
+            f"Sum of patch_lengths ({patch_lengths_np.sum()}) "
+            f"does not match length of scores ({num_total_bytes_from_scores})."
+        )
+
+    # --- Plotting Setup ---
+    fig, ax = plt.subplots(figsize=(15, 5))
+    x_byte_indices = np.arange(num_total_bytes_from_scores)
+
+    # --- Plot Scores (one point per byte) ---
+    ax.plot(x_byte_indices, scores_np, marker='.', linestyle='-', color='steelblue', label='Scores per byte')
+
+    # --- Plot Vertical Patch Boundary Lines ---
+    # Cumulative patch lengths give the boundary between byte elements,
+    # matching the intent of `boundary - 1 + 0.5` in the previous version.
+    patch_end_byte_cumulative_lengths = np.cumsum(patch_lengths_np)
+    for boundary_len in patch_end_byte_cumulative_lengths[:-1]:  # Skip the last boundary (end of all data)
+        ax.axvline(x=boundary_len, color='grey', linestyle='--', linewidth=1)
+
+    # --- Horizontal Threshold Line and Annotation ---
+    ax.axhline(y=threshold, color='red', linestyle='--', linewidth=1)
+    ax.annotate('Entropy Threshold',
+                xy=(0.05, threshold),
+                xytext=(0.05, threshold + 0.1),
+                xycoords='axes fraction',
+                textcoords='data',
                 color='red'
                 )
 
-    # Set x-axis ticks and labels
-    ax.set_xticks(x_indices)
-    ax.set_xticklabels(tokens_np, rotation=0, fontsize=8)  # Rotate labels for better readability
+    # --- X-axis Ticks and Labels (character labels at the start of their byte sequences) ---
+    char_label_positions = []
+    char_labels_for_ticks = []
+    current_byte_tracker = 0
+    if num_display_chars > 0:  # Ensure byte_counts_per_display_char is not empty
+        for i_char in range(num_display_chars):
+            char_label_positions.append(current_byte_tracker)
+            char_labels_for_ticks.append(display_chars_list[i_char])
+            current_byte_tracker += byte_counts_per_display_char[i_char]
+
+    ax.set_xticks(char_label_positions)
+    ax.set_xticklabels(char_labels_for_ticks, rotation=0, fontsize=8)
 
-    # Set labels for axes
-    # Using the Y-axis label from the example image
-    ax.set_ylabel("Entropy of Next Byte", fontsize=12)
-    ax.set_xlabel("Tokens", fontsize=12)
-
-    # Set y-axis limits (optional, but often good practice)
-    ax.set_ylim(bottom=0)  # Start y-axis at 0 like the example
-    ax.set_xlim(left=x_indices[0]-1.0, right=x_indices[-1]+1.0)  # Add padding to x-axis
-
-    # Add grid lines (optional)
-    # ax.grid(True, axis='y', linestyle=':', color='lightgrey')
-
-    # Remove the top and right spines for cleaner look (optional)
+    # --- Axes Configuration ---
+    ax.set_ylabel("Entropy of Next Byte", fontsize=12)
+    ax.set_xlabel("Characters (on underlying byte sequence)", fontsize=12)
+
+    ax.set_ylim(bottom=0)
+    # Show all bytes clearly, from -0.5 to last_byte_idx + 0.5
+    if num_total_bytes_from_scores > 0:
+        ax.set_xlim(left=-0.5, right=num_total_bytes_from_scores - 0.5)
+    else:  # No bytes at all (empty chars and scores)
+        ax.set_xlim(left=-0.5, right=0.5)
+
+    # Spines: keep removing top and right, as before
     ax.spines['top'].set_visible(False)
     ax.spines['right'].set_visible(False)
 
-    # Adjust layout and display the plot
     plt.tight_layout()
     return fig
-    # output_filename = "token_score_plot.png"
-    # fig.savefig(output_filename, dpi=300, bbox_inches='tight')  # Save the figure
-    # print(f"Plot saved to {os.path.abspath(output_filename)}")  # Print confirmation with full path
-
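A hypothetical smoke test for the new plot_entropies signature (all values made up: "Hi!" plus the prepended '<' gives four single-byte UTF-8 characters, so four scores and patch lengths summing to four are required):

    import torch
    from bytelatent.plotting.entropy_figure_via_matplot_lib import plot_entropies

    patch_lengths = torch.tensor([2, 2])
    scores = torch.tensor([0.1, 0.9, 0.4, 0.7])
    tokens = torch.tensor([1, 72, 105, 33])  # same length as scores
    fig = plot_entropies(patch_lengths, scores, tokens, chars="Hi!", threshold=0.6)
    fig.savefig("entropy_demo.png")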