File size: 3,118 Bytes
bcc039b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# Copyright (c) Meta Platforms, Inc. and affiliates.
import sys
from pathlib import Path

import altair as alt
import pandas as pd
import pydantic
from omegaconf import OmegaConf


class ScalingPlotsConfig(pydantic.BaseModel):
    """Validated settings for the scaling-plot script.

    ``main`` builds this from the config file merged with command-line
    overrides; unknown keys are rejected (see ``Config.extra``).
    """

    # Directory holding the results files named in ``frame_files``.
    df_dir: str
    # Directory the rendered charts are written to (``main`` creates it if missing).
    output_chart_dir: str
    # File names, relative to ``df_dir``; each is loaded with ``pd.read_json``
    # and rendered as one chart.
    frame_files: list[str]

    class Config:
        # Reject unexpected keys (e.g. typos in the config) instead of ignoring
        # them. NOTE(review): class-based ``Config`` is the pydantic v1 style;
        # pydantic v2 still accepts it but deprecates it in favour of
        # ``model_config = ConfigDict(extra="forbid")``.
        extra = "forbid"


def determine_family(key: str):
    """Map a model key to the model family it belongs to.

    The family is identified purely by the key's prefix, e.g.
    ``"BLT Space ps=6"`` -> ``"BLT"`` and ``"SpaceByte"`` -> ``"Space"``.

    Args:
        key: Model label as it appears in the ``key`` column of a results frame.

    Returns:
        The matching family name, or ``None`` when ``key`` starts with none of
        the known family prefixes.
    """
    # Prefixes are checked in order; the first match wins.
    for family in ("Megabyte++", "BLT", "LLaMA", "Space"):
        if key.startswith(family):
            return family
    # Explicit rather than falling off the end: unknown keys have no family.
    return None


# NOTE(review): not referenced anywhere in this file; presumably leftover or
# used by an importer — confirm before removing.
file_to_vars: dict = {}


def create_chart(df: pd.DataFrame, output_file: "str | Path"):
    """Plot bits-per-byte against training FLOPs per model and save the chart.

    Each model (``key`` column) is drawn as a line plus point markers. Lines
    are dashed by model family (via ``determine_family``); colors and marker
    shapes come from the fixed per-model style table below.

    Args:
        df: Results frame with at least the columns ``key``, ``flops`` and
            ``bpb/not_heldout.jsonl``. It is not modified.
        output_file: Destination passed to ``altair.Chart.save``; Altair infers
            the output format from the file extension when none is given
            (``main`` passes a ``.pdf`` path).
    """
    # Derive the plotting columns on a copy so the caller's frame is untouched.
    df = df.assign(
        metric=df["bpb/not_heldout.jsonl"],
        family=df["key"].map(determine_family),
    )

    # (legend label, color, marker shape) per model. Keeping the three values
    # in one row guarantees the scale domain and ranges stay aligned.
    model_styles = [
        ("BLT Space ps=6", "#1f77b4", "circle"),
        ("BLT Space w/o cross-attn", "#1f77b4", "square"),
        ("SpaceByte", "#1f77b4", "cross"),
        ("LLaMA 3 BPE", "#ff7f0e", "diamond"),
        ("Megabyte++ ps=4", "#2ca02c", "triangle-up"),
        ("Megabyte++ ps=6", "#2ca02c", "triangle-down"),
    ]
    model_domain = [label for label, _, _ in model_styles]
    color_scale = alt.Scale(
        domain=model_domain, range=[color for _, color, _ in model_styles]
    )
    shape_scale = alt.Scale(
        domain=model_domain, range=[shape for _, _, shape in model_styles]
    )

    base_chart = alt.Chart(df).encode(
        x=alt.X("flops", title="Training FLOPS")
        .scale(type="log", domain=[2e20, 1.25e22])
        .axis(values=[2e20, 4e20, 8e20, 1e21, 2e21, 4e21, 8e21, 1e22]),
        y=alt.Y("metric", title="Bits per Byte (BPB)").scale(zero=False),
    )
    # Lines: colored per model, dashed per family; their legends are hidden so
    # only the point legend is shown.
    lines = base_chart.encode(
        color=alt.Color("key", title="Model Color", scale=color_scale, legend=None),
        strokeDash=alt.StrokeDash("family", title="Model Family", legend=None),
    ).mark_line()
    # Points: same per-model colors plus a per-model marker shape; these
    # encodings supply the visible legend.
    points = base_chart.encode(
        color=alt.Color("key", title="Model", scale=color_scale),
        shape=alt.Shape("key", title="", scale=shape_scale),
    ).mark_point(size=70)
    # NOTE(review): color/shape are resolved independently per layer, presumably
    # so the hidden line legend does not interfere with the point legend.
    chart = (
        (lines + points)
        .resolve_scale(color="independent", shape="independent")
        .configure_legend(orient="right")
        .properties(height=300, width=400)
    )
    print("Saving", output_file)
    chart.save(output_file)


def main(argv=None):
    """Render one scaling chart per data frame listed in the config.

    Command line: ``<script> CONFIG_PATH [key=value ...]``. The config file at
    ``CONFIG_PATH`` is merged with the dotlist overrides that follow it (later
    values win), resolved, and validated as ``ScalingPlotsConfig``. Each entry
    of ``frame_files`` is read from ``df_dir`` with ``pd.read_json`` and saved
    as ``<output_chart_dir>/<file name>.pdf``.

    Args:
        argv: Arguments after the program name. Defaults to ``sys.argv[1:]``;
            passing a list lets the function be driven programmatically.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if not args:
        # Fail with a usage message instead of an IndexError traceback.
        prog = Path(sys.argv[0]).name if sys.argv else "script"
        sys.exit(f"usage: {prog} CONFIG_PATH [key=value ...]")
    config_path, overrides = args[0], args[1:]

    file_config = OmegaConf.load(config_path)
    # Everything after the config path is treated as key=value overrides.
    cli_conf = OmegaConf.from_cli(overrides)
    conf_dict = OmegaConf.to_container(
        OmegaConf.merge(file_config, cli_conf), resolve=True, throw_on_missing=True
    )
    plot_config = ScalingPlotsConfig(**conf_dict)

    df_dir = Path(plot_config.df_dir)
    chart_dir = Path(plot_config.output_chart_dir)
    chart_dir.mkdir(exist_ok=True, parents=True)
    for frame_file in plot_config.frame_files:
        path = df_dir / frame_file
        df = pd.read_json(path)
        print(df)
        print(df.columns)
        # Keeps the source suffix, e.g. "results.json" -> "results.json.pdf".
        create_chart(df, chart_dir / f"{path.name}.pdf")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()