Spaces:
Running
Running
# Copyright (c) Meta Platforms, Inc. and affiliates.
import io
import json
import os
import sys
from pathlib import Path

import altair as alt
import pandas as pd
from omegaconf import OmegaConf
from pydantic import BaseModel, ConfigDict
class PlotEntropiesConfig(BaseModel):
    """Settings for the entropy plot, built from the merged config passed to ``main``.

    Unknown keys are rejected (``extra="forbid"``), so a misspelled option
    fails validation instead of being silently ignored.
    """

    # Path of the JSON file validated as ``PlotEntropiesData``. The field is
    # required but may be None; ``main`` needs a real path to produce a plot.
    data_path: str | None
    # Where the rendered chart is saved.
    chart_path: str
    # Optional JSON file whose "score" list replaces the "entropies" column;
    # it must contain one score per data row.
    score_override_path: str | None = None
    # When set, this is used instead of the threshold stored in the data file.
    threshold_override: float | None = None

    # Pydantic v2 configuration; the class-based ``Config`` is deprecated.
    model_config = ConfigDict(extra="forbid")
class PlotEntropiesData(BaseModel):
    """Schema of the JSON data file read by ``main``.

    Unknown keys are rejected (``extra="forbid"``).
    """

    # Raw text stored alongside the table; ``main`` does not read it.
    text: str
    # Patching threshold drawn on the chart unless the config overrides it.
    threshold: float = 1.335442066192627
    # JSON-serialized DataFrame. ``main`` reads its "position", "tokens" and
    # "entropies" columns. Required, but may be None (no table available).
    dataframe_json: str | None

    # Pydantic v2 configuration; the class-based ``Config`` is deprecated.
    model_config = ConfigDict(extra="forbid")
def _load_plot_config():
    """Build a ``PlotEntropiesConfig`` from the command line.

    ``sys.argv[1]`` is the config file loaded with ``OmegaConf.load``. Any
    further arguments are parsed by ``OmegaConf.from_cli`` and merged on top,
    so they override values from the file.

    Raises:
        SystemExit: if no config path was given.
    """
    if len(sys.argv) < 2:
        raise SystemExit("usage: python <script> CONFIG_PATH [key=value ...]")
    file_config = OmegaConf.load(sys.argv[1])
    # Omit program name and config file name
    cli_conf = OmegaConf.from_cli(sys.argv[2:])
    conf_dict = OmegaConf.to_container(
        OmegaConf.merge(file_config, cli_conf), resolve=True, throw_on_missing=True
    )
    return PlotEntropiesConfig(**conf_dict)


def _load_entropy_frame(plot_config):
    """Load the per-byte table and return ``(df, threshold)`` ready for plotting.

    Two columns are added to ``df``:

    * ``start``: 1 where a new patch begins, else 0. The first row always
      starts a patch. Every later row starts one when the previous row's
      entropy exceeds ``threshold``.
    * ``position_with_token``: the ordinal x-axis key ``"<position>|<token>"``.

    Raises:
        ValueError: if ``data_path`` or the stored ``dataframe_json`` is None,
            or if the override score list length differs from the row count.
    """
    if plot_config.data_path is None:
        raise ValueError("data_path must be set to plot entropies")
    with open(plot_config.data_path, encoding="utf-8") as f:
        plot_data = PlotEntropiesData.model_validate_json(f.read())
    if plot_data.dataframe_json is None:
        raise ValueError(f"{plot_config.data_path} contains no dataframe_json")
    # Passing literal JSON text to read_json is deprecated (pandas >= 2.1);
    # wrap it in a file-like object instead.
    df = pd.read_json(io.StringIO(plot_data.dataframe_json))
    print("LEN", len(df))

    if plot_config.threshold_override is None:
        threshold = plot_data.threshold
    else:
        threshold = plot_config.threshold_override

    if plot_config.score_override_path is not None:
        with open(plot_config.score_override_path, encoding="utf-8") as f:
            scores = json.load(f)["score"]
        # Explicit check: an `assert` would be stripped under `python -O`.
        if len(scores) != len(df):
            raise ValueError(
                f"score override has {len(scores)} entries, "
                f"but the data has {len(df)} rows"
            )
        df["entropies"] = scores

    # A row starts a patch when the previous row's entropy exceeds the
    # threshold; the first row always starts one. Works for empty frames too.
    exceeds = df["entropies"] > threshold
    df["start"] = exceeds.shift(1, fill_value=True).astype(int)

    # The ordinal x axis orders these string keys lexicographically by default,
    # so positions are zero-padded to a common width (at least 3) to keep them
    # in numeric order even past position 999.
    positions = [str(p) for p in df["position"]]
    pad_width = max([3, *map(len, positions)])
    df["position_with_token"] = [
        f"{pos.zfill(pad_width)}|{token}"
        for pos, token in zip(positions, df["tokens"])
    ]
    print(df)
    return df, threshold


def _build_chart(df, threshold):
    """Layer patch-start rules, the threshold line and the entropy curve."""
    x_axis = alt.Axis(
        # Show only the token. Slice after the first "|" instead of taking
        # split(...)[1], so tokens that themselves contain "|" are not cut.
        labelExpr="slice(datum.label, indexof(datum.label, '|') + 1)",
        grid=False,
        labelOverlap=False,
        labelAngle=0,
    )
    width = 1200
    height = 150
    base = alt.Chart(df).properties(width=width, height=height)
    points = base.mark_line(point=True).encode(
        x=alt.X("position_with_token:O", title=None, axis=x_axis),
        y=alt.Y(
            "entropies",
            title="Entropy of Next Byte",
        ),
    )
    # Dashed horizontal line at the patching threshold.
    rule = base.mark_rule(color="red", strokeDash=[4, 4]).encode(
        y=alt.datum(threshold),
    )
    # Dashed vertical line at every position that starts a patch.
    patch_rules = (
        alt.Chart(df[df["start"] > 0])
        .properties(width=width, height=height)
        .mark_rule(color="#474747", strokeDash=[4, 2])
        .encode(x=alt.X("position_with_token:O", axis=x_axis))
    )
    chart = patch_rules + rule + points
    return chart.configure_axis(labelFontSize=15, titleFontSize=15)


def main():
    """Plot next-byte entropies with patch boundaries and save the chart.

    Usage: ``python <script> CONFIG_PATH [key=value ...]``. The merged config
    is validated as ``PlotEntropiesConfig``. The chart is written to
    ``chart_path``, and missing parent directories are created.
    """
    plot_config = _load_plot_config()
    df, threshold = _load_entropy_frame(plot_config)
    chart = _build_chart(df, threshold)
    path = Path(plot_config.chart_path)
    # parents=True so a nested output directory that does not exist yet works.
    path.parent.mkdir(parents=True, exist_ok=True)
    chart.save(path)


if __name__ == "__main__":
    main()