kisate
Test app
73ab266
raw
history blame
2.77 kB
import gradio as gr
import pyarrow.parquet as pq
import pyarrow.compute as pc
from transformers import AutoTokenizer
import os
import numpy as np
# Table whose "tokens" column is sliced by row position and decoded to
# rebuild the text around an activating token.
token_table = pq.read_table("weights/tokens.parquet")

# Directory containing the per-cache parquet files offered in the dropdown.
cache_path = "weights/caches"

# os.listdir returns entries in arbitrary order; sort them so the dropdown
# lists the caches in a stable, predictable order across runs.
parquets = sorted(os.listdir(cache_path))

# Default tokenizer shown in the tokenizer textbox.
TOKENIZER = "microsoft/Phi-3-mini-4k-instruct"

# Half-width of the token window sliced around each activating position.
nearby = 8

# Scale factor applied to the stored "nearby" activation values before they
# are normalised (presumably the quantisation step of the cache — confirm).
stride = 0.25

# Number of steps in the colour legend (n_bins + 1 entries are emitted).
n_bins = 10
with gr.Blocks() as demo:
    # --- Inputs -----------------------------------------------------------
    feature_table = gr.State(None)  # declared but not wired to any event
    tokenizer_name = gr.Textbox(TOKENIZER)
    dropdown = gr.Dropdown(parquets)
    feature_input = gr.Number(0)
    token_range = gr.Number(64)

    # --- Outputs ----------------------------------------------------------
    frequency = gr.Number(0, label="Total frequency (%)")
    histogram = gr.LinePlot(x="activation", y="freq")
    cm = gr.HighlightedText()
    # The component class is `HighlightedText`; the previous spelling
    # `Highlightedtext` is not an attribute of gradio and raised
    # AttributeError as soon as the UI was built.
    frame = gr.HighlightedText(show_legend=True)

    def update(cache_name, feature, tokenizer_name, token_range):
        """Recompute all four outputs for the selected cache and feature.

        Args:
            cache_name: File name chosen in the dropdown, resolved relative
                to ``cache_path``; ``None`` when nothing is selected.
            feature: Value compared for equality against the ``feature``
                column of the cache table.
            tokenizer_name: Name handed to ``AutoTokenizer.from_pretrained``.
            token_range: Currently unused; accepted because the event
                wiring passes it.

        Returns:
            A 4-tuple ``(flat_data, color_map_data, freq_t, total_freq)``:
            ``(token, normalised activation)`` pairs for the highlighted
            text, ``(label, value)`` pairs for the colour legend, the
            ``activation``/``freq`` columns for the line plot, and the sum
            of ``freq`` multiplied by 100. When ``cache_name`` is ``None``
            four no-op updates are returned instead.
        """
        if cache_name is None:
            # The handler is wired to four output components, so a bare
            # `return` (a single None) does not match them. Emit one no-op
            # update per output to leave the UI unchanged.
            return gr.update(), gr.update(), gr.update(), gr.update()

        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        table = pq.read_table(f"{cache_path}/{cache_name}")
        table_feat = table.filter(pc.field("feature") == feature).to_pandas()

        # The plot data and the total frequency are taken from every row of
        # the feature, before non-positive rows are dropped below.
        freq_t = table_feat[["activation", "freq"]]
        total_freq = float(table_feat["freq"].sum()) * 100

        # Keep only rows that actually fired, strongest activation first.
        fired = (table_feat["activation"] > 0) & (table_feat["freq"] > 0)
        table_feat = table_feat[fired].sort_values("activation", ascending=False)

        def decode_context(position):
            # Decode the window of token ids surrounding `position`,
            # clamping the start so it never goes below row 0.
            start = max(0, position - nearby - 1)
            window = token_table[start:position + nearby + 1]
            return tokenizer.decode(window["tokens"].to_numpy())

        # Decode each window, then split it back into token strings so each
        # displayed piece can be paired with one activation value.
        texts = [
            tokenizer.tokenize(text)
            for text in table_feat["token"].apply(decode_context)
        ]

        activations = table_feat["nearby"].to_numpy()
        if len(activations) == 0:
            return [], [], freq_t, total_freq

        # Rescale the stored values by `stride`, then normalise by the
        # strongest activation (strictly positive after the filter above).
        max_act = table_feat["activation"].max()
        activations = np.stack(activations) * stride / max_act

        # One run of (token, activation) pairs per row; a newline entry
        # separates consecutive rows in the rendered text.
        flat_data = []
        for tokens, row in zip(texts, activations):
            flat_data.extend(zip(tokens, row))
            flat_data.append(("\n", 0))

        # Legend: n_bins + 1 evenly spaced levels in [0, 1], each labelled
        # with the un-normalised activation it corresponds to.
        levels = [step / n_bins for step in range(n_bins + 1)]
        color_map_data = [(f"{level*max_act:.2f}", level) for level in levels]

        return flat_data, color_map_data, freq_t, total_freq

    # Both triggers share the same inputs and outputs.
    update_inputs = [dropdown, feature_input, tokenizer_name, token_range]
    update_outputs = [frame, cm, histogram, frequency]
    dropdown.change(update, update_inputs, update_outputs)
    feature_input.change(update, update_inputs, update_outputs)
# Launch the app only when this file is executed directly; share=True is
# passed through to Gradio's launch() unchanged.
if __name__ == "__main__":
    demo.launch(share=True)