import argparse
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import unimernet.tasks as tasks
from unimernet.common.config import Config
from unimernet.common.dist_utils import get_rank, init_distributed_mode
from unimernet.common.logger import setup_logger
from unimernet.common.optims import (
    LinearWarmupCosineLRScheduler,
    LinearWarmupStepLRScheduler,
)
from unimernet.common.registry import registry
from unimernet.common.utils import now
# imports modules for registration
from unimernet.datasets.builders import *
from unimernet.models import *
from unimernet.processors import *
from unimernet.runners import *
from unimernet.tasks import *
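
# HuggingFace tokenizers spawn worker threads; disabling their parallelism here
# avoids the fork-related deadlock warning once DataLoader workers fork this process.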
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def parse_args():
    parser = argparse.ArgumentParser(description="Training")
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config; key-value pairs "
        "in xxx=yyy format will be merged into the config file "
        "(deprecated; use --cfg-options instead).",
    )

    args = parser.parse_args()
    # if 'LOCAL_RANK' not in os.environ:
    #     os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args
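

# Example invocation (the config path and option key below are illustrative,
# not taken from this repo):
#   python train.py --cfg-path configs/train.yaml --options run.batch_size_train=8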


def setup_seeds(config):
    """Seed every RNG, offset by rank so each distributed process gets a distinct seed."""
    seed = config.run_cfg.seed + get_rank()

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    cudnn.benchmark = False
    cudnn.deterministic = True
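

# Note: benchmark=False with deterministic=True trades cuDNN autotuning speed for
# run-to-run reproducibility; flipping both flags is a common speed-up when exact
# repeatability is not required.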


def get_runner_class(cfg):
    """
    Get the runner class from the config. Defaults to the epoch-based runner.
    """
    runner_cls = registry.get_runner_class(cfg.run_cfg.get("runner", "runner_base"))

    return runner_cls
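

# A different runner can be selected from the run config, e.g. (assuming an
# iteration-based runner is registered under this name, as in LAVIS-style codebases):
#   run:
#     runner: runner_iter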


def main():
    # allow auto-dl to complete on the main process without timeout when using the NCCL backend.
    # os.environ["NCCL_BLOCKING_WAIT"] = "1"

    # set before init_distributed_mode() so the same job_id is shared across all ranks.
    job_id = now()

    cfg = Config(parse_args())

    init_distributed_mode(cfg.run_cfg)

    setup_seeds(cfg)

    # set after init_distributed_mode() so that we only log on the master rank.
    setup_logger()

    cfg.pretty_print()

    task = tasks.setup_task(cfg)
    datasets = task.build_datasets(cfg)
    model = task.build_model(cfg)

    runner = get_runner_class(cfg)(
        cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets
    )
    runner.train()


if __name__ == "__main__":
    main()
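
# Distributed launch sketch (assuming init_distributed_mode() reads the standard
# torchrun environment variables, as in LAVIS-style codebases):
#   torchrun --nproc_per_node=4 train.py --cfg-path <config.yaml>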