lucalp commited on
Commit
a528449
·
1 Parent(s): 2af55e5
Files changed (1) hide show
  1. app.py +133 -43
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import gradio as gr
3
  import torch
 
4
 
5
  from bytelatent.data.file_util import get_fs
6
  from bytelatent.generate_patcher import patcher_nocache
@@ -12,16 +13,78 @@ from download_blt_weights import main as ensure_present
12
  # --- Global Setup (Consider loading models outside if necessary) ---
13
  # Kept inside the function for simplicity as before.
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def process_text(prompt: str, model_name: str = "blt-1b"):
16
  """
17
- Processes the input prompt using the ByteLatent model and returns decoded characters.
 
18
 
19
  Args:
20
  prompt: The input text string from the Gradio interface.
21
  model_name: The name of the model to use.
22
 
23
  Returns:
24
- A string containing the decoded characters after processing, or an error message.
 
 
 
25
  """
26
  try:
27
  # --- Model and Tokenizer Loading ---
@@ -63,55 +126,69 @@ def process_text(prompt: str, model_name: str = "blt-1b"):
63
 
64
  if not results:
65
  print("Processing returned no results.")
66
- return "Processing completed, but no results were generated." # Return info message
67
 
68
  batch_patch_lengths, batch_scores, batch_tokens = results
