|
|
|
import importlib |
|
import importlib.util |
|
import logging |
|
import numpy as np |
|
import os |
|
import random |
|
import sys |
|
from datetime import datetime |
|
import torch |
|
import socket |
|
import subprocess |
|
import time |
|
from . import comm |
|
|
|
__all__ = ["seed_all_rng"] |
|
|
|
|
|
# Parsed once at import time; e.g. torch 1.10.2 -> (1, 10).
TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2])
"""
PyTorch version as a tuple of 2 ints. Useful for comparison.
"""
|
|
|
|
|
def seed_all_rng(seed=None):
    """
    Seed the RNGs of torch, numpy, and python's `random` module with the
    same value.

    Args:
        seed (int): if None, will use a strong random seed.
    """
    if seed is None:
        # Mix the pid, a sub-second timestamp, and 2 bytes of OS entropy so
        # that concurrently-started processes get different seeds.
        pid_part = os.getpid()
        time_part = int(datetime.now().strftime("%S%f"))
        entropy_part = int.from_bytes(os.urandom(2), "big")
        seed = pid_part + time_part + entropy_part
        logging.getLogger(__name__).info(
            "Using a generated random seed {}".format(seed)
        )
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)
|
|
|
|
|
|
|
def _import_file(module_name, file_path, make_importable=False): |
|
spec = importlib.util.spec_from_file_location(module_name, file_path) |
|
module = importlib.util.module_from_spec(spec) |
|
spec.loader.exec_module(module) |
|
if make_importable: |
|
sys.modules[module_name] = module |
|
return module |
|
|
|
|
|
def _configure_libraries():
    """
    Apply workarounds for some third-party libraries and check minimum
    versions of key dependencies.
    """
    # Allow opting out of cv2 entirely via env var; registering None in
    # sys.modules makes any later `import cv2` fail fast.
    if int(os.environ.get("DETECTRON2_DISABLE_CV2", False)):
        sys.modules["cv2"] = None
    else:
        # Disable OpenCV's OpenCL runtime before cv2 is imported.
        os.environ["OPENCV_OPENCL_RUNTIME"] = "disabled"
        try:
            import cv2

            if int(cv2.__version__.split(".")[0]) >= 3:
                cv2.ocl.setUseOpenCL(False)
        except ModuleNotFoundError:
            # cv2 is an optional dependency.
            pass

    def _version_of(module, digit=2):
        # First `digit` components of module.__version__, as a tuple of ints.
        return tuple(map(int, module.__version__.split(".")[:digit]))

    assert _version_of(torch) >= (1, 4), "Requires torch>=1.4"
    import fvcore

    assert _version_of(fvcore, 3) >= (0, 1, 2), "Requires fvcore>=0.1.2"
    import yaml

    assert _version_of(yaml) >= (5, 1), "Requires pyyaml>=5.1"
|
|
|
|
|
|
|
# Module-level guard so setup_environment() does its work at most once per process.
_ENV_SETUP_DONE = False
|
|
|
|
|
def setup_environment():
    """Perform environment setup work. The default setup is a no-op, but this
    function allows the user to specify a Python source file or a module in
    the $DETECTRON2_ENV_MODULE environment variable, that performs
    custom setup work that may be necessary to their computing environment.
    """
    global _ENV_SETUP_DONE
    # Run at most once per process.
    if _ENV_SETUP_DONE:
        return
    _ENV_SETUP_DONE = True

    _configure_libraries()

    custom_module = os.environ.get("DETECTRON2_ENV_MODULE")
    if custom_module:
        setup_custom_environment(custom_module)
|
|
|
|
|
def setup_custom_environment(custom_module):
    """
    Load a custom environment setup from either a python source file (path
    ending in ".py") or an importable module name, then run its
    `setup_environment()` function.
    """
    if custom_module.endswith(".py"):
        module = _import_file("detectron2.utils.env.custom_module", custom_module)
    else:
        module = importlib.import_module(custom_module)
    setup_fn = getattr(module, "setup_environment", None)
    assert callable(setup_fn), (
        "Custom environment module defined in {} does not have the "
        "required callable attribute 'setup_environment'."
    ).format(custom_module)
    setup_fn()
|
|
|
def check_dist_portfile():
    """
    On the rank-0 SLURM process, remove any stale rendezvous port file left
    over from a previous job with the same job id.
    """
    # NOTE(review): presence is tested via SLURM_JOB_ID but the values read
    # are SLURM_PROCID / SLURM_JOBID — assumes SLURM sets all three.
    if "SLURM_JOB_ID" not in os.environ:
        return
    if int(os.environ["SLURM_PROCID"]) != 0:
        return
    hostfile = "dist_url_" + os.environ["SLURM_JOBID"] + ".txt"
    if os.path.exists(hostfile):
        os.remove(hostfile)
