File size: 2,907 Bytes
bcc039b
d4ddb95
bcc039b
 
 
 
 
 
 
 
 
 
 
 
 
d4ddb95
 
bcc039b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4ddb95
bcc039b
 
d4ddb95
 
 
 
 
 
 
 
 
 
 
bcc039b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4ddb95
bcc039b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# Copyright (c) Meta Platforms, Inc. and affiliates.
import io
import json
import os
import sys
from pathlib import Path

import altair as alt
import pandas as pd
from omegaconf import OmegaConf
from pydantic import BaseModel, ConfigDict


class PlotEntropiesConfig(BaseModel):
    """Settings for the entropy plotting script.

    Built from a config file merged with ``key=value`` command-line overrides.
    Unknown keys are rejected (``extra="forbid"``).
    """

    # JSON file parsed as ``PlotEntropiesData``. Nullable, but it has no
    # default, so the key must still be present.
    data_path: str | None
    # Output path handed to Altair's ``Chart.save``.
    chart_path: str
    # Optional JSON file whose "score" list replaces the per-token entropies.
    score_override_path: str | None = None
    # When set, used instead of the threshold stored in the data file.
    threshold_override: float | None = None

    # Pydantic v2 configuration (the class-based ``Config`` form is deprecated).
    model_config = ConfigDict(extra="forbid")


class PlotEntropiesData(BaseModel):
    """Serialized entropy trace loaded from ``PlotEntropiesConfig.data_path``.

    Unknown keys are rejected (``extra="forbid"``).
    """

    # Text associated with this trace.
    text: str
    # Entropy threshold drawn as the dashed reference line.
    # NOTE(review): the origin of this default value is not recorded here.
    threshold: float = 1.335442066192627
    # DataFrame serialized as JSON text. The plotting script reads the columns
    # "position", "tokens", "entropies" and "start" from it.
    dataframe_json: str | None

    # Pydantic v2 configuration (the class-based ``Config`` form is deprecated).
    model_config = ConfigDict(extra="forbid")


def _load_plot_config(argv: list[str]) -> "PlotEntropiesConfig":
    """Merge the config file at ``argv[1]`` with ``key=value`` overrides in ``argv[2:]``."""
    file_config = OmegaConf.load(argv[1])
    # Omit program name and config file name
    cli_conf = OmegaConf.from_cli(argv[2:])
    conf_dict = OmegaConf.to_container(
        OmegaConf.merge(file_config, cli_conf), resolve=True, throw_on_missing=True
    )
    return PlotEntropiesConfig(**conf_dict)


def _load_entropy_frame(
    plot_config: "PlotEntropiesConfig",
) -> tuple[pd.DataFrame, float]:
    """Load the per-token dataframe and choose the threshold to draw.

    Returns:
        The dataframe, optionally with "entropies"/"start" replaced from the
        score override file, and the effective threshold.

    Raises:
        ValueError: if ``data_path`` or ``dataframe_json`` is missing, or the
            override has a different number of scores than the frame has rows.
    """
    if plot_config.data_path is None:
        raise ValueError("data_path must point to a PlotEntropiesData JSON file")
    with open(plot_config.data_path, encoding="utf-8") as f:
        plot_data = PlotEntropiesData.model_validate_json(f.read())
    if plot_data.dataframe_json is None:
        raise ValueError(f"{plot_config.data_path} has no dataframe_json")

    # pandas >= 2.1 deprecates passing literal JSON text; wrap it in a buffer.
    df = pd.read_json(io.StringIO(plot_data.dataframe_json))
    print("LEN", len(df))

    threshold = (
        plot_data.threshold
        if plot_config.threshold_override is None
        else plot_config.threshold_override
    )

    if plot_config.score_override_path is not None:
        with open(plot_config.score_override_path, encoding="utf-8") as f:
            scores = json.load(f)["score"]
        if len(scores) != len(df):
            raise ValueError(
                f"score override has {len(scores)} scores but the dataframe "
                f"has {len(df)} rows"
            )
        df["entropies"] = scores
        # Row 0 always starts a patch; every later row starts one when the
        # previous row's entropy exceeds the threshold.
        df["start"] = (df["entropies"] > threshold).astype(int).shift(1, fill_value=1)
    return df, threshold


def _build_entropy_chart(
    df: pd.DataFrame, threshold: float, width: int = 1200, height: int = 150
) -> "alt.LayerChart":
    """Layer patch-start rules, the threshold rule and the entropy line.

    Adds a "position_with_token" column to ``df`` in place.
    """
    positions = df["position"].tolist()
    # "position|token" keeps x categories unique even when tokens repeat. The
    # zero padding presumably keeps the ordinal axis in position order; pad to
    # the widest position so that order still holds past 999.
    pad = max([3, *(len(str(p)) for p in positions)])
    df["position_with_token"] = [
        f"{str(position).zfill(pad)}|{token}"
        for position, token in zip(positions, df["tokens"])
    ]
    print(df)

    # Show only the token part of "position|token" as the tick label.
    x_axis = alt.Axis(
        labelExpr="split(datum.label, '|')[1]",
        grid=False,
        labelOverlap=False,
        labelAngle=0,
    )
    base = alt.Chart(df).properties(width=width, height=height)
    points = base.mark_line(point=True).encode(
        x=alt.X("position_with_token:O", title=None, axis=x_axis),
        y=alt.Y("entropies", title="Entropy of Next Byte"),
    )
    rule = base.mark_rule(color="red", strokeDash=[4, 4]).encode(
        y=alt.datum(threshold),
    )
    patch_rules = (
        alt.Chart(df[df["start"] > 0])
        .properties(width=width, height=height)
        .mark_rule(color="#474747", strokeDash=[4, 2])
        .encode(x=alt.X("position_with_token:O", axis=x_axis))
    )
    chart = patch_rules + rule + points
    return chart.configure_axis(labelFontSize=15, titleFontSize=15)


def main():
    """Render a per-byte entropy chart described by a config file.

    Usage: ``python <script> <config-file> [key=value ...]``. The config file
    is merged with the command-line overrides and validated as
    ``PlotEntropiesConfig``. The data file is then loaded, and the chart is
    saved to ``chart_path``.

    Raises:
        ValueError: on missing data or a score override of the wrong length.
    """
    plot_config = _load_plot_config(sys.argv)
    df, threshold = _load_entropy_frame(plot_config)
    chart = _build_entropy_chart(df, threshold)
    path = Path(plot_config.chart_path)
    # parents=True so nested output directories are created as well.
    path.parent.mkdir(parents=True, exist_ok=True)
    chart.save(path)


if __name__ == "__main__":
    # Invoked as: python <script> <config-file> [key=value ...]
    main()