from pathlib import Path
import torch, os
from tqdm import tqdm
import pickle
import argparse
import logging, datetime
import torch.distributed as dist
from config import MyParser
from steps import trainer
from copy_codebase import copy_codebase
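# Launch example (a sketch, not part of the original script): this entry point
# expects torchrun-style env vars (RANK, LOCAL_RANK, WORLD_SIZE) and an env://
# rendezvous, e.g. on a single node with 8 GPUs:
#   torchrun --nproc_per_node=8 main.py --exp_dir /path/to/exp
# The script filename and the exact flag names accepted by MyParser are
# assumptions here; only attributes such as exp_dir, resume, multinodes,
# dist_backend, dist_url and local_wandb are read below.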
def world_info_from_env():
    # These variables are set by the distributed launcher (e.g. torchrun).
    local_rank = int(os.environ["LOCAL_RANK"])
    global_rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    return local_rank, global_rank, world_size
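# Illustrative example (assuming a torchrun launch): on a 2-node x 8-GPU job,
# world_info_from_env() would see WORLD_SIZE=16, RANK in 0..15 across all
# processes, and LOCAL_RANK in 0..7 within each node.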
if __name__ == "__main__":
    formatter = (
        "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO)
    torch.cuda.empty_cache()

    args = MyParser().parse_args()
    exp_dir = Path(args.exp_dir)
    exp_dir.mkdir(exist_ok=True, parents=True)
    logging.info(f"exp_dir: {str(exp_dir)}")
    # Resume only if a checkpoint bundle exists; fall back to the previous
    # bundle if the latest one is missing.
    if args.resume and (os.path.exists("%s/bundle.pth" % args.exp_dir) or os.path.exists("%s/bundle_prev.pth" % args.exp_dir)):
        if not os.path.exists("%s/bundle.pth" % args.exp_dir):
            os.system(f"cp {args.exp_dir}/bundle_prev.pth {args.exp_dir}/bundle.pth")
        resume = args.resume
        assert bool(args.exp_dir)
        # Restore the args saved with the checkpoint, then override them with
        # any values passed on this run.
        with open("%s/args.pkl" % args.exp_dir, "rb") as f:
            old_args = pickle.load(f)
        new_args = vars(args)
        old_args = vars(old_args)
        for key in new_args:
            if key not in old_args or old_args[key] != new_args[key]:
                old_args[key] = new_args[key]
        args = argparse.Namespace(**old_args)
        args.resume = resume
    else:
        args.resume = False
        with open("%s/args.pkl" % args.exp_dir, "wb") as f:
            pickle.dump(args, f)
    # make the process-group timeout longer (for generation); 7200 s = 2 hours
    timeout = datetime.timedelta(seconds=7200)

    if args.multinodes:
        _local_rank, _, _ = world_info_from_env()
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, timeout=timeout)
    else:
        dist.init_process_group(backend='nccl', init_method='env://', timeout=timeout)

    if args.local_wandb:
        os.environ["WANDB_MODE"] = "offline"
    rank = dist.get_rank()
    if rank == 0:
        logging.info(args)
        logging.info(f"exp_dir: {str(exp_dir)}")
    world_size = dist.get_world_size()
    local_rank = int(_local_rank) if args.multinodes else rank
    num_devices = torch.cuda.device_count()
    logging.info(f"{local_rank=}, {rank=}, {world_size=}, {type(local_rank)=}, {type(rank)=}, {type(world_size)=}")
    for device_idx in range(num_devices):
        device_name = torch.cuda.get_device_name(device_idx)
        logging.info(f"Device {device_idx}: {device_name}")
    torch.cuda.set_device(local_rank)
    if rank == 0:
        # Snapshot the codebase into the experiment directory for reproducibility.
        user_dir = os.path.expanduser("~")
        codebase_name = "VoiceStar"
        now = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        copy_codebase(os.path.join(user_dir, codebase_name), os.path.join(exp_dir, f"{codebase_name}_{now}"), max_size_mb=5, gitignore_path=os.path.join(user_dir, codebase_name, ".gitignore"))

    my_trainer = trainer.Trainer(args, world_size, rank, local_rank)
    my_trainer.train()