# Captured from a Hugging Face Space page (page header "Spaces:"; status badge shown: "Runtime error").
from omegaconf import OmegaConf | |
import torch | |
from typing import ( | |
Any, | |
Callable, | |
Dict, | |
Iterable, | |
List, | |
NamedTuple, | |
NewType, | |
Optional, | |
Sized, | |
Tuple, | |
Type, | |
TypeVar, | |
Union, | |
) | |
try: | |
from typing import Literal | |
except ImportError: | |
from typing_extensions import Literal | |
# Tensor dtype | |
# for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md | |
from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt | |
# Config type | |
from omegaconf import DictConfig | |
# PyTorch Tensor type | |
from torch import Tensor | |
# Runtime type checking decorator | |
from typeguard import typechecked as typechecker | |
def broadcast(tensor, src=0):
    """Broadcast ``tensor`` from rank ``src`` when torch.distributed is active.

    If distributed support is unavailable or no process group has been
    initialized, the tensor is passed through untouched. Otherwise
    ``torch.distributed.broadcast`` is called on it (no ``group`` argument is
    passed, so torch's default process group is used).

    Args:
        tensor: The tensor to broadcast; it is handed directly to
            ``torch.distributed.broadcast``.
        src: Rank that supplies the data. Defaults to 0.

    Returns:
        The same ``tensor`` object in both cases.
    """
    if _distributed_available():
        torch.distributed.broadcast(tensor, src=src)
    return tensor
def _distributed_available(): | |
return torch.distributed.is_available() and torch.distributed.is_initialized() | |
def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
    """Instantiate ``fields`` from ``cfg`` and wrap the result with OmegaConf.

    Args:
        fields: Callable that accepts the entries of ``cfg`` as keyword
            arguments (presumably a dataclass type -- NOTE(review): confirm
            against callers).
        cfg: Mapping of argument name to value. ``None`` is treated as an
            empty mapping, so ``fields`` is called with no arguments. The
            mapping passed in is never modified.

    Returns:
        ``OmegaConf.structured(fields(**cfg))``, computed after the
        launcher-injected local-rank keys have been removed from ``cfg``.
    """
    # In multi-node training a local-rank option shows up in the config,
    # presumably injected on the command line by the PyTorch distributed
    # launcher (``--local-rank`` in recent releases, ``--local_rank`` in older
    # ones -- NOTE(review): confirm for the launcher in use). Neither spelling
    # is a valid Python identifier, so it can never be a keyword argument of
    # ``fields`` and must be dropped.
    launcher_keys = ("--local-rank", "--local_rank")
    if cfg is None:
        cfg = {}
    # Build a filtered copy rather than ``del cfg[key]`` so the caller's
    # mapping (possibly a shared DictConfig) is left untouched.
    kwargs = {key: value for key, value in cfg.items() if key not in launcher_keys}
    return OmegaConf.structured(fields(**kwargs))