# Hosting-page header captured with the file ("Spaces: Sleeping"); not part of the program.
import functools
import os

import gradio as gr
import numpy as np
import pyarrow.compute as pc
import pyarrow.parquet as pq
from transformers import AutoTokenizer
# Table whose "tokens" column holds the token ids that `update` slices and
# decodes to rebuild the text around each selected position.
token_table = pq.read_table("weights/tokens.parquet")

# Directory holding the cache files offered in the dropdown; `update` reads the
# chosen one with pq.read_table. os.listdir() returns entries in arbitrary
# order, so sort them to keep the dropdown order stable between runs.
cache_path = "weights/caches"
parquets = sorted(os.listdir(cache_path))

# Default value pre-filled in the tokenizer textbox.
TOKENIZER = "microsoft/Phi-3-mini-4k-instruct"

# Half-width used by `update` when slicing the token window around a position
# (the slice is [pos - nearby - 1, pos + nearby + 1)).
nearby = 8
# Factor `update` multiplies the stored "nearby" activation values by.
stride = 0.25
# The colour legend built in `update` has n_bins + 1 evenly spaced entries.
n_bins = 10
with gr.Blocks() as demo:
    # State slot; nothing visible in this file reads or writes it.
    feature_table = gr.State(None)

    # Inputs handed to `update`.
    tokenizer_name = gr.Textbox(TOKENIZER)  # passed to AutoTokenizer.from_pretrained
    dropdown = gr.Dropdown(parquets)  # cache file to read from `cache_path`
    feature_input = gr.Number(0)  # matched against the "feature" column
    token_range = gr.Number(64)  # forwarded to `update`, which does not use it yet

    # Outputs filled by `update`.
    frequency = gr.Number(0, label="Total frequency (%)")
    histogram = gr.LinePlot(x="activation", y="freq")
    cm = gr.HighlightedText()  # colour legend
    # Use the canonical class name, consistent with `cm` above.
    frame = gr.HighlightedText(
        show_legend=True
    )
@functools.lru_cache(maxsize=4)
def _load_tokenizer(name):
    """Load the tokenizer for *name* once instead of on every UI event.

    The cache is bounded because the name comes from a free-text box.
    """
    return AutoTokenizer.from_pretrained(name)


def update(cache_name, feature, tokenizer_name, token_range):
    """Build the four UI outputs for one feature of one cache file.

    Args:
        cache_name: File name inside ``cache_path`` to read, or None when
            nothing is selected.
        feature: Value matched against the table's "feature" column, or None
            when the number box is empty.
        tokenizer_name: Identifier passed to ``AutoTokenizer.from_pretrained``.
        token_range: Accepted because the UI passes it; currently unused.

    Returns:
        A 4-tuple ``(highlight, legend, freq_table, total_freq)``:
        ``highlight`` is a flat list of ``(token, value)`` pairs with a
        ``("\\n", 0)`` separator after each context; ``legend`` is a list of
        ``(label, value)`` pairs; ``freq_table`` is the "activation"/"freq"
        columns for the feature (None when there is nothing to show);
        ``total_freq`` is the summed "freq" column times 100.
    """
    if cache_name is None or feature is None:
        # Four outputs are wired to this handler, so always hand back four
        # values; a bare `return` would yield a single None.
        return [], [], None, 0.0

    tokenizer = _load_tokenizer(tokenizer_name)
    table = pq.read_table(f"{cache_path}/{cache_name}")
    table_feat = table.filter(pc.field("feature") == feature).to_pandas()

    # The plot and the total use every row of the feature, before filtering.
    freq_t = table_feat[["activation", "freq"]]
    total_freq = float(table_feat["freq"].sum()) * 100

    # Only rows with a positive activation and a positive frequency are shown
    # as text, strongest activation first.
    shown = table_feat[(table_feat["activation"] > 0) & (table_feat["freq"] > 0)]
    shown = shown.sort_values("activation", ascending=False)
    if shown.empty:
        return [], [], freq_t, total_freq

    def decode_window(position):
        # Decode the token ids surrounding `position` back into text.
        start = max(0, position - nearby - 1)
        ids = token_table[start:position + nearby + 1]["tokens"].to_numpy()
        return tokenizer.decode(ids)

    # NOTE(review): the decoded text is re-tokenized, so the number of pieces
    # may differ from the length of the stored "nearby" row; zip() below
    # silently truncates to the shorter of the two. Confirm this is intended.
    texts = [tokenizer.tokenize(decode_window(pos)) for pos in shown["token"]]

    # max_act is > 0 here because only rows with activation > 0 remain.
    max_act = shown["activation"].max()
    activations = np.stack(shown["nearby"].to_numpy()) * stride / max_act

    flat_data = []
    for tokens, row in zip(texts, activations):
        flat_data.extend(zip(tokens, row))
        flat_data.append(("\n", 0))

    # Legend: n_bins + 1 evenly spaced fractions, labelled on the max_act scale.
    fractions = [step / n_bins for step in range(n_bins + 1)]
    color_map_data = [(f"{frac * max_act:.2f}", frac) for frac in fractions]

    return flat_data, color_map_data, freq_t, total_freq
# Event listeners have to be registered while a Blocks context is active, so
# enter `demo` explicitly for the wiring.
with demo:
    # Recompute every output whenever the cache file or the feature changes.
    for trigger in (dropdown, feature_input):
        trigger.change(
            update,
            [dropdown, feature_input, tokenizer_name, token_range],
            [frame, cm, histogram, frequency],
        )
if __name__ == "__main__":
    # Start the app only when run as a script; share=True is passed through
    # to Gradio's launch().
    demo.launch(share=True)