"""
Misc

Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
Please cite our work if the code is helpful to you.
"""

import os
import warnings
from collections import abc
import numpy as np
import torch
from importlib import import_module


class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes:
        val: The most recent value passed to :meth:`update`.
        sum: Running total of ``val * n`` over all updates.
        count: Running total of the weights ``n``.
        avg: ``sum / count`` as of the last update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set ``val``, ``avg``, ``sum`` and ``count`` back to 0."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: The new value; it is multiplied by ``n`` before being
                added to ``sum``.
            n (int): Weight added to ``count``. Default: 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def intersection_and_union(output, target, K, ignore_index=-1):
    """Count per-class intersection, union and target areas (NumPy).

    Args:
        output (np.ndarray): Predicted labels of shape N, N * L or N * H * W,
            each value in the range 0 to K - 1. Not modified.
        target (np.ndarray): Ground-truth labels, same shape as ``output``.
        K (int): Number of classes.
        ignore_index (int): Label in ``target`` whose positions are excluded
            from all counts. Default: -1.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]: ``(area_intersection,
        area_union, area_target)``, each a histogram with ``K`` bins.
    """
    assert output.ndim in [1, 2, 3]
    assert output.shape == target.shape
    # Flatten; copy the prediction because it is written to below.
    output = output.reshape(output.size).copy()
    target = target.reshape(target.size)
    # Overwrite predictions at ignored positions with ignore_index. Assuming
    # ignore_index lies outside [0, K) (as the default -1 does), those
    # positions fall outside every histogram bin and are not counted.
    output[target == ignore_index] = ignore_index
    intersection = output[output == target]
    bins = np.arange(K + 1)
    area_intersection, _ = np.histogram(intersection, bins=bins)
    area_output, _ = np.histogram(output, bins=bins)
    area_target, _ = np.histogram(target, bins=bins)
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target


def intersection_and_union_gpu(output, target, k, ignore_index=-1):
    """Count per-class intersection, union and target areas (PyTorch).

    Tensor counterpart of ``intersection_and_union``; the work runs on
    whatever device ``output`` and ``target`` live on.

    Args:
        output (torch.Tensor): Predicted labels of shape N, N * L or
            N * H * W, each value in the range 0 to k - 1. Not modified.
        target (torch.Tensor): Ground-truth labels, same shape as ``output``.
        k (int): Number of classes.
        ignore_index (int): Label in ``target`` whose positions are excluded
            from all counts. Default: -1.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ``(area_intersection,
        area_union, area_target)``, each a ``torch.histc`` result with ``k``
        bins.
    """
    assert output.dim() in [1, 2, 3]
    assert output.shape == target.shape
    # reshape + clone rather than view: view() shares storage, so the masked
    # assignment below would overwrite the caller's prediction tensor, and
    # view() cannot flatten non-contiguous tensors.
    output = output.reshape(-1).clone()
    target = target.reshape(-1)
    output[target == ignore_index] = ignore_index
    intersection = output[output == target]
    # For k >= 2, k bins spanning [0, k - 1] put integer label i in bin i;
    # values outside that range (e.g. the default ignore_index -1) are not
    # counted by torch.histc.
    area_intersection = torch.histc(intersection, bins=k, min=0, max=k - 1)
    area_output = torch.histc(output, bins=k, min=0, max=k - 1)
    area_target = torch.histc(target, bins=k, min=0, max=k - 1)
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target


def make_dirs(dir_name):
    """Create ``dir_name`` (with parents) unless the path already exists."""
    if not os.path.exists(dir_name):
        # exist_ok tolerates the directory appearing between the check and
        # the creation (e.g. created concurrently by another process).
        os.makedirs(dir_name, exist_ok=True)


def find_free_port():
    """Ask the OS for an unused TCP port and return its number.

    Returns:
        int: The port number the OS assigned.

    NOTE: the probe socket is closed before returning, so there is still a
    chance the port could be taken by other processes.
    """
    import socket

    # The context manager closes the socket even if bind() raises.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Binding to port 0 will cause the OS to find an available port for us
        sock.bind(("", 0))
        port = sock.getsockname()[1]
    return port


def is_seq_of(seq, expected_type, seq_type=None):
    """Check whether it is a sequence of some type.

    Args:
        seq (Sequence): The sequence to be checked.
        expected_type (type): Expected type of sequence items.
        seq_type (type, optional): Expected sequence type. Defaults to
            ``collections.abc.Sequence``.

    Returns:
        bool: Whether ``seq`` is an instance of the sequence type and every
        item is an instance of ``expected_type``.
    """
    if seq_type is None:
        exp_seq_type = abc.Sequence
    else:
        assert isinstance(seq_type, type)
        exp_seq_type = seq_type
    if not isinstance(seq, exp_seq_type):
        return False
    return all(isinstance(item, expected_type) for item in seq)


def is_str(x):
    """Whether the input is a string instance.

    Note: This method is deprecated since python 2 is no longer supported.
    """
    return isinstance(x, str)


def import_modules_from_strings(imports, allow_failed_imports=False):
    """Import modules from the given list of strings.

    Args:
        imports (list | str | None): The given module names to be imported.
        allow_failed_imports (bool): If True, the failed imports will return
            None (and a ``UserWarning`` is emitted). Otherwise, the original
            ImportError is raised. Default: False.

    Returns:
        list[module] | module | None: The imported modules; a single module
        when ``imports`` is a str; None when ``imports`` is empty or None.

    Raises:
        TypeError: If ``imports`` is not a str or list, or an entry is not a
            str.
        ImportError: If a module fails to import and
            ``allow_failed_imports`` is False.

    Examples:
        >>> osp, sys = import_modules_from_strings(
        ...     ['os.path', 'sys'])
        >>> import os.path as osp_
        >>> import sys as sys_
        >>> assert osp == osp_
        >>> assert sys == sys_
    """
    if not imports:
        return
    single_import = False
    if isinstance(imports, str):
        single_import = True
        imports = [imports]
    if not isinstance(imports, list):
        raise TypeError(f"custom_imports must be a list but got type {type(imports)}")
    imported = []
    for imp in imports:
        if not isinstance(imp, str):
            raise TypeError(f"{imp} is of type {type(imp)} and cannot be imported.")
        try:
            imported_tmp = import_module(imp)
        except ImportError:
            if not allow_failed_imports:
                # Re-raise the original error so the module name and the
                # underlying cause are not lost.
                raise
            warnings.warn(f"{imp} failed to import and is ignored.", UserWarning)
            imported_tmp = None
        imported.append(imported_tmp)
    if single_import:
        imported = imported[0]
    return imported


class DummyClass:
    """Placeholder class whose constructor takes no arguments and does nothing."""

    def __init__(self):
        pass