Spaces:
Running
Running
# Copyright (c) Meta Platforms, Inc. and affiliates.
import sys | |
from pathlib import Path | |
import altair as alt | |
import pandas as pd | |
import pydantic | |
from omegaconf import OmegaConf | |
class ScalingPlotsConfig(pydantic.BaseModel):
    """Validated settings for the scaling-plot script.

    An instance is built in ``main`` from the merged config file and
    command-line overrides.
    """

    # Directory that contains the JSON data frames to plot.
    df_dir: str
    # Directory that receives the rendered charts; ``main`` creates it
    # (including parents) when it does not exist yet.
    output_chart_dir: str
    # File names, joined onto ``df_dir``, that are read and plotted one by one.
    frame_files: list[str]

    class Config:
        # Reject unknown keys instead of silently accepting them.
        extra = "forbid"
def determine_family(key: str):
    """Return the model-family label for a model name.

    The family is the first known prefix that ``key`` starts with
    ("Megabyte++", "BLT", "LLaMA" or "Space"). Names that match none of
    these prefixes yield ``None``.
    """
    known_prefixes = ("Megabyte++", "BLT", "LLaMA", "Space")
    for prefix in known_prefixes:
        if key.startswith(prefix):
            return prefix
    return None
# NOTE(review): not referenced anywhere else in the code shown here;
# presumably a leftover -- confirm nothing imports it before removing.
file_to_vars = {}
def create_chart(df: pd.DataFrame, output_file: str):
    """Render the bits-per-byte versus training-FLOPs chart and save it.

    Each model (``key`` column) is drawn as a line with point markers; the
    colour and marker shape per model come from fixed lookup scales, and the
    line dash pattern is keyed on the model family.

    Args:
        df: Frame that provides the columns ``key``, ``flops`` and
            ``bpb/not_heldout.jsonl``. The frame itself is left unmodified;
            the derived ``metric`` and ``family`` columns are added to a new
            frame used only for plotting.
        output_file: Destination that is handed to ``chart.save``.

    Raises:
        KeyError: If ``key`` or ``bpb/not_heldout.jsonl`` is missing from
            ``df``.
    """
    # Derive the plotting columns on a new frame so the caller's DataFrame
    # is not mutated as a side effect of drawing a chart.
    df = df.assign(
        metric=df["bpb/not_heldout.jsonl"],
        family=df["key"].map(determine_family),
    )

    # The three lists below are parallel: entry i of the colour and shape
    # ranges belongs to entry i of the model domain.
    model_domain = [
        "BLT Space ps=6",
        "BLT Space w/o cross-attn",
        "SpaceByte",
        "LLaMA 3 BPE",
        "Megabyte++ ps=4",
        "Megabyte++ ps=6",
    ]
    color_range = ["#1f77b4", "#1f77b4", "#1f77b4", "#ff7f0e", "#2ca02c", "#2ca02c"]
    shape_range = [
        "circle",
        "square",
        "cross",
        "diamond",
        "triangle-up",
        "triangle-down",
    ]
    color_scale = alt.Scale(domain=model_domain, range=color_range)
    shape_scale = alt.Scale(
        domain=model_domain,
        range=shape_range,
    )

    # Shared axes: log-scaled FLOPs on x, BPB on y without forcing zero.
    base_chart = alt.Chart(df).encode(
        x=alt.X("flops", title="Training FLOPS")
        .scale(type="log", domain=[2e20, 1.25e22])
        .axis(values=[2e20, 4e20, 8e20, 1e21, 2e21, 4e21, 8e21, 1e22]),
        y=alt.Y("metric", title="Bits per Byte (BPB)").scale(zero=False),
    )

    # The line layer suppresses its legends; only the point layer shows one.
    lines = base_chart.encode(
        color=alt.Color("key", title="Model Color", scale=color_scale, legend=None),
        strokeDash=alt.StrokeDash("family", title="Model Family", legend=None),
    ).mark_line()
    points = base_chart.encode(
        color=alt.Color("key", title="Model", scale=color_scale),
        shape=alt.Shape("key", title="", scale=shape_scale),
    ).mark_point(size=70)

    chart = (
        (lines + points)
        # Only colour and shape are resolved independently per layer;
        # strokeDash keeps the default resolution.
        .resolve_scale(
            color="independent",
            shape="independent",
        )
        .configure_legend(orient="right")
        .properties(height=300, width=400)
    )
    print("Saving", output_file)
    chart.save(output_file)
def main():
    """Load the plot configuration and write one chart per data frame.

    Usage: ``<script> CONFIG_FILE [key=value ...]``. The config file is
    loaded with OmegaConf, merged with the remaining command-line arguments,
    and validated as a ``ScalingPlotsConfig``. Every entry of
    ``frame_files`` is read as JSON from ``df_dir`` and plotted to
    ``output_chart_dir`` as ``<input file name>.pdf`` (so ``a.json``
    becomes ``a.json.pdf``).

    Raises:
        SystemExit: With a usage message when no config file is given.
    """
    # Fail with a readable usage line instead of a bare IndexError.
    if len(sys.argv) < 2:
        prog = sys.argv[0] if sys.argv else "<script>"
        sys.exit(f"usage: {prog} CONFIG_FILE [key=value ...]")

    config_path = sys.argv[1]
    file_config = OmegaConf.load(config_path)
    # Omit program name and config file name
    cli_conf = OmegaConf.from_cli(sys.argv[2:])
    conf_dict = OmegaConf.to_container(
        OmegaConf.merge(file_config, cli_conf), resolve=True, throw_on_missing=True
    )
    plot_config = ScalingPlotsConfig(**conf_dict)

    df_dir = Path(plot_config.df_dir)
    chart_dir = Path(plot_config.output_chart_dir)
    chart_dir.mkdir(exist_ok=True, parents=True)

    for ff in plot_config.frame_files:
        path = df_dir / ff
        df = pd.read_json(path)
        print(df)
        print(df.columns)
        # create_chart declares output_file as str, so pass a string path.
        create_chart(df, str(chart_dir / f"{path.name}.pdf"))
# Run the plotting entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()