Bor Hodošček committed on
Commit 6b833b2 · unverified · 1 Parent(s): fade8c0

fix: bpe byte display and misc display tweaks
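For context, the display fix concerns the marker characters that byte-level BPE and SentencePiece tokenizers embed in token strings. A minimal sketch of the mapping idea (not code from this commit; the example tokens are illustrative):

# 'Ġ' and 'Ċ' are GPT-2-style byte-level BPE markers for a leading space and a newline;
# '▁' is the SentencePiece space marker; '<0x0A>' is a SentencePiece byte-fallback token.
def show(token: str) -> str:
    return token.replace("Ġ", "·").replace("▁", "·").replace("Ċ", "↵")

print([show(t) for t in ["Ġworld", "Ċ", "▁world", "<0x0A>"]])
# ['·world', '↵', '·world', '<0x0A>']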

Files changed (1)
  1. app.py +350 -222
app.py CHANGED
@@ -13,30 +13,42 @@ app = marimo.App(width="medium")
def _():
    import hashlib
    import math

    import altair as alt
    import marimo as mo
    import polars as pl
    import spacy
-    from transformers import AutoTokenizer

    # Load spaCy models for English and Japanese
-    nlp_en = spacy.load("en_core_web_md")
-    nlp_ja = spacy.load("ja_core_news_md")

    # List of tokenizer models
-    llm_model_choices = [
-        "meta-llama/Llama-4-Scout-17B-16E-Instruct",
        "google/gemma-3-27b-it",
-        "deepseek-ai/DeepSeek-R1",
-        "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
-        "Qwen/Qwen2.5-72B-Instruct",
        "google-bert/bert-large-uncased",
-        "openai-community/gpt2",
    ]
-
    return (
        AutoTokenizer,
        alt,
        hashlib,
        llm_model_choices,
@@ -45,18 +57,23 @@ def _():
        nlp_en,
        nlp_ja,
        pl,
    )


@app.cell
def _(mo):
-    mo.md("# Tokenization for English and Japanese")
    return


@app.cell
-def _(mo):
    # Central state for the text input content
    get_text_content, set_text_content = mo.state("")
    return get_text_content, set_text_content

@@ -73,7 +90,7 @@ def _(mo):
    """.strip()

    # Create UI element for language selection
-    language_selector = mo.ui.radio(
        options=["English", "Japanese"], value="English", label="Language"
    )

@@ -91,29 +108,30 @@ def _(
    set_text_content,
):
    # Define text_input dynamically based on language
-    current_placeholder = (
        en_placeholder if language_selector.value == "English" else ja_placeholder
    )
-    text_input = mo.ui.text_area(
-        # Read value from state
        value=get_text_content(),
        label="Enter text",
        placeholder=current_placeholder,
        full_width=True,
-        # Update state on user input
        on_change=lambda v: set_text_content(v),
    )
    return current_placeholder, text_input


@app.cell
-def _(current_placeholder, mo, set_text_content):
-    def apply_placeholder():
        set_text_content(current_placeholder)

-    apply_placeholder_button = mo.ui.button(
        label="Use Placeholder Text", on_click=lambda _: apply_placeholder()
    )
    return (apply_placeholder_button,)


@@ -129,37 +147,41 @@ def _(apply_placeholder_button, language_selector, mo, text_input):


@app.cell
-def _(get_text_content, language_selector, mo, nlp_en, nlp_ja):
    # Analyze text using spaCy based on selected language
-    # Read text from state
-    current_text = get_text_content()
    if language_selector.value == "English":
        doc = nlp_en(current_text)
    else:
        doc = nlp_ja(current_text)

-    # Tokenized version and count
-    tokenized_text = [token.text for token in doc]
-    token_count = len(tokenized_text)

    mo.md(
-        f"**Tokenized Text:** {' | '.join(tokenized_text)}\n\n**Token Count:** {token_count}"
    )
    return current_text, doc


@app.cell
def _(doc, mo, pl):
-    # Create a polars DataFrame with token attributes
-    token_data = pl.DataFrame(
        {
            "Token": [token.text for token in doc],
            "Lemma": [token.lemma_ for token in doc],
            "POS": [token.pos_ for token in doc],
            "Tag": [token.tag_ for token in doc],
-            "Morph": [
-                str(token.morph) for token in doc
-            ],  # To be more precise, this should be merged back in via .to_dict()
            "Token Position": list(range(len(doc))),
            "Sentence Number": [
                i for i, sent in enumerate(doc.sents) for token in sent
@@ -173,9 +195,8 @@ def _(doc, mo, pl):

@app.cell
def _(mo):
-    # Create UI element for selecting the column to visualize
-    column_selector = mo.ui.dropdown(
-        options=["POS", "Tag", "Lemma", "Token", "Morph"],
        value="POS",
        label="Select column to visualize",
    )
@@ -185,18 +206,18 @@ def _(mo):


@app.cell
-def _(alt, column_selector, mo, token_data):
    mo.stop(token_data.is_empty(), "Please set input text.")

-    selected_column = column_selector.value
    # Calculate value counts for the selected column
-    counts_df = (
        token_data[selected_column]
        .value_counts()
        .sort(by=["count", selected_column], descending=[True, False])
    )

-    chart = (
        alt.Chart(counts_df)
        .mark_bar()
        .encode(
@@ -213,10 +234,9 @@ def _(alt, column_selector, mo, token_data):

@app.cell
def _(llm_model_choices, mo):
-    # UI for selecting the LLM tokenizer model
-    llm_tokenizer_selector = mo.ui.dropdown(
        options=llm_model_choices,
-        value=llm_model_choices[-1],  # Default to gpt2 for faster loading initially
        label="Select LLM Tokenizer Model",
    )
    llm_tokenizer_selector
@@ -224,101 +244,92 @@ def _(llm_model_choices, mo):


@app.cell
-def _(AutoTokenizer, llm_tokenizer_selector):
-    # Load the selected tokenizer
    # Adapted code from: https://huggingface.co/spaces/barttee/tokenizers/blob/main/app.py
-    # This cell will re-run when llm_tokenizer_selector.value changes
-    # Marimo caches the result implicitly based on inputs
-    selected_model_name = llm_tokenizer_selector.value
-    tokenizer = AutoTokenizer.from_pretrained(selected_model_name)
    return (tokenizer,)


@app.cell
-def _(math):
-    # Function to calculate token statistics
-    def get_token_stats(tokens: list, original_text: str) -> dict:
        """Calculate enhanced statistics about the tokens."""
        if not tokens:
-            return {  # Return default structure even for empty input
                "basic_stats": {
                    "total_tokens": 0,
                    "unique_tokens": 0,
-                    "compression_ratio": 0,
                    "space_tokens": 0,
                    "newline_tokens": 0,
                    "special_tokens": 0,
                    "punctuation_tokens": 0,
-                    "unique_percentage": 0,
                },
                "length_stats": {
-                    "avg_length": 0,
-                    "std_dev": 0,
                    "min_length": 0,
                    "max_length": 0,
-                    "median_length": 0,
                },
            }

-        total_tokens = len(tokens)
-        unique_tokens = len(set(tokens))
-        # Handle potential division by zero if total_tokens is 0 (already checked by `if not tokens`)
-        avg_length = (
-            sum(len(t) for t in tokens) / total_tokens if total_tokens > 0 else 0
        )
-        # Handle potential division by zero if total_tokens is 0
-        compression_ratio = len(original_text) / total_tokens if total_tokens > 0 else 0
-
-        # Token type analysis (Note: Heuristics might vary between tokenizers)
-        # Using startswith(('Ġ', '▁')) covers common space markers like SentencePiece's U+2581 and BPE's 'Ġ'
-        space_tokens = sum(1 for t in tokens if t.startswith(("Ġ", "▁")))
-        # Check for common newline representations
-        newline_tokens = sum(
            1 for t in tokens if "Ċ" in t or t == "\n" or t == "<0x0A>"
        )
-        # A broader definition for special tokens based on common patterns (control tokens)
-        special_tokens = sum(
            1
            for t in tokens
            if (t.startswith("<") and t.endswith(">"))
            or (t.startswith("[") and t.endswith("]"))
        )
-        # Simple punctuation check (might overlap with other categories, focuses on single char punct)
-        punctuation_tokens = sum(
            1
            for t in tokens
            if len(t) == 1 and not t.isalnum() and t not in [" ", "\n", "Ġ", "Ċ"]
        )

-        # Length distribution
-        lengths = [len(t) for t in tokens]
        if not lengths:  # Should not happen if tokens is not empty, but safe check
-            return {
                "basic_stats": {
                    "total_tokens": 0,
                    "unique_tokens": 0,
-                    "compression_ratio": 0,
                    "space_tokens": 0,
                    "newline_tokens": 0,
                    "special_tokens": 0,
                    "punctuation_tokens": 0,
-                    "unique_percentage": 0,
                },
                "length_stats": {
-                    "avg_length": 0,
-                    "std_dev": 0,
                    "min_length": 0,
                    "max_length": 0,
-                    "median_length": 0,
                },
            }

-        mean_length = sum(lengths) / len(lengths)
-        variance = sum((x - mean_length) ** 2 for x in lengths) / len(lengths)
-        std_dev = math.sqrt(variance)
-        sorted_lengths = sorted(lengths)
-        # Handle case where lengths list might be empty after filtering, though unlikely here
-        median_length = sorted_lengths[len(lengths) // 2] if lengths else 0

        return {
            "basic_stats": {
@@ -331,13 +342,13 @@ def _(math):
                "punctuation_tokens": punctuation_tokens,
                "unique_percentage": round(unique_tokens / total_tokens * 100, 1)
                if total_tokens > 0
-                else 0,
            },
            "length_stats": {
-                "avg_length": round(avg_length, 2),
                "std_dev": round(std_dev, 2),
-                "min_length": min(lengths) if lengths else 0,
-                "max_length": max(lengths) if lengths else 0,
                "median_length": median_length,
            },
        }
@@ -347,17 +358,13 @@ def _(math):

@app.cell
def _(hashlib):
-    def get_varied_color(token: str) -> dict:
        """Generate vibrant colors with HSL for better visual distinction."""
-        # Use a fixed salt or seed if you want consistent colors across runs for the same token
-        token_hash = hashlib.md5(token.encode()).hexdigest()
-        hue = int(token_hash[:3], 16) % 360
-        saturation = 70 + (int(token_hash[3:5], 16) % 20)  # Saturation between 70-90%
-        lightness = 80 + (
-            int(token_hash[5:7], 16) % 10
-        )  # Lightness between 80-90% (light background)
-        # Ensure text color contrasts well with the light background
-        text_lightness = 20  # Dark text for light background

        return {
            "background": f"hsl({hue}, {saturation}%, {lightness}%)",
@@ -368,76 +375,67 @@ def _(hashlib):


@app.function
-def fix_token(token: str) -> str:
-    """Fix token for display with improved space visualization."""
-    # Replace SentencePiece space marker U+2581 with a middle dot
    token = token.replace("▁", "·")
    # Replace BPE space marker 'Ġ' with a middle dot
    if token.startswith("Ġ"):
        space_count = token.count("Ġ")
        return "·" * space_count + token[space_count:]
    # Replace newline markers for display
-    token = token.replace(
-        "Ċ", "↵\n"
-    )  # Replace newline marker with symbol and actual newline
-    token = token.replace("<0x0A>", "↵\n")  # Handle byte representation of newline
    return token


-@app.function
-def get_tokenizer_info(tokenizer):
-    """
-    Extract useful information from a tokenizer.
-    Returns a dictionary with tokenizer details.
-    """
-
-    info = {}
-    try:
-        # Get vocabulary size (dictionary size)
-        if hasattr(tokenizer, "vocab_size"):
-            info["vocab_size"] = tokenizer.vocab_size
-        elif hasattr(tokenizer, "get_vocab"):
-            info["vocab_size"] = len(tokenizer.get_vocab())
-
-        # Get model max length if available
-        if (
-            hasattr(tokenizer, "model_max_length")
-            and tokenizer.model_max_length < 1000000
-        ):  # Sanity check for realistic values
-            info["model_max_length"] = tokenizer.model_max_length
-        else:
-            info["model_max_length"] = "Not specified or very large"
-
-        # Check tokenizer type
-        info["tokenizer_type"] = tokenizer.__class__.__name__
-
-        # Get special tokens using the recommended attributes/methods
-        special_tokens = {}
-        # Prefer all_special_tokens if available
-        if hasattr(tokenizer, "all_special_tokens"):
-            for token in tokenizer.all_special_tokens:
-                # Try to find the attribute name corresponding to the token value
-                token_name = "unknown_special_token"  # Default name
-                for attr_name in [
-                    "pad_token",
-                    "eos_token",
-                    "bos_token",
-                    "sep_token",
-                    "cls_token",
-                    "unk_token",
-                    "mask_token",
-                ]:
-                    if (
-                        hasattr(tokenizer, attr_name)
-                        and getattr(tokenizer, attr_name) == token
-                    ):
-                        token_name = attr_name
-                        break
-                if token and str(token).strip():
-                    special_tokens[token_name] = str(token)
-        else:
-            # Fallback to checking individual attributes
-            for token_name in [
                "pad_token",
                "eos_token",
                "bos_token",
@@ -445,129 +443,259 @@ def get_tokenizer_info(tokenizer):
                "cls_token",
                "unk_token",
                "mask_token",
-            ]:
-                if (
-                    hasattr(tokenizer, token_name)
-                    and getattr(tokenizer, token_name) is not None
-                ):
                    token_value = getattr(tokenizer, token_name)
-                    if token_value and str(token_value).strip():
                        special_tokens[token_name] = str(token_value)

-        info["special_tokens"] = special_tokens if special_tokens else "None found"

-    except Exception as e:
-        info["error"] = f"Error extracting tokenizer info: {str(e)}"

-    return info


@app.cell
def _(mo):
-    show_ids_switch = mo.ui.switch(label="Show Token IDs instead of Text", value=False)
    return (show_ids_switch,)


@app.cell
def _(
    current_text,
    get_token_stats,
    get_varied_color,
    llm_tokenizer_selector,
    mo,
    show_ids_switch,
    tokenizer,
):
-    # --- Tokenization and Data Preparation ---

    # Get tokenizer metadata
-    tokenizer_info = get_tokenizer_info(tokenizer)

-    # Tokenize the input text
-    # Use tokenize to get string representations for analysis and display
-    all_tokens = tokenizer.tokenize(current_text)
-    total_token_count = len(all_tokens)

-    # Limit the number of tokens for display to avoid browser slowdown
-    display_limit = 1000
-    display_tokens = all_tokens[:display_limit]
-    display_limit_reached = total_token_count > display_limit

    # Generate data for visualization
-    llm_token_data = []
-    for idx, token in enumerate(display_tokens):
-        colors = get_varied_color(token)
-        fixed_token_display = fix_token(token)  # Apply fixes for display
-        # Handle potential errors during ID conversion (e.g., unknown tokens if not handled by tokenizer)
-        try:
-            token_id = tokenizer.convert_tokens_to_ids(token)
-        except KeyError:
-            token_id = (
-                tokenizer.unk_token_id if hasattr(tokenizer, "unk_token_id") else -1
-            )  # Use UNK id or -1

        llm_token_data.append(
            {
-                "original": token,
-                "display": fixed_token_display,
                "colors": colors,
-                "is_newline": "↵"
-                in fixed_token_display,  # Check if it represents a newline
                "token_id": token_id,
                "token_index": idx,
            }
        )

-    # Calculate statistics using the full token list
-    token_stats = get_token_stats(all_tokens, current_text)
-
-    # Construct HTML for colored tokens
-    html_parts = []
-    for item in llm_token_data:
-        # Use pre-wrap to respect spaces and newlines within the token display
-        style = f"background-color: {item['colors']['background']}; color: {item['colors']['text']}; padding: 1px 3px; margin: 1px; border-radius: 3px; display: inline-block; white-space: pre-wrap; line-height: 1.4;"
-        # Add title attribute for hover info (original token + ID)
-        title = f"Original: {item['original']}\nID: {item['token_id']}"
-        display_content = (
-            str(item["token_id"]) if show_ids_switch.value else item["display"]
-        )
-        html_parts.append(
-            f'<span style="{style}" title="{title}">{display_content}</span>'
-        )

-    token_viz_html = mo.Html(
        f'<div style="line-height: 1.6;">{"".join(html_parts)}</div>'
    )

-    basic_stats = token_stats["basic_stats"]
-    length_stats = token_stats["length_stats"]

-    basic_stats_md = "**Basic Stats:**\n\n" + "\n".join(
        f"- **{key.replace('_', ' ').title()}:** `{value}`"
        for key, value in basic_stats.items()
    )

-    length_stats_md = "**Length (Character) Stats:**\n\n" + "\n".join(
        f"- **{key.replace('_', ' ').title()}:** `{value}`"
        for key, value in length_stats.items()
    )

    mo.md(f"""# LLM tokenizer: {llm_tokenizer_selector.value}

    {show_ids_switch}

    ## Tokenizer output
-
    {mo.as_html(token_viz_html)}

    ## Token Statistics

    {basic_stats_md}

    {length_stats_md}

    """)
-    return


@app.cell
 
@@ -13,30 +13,42 @@ app = marimo.App(width="medium")
def _():
    import hashlib
    import math
+    import re
+    from typing import Any, Callable, Optional, Union

    import altair as alt
    import marimo as mo
    import polars as pl
    import spacy
+    import spacy.language
+    from transformers import (
+        AutoTokenizer,
+        PreTrainedTokenizerBase,
+    )

    # Load spaCy models for English and Japanese
+    nlp_en: spacy.language.Language = spacy.load("en_core_web_md")
+    nlp_ja: spacy.language.Language = spacy.load("ja_core_news_md")

    # List of tokenizer models
+    llm_model_choices: list[str] = [
+        # "meta-llama/Llama-4-Scout-17B-16E-Instruct",
        "google/gemma-3-27b-it",
+        "ibm-granite/granite-3.3-8b-instruct",
+        "shisa-ai/shisa-v2-qwen2.5-7b",
+        # "deepseek-ai/DeepSeek-R1",
+        # "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
+        # "Qwen/Qwen2.5-72B-Instruct",
+        # "openai-community/gpt2",
        "google-bert/bert-large-uncased",
    ]
    return (
+        Any,
        AutoTokenizer,
+        Callable,
+        Optional,
+        PreTrainedTokenizerBase,
+        Union,
        alt,
        hashlib,
        llm_model_choices,

@@ -45,18 +57,23 @@ def _():
        nlp_en,
        nlp_ja,
        pl,
+        re,
+        spacy,
    )


@app.cell
def _(mo):
+    mo.md("""# Tokenization for English and Japanese""")
    return


@app.cell
+def _(Callable, mo):
    # Central state for the text input content
+    # Type the getter and setter
+    get_text_content: Callable[[], str]
+    set_text_content: Callable[[str], None]
    get_text_content, set_text_content = mo.state("")
    return get_text_content, set_text_content

@@ -73,7 +90,7 @@ def _(mo):
    """.strip()

    # Create UI element for language selection
+    language_selector: mo.ui.radio = mo.ui.radio(
        options=["English", "Japanese"], value="English", label="Language"
    )


@@ -91,29 +108,30 @@ def _(
    set_text_content,
):
    # Define text_input dynamically based on language
+    current_placeholder: str = (
        en_placeholder if language_selector.value == "English" else ja_placeholder
    )
+    text_input: mo.ui.text_area = mo.ui.text_area(
        value=get_text_content(),
        label="Enter text",
        placeholder=current_placeholder,
        full_width=True,
        on_change=lambda v: set_text_content(v),
    )
+    # Type the return tuple
    return current_placeholder, text_input


@app.cell
+def _(Callable, current_placeholder, mo, set_text_content):
+    # Type the inner function
+    def apply_placeholder() -> None:
        set_text_content(current_placeholder)

+    apply_placeholder_button: mo.ui.button = mo.ui.button(
        label="Use Placeholder Text", on_click=lambda _: apply_placeholder()
    )
+    # Type the return tuple
    return (apply_placeholder_button,)


@@ -129,37 +147,41 @@ def _(apply_placeholder_button, language_selector, mo, text_input):


@app.cell
+def _(get_text_content, language_selector, mo, nlp_en, nlp_ja, spacy):
    # Analyze text using spaCy based on selected language
+    current_text: str = get_text_content()
+    doc: spacy.tokens.Doc
    if language_selector.value == "English":
        doc = nlp_en(current_text)
    else:
        doc = nlp_ja(current_text)
+    model_name: str = (
+        nlp_en.meta["name"]
+        if language_selector.value == "English"
+        else nlp_ja.meta["name"]
+    )

+    tokenized_text: list[str] = [token.text for token in doc]
+    token_count: int = len(tokenized_text)

    mo.md(
+        f"**Tokenized Text using spaCy {'en_' if language_selector.value == 'English' else 'ja_'}{model_name}:** {' | '.join(tokenized_text)}\n\n**Token Count:** {token_count}"
    )
    return current_text, doc


@app.cell
def _(doc, mo, pl):
+    token_data: pl.DataFrame = pl.DataFrame(
        {
            "Token": [token.text for token in doc],
            "Lemma": [token.lemma_ for token in doc],
            "POS": [token.pos_ for token in doc],
            "Tag": [token.tag_ for token in doc],
+            "Morph": [str(token.morph) for token in doc],
+            "OOV": [
+                token.is_oov for token in doc
+            ],  # FIXME: How to get .is_oov() from sudachi directly? This only works for English now...
            "Token Position": list(range(len(doc))),
            "Sentence Number": [
                i for i, sent in enumerate(doc.sents) for token in sent

@@ -173,9 +195,8 @@ def _(doc, mo, pl):

@app.cell
def _(mo):
+    column_selector: mo.ui.dropdown = mo.ui.dropdown(
+        options=["POS", "Tag", "Lemma", "Token", "Morph", "OOV"],
        value="POS",
        label="Select column to visualize",
    )
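As a rough, standalone illustration of the per-token table built in the cell above (a sketch assuming en_core_web_md is installed; not code from this commit):

import polars as pl
import spacy

nlp = spacy.load("en_core_web_md")
doc = nlp("Tokenization is fun.")
# A few of the same per-token attributes collected above.
print(
    pl.DataFrame(
        {
            "Token": [t.text for t in doc],
            "POS": [t.pos_ for t in doc],
            "Morph": [str(t.morph) for t in doc],
            "OOV": [t.is_oov for t in doc],
        }
    )
)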
 
@@ -185,18 +206,18 @@ def _(mo):


@app.cell
+def _(alt, column_selector, mo, pl, token_data):
    mo.stop(token_data.is_empty(), "Please set input text.")

+    selected_column: str = column_selector.value
    # Calculate value counts for the selected column
+    counts_df: pl.DataFrame = (
        token_data[selected_column]
        .value_counts()
        .sort(by=["count", selected_column], descending=[True, False])
    )

+    chart: alt.Chart = (
        alt.Chart(counts_df)
        .mark_bar()
        .encode(

@@ -213,10 +234,9 @@ def _(alt, column_selector, mo, token_data):

@app.cell
def _(llm_model_choices, mo):
+    llm_tokenizer_selector: mo.ui.dropdown = mo.ui.dropdown(
        options=llm_model_choices,
+        value=llm_model_choices[0],
        label="Select LLM Tokenizer Model",
    )
    llm_tokenizer_selector

@@ -224,101 +244,92 @@ def _(llm_model_choices, mo):


@app.cell
+def _(AutoTokenizer, PreTrainedTokenizerBase, llm_tokenizer_selector):
    # Adapted code from: https://huggingface.co/spaces/barttee/tokenizers/blob/main/app.py
+    selected_model_name: str = llm_tokenizer_selector.value
+    tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
+        selected_model_name
+    )
    return (tokenizer,)


@app.cell
+def _(Union, math):
+    TokenStatsDict = dict[str, dict[str, Union[int, float]]]
+
+    def get_token_stats(tokens: list[str], original_text: str) -> TokenStatsDict:
        """Calculate enhanced statistics about the tokens."""
        if not tokens:
+            # Return default structure matching TokenStatsDict
+            return {
                "basic_stats": {
                    "total_tokens": 0,
                    "unique_tokens": 0,
+                    "compression_ratio": 0.0,
                    "space_tokens": 0,
                    "newline_tokens": 0,
                    "special_tokens": 0,
                    "punctuation_tokens": 0,
+                    "unique_percentage": 0.0,
                },
                "length_stats": {
+                    "avg_length": 0.0,
+                    "std_dev": 0.0,
                    "min_length": 0,
                    "max_length": 0,
+                    "median_length": 0.0,
                },
            }

+        total_tokens: int = len(tokens)
+        unique_tokens: int = len(set(tokens))
+        compression_ratio: float = (
+            len(original_text) / total_tokens if total_tokens > 0 else 0.0
        )
+
+        space_tokens: int = sum(1 for t in tokens if t.startswith(("Ġ", "▁")))
+        newline_tokens: int = sum(
            1 for t in tokens if "Ċ" in t or t == "\n" or t == "<0x0A>"
        )
+        special_tokens: int = sum(
            1
            for t in tokens
            if (t.startswith("<") and t.endswith(">"))
            or (t.startswith("[") and t.endswith("]"))
        )
+        punctuation_tokens: int = sum(
            1
            for t in tokens
            if len(t) == 1 and not t.isalnum() and t not in [" ", "\n", "Ġ", "Ċ"]
        )

+        lengths: list[int] = [len(t) for t in tokens]
        if not lengths:  # Should not happen if tokens is not empty, but safe check
+            return {  # Return default structure matching TokenStatsDict
                "basic_stats": {
                    "total_tokens": 0,
                    "unique_tokens": 0,
+                    "compression_ratio": 0.0,
                    "space_tokens": 0,
                    "newline_tokens": 0,
                    "special_tokens": 0,
                    "punctuation_tokens": 0,
+                    "unique_percentage": 0.0,
                },
                "length_stats": {
+                    "avg_length": 0.0,
+                    "std_dev": 0.0,
                    "min_length": 0,
                    "max_length": 0,
+                    "median_length": 0.0,
                },
            }

+        mean_length: float = sum(lengths) / len(lengths)
+        variance: float = sum((x - mean_length) ** 2 for x in lengths) / len(lengths)
+        std_dev: float = math.sqrt(variance)
+        sorted_lengths: list[int] = sorted(lengths)
+        median_length: float = float(sorted_lengths[len(lengths) // 2])

        return {
            "basic_stats": {

@@ -331,13 +342,13 @@ def _(math):
                "punctuation_tokens": punctuation_tokens,
                "unique_percentage": round(unique_tokens / total_tokens * 100, 1)
                if total_tokens > 0
+                else 0.0,
            },
            "length_stats": {
+                "avg_length": round(mean_length, 2),
                "std_dev": round(std_dev, 2),
+                "min_length": min(lengths),
+                "max_length": max(lengths),
                "median_length": median_length,
            },
        }
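A quick, illustrative call to the helper defined above (not part of the commit; the expected values are worked out by hand):

# Illustrative only; the dictionary keys mirror the structure returned above.
stats = get_token_stats(["Ġhello", "Ġworld", "!"], "hello world!")
print(stats["basic_stats"]["total_tokens"])       # 3
print(stats["basic_stats"]["compression_ratio"])  # len("hello world!") / 3 = 4.0
print(stats["length_stats"]["max_length"])        # 6 ("Ġhello" and "Ġworld")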
 
@@ -347,17 +358,13 @@ def _(math):

@app.cell
def _(hashlib):
+    def get_varied_color(token: str) -> dict[str, str]:
        """Generate vibrant colors with HSL for better visual distinction."""
+        token_hash: str = hashlib.md5(token.encode()).hexdigest()
+        hue: int = int(token_hash[:3], 16) % 360
+        saturation: int = 70 + (int(token_hash[3:5], 16) % 20)
+        lightness: int = 80 + (int(token_hash[5:7], 16) % 10)
+        text_lightness: int = 20

        return {
            "background": f"hsl({hue}, {saturation}%, {lightness}%)",

@@ -368,76 +375,67 @@ def _(hashlib):


@app.function
+def fix_token(
+    token: str, re
+) -> (
+    str
+):  # re module type is complex, leave as Any implicitly or import types.ModuleType
+    """Fix token for display, handling byte fallbacks and spaces."""
+    # Check for byte fallback pattern <0xHH> using a full match
+    byte_match = re.fullmatch(r"<0x([0-9A-Fa-f]{2})>", token)
+    if byte_match:
+        hex_value = byte_match.group(1).upper()
+        # Return a clear representation indicating it's a byte
+        return f"<0x{hex_value}>"
+
+    # Replace SentencePiece space marker U+2581 ('▁') with a middle dot
    token = token.replace("▁", "·")
+
    # Replace BPE space marker 'Ġ' with a middle dot
    if token.startswith("Ġ"):
        space_count = token.count("Ġ")
+        # Ensure we only replace the leading 'Ġ' markers
        return "·" * space_count + token[space_count:]
+
    # Replace newline markers for display
+    token = token.replace("Ċ", "↵\n")
+    # Handle byte representation of newline AFTER general byte check
+    # This specific check might become redundant if <0x0A> is caught by the byte_match above
+    # Keep it for now as a fallback.
+    token = token.replace("<0x0A>", "↵\n")
+
    return token


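For instance, with the byte-fallback check above, tokens would render roughly as follows (illustrative calls reusing fix_token as defined above):

import re
print(fix_token("ĠTokyo", re))   # '·Tokyo'  (BPE space marker)
print(fix_token("▁Tokyo", re))   # '·Tokyo'  (SentencePiece space marker)
print(fix_token("<0x0A>", re))   # '<0x0A>'  (byte-fallback token kept verbatim)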
 
+@app.cell
+def _(Any, PreTrainedTokenizerBase):
+    def get_tokenizer_info(
+        tokenizer: PreTrainedTokenizerBase,
+    ) -> dict[str, Any]:
+        """
+        Extract useful information from a tokenizer.
+        Returns a dictionary with tokenizer details.
+        """
+        info: dict[str, Any] = {}
+        try:
+            if hasattr(tokenizer, "vocab_size"):
+                info["vocab_size"] = tokenizer.vocab_size
+            elif hasattr(tokenizer, "get_vocab"):
+                info["vocab_size"] = len(tokenizer.get_vocab())
+
+            if (
+                hasattr(tokenizer, "model_max_length")
+                and isinstance(tokenizer.model_max_length, int)
+                and tokenizer.model_max_length < 1000000
+            ):
+                info["model_max_length"] = tokenizer.model_max_length
+            else:
+                info["model_max_length"] = "Not specified or very large"
+
+            info["tokenizer_type"] = tokenizer.__class__.__name__
+
+            special_tokens: dict[str, str] = {}
+            special_token_attributes: list[str] = [
                "pad_token",
                "eos_token",
                "bos_token",

@@ -445,129 +443,259 @@ def get_tokenizer_info(tokenizer):
                "cls_token",
                "unk_token",
                "mask_token",
+            ]
+
+            processed_tokens: set[str] = (
+                set()
+            )  # Keep track of processed tokens to avoid duplicates
+
+            # Prefer all_special_tokens if available
+            if hasattr(tokenizer, "all_special_tokens"):
+                for token_value in tokenizer.all_special_tokens:
+                    if (
+                        not token_value
+                        or not str(token_value).strip()
+                        or str(token_value) in processed_tokens
+                    ):
+                        continue
+
+                    token_name = "special_token"  # Default name
+                    # Find the attribute name corresponding to the token value
+                    for attr_name in special_token_attributes:
+                        if (
+                            hasattr(tokenizer, attr_name)
+                            and getattr(tokenizer, attr_name) == token_value
+                        ):
+                            token_name = attr_name
+                            break
+                    special_tokens[token_name] = str(token_value)
+                    processed_tokens.add(str(token_value))
+
+            # Fallback/Augment with individual attributes if not covered by all_special_tokens
+            for token_name in special_token_attributes:
+                if hasattr(tokenizer, token_name):
                    token_value = getattr(tokenizer, token_name)
+                    if (
+                        token_value
+                        and str(token_value).strip()
+                        and str(token_value) not in processed_tokens
+                    ):
                        special_tokens[token_name] = str(token_value)
+                        processed_tokens.add(str(token_value))

+            info["special_tokens"] = special_tokens if special_tokens else "None found"

+        except Exception as e:
+            info["error"] = f"Error extracting tokenizer info: {str(e)}"

+        return info
+
+    return (get_tokenizer_info,)


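The helper above returns a plain dictionary of roughly this shape (the values below are made up for illustration):

example_info = {
    "vocab_size": 32000,                          # hypothetical value
    "model_max_length": 4096,                     # hypothetical value
    "tokenizer_type": "PreTrainedTokenizerFast",  # hypothetical value
    "special_tokens": {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>"},
}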
 
@app.cell
def _(mo):
+    show_ids_switch: mo.ui.switch = mo.ui.switch(
+        label="Show token IDs instead of text", value=False
+    )
    return (show_ids_switch,)


@app.cell
def _(
+    Any,
+    Optional,
+    Union,
    current_text,
+    fix_token,
    get_token_stats,
+    get_tokenizer_info,
    get_varied_color,
    llm_tokenizer_selector,
    mo,
+    re,
    show_ids_switch,
    tokenizer,
):
+    # Define the Unicode replacement character
+    REPLACEMENT_CHARACTER = "\ufffd"

    # Get tokenizer metadata
+    tokenizer_info: dict[str, Any] = get_tokenizer_info(tokenizer)
+
+    # 1. Encode text to get token IDs first.
+    token_ids: list[int] = tokenizer.encode(current_text, add_special_tokens=False)
+    # 2. Decode each token ID individually.
+    #    We will check for REPLACEMENT_CHARACTER later.
+    all_decoded_tokens: list[str] = [
+        tokenizer.decode(
+            [token_id], skip_special_tokens=False, clean_up_tokenization_spaces=False
+        )
+        for token_id in token_ids
+    ]

+    total_token_count: int = len(token_ids)  # Count based on IDs

+    # Limit the number of tokens for display
+    display_limit: int = 1000
+    # Limit consistently using token IDs and the decoded tokens
+    display_token_ids: list[int] = token_ids[:display_limit]
+    display_decoded_tokens: list[str] = all_decoded_tokens[:display_limit]
+    display_limit_reached: bool = total_token_count > display_limit

    # Generate data for visualization
+    TokenVisData = dict[str, Union[str, int, bool, dict[str, str]]]
+    llm_token_data: list[TokenVisData] = []
+
+    # Use zip for parallel iteration
+    for idx, (token_id, token_str) in enumerate(
+        zip(display_token_ids, display_decoded_tokens)
+    ):
+        colors: dict[str, str] = get_varied_color(
+            token_str
+            if REPLACEMENT_CHARACTER not in token_str
+            else f"invalid_{token_id}"
+        )  # Color based on string or ID if invalid
+
+        is_invalid_utf8 = REPLACEMENT_CHARACTER in token_str
+        fixed_token_display: str
+        original_for_title: str = (
+            token_str  # Store the potentially problematic string for title
+        )
+
+        if is_invalid_utf8:
+            # If decode failed, show a representation with the hex ID
+            fixed_token_display = f"<0x{token_id:X}>"
+        else:
+            # If decode succeeded, apply standard fixes
+            fixed_token_display = fix_token(token_str, re)

        llm_token_data.append(
            {
+                "original": original_for_title,  # Store the raw decoded string (might contain �)
+                "display": fixed_token_display,  # Store the cleaned/invalid representation
                "colors": colors,
+                "is_newline": "↵" in fixed_token_display,  # Check the display version
                "token_id": token_id,
                "token_index": idx,
+                "is_invalid": is_invalid_utf8,  # Add flag for potential styling/title changes
            }
        )

+    # Calculate statistics using the list of *successfully* decoded token strings
+    # We might want to reconsider what `all_tokens` means for stats if many are invalid.
+    # For now, let's use the potentially problematic strings, as stats are mostly length/count based.
+    token_stats: dict[str, dict[str, Union[int, float]]] = get_token_stats(
+        all_decoded_tokens,
+        current_text,  # Pass the full list from decode()
+    )
+
+    # Construct HTML for colored tokens using list comprehension (functional style)
+    html_parts: list[str] = [
+        (
+            lambda item: (
+                style
+                := f"background-color: {item['colors']['background']}; color: {item['colors']['text']}; padding: 1px 3px; margin: 1px; border-radius: 3px; display: inline-block; white-space: pre-wrap; line-height: 1.4;"
+                # Add specific style for invalid tokens if needed
+                + (" border: 1px solid red;" if item.get("is_invalid") else ""),
+                # Modify title based on validity
+                title := (
+                    f"Original: {item['original']}\nID: {item['token_id']}"
+                    + ("\n(Invalid UTF-8)" if item.get("is_invalid") else "")
+                    + ("\n(Byte Token)" if item["display"].startswith("byte[") else "")
+                ),
+                display_content := str(item["token_id"])
+                if show_ids_switch.value
+                else item["display"],
+                f'<span style="{style}" title="{title}">{display_content}</span>',
+            )[-1]  # Get the last element (the formatted string) from the lambda's tuple
+        )(item)
+        for item in llm_token_data
+    ]

+    token_viz_html: mo.Html = mo.Html(
        f'<div style="line-height: 1.6;">{"".join(html_parts)}</div>'
    )

+    # Optional: Add a warning if the display limit was reached
+    limit_warning: Optional[mo.md] = None  # Use Optional type
+    if display_limit_reached:
+        limit_warning = mo.md(f"""**Warning:** Displaying only the first {display_limit:,} tokens out of {total_token_count:,}.
+        Statistics are calculated on the full text.""").callout(kind="warn")

+    # Use dict access safely with .get() for stats
+    basic_stats: dict[str, Union[int, float]] = token_stats.get("basic_stats", {})
+    length_stats: dict[str, Union[int, float]] = token_stats.get("length_stats", {})
+
+    # Use list comprehensions for markdown generation (functional style)
+    basic_stats_md: str = "**Basic Stats:**\n\n" + "\n".join(
        f"- **{key.replace('_', ' ').title()}:** `{value}`"
        for key, value in basic_stats.items()
    )

+    length_stats_md: str = "**Length (Character) Stats:**\n\n" + "\n".join(
        f"- **{key.replace('_', ' ').title()}:** `{value}`"
        for key, value in length_stats.items()
    )

+    # Build tokenizer info markdown parts
+    tokenizer_info_md_parts: list[str] = [
+        f"**Tokenizer Type:** `{tokenizer_info.get('tokenizer_type', 'N/A')}`"
+    ]
+    if vocab_size := tokenizer_info.get("vocab_size"):
+        tokenizer_info_md_parts.append(f"**Vocab Size:** `{vocab_size:,}`")
+    if max_len := tokenizer_info.get("model_max_length"):
+        tokenizer_info_md_parts.append(f"**Model Max Length:** `{max_len:,}`")
+
+    special_tokens_info = tokenizer_info.get("special_tokens")
+    if isinstance(special_tokens_info, dict) and special_tokens_info:
+        tokenizer_info_md_parts.append("**Special Tokens:**")
+        tokenizer_info_md_parts.extend(
+            f"  - `{name}`: `{str(val)}`" for name, val in special_tokens_info.items()
+        )
+    elif isinstance(special_tokens_info, str):  # Handle "None found" case
+        tokenizer_info_md_parts.append(f"**Special Tokens:** `{special_tokens_info}`")
+
+    if error_info := tokenizer_info.get("error"):
+        tokenizer_info_md_parts.append(f"**Info Error:** `{error_info}`")
+
+    tokenizer_info_md: str = "\n\n".join(tokenizer_info_md_parts)
+
+    # Display the final markdown output
    mo.md(f"""# LLM tokenizer: {llm_tokenizer_selector.value}

+    ## Tokenizer Info
+    {tokenizer_info_md}
+
    {show_ids_switch}

    ## Tokenizer output
+    {limit_warning if limit_warning else ""}
    {mo.as_html(token_viz_html)}

    ## Token Statistics
+    (Calculated on full text if truncated above)

    {basic_stats_md}

    {length_stats_md}

    """)
+
+    return (
+        all_decoded_tokens,
+        token_ids,
+        basic_stats_md,
+        display_limit_reached,
+        length_stats_md,
+        limit_warning,
+        llm_token_data,
+        token_stats,
+        token_viz_html,
+        tokenizer_info,
+        tokenizer_info_md,
+        total_token_count,
+    )
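As a standalone sketch of the encode-then-decode-per-ID approach used in the cell above (any tokenizer from the model list works; token IDs and pieces vary by model):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google-bert/bert-large-uncased")
ids = tok.encode("Tokenization for English and Japanese", add_special_tokens=False)
# Decode each ID on its own; pieces that are not valid UTF-8 by themselves
# come back containing U+FFFD and get the byte-style display treatment above.
pieces = [tok.decode([i], skip_special_tokens=False) for i in ids]
print([(i, p, "\ufffd" in p) for i, p in zip(ids, pieces)])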
 

@app.cell