# /// script
# [tool.marimo.runtime]
# auto_instantiate = false
# ///

import marimo

__generated_with = "0.13.0"
app = marimo.App(width="medium")


@app.cell
def _():
    import hashlib
    import math
    import re
    from typing import Any, Callable, Optional, Union

    import altair as alt
    import marimo as mo
    import polars as pl
    import spacy
    import spacy.language
    from transformers import (
        AutoTokenizer,
        PreTrainedTokenizerBase,
    )

    # Load spaCy models for English and Japanese
    nlp_en: spacy.language.Language = spacy.load("en_core_web_md")
    nlp_ja: spacy.language.Language = spacy.load("ja_core_news_md")
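    # Note: spacy.load() assumes both pipelines are already installed
    # (e.g. `python -m spacy download en_core_web_md` and
    # `python -m spacy download ja_core_news_md`); otherwise it raises OSError.
    # The Japanese pipeline tokenizes with SudachiPy under the hood.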

    # List of tokenizer models
    llm_model_choices: list[str] = [
        # "meta-llama/Llama-4-Scout-17B-16E-Instruct",
        "google/gemma-3-27b-it",
        "ibm-granite/granite-3.3-8b-instruct",
        "shisa-ai/shisa-v2-qwen2.5-7b",
        # "deepseek-ai/DeepSeek-R1",
        # "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
        # "Qwen/Qwen2.5-72B-Instruct",
        # "openai-community/gpt2",
        "google-bert/bert-large-uncased",
    ]
    return (
        Any,
        AutoTokenizer,
        Callable,
        Optional,
        PreTrainedTokenizerBase,
        Union,
        alt,
        hashlib,
        llm_model_choices,
        math,
        mo,
        nlp_en,
        nlp_ja,
        pl,
        re,
        spacy,
    )


@app.cell
def _(mo):
    mo.md("""# Tokenization for English and Japanese""")
    return


@app.cell
def _(Callable, mo):
    # Central state for the text input content
    # Type the getter and setter
    get_text_content: Callable[[], str]
    set_text_content: Callable[[str], None]
    get_text_content, set_text_content = mo.state("")
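    # mo.state("") returns a (getter, setter) pair; calling the setter re-runs
    # any cell that reads the getter, which is what lets the
    # "Use Placeholder Text" button below refresh the text area.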
    return get_text_content, set_text_content


@app.cell
def _(mo):
    # Placeholder texts
    en_placeholder = """
Mrs. Ferrars died on the night of the 16th–17th September—a Thursday. I was sent for at eight o’clock on the morning of Friday the 17th. There was nothing to be done. She had been dead some hours.
""".strip()
    ja_placeholder = """
吾輩は猫である。名前はまだ無い。
どこで生れたかとんと見当がつかぬ。何でも薄暗いじめじめした所でニャーニャー泣いていた事だけは記憶している。
""".strip()

    # Create UI element for language selection
    language_selector: mo.ui.radio = mo.ui.radio(
        options=["English", "Japanese"], value="English", label="Language"
    )

    # Return selector and placeholders
    return en_placeholder, ja_placeholder, language_selector


@app.cell
def _(
    en_placeholder,
    get_text_content,
    ja_placeholder,
    language_selector,
    mo,
    set_text_content,
):
    # Define text_input dynamically based on language
    current_placeholder: str = (
        en_placeholder if language_selector.value == "English" else ja_placeholder
    )
    text_input: mo.ui.text_area = mo.ui.text_area(
        value=get_text_content(),
        label="Enter text",
        placeholder=current_placeholder,
        full_width=True,
        on_change=lambda v: set_text_content(v),
    )
    # Type the return tuple
    return current_placeholder, text_input


@app.cell
def _(Callable, current_placeholder, mo, set_text_content):
    # Type the inner function
    def apply_placeholder() -> None:
        set_text_content(current_placeholder)

    apply_placeholder_button: mo.ui.button = mo.ui.button(
        label="Use Placeholder Text", on_click=lambda _: apply_placeholder()
    )
    # Type the return tuple
    return (apply_placeholder_button,)


@app.cell
def _(apply_placeholder_button, language_selector, mo, text_input):
    mo.vstack(
        [
            text_input,
            mo.hstack([language_selector, apply_placeholder_button], justify="start"),
        ]
    )
    return


@app.cell
def _(get_text_content, language_selector, mo, nlp_en, nlp_ja, spacy):
    # Analyze text using spaCy based on selected language
    current_text: str = get_text_content()
    doc: spacy.tokens.Doc
    if language_selector.value == "English":
        doc = nlp_en(current_text)
    else:
        doc = nlp_ja(current_text)

    model_name: str = (
        nlp_en.meta["name"]
        if language_selector.value == "English"
        else nlp_ja.meta["name"]
    )
    tokenized_text: list[str] = [token.text for token in doc]
    token_count: int = len(tokenized_text)

    mo.md(
        f"**Tokenized Text using spaCy {'en_' if language_selector.value == 'English' else 'ja_'}{model_name}:** {' | '.join(tokenized_text)}\n\n**Token Count:** {token_count}"
    )
    return current_text, doc


@app.cell
def _(doc, mo, pl):
    token_data: pl.DataFrame = pl.DataFrame(
        {
            "Token": [token.text for token in doc],
            "Lemma": [token.lemma_ for token in doc],
            "POS": [token.pos_ for token in doc],
            "Tag": [token.tag_ for token in doc],
            "Morph": [str(token.morph) for token in doc],
            "OOV": [
                token.is_oov for token in doc
            ],  # FIXME: How to get .is_oov() from sudachi directly? This only works for English now...
            "Token Position": list(range(len(doc))),
            "Sentence Number": [
                i for i, sent in enumerate(doc.sents) for token in sent
            ],
        }
    )

    mo.ui.dataframe(token_data, page_size=50)
    return (token_data,)


@app.cell
def _(mo):
    column_selector: mo.ui.dropdown = mo.ui.dropdown(
        options=["POS", "Tag", "Lemma", "Token", "Morph", "OOV"],
        value="POS",
        label="Select column to visualize",
    )
    column_selector
    return (column_selector,)


@app.cell
def _(alt, column_selector, mo, pl, token_data):
    mo.stop(token_data.is_empty(), "Please set input text.")

    selected_column: str = column_selector.value

    # Calculate value counts for the selected column
    counts_df: pl.DataFrame = (
        token_data[selected_column]
        .value_counts()
        .sort(by=["count", selected_column], descending=[True, False])
    )
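    # value_counts() on a Polars Series yields a two-column DataFrame
    # (the selected column plus "count"); sorting by count, then value,
    # keeps the bar order stable when frequencies tie.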
    chart: alt.Chart = (
        alt.Chart(counts_df)
        .mark_bar()
        .encode(
            x=alt.X("count", title="Frequency"),
            y=alt.Y(selected_column, title=selected_column, sort=None),
            tooltip=[selected_column, "count"],
        )
        .properties(title=f"{selected_column} Distribution")
        .interactive()
    )

    mo.ui.altair_chart(chart)
    return


@app.cell
def _(llm_model_choices, mo):
    llm_tokenizer_selector: mo.ui.dropdown = mo.ui.dropdown(
        options=llm_model_choices,
        value=llm_model_choices[0],
        label="Select LLM Tokenizer Model",
    )
    llm_tokenizer_selector
    return (llm_tokenizer_selector,)


@app.cell
def _(AutoTokenizer, PreTrainedTokenizerBase, llm_tokenizer_selector):
    # Adapted code from: https://huggingface.co/spaces/barttee/tokenizers/blob/main/app.py
    selected_model_name: str = llm_tokenizer_selector.value
    tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
        selected_model_name
    )
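    # from_pretrained() downloads and caches the tokenizer files from the
    # Hugging Face Hub on first use; gated repos (e.g. the Gemma models)
    # may additionally require accepting the model license and an auth token.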
    return (tokenizer,)


@app.cell
def _(Union, math):
    TokenStatsDict = dict[str, dict[str, Union[int, float]]]

    def get_token_stats(tokens: list[str], original_text: str) -> TokenStatsDict:
        """Calculate enhanced statistics about the tokens."""
        if not tokens:
            # Return default structure matching TokenStatsDict
            return {
                "basic_stats": {
                    "total_tokens": 0,
                    "unique_tokens": 0,
                    "compression_ratio": 0.0,
                    "space_tokens": 0,
                    "newline_tokens": 0,
                    "special_tokens": 0,
                    "punctuation_tokens": 0,
                    "unique_percentage": 0.0,
                },
                "length_stats": {
                    "avg_length": 0.0,
                    "std_dev": 0.0,
                    "min_length": 0,
                    "max_length": 0,
                    "median_length": 0.0,
                },
            }

        total_tokens: int = len(tokens)
        unique_tokens: int = len(set(tokens))
        compression_ratio: float = (
            len(original_text) / total_tokens if total_tokens > 0 else 0.0
        )
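        # compression_ratio is characters per token: higher values mean the
        # tokenizer packs more of the original text into each token.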
        space_tokens: int = sum(1 for t in tokens if t.startswith(("Ġ", " ")))
        newline_tokens: int = sum(
            1 for t in tokens if "Ċ" in t or t == "\n" or t == "<0x0A>"
        )
        special_tokens: int = sum(
            1
            for t in tokens
            if (t.startswith("<") and t.endswith(">"))
            or (t.startswith("[") and t.endswith("]"))
        )
        punctuation_tokens: int = sum(
            1
            for t in tokens
            if len(t) == 1 and not t.isalnum() and t not in [" ", "\n", "Ġ", "Ċ"]
        )

        lengths: list[int] = [len(t) for t in tokens]
        if not lengths:  # Should not happen if tokens is not empty, but safe check
            return {  # Return default structure matching TokenStatsDict
                "basic_stats": {
                    "total_tokens": 0,
                    "unique_tokens": 0,
                    "compression_ratio": 0.0,
                    "space_tokens": 0,
                    "newline_tokens": 0,
                    "special_tokens": 0,
                    "punctuation_tokens": 0,
                    "unique_percentage": 0.0,
                },
                "length_stats": {
                    "avg_length": 0.0,
                    "std_dev": 0.0,
                    "min_length": 0,
                    "max_length": 0,
                    "median_length": 0.0,
                },
            }

        mean_length: float = sum(lengths) / len(lengths)
        variance: float = sum((x - mean_length) ** 2 for x in lengths) / len(lengths)
        std_dev: float = math.sqrt(variance)
        sorted_lengths: list[int] = sorted(lengths)
        median_length: float = float(sorted_lengths[len(lengths) // 2])
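        # Note: for an even number of tokens this picks the upper of the two
        # middle lengths rather than averaging them (statistics.median would
        # interpolate); close enough for a summary display.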
        return {
            "basic_stats": {
                "total_tokens": total_tokens,
                "unique_tokens": unique_tokens,
                "compression_ratio": round(compression_ratio, 2),
                "space_tokens": space_tokens,
                "newline_tokens": newline_tokens,
                "special_tokens": special_tokens,
                "punctuation_tokens": punctuation_tokens,
                "unique_percentage": round(unique_tokens / total_tokens * 100, 1)
                if total_tokens > 0
                else 0.0,
            },
            "length_stats": {
                "avg_length": round(mean_length, 2),
                "std_dev": round(std_dev, 2),
                "min_length": min(lengths),
                "max_length": max(lengths),
                "median_length": median_length,
            },
        }

    return (get_token_stats,)


@app.cell
def _(hashlib):
    def get_varied_color(token: str) -> dict[str, str]:
        """Generate vibrant colors with HSL for better visual distinction."""
        token_hash: str = hashlib.md5(token.encode()).hexdigest()
        hue: int = int(token_hash[:3], 16) % 360
        saturation: int = 70 + (int(token_hash[3:5], 16) % 20)
        lightness: int = 80 + (int(token_hash[5:7], 16) % 10)
        text_lightness: int = 20

        return {
            "background": f"hsl({hue}, {saturation}%, {lightness}%)",
            "text": f"hsl({hue}, {saturation}%, {text_lightness}%)",
        }

    return (get_varied_color,)


@app.function
def fix_token(
    token: str, re
) -> str:  # re module type is complex, leave as Any implicitly or import types.ModuleType
    """Fix token for display, handling byte fallbacks and spaces."""
    # Check for byte fallback pattern <0xHH> using a full match
    byte_match = re.fullmatch(r"<0x([0-9A-Fa-f]{2})>", token)
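    # Byte-fallback vocabularies (SentencePiece/Llama-style tokenizers) emit
    # <0xHH> tokens for bytes with no dedicated vocabulary entry; keep them
    # recognizable instead of rendering the raw byte.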
    if byte_match:
        hex_value = byte_match.group(1).upper()
        # Return a clear representation indicating it's a byte
        return f"<0x{hex_value}>"

    # Replace SentencePiece space marker U+2581 ('▁') with a middle dot
    token = token.replace("▁", "·")

    # Replace BPE space marker 'Ġ' with a middle dot
    if token.startswith("Ġ"):
        space_count = token.count("Ġ")
        # Ensure we only replace the leading 'Ġ' markers
        return "·" * space_count + token[space_count:]

    # Replace newline markers for display
    token = token.replace("Ċ", "↵\n")
    # Handle byte representation of newline AFTER general byte check
    # This specific check might become redundant if <0x0A> is caught by the byte_match above
    # Keep it for now as a fallback.
    token = token.replace("<0x0A>", "↵\n")

    return token


@app.cell
def _(Any, PreTrainedTokenizerBase):
    def get_tokenizer_info(
        tokenizer: PreTrainedTokenizerBase,
    ) -> dict[str, Any]:
        """
        Extract useful information from a tokenizer.
        Returns a dictionary with tokenizer details.
        """
        info: dict[str, Any] = {}
        try:
            if hasattr(tokenizer, "vocab_size"):
                info["vocab_size"] = tokenizer.vocab_size
            elif hasattr(tokenizer, "get_vocab"):
                info["vocab_size"] = len(tokenizer.get_vocab())

            if (
                hasattr(tokenizer, "model_max_length")
                and isinstance(tokenizer.model_max_length, int)
                and tokenizer.model_max_length < 1000000
            ):
                info["model_max_length"] = tokenizer.model_max_length
            else:
                info["model_max_length"] = "Not specified or very large"

            info["tokenizer_type"] = tokenizer.__class__.__name__

            special_tokens: dict[str, str] = {}
            special_token_attributes: list[str] = [
                "pad_token",
                "eos_token",
                "bos_token",
                "sep_token",
                "cls_token",
                "unk_token",
                "mask_token",
            ]
            # Keep track of processed tokens to avoid duplicates
            processed_tokens: set[str] = set()

            # Prefer all_special_tokens if available
            if hasattr(tokenizer, "all_special_tokens"):
                for token_value in tokenizer.all_special_tokens:
                    if (
                        not token_value
                        or not str(token_value).strip()
                        or str(token_value) in processed_tokens
                    ):
                        continue

                    token_name = "special_token"  # Default name
                    # Find the attribute name corresponding to the token value
                    for attr_name in special_token_attributes:
                        if (
                            hasattr(tokenizer, attr_name)
                            and getattr(tokenizer, attr_name) == token_value
                        ):
                            token_name = attr_name
                            break
                    special_tokens[token_name] = str(token_value)
                    processed_tokens.add(str(token_value))

            # Fallback/Augment with individual attributes if not covered by all_special_tokens
            for token_name in special_token_attributes:
                if hasattr(tokenizer, token_name):
                    token_value = getattr(tokenizer, token_name)
                    if (
                        token_value
                        and str(token_value).strip()
                        and str(token_value) not in processed_tokens
                    ):
                        special_tokens[token_name] = str(token_value)
                        processed_tokens.add(str(token_value))

            info["special_tokens"] = special_tokens if special_tokens else "None found"
        except Exception as e:
            info["error"] = f"Error extracting tokenizer info: {str(e)}"

        return info

    return (get_tokenizer_info,)


@app.cell
def _(mo):
    show_ids_switch: mo.ui.switch = mo.ui.switch(
        label="Show token IDs instead of text", value=False
    )
    return (show_ids_switch,)


@app.cell
def _(
    Any,
    Optional,
    Union,
    current_text,
    fix_token,
    get_token_stats,
    get_tokenizer_info,
    get_varied_color,
    llm_tokenizer_selector,
    mo,
    re,
    show_ids_switch,
    tokenizer,
):
    # Define the Unicode replacement character
    REPLACEMENT_CHARACTER = "\ufffd"

    # Get tokenizer metadata
    tokenizer_info: dict[str, Any] = get_tokenizer_info(tokenizer)

    # 1. Encode text to get token IDs first.
    token_ids: list[int] = tokenizer.encode(current_text, add_special_tokens=False)
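    # add_special_tokens=False keeps markers such as BOS/CLS/SEP out of the
    # encoding, so the count and visualization reflect only the input text.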

    # 2. Decode each token ID individually.
    #    We will check for REPLACEMENT_CHARACTER later.
    all_decoded_tokens: list[str] = [
        tokenizer.decode(
            [token_id], skip_special_tokens=False, clean_up_tokenization_spaces=False
        )
        for token_id in token_ids
    ]
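    # Decoding one ID at a time can split a multi-byte UTF-8 character across
    # tokens, in which case decode() yields U+FFFD; such tokens are flagged
    # below rather than shown as mojibake.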
    total_token_count: int = len(token_ids)  # Count based on IDs

    # Limit the number of tokens for display
    display_limit: int = 1000
    # Limit consistently using token IDs and the decoded tokens
    display_token_ids: list[int] = token_ids[:display_limit]
    display_decoded_tokens: list[str] = all_decoded_tokens[:display_limit]
    display_limit_reached: bool = total_token_count > display_limit

    # Generate data for visualization
    TokenVisData = dict[str, Union[str, int, bool, dict[str, str]]]
    llm_token_data: list[TokenVisData] = []

    # Use zip for parallel iteration
    for idx, (token_id, token_str) in enumerate(
        zip(display_token_ids, display_decoded_tokens)
    ):
        colors: dict[str, str] = get_varied_color(
            token_str
            if REPLACEMENT_CHARACTER not in token_str
            else f"invalid_{token_id}"
        )  # Color based on string or ID if invalid

        is_invalid_utf8 = REPLACEMENT_CHARACTER in token_str
        fixed_token_display: str
        original_for_title: str = token_str  # Store the potentially problematic string for title

        if is_invalid_utf8:
            # If decode failed, show a representation with the hex ID
            fixed_token_display = f"<0x{token_id:X}>"
        else:
            # If decode succeeded, apply standard fixes
            fixed_token_display = fix_token(token_str, re)

        llm_token_data.append(
            {
                "original": original_for_title,  # Store the raw decoded string (might contain �)
                "display": fixed_token_display,  # Store the cleaned/invalid representation
                "colors": colors,
                "is_newline": "↵" in fixed_token_display,  # Check the display version
                "token_id": token_id,
                "token_index": idx,
                "is_invalid": is_invalid_utf8,  # Add flag for potential styling/title changes
            }
        )

    # Calculate statistics using the list of *successfully* decoded token strings.
    # We might want to reconsider what `all_tokens` means for stats if many are invalid.
    # For now, use the potentially problematic strings, as stats are mostly length/count based.
    token_stats: dict[str, dict[str, Union[int, float]]] = get_token_stats(
        all_decoded_tokens,
        current_text,  # Pass the full list from decode()
    )

    # Construct HTML for colored tokens using list comprehension (functional style)
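    # Each element below is produced by an immediately-invoked lambda: the
    # walrus assignments inside its tuple bind style/title/display_content per
    # item, and [-1] returns the tuple's last element, the formatted <span>.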
    html_parts: list[str] = [
        (
            lambda item: (
                style := (
                    f"background-color: {item['colors']['background']}; color: {item['colors']['text']}; padding: 1px 3px; margin: 1px; border-radius: 3px; display: inline-block; white-space: pre-wrap; line-height: 1.4;"
                    # Add specific style for invalid tokens if needed
                    + (" border: 1px solid red;" if item.get("is_invalid") else "")
                ),
                # Modify title based on validity
                title := (
                    f"Original: {item['original']}\nID: {item['token_id']}"
                    + ("\n(Invalid UTF-8)" if item.get("is_invalid") else "")
                    + ("\n(Byte Token)" if item["display"].startswith("<0x") else "")
                ),
                display_content := (
                    str(item["token_id"]) if show_ids_switch.value else item["display"]
                ),
                f'<span style="{style}" title="{title}">{display_content}</span>',
            )[-1]  # Get the last element (the formatted string) from the lambda's tuple
        )(item)
        for item in llm_token_data
    ]

    token_viz_html: mo.Html = mo.Html(
        f'<div style="line-height: 1.6;">{"".join(html_parts)}</div>'
    )

    # Optional: Add a warning if the display limit was reached
    limit_warning: Optional[mo.md] = None  # Use Optional type
    if display_limit_reached:
        limit_warning = mo.md(f"""**Warning:** Displaying only the first {display_limit:,} tokens out of {total_token_count:,}.
Statistics are calculated on the full text.""").callout(kind="warn")

    # Use dict access safely with .get() for stats
    basic_stats: dict[str, Union[int, float]] = token_stats.get("basic_stats", {})
    length_stats: dict[str, Union[int, float]] = token_stats.get("length_stats", {})

    # Use list comprehensions for markdown generation (functional style)
    basic_stats_md: str = "**Basic Stats:**\n\n" + "\n".join(
        f"- **{key.replace('_', ' ').title()}:** `{value}`"
        for key, value in basic_stats.items()
    )
    length_stats_md: str = "**Length (Character) Stats:**\n\n" + "\n".join(
        f"- **{key.replace('_', ' ').title()}:** `{value}`"
        for key, value in length_stats.items()
    )

    # Build tokenizer info markdown parts
    tokenizer_info_md_parts: list[str] = [
        f"**Tokenizer Type:** `{tokenizer_info.get('tokenizer_type', 'N/A')}`"
    ]
    if vocab_size := tokenizer_info.get("vocab_size"):
        tokenizer_info_md_parts.append(f"**Vocab Size:** `{vocab_size:,}`")
    if max_len := tokenizer_info.get("model_max_length"):
        # model_max_length may be an int or the "Not specified or very large" string
        max_len_str = f"{max_len:,}" if isinstance(max_len, int) else str(max_len)
        tokenizer_info_md_parts.append(f"**Model Max Length:** `{max_len_str}`")

    special_tokens_info = tokenizer_info.get("special_tokens")
    if isinstance(special_tokens_info, dict) and special_tokens_info:
        tokenizer_info_md_parts.append("**Special Tokens:**")
        tokenizer_info_md_parts.extend(
            f" - `{name}`: `{str(val)}`" for name, val in special_tokens_info.items()
        )
    elif isinstance(special_tokens_info, str):  # Handle "None found" case
        tokenizer_info_md_parts.append(f"**Special Tokens:** `{special_tokens_info}`")

    if error_info := tokenizer_info.get("error"):
        tokenizer_info_md_parts.append(f"**Info Error:** `{error_info}`")

    tokenizer_info_md: str = "\n\n".join(tokenizer_info_md_parts)

    # Display the final markdown output
    mo.md(f"""# LLM tokenizer: {llm_tokenizer_selector.value}

## Tokenizer Info

{tokenizer_info_md}

{show_ids_switch}

## Tokenizer output

{limit_warning if limit_warning else ""}

{mo.as_html(token_viz_html)}

## Token Statistics

(Calculated on full text if truncated above)

{basic_stats_md}

{length_stats_md}
""")
    return (
        all_decoded_tokens,
        token_ids,
        basic_stats_md,
        display_limit_reached,
        length_stats_md,
        limit_warning,
        llm_token_data,
        token_stats,
        token_viz_html,
        tokenizer_info,
        tokenizer_info_md,
        total_token_count,
    )


@app.cell
def _():
    return


if __name__ == "__main__":
    app.run()