Spaces: Running on Zero
Commit 2af55e5 · Parent(s): d52b754
Working locally, TBD HF space

Files changed:
- app.py +39 -11
- bytelatent/plotting/entropy_figure_via_matplot_lib.py +4 -5
- demo_patcher.py +0 -1
app.py
CHANGED
@@ -1,14 +1,11 @@
 import os
 import gradio as gr
 import torch
-# Removed matplotlib and plot_entropies imports
 
-# Assuming bytelatent library and its dependencies are installed
 from bytelatent.data.file_util import get_fs
-# from bytelatent.distributed import DistributedArgs, setup_torch_distributed # Not needed
 from bytelatent.generate_patcher import patcher_nocache
 from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
-
+from bytelatent.plotting.entropy_figure_via_matplot_lib import plot_entropies
 from bytelatent.args import TrainArgs
 from download_blt_weights import main as ensure_present
 
@@ -71,14 +68,20 @@ def process_text(prompt: str, model_name: str = "blt-1b"):
         batch_patch_lengths, batch_scores, batch_tokens = results
         # Decode the first (and only) result in the batch
         decoded_chars_list = [tokenizer.decode(row_tokens.tolist()) for row_tokens in batch_tokens]
-
+        fig = None
+        if decoded_chars_list:
+            decoded_output = decoded_chars_list[0]
+            fig = plot_entropies(
+                batch_patch_lengths[0],
+                batch_scores[0],
+                decoded_output,
+                threshold=patcher.threshold
+            )
 
         print("Processing and decoding complete.")
         # --- End Processing ---
 
-
-        # Return the decoded text string
-        return decoded_output
+        return fig
 
     except FileNotFoundError as e:
         print(f"Error: {e}")
@@ -92,20 +95,45 @@ def process_text(prompt: str, model_name: str = "blt-1b"):
         return f"An unexpected error occurred: {e}"  # Return error as text output
 
 
-# --- Gradio Interface Definition ---
 iface = gr.Interface(
     fn=process_text,
     inputs=gr.Textbox(
         label="Input Prompt",
         placeholder="Enter your text here..."
     ),
-
-    outputs=gr.Text(label="Decoded Output"),
+    outputs=gr.Plot(label="Entropy Plot"),
     title="ByteLatent Text Processor",
     description="Enter text to process it with the ByteLatent model ('blt-1b' by default). The decoded output will be shown.",
     allow_flagging="never",
 )
 
+with gr.Blocks() as iface:
+    gr.Markdown("# ByteLatent Entropy Visualizer")  # Title
+    gr.Markdown(
+        "Process any prompt (limited to 512 bytes) with the 100M entropy patcher model "
+        "and visualize the token entropies plot below.<br><br>"
+        "NOTE: this implementation differs slightly by excluding local attention, so the prompt "
+        "is limited to 512 characters to avoid any deviation.",
+        line_breaks=True
+    )
+
+    with gr.Column():
+        prompt_input = gr.Textbox(
+            label="Input Prompt",
+            value="Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin.",
+            placeholder="Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin.",
+            max_length=512
+        )
+        submit_button = gr.Button("Generate Plot")  # Add button
+        plot_output = gr.Plot(label="Entropy w Threshold")  # Output component
+
+    # Define the action when the button is clicked
+    submit_button.click(
+        fn=process_text,
+        inputs=prompt_input,   # Input component(s)
+        outputs=plot_output    # Output component(s)
+    )
+
 # --- Launch the Gradio App ---
 if __name__ == "__main__":
     ensure_present(["blt-1b"])
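Net effect of the app.py changes: process_text now returns a matplotlib Figure instead of a decoded string, and the Blocks UI feeds that return value into a gr.Plot component through submit_button.click. Below is a minimal, self-contained sketch of that pattern using the same Gradio APIs the commit relies on; make_figure and its contents are illustrative placeholders, not code from the commit.

    import gradio as gr
    import matplotlib
    matplotlib.use("Agg")  # non-interactive backend, typical for a hosted Space
    import matplotlib.pyplot as plt

    def make_figure(prompt: str):
        # Stand-in for process_text: build a Figure and return it;
        # gr.Plot renders whatever Figure the click handler returns.
        fig, ax = plt.subplots()
        ax.plot(range(len(prompt)))
        ax.set_title("placeholder entropy curve")
        return fig

    with gr.Blocks() as demo:
        prompt_input = gr.Textbox(label="Input Prompt", max_length=512)
        submit_button = gr.Button("Generate Plot")
        plot_output = gr.Plot(label="Entropy Plot")
        # Same wiring as app.py: the handler's Figure goes to gr.Plot
        submit_button.click(fn=make_figure, inputs=prompt_input, outputs=plot_output)

    if __name__ == "__main__":
        demo.launch()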
bytelatent/plotting/entropy_figure_via_matplot_lib.py
CHANGED
@@ -59,9 +59,8 @@ def plot_entropies(patch_lengths: torch.Tensor, scores: torch.Tensor, chars: str
 
     # Adjust layout and display the plot
     plt.tight_layout()
-
-
-
+    return fig
+    # output_filename = "token_score_plot.png"
+    # fig.savefig(output_filename, dpi=300, bbox_inches='tight')  # Save the figure
+    # print(f"Plot saved to {os.path.abspath(output_filename)}")  # Print confirmation with full path
 
-    # Close the plot figure to free memory (good practice)
-    plt.close(fig)
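Because plot_entropies now returns the open Figure rather than closing it, figure cleanup moves to the caller (in the Space, Gradio takes the Figure and renders it). A rough sketch of a non-Gradio caller, assuming only the signature shown in the hunk header; the tensors, threshold value, and output path are illustrative only.

    import torch
    import matplotlib.pyplot as plt

    from bytelatent.plotting.entropy_figure_via_matplot_lib import plot_entropies

    # Illustrative inputs; real values come from the entropy patcher.
    patch_lengths = torch.tensor([3, 2, 4])
    scores = torch.rand(9)
    chars = "entropy!!"

    fig = plot_entropies(patch_lengths, scores, chars, threshold=1.33)

    # The function no longer calls plt.close(fig), so the caller saves
    # (or displays) the figure and then frees it explicitly.
    fig.savefig("token_score_plot.png", dpi=300, bbox_inches="tight")
    plt.close(fig)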
demo_patcher.py
CHANGED
@@ -1,6 +1,5 @@
 import os
 
-import torch
 import typer
 
 from bytelatent.data.file_util import get_fs