Spaces:
Running
on
Zero
Running
on
Zero
fix stool (#44)
Browse filesCo-authored-by: Srini Iyer <[email protected]>
- bytelatent/stool.py +8 -8
bytelatent/stool.py
CHANGED
@@ -4,14 +4,15 @@ import json
|
|
4 |
import os
|
5 |
import shutil
|
6 |
import subprocess
|
7 |
-
from
|
8 |
from typing import Any, Dict
|
9 |
|
10 |
from omegaconf import OmegaConf
|
11 |
|
12 |
|
13 |
-
|
14 |
-
|
|
|
15 |
config: Any = None
|
16 |
launcher: str = "sbatch" # Can be sbatch or bash if already in salloc
|
17 |
script: str = "apps.main.train" # The script to run.
|
@@ -64,7 +65,7 @@ source activate {conda_env_path}
|
|
64 |
export OMP_NUM_THREADS=1
|
65 |
export LAUNCH_WITH="SBATCH"
|
66 |
export DUMP_DIR={dump_dir}
|
67 |
-
srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml
|
68 |
"""
|
69 |
|
70 |
|
@@ -150,8 +151,8 @@ def validate_args(args) -> None:
|
|
150 |
def launch_job(args: StoolArgs):
|
151 |
# Set up args default and validate them depending on the cluster or partition requested
|
152 |
validate_args(args)
|
153 |
-
|
154 |
-
|
155 |
print("Creating directories...")
|
156 |
os.makedirs(dump_dir, exist_ok=args.dirs_exists_ok or args.override)
|
157 |
if args.override:
|
@@ -230,8 +231,7 @@ if __name__ == "__main__":
|
|
230 |
Then you can pass model.dim=32 to change values in LMTransformerArgs
|
231 |
or just name=tictac for top level attributes.
|
232 |
"""
|
233 |
-
raise NotImplementedError("Update this to blt code")
|
234 |
args = OmegaConf.from_cli()
|
235 |
args.config = OmegaConf.load(args.config)
|
236 |
-
args =
|
237 |
launch_job(args)
|
|
|
4 |
import os
|
5 |
import shutil
|
6 |
import subprocess
|
7 |
+
from pydantic import BaseModel
|
8 |
from typing import Any, Dict
|
9 |
|
10 |
from omegaconf import OmegaConf
|
11 |
|
12 |
|
13 |
+
class StoolArgs(BaseModel):
|
14 |
+
name: str = None
|
15 |
+
dump_dir: str = None
|
16 |
config: Any = None
|
17 |
launcher: str = "sbatch" # Can be sbatch or bash if already in salloc
|
18 |
script: str = "apps.main.train" # The script to run.
|
|
|
65 |
export OMP_NUM_THREADS=1
|
66 |
export LAUNCH_WITH="SBATCH"
|
67 |
export DUMP_DIR={dump_dir}
|
68 |
+
srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml dump_dir=$DUMP_DIR name={name}
|
69 |
"""
|
70 |
|
71 |
|
|
|
151 |
def launch_job(args: StoolArgs):
|
152 |
# Set up args default and validate them depending on the cluster or partition requested
|
153 |
validate_args(args)
|
154 |
+
job_name = args.name or args.config["name"]
|
155 |
+
dump_dir = os.path.join(args.dump_dir, job_name) or args.config["dump_dir"]
|
156 |
print("Creating directories...")
|
157 |
os.makedirs(dump_dir, exist_ok=args.dirs_exists_ok or args.override)
|
158 |
if args.override:
|
|
|
231 |
Then you can pass model.dim=32 to change values in LMTransformerArgs
|
232 |
or just name=tictac for top level attributes.
|
233 |
"""
|
|
|
234 |
args = OmegaConf.from_cli()
|
235 |
args.config = OmegaConf.load(args.config)
|
236 |
+
args = StoolArgs.model_validate(args)
|
237 |
launch_job(args)
|