import os
import gradio as gr
import torch
import itertools # For color cycling
import tiktoken # For GPT-4 tokenizer
from transformers import AutoTokenizer  # For Llama 3 tokenizer
# Bytelatent imports (assuming the package is on the Python path)
from bytelatent.data.file_util import get_fs
from bytelatent.generate_patcher import patcher_nocache
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
from bytelatent.plotting.entropy_figure_via_matplot_lib import plot_entropies
from bytelatent.args import TrainArgs
from download_blt_weights import main as ensure_present
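# Suggested dependencies (indicative, not pinned):
#   pip install gradio torch tiktoken transformers sentencepiece matplotlib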
# --- Global Setup ---
# Define colors for patches/tokens
VIZ_COLORS = [
"#a6cee3", "#1f78b4", "#b2df8a", "#33a02c", "#fb9a99", "#e31a1c",
"#fdbf6f", "#ff7f00", "#cab2d6", "#6a3d9a", "#ffff99", "#b15928"
] # Add more if you expect many segments
LLAMA3_MODEL_NAME = "meta-llama/Meta-Llama-3-8B" # Or choose another variant like Instruct
# --- Helper Functions ---
def create_bytelatent_highlight_data(tokenizer, patch_lengths_tensor, tokens_tensor, colors):
"""Generates data for gr.HighlightedText based on bytelatent patches."""
if patch_lengths_tensor is None or tokens_tensor is None or patch_lengths_tensor.numel() == 0:
return None
patch_lengths = patch_lengths_tensor.tolist()
all_tokens = tokens_tensor.tolist()
highlighted_data = []
current_token_index = 0
color_cycler = itertools.cycle(colors)
for i, length in enumerate(patch_lengths):
if length <= 0: continue
patch_token_ids = all_tokens[current_token_index : current_token_index + length]
if not patch_token_ids: continue
        try:
            patch_text = tokenizer.decode(patch_token_ids)
        except Exception as decode_err:
            print(f"Warning: Bytelatent patch decoding failed: {decode_err}")
            patch_text = f"[Decode Error: {len(patch_token_ids)} tokens]"
patch_label = f"BL Patch {i+1}"
highlighted_data.append((patch_text, patch_label))
current_token_index += length
if current_token_index != len(all_tokens):
print(f"Warning: Bytelatent token mismatch. Consumed {current_token_index}, total {len(all_tokens)}")
remaining_tokens = all_tokens[current_token_index:]
if remaining_tokens:
            try:
                remaining_text = tokenizer.decode(remaining_tokens)
            except Exception:
                remaining_text = f"[Decode Error: {len(remaining_tokens)} remaining tokens]"
highlighted_data.append((remaining_text, "BL Remainder"))
return highlighted_data
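# Illustrative output (hypothetical ids): for patch_lengths [3, 2] whose byte
# tokens decode to "Hel" and "lo", this returns
# [("Hel", "BL Patch 1"), ("lo", "BL Patch 2")] -- the (text, label) pair
# format that gr.HighlightedText expects.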
def create_tiktoken_highlight_data(prompt, colors):
"""Generates data for gr.HighlightedText based on tiktoken (gpt-4) tokens."""
try:
enc = tiktoken.get_encoding("cl100k_base")
tiktoken_ids = enc.encode(prompt)
highlighted_data = []
color_cycler = itertools.cycle(colors)
for i, token_id in enumerate(tiktoken_ids):
            try:
                token_text = enc.decode([token_id])
            except UnicodeDecodeError:
                # Token may map to a partial UTF-8 sequence; fall back to raw bytes.
                try:
                    token_bytes = enc.decode_single_token_bytes(token_id)
                    token_text = f"[Bytes: {token_bytes.hex()}]"
                except Exception:
                    token_text = "[Decode Error]"
            except Exception as e:
                print(f"Unexpected tiktoken decode error: {e}")
                token_text = "[Decode Error]"
token_label = f"GPT4 Tk {i+1}"
highlighted_data.append((token_text, token_label))
print(f"Tiktoken processing complete. Found {len(tiktoken_ids)} tokens.")
return highlighted_data
except ImportError:
print("Error: tiktoken library not found. Please install it: pip install tiktoken")
return [("tiktoken library not installed.", "Error")]
except Exception as tiktoken_err:
print(f"Error during tiktoken processing: {tiktoken_err}")
return [(f"Error processing with tiktoken: {str(tiktoken_err)}", "Error")]
def create_llama3_highlight_data(prompt, colors, model_name=LLAMA3_MODEL_NAME):
"""Generates data for gr.HighlightedText based on Llama 3 tokenizer."""
try:
# Load Llama 3 tokenizer from Hugging Face Hub
# This might download the tokenizer files on the first run
# May require `huggingface-cli login` if model is private or gated
print(f"Loading Llama 3 tokenizer: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("Llama 3 tokenizer loaded.")
# Encode the prompt
llama_token_ids = tokenizer.encode(prompt)
highlighted_data = []
color_cycler = itertools.cycle(colors)
for i, token_id in enumerate(llama_token_ids):
try:
                # Decode the token individually; Llama 3's BPE tokenizer
                # typically round-trips single tokens cleanly. Word-initial
                # tokens may decode with a leading space, which is kept so
                # segment boundaries stay visible.
                token_text = tokenizer.decode([token_id])
except Exception as e:
print(f"Unexpected Llama 3 decode error for token {token_id}: {e}")
token_text = "[Decode Error]"
token_label = f"Llama3 Tk {i+1}" # Clearer label prefix
highlighted_data.append((token_text, token_label))
print(f"Llama 3 processing complete. Found {len(llama_token_ids)} tokens.")
return highlighted_data
except ImportError:
print("Error: transformers or sentencepiece library not found. Please install them: pip install transformers sentencepiece")
return [("transformers/sentencepiece library not installed.", "Error")]
except OSError as e:
# Handle errors like model not found, network issues, authentication needed
print(f"Error loading Llama 3 tokenizer '{model_name}': {e}")
if "authentication" in str(e).lower():
return [(f"Authentication required for Llama 3 tokenizer '{model_name}'. Use `huggingface-cli login`.", "Error")]
else:
return [(f"Could not load Llama 3 tokenizer '{model_name}'. Check model name and network. Error: {e}", "Error")]
except Exception as llama_err:
print(f"Error during Llama 3 processing: {llama_err}")
import traceback
traceback.print_exc() # Print full traceback for debugging
return [(f"Error processing with Llama 3: {str(llama_err)}", "Error")]
# --- Main Processing Function ---
def process_text(prompt: str, model_name: str = "blt-1b"):
"""
Processes the input prompt using ByteLatent, Tiktoken, and Llama 3,
returning visualizations and status.
Args:
prompt: The input text string from the Gradio interface.
model_name: The name of the bytelatent model to use.
Returns:
A tuple containing:
- Matplotlib Figure for the entropy plot (or None).
- List of tuples for bytelatent gr.HighlightedText (or None).
- List of tuples for tiktoken gr.HighlightedText (or None).
- List of tuples for Llama 3 gr.HighlightedText (or None).
- Status/Error message string.
"""
fig = None
bl_highlighted_data = None
tk_highlighted_data = None
llama_highlighted_data = None
status_message = "Starting processing..."
# --- 1. Tiktoken Processing (Independent) ---
status_message += "\nProcessing with Tiktoken (gpt-4)..."
tk_highlighted_data = create_tiktoken_highlight_data(prompt, VIZ_COLORS)
if tk_highlighted_data and tk_highlighted_data[0][1] == "Error":
status_message += f"\nTiktoken Error: {tk_highlighted_data[0][0]}"
else:
status_message += "\nTiktoken processing successful."
# --- 2. Llama 3 Processing (Independent) ---
status_message += "\nProcessing with Llama 3 tokenizer..."
llama_highlighted_data = create_llama3_highlight_data(prompt, VIZ_COLORS)
if llama_highlighted_data and llama_highlighted_data[0][1] == "Error":
status_message += f"\nLlama 3 Error: {llama_highlighted_data[0][0]}"
else:
status_message += "\nLlama 3 processing successful."
# --- 3. Bytelatent Processing ---
try:
status_message += f"\nLoading entropy model for '{model_name}'..."
        consolidated_path = os.path.join("hf-weights", model_name)
        train_args_path = os.path.join(consolidated_path, "params.json")
        if not os.path.exists(train_args_path):
            raise FileNotFoundError(f"Bytelatent training args not found at {train_args_path}.")
        fs = get_fs(train_args_path)
        train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
        bl_tokenizer = train_args.data.tokenizer_args.build()
        assert isinstance(bl_tokenizer, BltTokenizer)
        patcher_args = train_args.data.patcher_args.model_copy(deep=True)
        patcher_args.realtime_patching = True
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using Bytelatent device: {device}")
        patcher_args.patching_device = device
        patcher_args.device = device
        entropy_model_dir = os.path.join(consolidated_path, "entropy_model")
        if not os.path.exists(entropy_model_dir):
            raise FileNotFoundError(f"Bytelatent entropy model directory not found at {entropy_model_dir}.")
        patcher_args.entropy_model_checkpoint_dir = entropy_model_dir
        bl_patcher = patcher_args.build()
        status_message += "\nBytelatent model loaded."
# --- Processing ---
status_message += "\nRunning Bytelatent patching..."
print(f"Processing prompt with Bytelatent: '{prompt}'")
# Limit prompt length for bytelatent if necessary
prompt_bytes = prompt.encode('utf-8')
if len(prompt_bytes) > 512:
print(f"Warning: Prompt exceeds 512 bytes ({len(prompt_bytes)}). Truncating for Bytelatent.")
prompt_bl = prompt_bytes[:512].decode('utf-8', errors='ignore')
status_message += "\nWarning: Prompt truncated to 512 bytes for Bytelatent."
else:
prompt_bl = prompt
results = patcher_nocache([prompt_bl], tokenizer=bl_tokenizer, patcher=bl_patcher)
if not results:
print("Bytelatent processing returned no results.")
status_message += "\nBytelatent Warning: Processing completed, but no results were generated."
else:
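            # patcher_nocache returns batch-major results; a single prompt was
            # passed in, so take element 0 of each field.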
batch_patch_lengths, batch_scores, batch_tokens = results
patch_lengths, scores, tokens = batch_patch_lengths[0], batch_scores[0], batch_tokens[0]
# --- Visualization Data Generation ---
            try:
                decoded_output_for_plot = bl_tokenizer.decode(tokens.tolist())
            except Exception as decode_err:
                print(f"Warning: Error decoding full sequence for plot: {decode_err}")
                # Use the (possibly truncated) prompt for the plot if decoding fails
                decoded_output_for_plot = prompt_bl
fig = plot_entropies(patch_lengths, scores, decoded_output_for_plot, threshold=bl_patcher.threshold)
bl_highlighted_data = create_bytelatent_highlight_data(bl_tokenizer, patch_lengths, tokens, VIZ_COLORS)
status_message += "\nBytelatent processing and visualization successful."
print("Bytelatent processing and decoding complete.")
except FileNotFoundError as e:
print(f"Bytelatent Error: {e}")
status_message += f"\nBytelatent FileNotFoundError: {str(e)}"
except Exception as e:
print(f"An unexpected Bytelatent error occurred: {e}")
import traceback
traceback.print_exc()
status_message += f"\nBytelatent Unexpected Error: {str(e)}"
# Return all generated data and the final status message
return fig, bl_highlighted_data, tk_highlighted_data, llama_highlighted_data, status_message
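# Standalone usage sketch (illustrative; assumes the "blt-1b" weights already
# exist under hf-weights/):
#   fig, bl_data, tk_data, llama_data, status = process_text("hello world")
#   print(status)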
# --- Gradio Interface ---
# Create color maps for HighlightedText dynamically
MAX_EXPECTED_SEGMENTS = 1000  # Upper bound on distinct segment labels in the color maps
common_error_map = {"Error": "#FF0000"} # Red for errors
bytelatent_color_map = {f"BL Patch {i+1}": color for i, color in zip(range(MAX_EXPECTED_SEGMENTS), itertools.cycle(VIZ_COLORS))}
bytelatent_color_map["BL Remainder"] = "#808080"; bytelatent_color_map.update(common_error_map)
tiktoken_color_map = {f"GPT4 Tk {i+1}": color for i, color in zip(range(MAX_EXPECTED_SEGMENTS), itertools.cycle(VIZ_COLORS))}
tiktoken_color_map.update(common_error_map)
llama3_color_map = {f"Llama3 Tk {i+1}": color for i, color in zip(range(MAX_EXPECTED_SEGMENTS), itertools.cycle(VIZ_COLORS))}
llama3_color_map.update(common_error_map)
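# With 12 entries in VIZ_COLORS the palette wraps every 12 labels, e.g.
# "BL Patch 1" and "BL Patch 13" both map to "#a6cee3".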
with gr.Blocks(theme=gr.themes.Soft()) as iface:
gr.Markdown("# BLT's Entropy Patcher Visualisation") # Updated Title
gr.Markdown(
"Enter text to visualize its segmentation according to different tokenizers:\n"
"1. **BLT:** Entropy plot and text segmented by dynamic patches (Input limited to 512 bytes).\n"
"2. **Tiktoken (GPT-4):** Text segmented by `cl100k_base` tokens.\n"
"3. **Llama 3:** Text segmented by the `meta-llama/Meta-Llama-3-8B` tokenizer."
)
with gr.Row():
with gr.Column(scale=1): # Input Column
prompt_input = gr.Textbox(
label="Input Prompt",
value="Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin.",
placeholder="Enter text here...",
                max_length=2048,  # Longer input is accepted; Bytelatent truncates to 512 bytes
lines=5,
info="Processing is limited to the first 512 bytes of the input."
)
submit_button = gr.Button("Generate Visualizations", variant="primary")
status_output = gr.Textbox(label="Processing Status", interactive=False, lines=5)
with gr.Column(scale=2): # Output Column
gr.Markdown("### BLT's Entropy Patcher Output (`100m`)")
highlighted_output_bl = gr.HighlightedText(
label="Bytelatent Patched Text",
color_map=bytelatent_color_map,
show_legend=False, # Legend can get very long, disable for compactness
show_inline_category=False,
)
plot_output = gr.Plot(label="Bytelatent Entropy vs. Token Index")
gr.Markdown("### Tiktoken Output (`cl100k_base` for GPT-4)")
highlighted_output_tk = gr.HighlightedText(
label="Tiktoken Segmented Text",
color_map=tiktoken_color_map,
show_legend=False,
show_inline_category=False,
)
gr.Markdown(f"### Llama 3 Output (`{LLAMA3_MODEL_NAME}`)")
highlighted_output_llama = gr.HighlightedText(
label="Llama 3 Segmented Text",
color_map=llama3_color_map,
show_legend=False,
show_inline_category=False,
)
# Define the action for the button click
submit_button.click(
fn=process_text,
inputs=prompt_input,
# Ensure order matches the 5 return values of process_text
outputs=[
plot_output,
highlighted_output_bl,
highlighted_output_tk,
highlighted_output_llama,
status_output
]
)
# --- Launch the Gradio App ---
if __name__ == "__main__":
print("Please ensure 'tiktoken', 'transformers', and 'sentencepiece' are installed (`pip install tiktoken transformers sentencepiece`)")
print(f"Attempting to use Llama 3 Tokenizer: {LLAMA3_MODEL_NAME}. Ensure you have access (e.g., via `huggingface-cli login` if needed).")
ensure_present(["blt-1b"]) # Ensure bytelatent model is present
iface.launch()