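"""Gradio demo for browsing feature activations and their explanations.

Reads per-(model, layer, revision) activation caches from weights/caches,
joins them with the kisate-team/feature-explanations dataset, and renders
max-activating token contexts as highlighted text.
"""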
import gradio as gr
import pyarrow.parquet as pq
import pyarrow.compute as pc
from transformers import AutoTokenizer
from datasets import load_dataset
import os
import numpy as np
cache_path = "weights/caches"
parquets = os.listdir(cache_path)
dataset = load_dataset("kisate-team/feature-explanations", split="train")
def find_revisions():
    """Collect the set of revisions present in the cache directory."""
    revisions = set()
    for parquet in parquets:
        if parquet.endswith(".parquet"):
            parts = parquet.split("-")
            if len(parts) > 2:
                revisions.add(int(parts[-3][1:]))
    return sorted(revisions)
def find_layers(revision):
    """Collect the layers that have a cache file for the given revision."""
    layers = set()
    for parquet in parquets:
        if parquet.endswith(".parquet"):
            parts = parquet.split("-")
            if len(parts) > 2 and int(parts[-3][1:]) == revision:
                layers.add(int(parts[-4][1:]))
    return sorted(layers)
revisions = find_revisions()
layers = {
    revision: find_layers(revision) for revision in revisions
}
features = {
    revision: {
        layer: {
            item["feature"]: item
            for item in dataset
            if item["layer"] == layer and item["version"] == revision
        }
        for layer in layers[revision]
    }
    for revision in revisions
}
# layers = dataset.unique("layer")
nearby = 8     # context tokens shown on each side of a max-activating token
stride = 0.25  # scale for cached activation values (matches "st0.25" in the cache file names)
n_bins = 10    # number of entries in the color-map legend
def make_cache_name(layer, revision, model):
    return f"{cache_path}/{model}-l{layer}-r{revision}-st0.25x128-activations.parquet"
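# Example (hypothetical layer/revision values):
#   make_cache_name(12, 4, "gemma-2b-residuals")
#   -> "weights/caches/gemma-2b-residuals-l12-r4-st0.25x128-activations.parquet"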
models = {
    "gemma-2b-r": "gemma-2b-residuals",
    "phi-3": "phi",
}
tokenizers = {
    "gemma-2b-r": "alpindale/gemma-2b",
    "phi-3": "microsoft/Phi-3-mini-4k-instruct",
}
token_tables = {
    "gemma-2b-r": pq.read_table("weights/tokens_gemma.parquet"),
    "phi-3": pq.read_table("weights/tokens.parquet"),
}
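# The token tables hold the raw corpus token ids; the "token" positions stored
# in the activation caches index into these tables (see update() below).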
with gr.Blocks() as demo:
    feature_table = gr.State(None)
    model_name = gr.Dropdown(["phi-3", "gemma-2b-r"], label="Model")
    revision_dropdown = gr.Dropdown(revisions, label="Revision")
    # Default to the first available revision's layers until one is selected.
    layer_dropdown = gr.Dropdown(layers[revisions[0]], label="Layer")
    # Currently unused (see the commented-out wiring below); kept for reference.
    def update_features(revision, layer):
        return gr.Dropdown(list(features[revision][layer].keys()))
    def update_layers(revision):
        return gr.Dropdown(layers[revision])
    frequency = gr.Number(0, label="Total frequency (%)")
    extra_tokens = gr.Number(0, label="Extra Max Act Tokens")
    # layer_dropdown.input(update_features, layer_dropdown, feature_dropdown)
    # histogram = gr.LinePlot(x="activation", y="freq")
    revision_dropdown.input(update_layers, revision_dropdown, layer_dropdown)
    feature_input = gr.Number(0, label="Feature")
    autoi_expl = gr.Textbox(label="AutoInterp Explanation")
    selfe_expl = gr.Textbox(label="SelfExplain Explanation")
    cm = gr.HighlightedText()
    frame = gr.HighlightedText()
    def update(model, revision, layer, feature, extra_tokens):
        # gr.Number delivers floats; both values are used as indices below.
        feature = int(feature)
        extra_tokens = int(extra_tokens)
        # Context-window offset; the Gemma token table needs no +1 shift.
        correction = 1
        if "gemma" in model:
            correction = 0
        token_table = token_tables[model]
        tokenizer_name = tokenizers[model]
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        table = pq.read_table(make_cache_name(layer, revision, models[model]))
        table_feat = table.filter(pc.field("feature") == feature).to_pandas()
        # freq_t = table_feat[["activation", "freq"]]
        total_freq = float(table_feat["freq"].sum()) * 100
        # Keep only rows with positive activation and frequency, strongest first.
        table_feat = table_feat[table_feat["activation"] > 0]
        table_feat = table_feat[table_feat["freq"] > 0]
        table_feat = table_feat.sort_values("activation", ascending=False)
        # Decode a window of `nearby` (+ extra) tokens around each max-activating position.
        texts = table_feat["token"].apply(
            lambda x: [
                tokenizer.decode(y).replace("\n", " ")
                for y in token_table[
                    max(0, x - nearby + correction - extra_tokens)
                    : x + extra_tokens + nearby + 1 + correction
                ]["tokens"].to_numpy()
            ]
        ).tolist()
        # texts = [tokenizer.tokenize(text) for text in texts]
        activations = table_feat["nearby"].to_numpy()
        # Zero-pad the cached window activations to cover the extra tokens.
        activations = [
            [0] * extra_tokens + a.tolist() + [0] * extra_tokens
            for i, a in enumerate(activations)
            if len(texts[i]) > 0
        ]
        texts = [text for text in texts if len(text) > 0]
        for t, a in zip(texts, activations):
            assert len(t) == len(a)
        if len(activations) > 0:
            # Rescale by the quantization stride, then normalize to [0, 1] for highlighting.
            activations = np.stack(activations) * stride
            max_act = table_feat["activation"].max()
            activations = activations / max_act
            highlight_data = [
                [(token, act) for token, act in zip(text, acts)] + [("\n", 0)]
                for text, acts in zip(texts, activations)
            ]
            flat_data = [item for sublist in highlight_data for item in sublist]
            color_map_data = [i / n_bins for i in range(n_bins + 1)]
            color_map_data = [(f"{i * max_act:.2f}", i) for i in color_map_data]
        else:
            flat_data = []
            color_map_data = []
        if feature in features[revision][layer]:
            autoi_expl = features[revision][layer][feature]["explanation"]
            selfe_expl = features[revision][layer][feature]["gen_explanations"]
            if selfe_expl is not None:
                selfe_expl = "\n".join(
                    f'{i + 1}. "{x}"' for i, x in enumerate(selfe_expl)
                )
        else:
            autoi_expl = "No explanation found"
            selfe_expl = "No explanation found"
        return flat_data, color_map_data, total_freq, autoi_expl, selfe_expl
    # feature_dropdown.change(update, [layer_dropdown, feature_dropdown, tokenizer_name], [frame, cm, frequency, autoi_expl, selfe_expl])
    feature_input.change(
        update,
        [model_name, revision_dropdown, layer_dropdown, feature_input, extra_tokens],
        [frame, cm, frequency, autoi_expl, selfe_expl],
    )
if __name__ == "__main__":
    demo.launch(share=True)
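# Launch notes: share=True creates a public Gradio link in addition to the
# local server. The script expects the weights/ directory (activation caches
# and token tables) to be present alongside it.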