69
- # Decode the first (and only) result in the batch
70
- decoded_chars_list = [tokenizer.decode(row_tokens.tolist()) for row_tokens in batch_tokens]
71
- fig = None
72
- if decoded_chars_list:
73
- decoded_output = decoded_chars_list[0]
74
- fig = plot_entropies(
75
- batch_patch_lengths[0],
76
- batch_scores[0],
77
- decoded_output,
78
- threshold=patcher.threshold
79
- )
80
-
81
- print("Processing and decoding complete.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  # --- End Processing ---
83
 
84
- return fig
85
 
86
  except FileNotFoundError as e:
87
  print(f"Error: {e}")
88
- # raise gr.Error(str(e)) # Display specific error in Gradio UI
89
- return f"Error: {str(e)}" # Return error as text output
90
  except Exception as e:
91
  print(f"An unexpected error occurred: {e}")
92
  import traceback
93
  traceback.print_exc()
94
- # raise gr.Error(f"An error occurred during processing: {e}")
95
- return f"An unexpected error occurred: {e}" # Return error as text output
96
-
97
-
98
- iface = gr.Interface(
99
- fn=process_text,
100
- inputs=gr.Textbox(
101
- label="Input Prompt",
102
- placeholder="Enter your text here..."
103
- ),
104
- outputs=gr.Plot(label="Entropy Plot"),
105
- title="ByteLatent Text Processor",
106
- description="Enter text to process it with the ByteLatent model ('blt-1b' by default). The decoded output will be shown.",
107
- allow_flagging="never",
108
- )
109
 
110
  with gr.Blocks() as iface:
111
  gr.Markdown("# ByteLatent Entropy Visualizer") # Title
112
  gr.Markdown(
113
  "Process any prompt (limited to 512 bytes) with the 100M entropy patcher model "
114
- "and visualize the token entropies plot below.<br><br>" # Updated description
115
  "NOTE: this implementation differs slightly by excluding local attention so we limit "
116
  "the characters limit to 512 to avoid any deviation.",
117
  line_breaks=True
@@ -121,20 +198,33 @@ with gr.Blocks() as iface:
121
  prompt_input = gr.Textbox(
122
  label="Input Prompt",
123
  value="Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin.",
124
- placeholder="Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin.",
125
- max_length=512
 
126
  )
127
- submit_button = gr.Button("Generate Plot") # Add button
128
- plot_output = gr.Plot(label="Entropy w Threshold") # Output component
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
- # Define the action when the button is clicked
131
  submit_button.click(
132
  fn=process_text,
133
- inputs=prompt_input, # Input component(s)
134
- outputs=plot_output # Output component(s)
135
  )
136
 
137
  # --- Launch the Gradio App ---
138
  if __name__ == "__main__":
139
- ensure_present(["blt-1b"])
140
  iface.launch()
 
1
  import os
2
  import gradio as gr
3
  import torch
4
+ import itertools # Import itertools for color cycling
5
 
6
  from bytelatent.data.file_util import get_fs
7
  from bytelatent.generate_patcher import patcher_nocache
 
13
  # --- Global Setup (Consider loading models outside if necessary) ---
14
  # Kept inside the function for simplicity as before.
15
 
16
+ # Define colors for patches (similar to the image style)
17
+ # Using colors from a qualitative colormap (e.g., Colorbrewer Set3 or Paired)
18
+ PATCH_COLORS = [
19
+ "#a6cee3", "#1f78b4", "#b2df8a", "#33a02c", "#fb9a99", "#e31a1c",
20
+ "#fdbf6f", "#ff7f00", "#cab2d6", "#6a3d9a", "#ffff99", "#b15928"
21
+ ] # Add more if you expect many patches
22
+
23
+
24
+ def create_highlighted_text_data(tokenizer, patch_lengths_tensor, tokens_tensor, colors):
25
+ """
26
+ Generates the data structure needed for gr.HighlightedText based on patches.
27
+
28
+ Args:
29
+ tokenizer: The BltTokenizer instance.
30
+ patch_lengths_tensor: Tensor containing the length of each patch (in tokens).
31
+ tokens_tensor: Tensor containing the token IDs for the entire sequence.
32
+ colors: A list of color hex codes to cycle through.
33
+
34
+ Returns:
35
+ A list of tuples for gr.HighlightedText, e.g., [(text, label), ...].
36
+ Returns None if input tensors are invalid.
37
+ """
38
+ if patch_lengths_tensor is None or tokens_tensor is None or patch_lengths_tensor.numel() == 0:
39
+ return None
40
+
41
+ patch_lengths = patch_lengths_tensor.tolist()
42
+ all_tokens = tokens_tensor.tolist()
43
+ highlighted_data = []
44
+ current_token_index = 0
45
+ color_cycler = itertools.cycle(colors) # Use itertools to cycle through colors
46
+
47
+ for i, length in enumerate(patch_lengths):
48
+ if length <= 0: # Skip empty patches if they somehow occur
49
+ continue
50
+ patch_token_ids = all_tokens[current_token_index : current_token_index + length]
51
+ if not patch_token_ids: # Should not happen if length > 0, but good practice
52
+ continue
53
+
54
+ patch_text = tokenizer.decode(patch_token_ids)
55
+ patch_label = f"Patch {i+1}" # Unique label for each patch
56
+ patch_color = next(color_cycler) # Get the next color
57
+
58
+ # Add to highlighted_data: (text, label_for_coloring)
59
+ highlighted_data.append((patch_text, patch_label))
60
+ current_token_index += length
61
+
62
+ # Check if all tokens were consumed (optional sanity check)
63
+ if current_token_index != len(all_tokens):
64
+ print(f"Warning: Token mismatch. Consumed {current_token_index}, total {len(all_tokens)}")
65
+ # Decode any remaining tokens if necessary, though this indicates a logic issue
66
+ remaining_tokens = all_tokens[current_token_index:]
67
+ if remaining_tokens:
68
+ remaining_text = tokenizer.decode(remaining_tokens)
69
+ highlighted_data.append((remaining_text, "Remainder")) # Assign a generic label
70
+
71
+ return highlighted_data
72
+
73
+
74
  def process_text(prompt: str, model_name: str = "blt-1b"):
75
  """
76
+ Processes the input prompt using the ByteLatent model and returns
77
+ an entropy plot and color-coded text data.
78
 
79
  Args:
80
  prompt: The input text string from the Gradio interface.
81
  model_name: The name of the model to use.
82
 
83
  Returns:
84
+ A tuple containing:
85
+ - Matplotlib Figure for the entropy plot (or None on error).
86
+ - List of tuples for gr.HighlightedText (or None on error/no results).
87
+ - Error message string (or None if successful).
88
  """
89
  try:
90
  # --- Model and Tokenizer Loading ---
 
126
 
127
  if not results:
128
  print("Processing returned no results.")
129
+ return None, None, "Processing completed, but no results were generated."
130
 
131
  batch_patch_lengths, batch_scores, batch_tokens = results
132
+
133
+ # Process the first (and only) result in the batch
134
+ patch_lengths = batch_patch_lengths[0]
135
+ scores = batch_scores[0]
136
+ tokens = batch_tokens[0]
137
+
138
+ # Decode the full output once for the plot labels (if needed by plot_entropies)
139
+ # Note: BltTokenizer might decode directly to bytes, then utf-8. Ensure it handles errors.
140
+ try:
141
+ # Using the raw tokens tensor for decoding consistency
142
+ decoded_output_for_plot = tokenizer.decode(tokens.tolist())
143
+ except Exception as decode_err:
144
+ print(f"Warning: Error decoding full sequence for plot: {decode_err}")
145
+ # Fallback: attempt to decode the original prompt if possible, or use generic labels
146
+ decoded_output_for_plot = prompt # Use original prompt as fallback
147
+
148
+ # Generate the plot
149
+ fig = plot_entropies(
150
+ patch_lengths,
151
+ scores,
152
+ decoded_output_for_plot, # Pass the decoded string for plot labels
153
+ threshold=patcher.threshold
154
+ )
155
+
156
+ # Generate data for HighlightedText
157
+ highlighted_data = create_highlighted_text_data(
158
+ tokenizer, patch_lengths, tokens, PATCH_COLORS
159
+ )
160
+
161
+ print("Processing and visualization data generation complete.")
162
  # --- End Processing ---
163
 
164
+ return fig, highlighted_data, None # Return plot, highlighted text data, no error
165
 
166
  except FileNotFoundError as e:
167
  print(f"Error: {e}")
168
+ return None, None, f"Error: {str(e)}" # Return None for plot/text, error message
 
169
  except Exception as e:
170
  print(f"An unexpected error occurred: {e}")
171
  import traceback
172
  traceback.print_exc()
173
+ return None, None, f"An unexpected error occurred: {e}" # Return None for plot/text, error message
174
+
175
+ # --- Gradio Interface ---
176
+
177
+ # Create the color map for HighlightedText dynamically
178
+ # Generate enough patch labels and map them to the cycled colors
179
+ MAX_EXPECTED_PATCHES = 50 # Estimate a reasonable maximum
180
+ color_map = {
181
+ f"Patch {i+1}": color
182
+ for i, color in zip(range(MAX_EXPECTED_PATCHES), itertools.cycle(PATCH_COLORS))
183
+ }
184
+ # Add a color for the potential 'Remainder' label from create_highlighted_text_data
185
+ color_map["Remainder"] = "#808080" # Grey for any leftovers
 
 
186
 
187
  with gr.Blocks() as iface:
188
  gr.Markdown("# ByteLatent Entropy Visualizer") # Title
189
  gr.Markdown(
190
  "Process any prompt (limited to 512 bytes) with the 100M entropy patcher model "
191
+ "and visualize the token entropies plot and color-coded patches below.<br><br>" # Updated description
192
  "NOTE: this implementation differs slightly by excluding local attention so we limit "
193
  "the characters limit to 512 to avoid any deviation.",
194
  line_breaks=True
 
198
  prompt_input = gr.Textbox(
199
  label="Input Prompt",
200
  value="Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin.",
201
+ placeholder="Enter text here...",
202
+ max_length=512,
203
+ lines=3
204
  )
205
+ submit_button = gr.Button("Generate Visualization") # Update button text
206
+
207
+ # Output for error messages or status
208
+ status_output = gr.Textbox(label="Status", interactive=False)
209
+
210
+ # Output component for the color-coded text
211
+ highlighted_output = gr.HighlightedText(
212
+ label="Patched Text Visualization",
213
+ color_map=color_map,
214
+ show_legend=False # Show the patch labels and colors
215
+ )
216
+
217
+ # Output component for the plot
218
+ plot_output = gr.Plot(label="Entropy vs. Token Index (with Patch Threshold)")
219
 
220
+ # Define the action for the button click
221
  submit_button.click(
222
  fn=process_text,
223
+ inputs=prompt_input,
224
+ outputs=[plot_output, highlighted_output, status_output] # Order matters!
225
  )
226
 
227
  # --- Launch the Gradio App ---
228
  if __name__ == "__main__":
229
+ ensure_present(["blt-1b"]) # Ensure model is present before launching
230
  iface.launch()