# Source: Hugging Face Space file view — author kisate, commit c98496e ("Add explanations"), 3.59 kB
import gradio as gr
import pyarrow.parquet as pq
import pyarrow.compute as pc
from transformers import AutoTokenizer
from datasets import load_dataset
import os
import numpy as np
# Corpus token ids; the activation caches store integer positions that index
# into this table (see the decode step in `update` below).
token_table = pq.read_table("weights/tokens.parquet")
cache_path = "weights/caches"
# Available per-layer cache files. NOTE(review): not referenced anywhere below —
# presumably kept for debugging/inspection; confirm before removing.
parquets = os.listdir(cache_path)
# Default tokenizer shown (and editable) in the UI textbox.
TOKENIZER = "microsoft/Phi-3-mini-4k-instruct"
# One row per (layer, feature) with "explanation" and "gen_explanations" fields.
dataset = load_dataset("kisate-team/feature-explanations", split="train")
layers = dataset.unique("layer")
# layer -> {feature id -> dataset row}, for O(1) lookup in the UI callbacks.
features = {layer:{item["feature"]:item for item in dataset if item["layer"] == layer} for layer in layers}
nearby = 8      # context tokens kept on each side of a firing token
stride = 0.25   # quantization stride of cached activations (also baked into the cache file name)
n_bins = 10     # bins in the highlight color-map legend
def make_cache_name(layer):
    """Return the path of the activation-cache parquet for *layer*."""
    filename = f"phi-l{layer}-r4-st0.25x128-activations.parquet"
    return f"{cache_path}/{filename}"
with gr.Blocks() as demo:
    # Placeholder state; no callback below reads or writes it.
    feature_table = gr.State(None)

    tokenizer_name = gr.Textbox(TOKENIZER)
    layer_dropdown = gr.Dropdown(layers)
    feature_dropdown = gr.Dropdown()

    def update_features(layer):
        """Repopulate the feature dropdown with the features known for *layer*."""
        return gr.Dropdown(features[layer].keys())

    layer_dropdown.input(update_features, layer_dropdown, feature_dropdown)

    frequency = gr.Number(0, label="Total frequency (%)")
    # histogram = gr.LinePlot(x="activation", y="freq")
    autoi_expl = gr.Textbox()
    selfe_expl = gr.Textbox()

    cm = gr.HighlightedText()
    # BUG FIX: was `gr.Highlightedtext()` — the gradio module has no such
    # attribute (correct CamelCase is `HighlightedText`, as used one line
    # above), so building the Blocks raised AttributeError at import time.
    frame = gr.HighlightedText()

    def update(layer, feature, tokenizer_name):
        """Render activation highlights and explanations for one feature.

        Returns a 5-tuple matching the output components wired below:
        (highlighted token data, color-map legend, total firing frequency
        in %, auto-interp explanation, numbered self-explanations).
        """
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        table = pq.read_table(make_cache_name(layer))
        table_feat = table.filter(pc.field("feature") == feature).to_pandas()
        # freq_t = table_feat[["activation", "freq"]]
        total_freq = float(table_feat["freq"].sum()) * 100

        # Keep only rows where the feature actually fired, strongest first.
        table_feat = table_feat[table_feat["activation"] > 0]
        table_feat = table_feat[table_feat["freq"] > 0]
        table_feat = table_feat.sort_values("activation", ascending=False)

        # Decode a window of ~`nearby` tokens around each firing position.
        # NOTE(review): assumes the "token" column is a row index into
        # token_table — confirm; the asymmetric -1/+1 window bounds also
        # look deliberate but are worth double-checking.
        texts = table_feat["token"].apply(
            lambda x: tokenizer.decode(token_table[max(0, x - nearby - 1):x + nearby + 1]["tokens"].to_numpy())
        )
        texts = [tokenizer.tokenize(text) for text in texts]
        # Per-window quantized activation vectors; multiply by `stride` to
        # recover raw values.
        activations = table_feat["nearby"].to_numpy()
        if len(activations) > 0:
            activations = np.stack(activations) * stride
            max_act = table_feat["activation"].max()
            # max_act > 0 is guaranteed by the activation > 0 filter above.
            activations = activations / max_act
            # zip() truncates each row to the shorter of (tokens, scores);
            # a trailing newline token separates the examples visually.
            highlight_data = [
                [(token, activation) for token, activation in zip(text, activation)] + [("\n", 0)]
                for text, activation in zip(texts, activations)
            ]
            flat_data = [item for sublist in highlight_data for item in sublist]
            # Legend mapping the normalized bin edges back to raw activations.
            color_map_data = [i / n_bins for i in range(n_bins + 1)]
            color_map_data = [(f"{i*max_act:.2f}", i) for i in color_map_data]
        else:
            flat_data = []
            color_map_data = []

        autoi_expl = features[layer][feature]["explanation"]
        selfe_expl = features[layer][feature]["gen_explanations"]
        if selfe_expl is not None:
            selfe_expl = "\n".join(
                f"{i+1}. \"{x}\"" for i, x in enumerate(selfe_expl)
            )
        return flat_data, color_map_data, total_freq, autoi_expl, selfe_expl

    feature_dropdown.change(update, [layer_dropdown, feature_dropdown, tokenizer_name], [frame, cm, frequency, autoi_expl, selfe_expl])
    # feature_input.change(update, [dropdown, feature_input, tokenizer_name, token_range], [frame, cm, histogram, frequency])
# Launch the Gradio server (with a public share link) when run as a script.
if __name__ == "__main__":
    demo.launch(share=True)