luca-peric committed on
Commit
2af55e5
·
1 Parent(s): d52b754

Working locally, TBD HF space

Browse files
app.py CHANGED
@@ -1,14 +1,11 @@
1
  import os
2
  import gradio as gr
3
  import torch
4
- # Removed matplotlib and plot_entropies imports
5
 
6
- # Assuming bytelatent library and its dependencies are installed
7
  from bytelatent.data.file_util import get_fs
8
- # from bytelatent.distributed import DistributedArgs, setup_torch_distributed # Not needed
9
  from bytelatent.generate_patcher import patcher_nocache
10
  from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
11
- # Removed: from bytelatent.plotting.entropy_figure_via_matplot_lib import plot_entropies
12
  from bytelatent.args import TrainArgs
13
  from download_blt_weights import main as ensure_present
14
 
@@ -71,14 +68,20 @@ def process_text(prompt: str, model_name: str = "blt-1b"):
71
  batch_patch_lengths, batch_scores, batch_tokens = results
72
  # Decode the first (and only) result in the batch
73
  decoded_chars_list = [tokenizer.decode(row_tokens.tolist()) for row_tokens in batch_tokens]
74
- decoded_output = decoded_chars_list[0] if decoded_chars_list else "No characters decoded."
 
 
 
 
 
 
 
 
75
 
76
  print("Processing and decoding complete.")
77
  # --- End Processing ---
78
 
79
-
80
- # Return the decoded text string
81
- return decoded_output
82
 
83
  except FileNotFoundError as e:
84
  print(f"Error: {e}")
@@ -92,20 +95,45 @@ def process_text(prompt: str, model_name: str = "blt-1b"):
92
  return f"An unexpected error occurred: {e}" # Return error as text output
93
 
94
 
95
- # --- Gradio Interface Definition ---
96
  iface = gr.Interface(
97
  fn=process_text,
98
  inputs=gr.Textbox(
99
  label="Input Prompt",
100
  placeholder="Enter your text here..."
101
  ),
102
- # Changed output to display the decoded text
103
- outputs=gr.Text(label="Decoded Output"),
104
  title="ByteLatent Text Processor",
105
  description="Enter text to process it with the ByteLatent model ('blt-1b' by default). The decoded output will be shown.",
106
  allow_flagging="never",
107
  )
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  # --- Launch the Gradio App ---
110
  if __name__ == "__main__":
111
  ensure_present(["blt-1b"])
 
1
  import os
2
  import gradio as gr
3
  import torch
 
4
 
 
5
  from bytelatent.data.file_util import get_fs
 
6
  from bytelatent.generate_patcher import patcher_nocache
7
  from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
8
+ from bytelatent.plotting.entropy_figure_via_matplot_lib import plot_entropies
9
  from bytelatent.args import TrainArgs
10
  from download_blt_weights import main as ensure_present
11
 
 
68
  batch_patch_lengths, batch_scores, batch_tokens = results
69
  # Decode the first (and only) result in the batch
70
  decoded_chars_list = [tokenizer.decode(row_tokens.tolist()) for row_tokens in batch_tokens]
71
+ fig = None
72
+ if decoded_chars_list:
73
+ decoded_output = decoded_chars_list[0]
74
+ fig = plot_entropies(
75
+ batch_patch_lengths[0],
76
+ batch_scores[0],
77
+ decoded_output,
78
+ threshold=patcher.threshold
79
+ )
80
 
81
  print("Processing and decoding complete.")
82
  # --- End Processing ---
83
 
84
+ return fig
 
 
85
 
86
  except FileNotFoundError as e:
87
  print(f"Error: {e}")
 
95
  return f"An unexpected error occurred: {e}" # Return error as text output
96
 
97
 
 
98
  iface = gr.Interface(
99
  fn=process_text,
100
  inputs=gr.Textbox(
101
  label="Input Prompt",
102
  placeholder="Enter your text here..."
103
  ),
104
+ outputs=gr.Plot(label="Entropy Plot"),
 
105
  title="ByteLatent Text Processor",
106
  description="Enter text to process it with the ByteLatent model ('blt-1b' by default). The decoded output will be shown.",
107
  allow_flagging="never",
108
  )
109
 
110
+ with gr.Blocks() as iface:
111
+ gr.Markdown("# ByteLatent Entropy Visualizer") # Title
112
+ gr.Markdown(
113
+ "Process any prompt (limited to 512 bytes) with the 100M entropy patcher model "
114
+ "and visualize the token entropies plot below.<br><br>" # Updated description
115
+ "NOTE: this implementation differs slightly by excluding local attention so we limit "
116
+ "the characters limit to 512 to avoid any deviation.",
117
+ line_breaks=True
118
+ )
119
+
120
+ with gr.Column():
121
+ prompt_input = gr.Textbox(
122
+ label="Input Prompt",
123
+ value="Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin.",
124
+ placeholder="Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin.",
125
+ max_length=512
126
+ )
127
+ submit_button = gr.Button("Generate Plot") # Add button
128
+ plot_output = gr.Plot(label="Entropy w Threshold") # Output component
129
+
130
+ # Define the action when the button is clicked
131
+ submit_button.click(
132
+ fn=process_text,
133
+ inputs=prompt_input, # Input component(s)
134
+ outputs=plot_output # Output component(s)
135
+ )
136
+
137
  # --- Launch the Gradio App ---
138
  if __name__ == "__main__":
139
  ensure_present(["blt-1b"])
bytelatent/plotting/entropy_figure_via_matplot_lib.py CHANGED
@@ -59,9 +59,8 @@ def plot_entropies(patch_lengths: torch.Tensor, scores: torch.Tensor, chars: str
59
 
60
  # Adjust layout and display the plot
61
  plt.tight_layout()
62
- output_filename = "token_score_plot.png"
63
- fig.savefig(output_filename, dpi=300, bbox_inches='tight') # Save the figure
64
- print(f"Plot saved to {os.path.abspath(output_filename)}") # Print confirmation with full path
 
65
 
66
- # Close the plot figure to free memory (good practice)
67
- plt.close(fig)
 
59
 
60
  # Adjust layout and display the plot
61
  plt.tight_layout()
62
+ return fig
63
+ # output_filename = "token_score_plot.png"
64
+ # fig.savefig(output_filename, dpi=300, bbox_inches='tight') # Save the figure
65
+ # print(f"Plot saved to {os.path.abspath(output_filename)}") # Print confirmation with full path
66
 
 
 
demo_patcher.py CHANGED
@@ -1,6 +1,5 @@
1
  import os
2
 
3
- import torch
4
  import typer
5
 
6
  from bytelatent.data.file_util import get_fs
 
1
  import os
2
 
 
3
  import typer
4
 
5
  from bytelatent.data.file_util import get_fs