Srinivasan Iyer sviyer commited on
Commit
6fbaf72
·
unverified ·
1 Parent(s): 7cf8fab

fix stool (#44)

Browse files

Co-authored-by: Srini Iyer <[email protected]>

Files changed (1) hide show
  1. bytelatent/stool.py +8 -8
bytelatent/stool.py CHANGED
@@ -4,14 +4,15 @@ import json
4
  import os
5
  import shutil
6
  import subprocess
7
- from dataclasses import dataclass
8
  from typing import Any, Dict
9
 
10
  from omegaconf import OmegaConf
11
 
12
 
13
- @dataclass
14
- class StoolArgs:
 
15
  config: Any = None
16
  launcher: str = "sbatch" # Can be sbatch or bash if already in salloc
17
  script: str = "apps.main.train" # The script to run.
@@ -64,7 +65,7 @@ source activate {conda_env_path}
64
  export OMP_NUM_THREADS=1
65
  export LAUNCH_WITH="SBATCH"
66
  export DUMP_DIR={dump_dir}
67
- srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml
68
  """
69
 
70
 
@@ -150,8 +151,8 @@ def validate_args(args) -> None:
150
  def launch_job(args: StoolArgs):
151
  # Set up args default and validate them depending on the cluster or partition requested
152
  validate_args(args)
153
- dump_dir = args.config["dump_dir"]
154
- job_name = args.config["name"]
155
  print("Creating directories...")
156
  os.makedirs(dump_dir, exist_ok=args.dirs_exists_ok or args.override)
157
  if args.override:
@@ -230,8 +231,7 @@ if __name__ == "__main__":
230
  Then you can pass model.dim=32 to change values in LMTransformerArgs
231
  or just name=tictac for top level attributes.
232
  """
233
- raise NotImplementedError("Update this to blt code")
234
  args = OmegaConf.from_cli()
235
  args.config = OmegaConf.load(args.config)
236
- args = dataclass_from_dict(StoolArgs, args)
237
  launch_job(args)
 
4
  import os
5
  import shutil
6
  import subprocess
7
+ from pydantic import BaseModel
8
  from typing import Any, Dict
9
 
10
  from omegaconf import OmegaConf
11
 
12
 
13
+ class StoolArgs(BaseModel):
14
+ name: str = None
15
+ dump_dir: str = None
16
  config: Any = None
17
  launcher: str = "sbatch" # Can be sbatch or bash if already in salloc
18
  script: str = "apps.main.train" # The script to run.
 
65
  export OMP_NUM_THREADS=1
66
  export LAUNCH_WITH="SBATCH"
67
  export DUMP_DIR={dump_dir}
68
+ srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml dump_dir=$DUMP_DIR name={name}
69
  """
70
 
71
 
 
151
  def launch_job(args: StoolArgs):
152
  # Set up args default and validate them depending on the cluster or partition requested
153
  validate_args(args)
154
+ job_name = args.name or args.config["name"]
155
+ dump_dir = os.path.join(args.dump_dir, job_name) or args.config["dump_dir"]
156
  print("Creating directories...")
157
  os.makedirs(dump_dir, exist_ok=args.dirs_exists_ok or args.override)
158
  if args.override:
 
231
  Then you can pass model.dim=32 to change values in LMTransformerArgs
232
  or just name=tictac for top level attributes.
233
  """
 
234
  args = OmegaConf.from_cli()
235
  args.config = OmegaConf.load(args.config)
236
+ args = StoolArgs.model_validate(args)
237
  launch_job(args)