|
|
|
def find_free_port():
    """
    Return a TCP port number that was free at the time of the call.

    The probe socket is closed before returning (the original leaked it and
    left the port bound until garbage collection), so the caller can bind
    the port. Note the inherent race: another process may grab the port
    between this call and the caller's bind — acceptable for rendezvous setup.

    Returns:
        int: an ephemeral port number chosen by the OS.
    """
    # Binding to port 0 asks the OS to pick an available ephemeral port.
    with socket.socket() as s:
        s.bind(('', 0))
        return s.getsockname()[1]
|
|
|
def init_distributed_mode(args):
    """
    Initialize torch.distributed from environment variables and populate
    distributed-related fields on `args` in place.

    Two launch modes are handled:
      * torchrun/launch style: RANK and WORLD_SIZE env vars already set.
      * SLURM: rank/world size derived from SLURM_* env vars; rank 0 picks a
        free TCP port and publishes it via a "dist_url_<jobid>.txt" file
        that the other ranks poll (assumes a shared filesystem — TODO confirm).

    Side effects: sets args.rank / world_size / gpu / dist_url / dist_backend /
    distributed; exports MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE,
    LOCAL_RANK, LOCAL_SIZE into os.environ (SLURM path); calls
    torch.distributed.init_process_group + barrier(); and builds per-machine
    process groups, storing this rank's group in comm._LOCAL_PROCESS_GROUP.
    If neither launch mode is detected, sets args.distributed = False and
    returns without initializing anything.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # Launched by torchrun / torch.distributed.launch: trust the env vars.
        if int(os.environ["RANK"])==0:
            print('this task is not running on cluster!')
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
        args.dist_url = 'env://'
        os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
        addr = socket.gethostname()

    elif 'SLURM_PROCID' in os.environ:
        # Launched under SLURM: derive everything from SLURM_* variables.
        proc_id = int(os.environ['SLURM_PROCID'])
        if proc_id==0:
            print('Init dist using slurm!')
            print("Job Id is {} on {} ".format(os.environ["SLURM_JOBID"], os.environ['SLURM_NODELIST']))
        ntasks = int(os.environ['SLURM_NTASKS'])
        node_list = os.environ['SLURM_NODELIST']
        num_gpus = torch.cuda.device_count()
        # Master address = first hostname in the job's node list.
        addr = subprocess.getoutput(
            'scontrol show hostname {} | head -n1'.format(node_list))
        jobid = os.environ["SLURM_JOBID"]
        hostfile = "dist_url_" + jobid + ".txt"
        if proc_id == 0:
            # Rank 0 chooses a free port and publishes it through the file.
            args.tcp_port = str( find_free_port())
            print('write port {} to file: {} '.format(args.tcp_port, hostfile))
            with open(hostfile, "w") as f:
                f.write(args.tcp_port)
        else:
            # Other ranks poll until the file appears, then wait briefly so
            # rank 0 has finished writing before the read.
            print('read port from file: {}'.format(hostfile))
            while not os.path.exists(hostfile):
                time.sleep(1)
            time.sleep(2)
            with open(hostfile, "r") as f:
                args.tcp_port = f.read()

        # Export the variables expected by init_method='env://'.
        os.environ['MASTER_PORT'] =str(args.tcp_port)
        os.environ['MASTER_ADDR'] = addr
        os.environ['WORLD_SIZE'] = str(ntasks)
        os.environ['RANK'] = str(proc_id)
        os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
        os.environ['LOCAL_SIZE'] = str(num_gpus)
        args.dist_url = 'env://'
        args.world_size = ntasks
        args.rank = proc_id
        args.gpu = proc_id % num_gpus
    else:
        # Neither launch mode detected: run single-process.
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('rank: {} addr: {} port: {}'.format(args.rank, addr, os.environ['MASTER_PORT']))
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    if 'SLURM_PROCID' in os.environ and args.rank == 0:
        # All ranks have connected; remove the rendezvous port file.
        if os.path.isfile(hostfile):
            os.remove(hostfile)
    if args.world_size >= 1:
        # Build one process group per machine so comm helpers can do
        # intra-node collectives. Assumes every node has the same GPU count
        # and ranks are assigned contiguously per node — TODO confirm for
        # heterogeneous jobs.
        assert comm._LOCAL_PROCESS_GROUP is None
        num_gpus = torch.cuda.device_count()
        num_machines = args.world_size // num_gpus
        for i in range(num_machines):
            ranks_on_i = list(range(i * num_gpus, (i + 1) * num_gpus))
            print('new_group: {}'.format(ranks_on_i))
            # new_group must be called collectively by all ranks; only the
            # group containing this rank is kept.
            pg = torch.distributed.new_group(ranks_on_i)
            if args.rank in ranks_on_i:
                comm._LOCAL_PROCESS_GROUP = pg
|
|