# Spaces: Sleeping  (hosting-page status header captured with the file; not part of the program)
import gradio as gr | |
import pyarrow.parquet as pq | |
import pyarrow.compute as pc | |
from transformers import AutoTokenizer | |
from datasets import load_dataset | |
import os | |
import numpy as np | |
def _index_features(rows, layer_ids):
    """Group explanation rows into a ``{layer: {feature: row}}`` mapping.

    Args:
        rows: Iterable of mappings that each carry "layer" and "feature" keys.
        layer_ids: Layer identifiers to create (possibly empty) entries for,
            in the order they should appear in the result.

    Returns:
        A dict keyed by layer whose values map feature id to the row itself.
        When several rows share a (layer, feature) pair, the last one wins.
    """
    index = {layer_id: {} for layer_id in layer_ids}
    # One pass over the rows instead of rescanning all of them per layer.
    for row in rows:
        index[row["layer"]][row["feature"]] = row
    return index


# Token table; sliced by row position and read through its "tokens" column.
token_table = pq.read_table("weights/tokens.parquet")
# Directory holding the per-layer activation parquet files.
cache_path = "weights/caches"
parquets = os.listdir(cache_path)
# Default tokenizer identifier shown in the UI.
TOKENIZER = "microsoft/Phi-3-mini-4k-instruct"
dataset = load_dataset("kisate-team/feature-explanations", split="train")
layers = dataset.unique("layer")
features = _index_features(dataset, layers)
# Number of tokens shown on each side of a token when building its context.
nearby = 8
# Multiplier applied to the stored "nearby" activation values.
stride = 0.25
# Number of intervals in the colour legend (n_bins + 1 legend entries).
n_bins = 10
def make_cache_name(layer):
    """Return the path of the activation parquet file for ``layer``.

    The file name embeds the layer id; the directory comes from the
    module-level ``cache_path``.
    """
    filename = f"phi-l{layer}-r4-st0.25x128-activations.parquet"
    return f"{cache_path}/{filename}"
with gr.Blocks() as demo:
    # No handler in this block reads or writes this state; kept so any
    # external reference to the name keeps working.
    feature_table = gr.State(None)
    tokenizer_name = gr.Textbox(TOKENIZER)
    layer_dropdown = gr.Dropdown(layers)
    feature_dropdown = gr.Dropdown()

    def update_features(layer):
        """Return a dropdown listing the features available for ``layer``.

        An unknown or empty layer yields a dropdown with no choices.
        """
        return gr.Dropdown(list(features.get(layer, {})))

    layer_dropdown.input(update_features, layer_dropdown, feature_dropdown)

    frequency = gr.Number(0, label="Total frequency (%)")
    autoi_expl = gr.Textbox()
    selfe_expl = gr.Textbox()
    cm = gr.HighlightedText()
    frame = gr.HighlightedText()

    def update(layer, feature, tokenizer_name):
        """Compute everything displayed for one (layer, feature) selection.

        Args:
            layer: Layer id chosen in ``layer_dropdown``.
            feature: Feature id chosen in ``feature_dropdown``.
            tokenizer_name: Identifier passed to
                ``AutoTokenizer.from_pretrained``.

        Returns:
            A 5-tuple matching the outputs wired below:
            ``(flat_data, color_map_data, total_freq, explanation,
            generated_explanations)``. ``flat_data`` is a list of
            ``(token, normalised_activation)`` pairs with a ``("\\n", 0)``
            separator after each example; ``color_map_data`` is a legend of
            ``(label, value)`` pairs; ``total_freq`` is the summed "freq"
            column times 100.
        """
        # The change event can fire with no selection (for example right
        # after the choices are replaced); show empty outputs instead of
        # failing on the lookups below.
        if feature is None or feature not in features.get(layer, {}):
            return [], [], 0, "", ""

        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        table = pq.read_table(make_cache_name(layer))
        rows = table.filter(pc.field("feature") == feature).to_pandas()

        # Summed over every row of the feature, before any filtering.
        total_freq = float(rows["freq"].sum()) * 100

        rows = rows[(rows["activation"] > 0) & (rows["freq"] > 0)]
        rows = rows.sort_values("activation", ascending=False)

        def context_tokens(position):
            """Decode the token window around ``position`` and re-tokenize it."""
            start = max(0, position - nearby - 1)
            window = token_table[start:position + nearby + 1]["tokens"].to_numpy()
            return tokenizer.tokenize(tokenizer.decode(window))

        texts = [context_tokens(position) for position in rows["token"]]

        nearby_acts = rows["nearby"].to_numpy()
        if len(nearby_acts) > 0:
            max_act = rows["activation"].max()
            # Scale by stride, then normalise by the strongest activation.
            scaled = np.stack(nearby_acts) * stride / max_act
            flat_data = []
            for tokens, acts in zip(texts, scaled):
                # zip truncates to the shorter of the two sequences.
                flat_data.extend((token, float(act)) for token, act in zip(tokens, acts))
                flat_data.append(("\n", 0))
            fractions = [i / n_bins for i in range(n_bins + 1)]
            color_map_data = [(f"{frac * max_act:.2f}", frac) for frac in fractions]
        else:
            flat_data = []
            color_map_data = []

        entry = features[layer][feature]
        explanation = entry["explanation"]
        generated = entry["gen_explanations"]
        if generated is not None:
            generated = "\n".join(
                f'{number}. "{text}"' for number, text in enumerate(generated, start=1)
            )
        return flat_data, color_map_data, total_freq, explanation, generated

    feature_dropdown.change(
        update,
        [layer_dropdown, feature_dropdown, tokenizer_name],
        [frame, cm, frequency, autoi_expl, selfe_expl],
    )
if __name__ == "__main__":
    # Launch the app only when the file is executed as a script.
    demo.launch(share=True)