""" | |
Copyright (c) Meta Platforms, Inc. and affiliates. | |
All rights reserved. | |
This source code is licensed under the license found in the | |
LICENSE file in the root directory of this source tree. | |
""" | |
import copy
import glob
import logging
import os
import re
from collections import OrderedDict
from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple, Union

import torch as th
import torch.nn as nn
from omegaconf import DictConfig, OmegaConf
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.tensorboard import SummaryWriter

from visualize.ca_body.utils.module_loader import build_optimizer, load_class
from visualize.ca_body.utils.torch import to_device

logging.basicConfig(
    format="[%(asctime)s][%(levelname)s][%(name)s]:%(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)


def process_losses(
    loss_dict: Dict[str, Any], reduce: bool = True, detach: bool = True
) -> Dict[str, th.Tensor]:
    """Preprocess the dict of loss outputs: keep `loss_*` entries, optionally detach and reduce them."""
    result = {k.replace("loss_", ""): v for k, v in loss_dict.items() if k.startswith("loss_")}
    if detach:
        result = {k: v.detach() for k, v in result.items()}
    if reduce:
        result = {k: float(v.mean().item()) for k, v in result.items()}
    return result
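
# Usage sketch for `process_losses` (hypothetical loss dict; values are illustrative only):
#   loss_dict = {"loss_rgb": th.tensor([0.1, 0.2]), "loss_mask": th.tensor(0.05), "weight": 1.0}
#   process_losses(loss_dict)  # -> {"rgb": ~0.15, "mask": 0.05}; non-"loss_" keys are dropped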


def load_config(path: str) -> DictConfig:
    # NOTE: THIS IS THE ONLY PLACE WHERE WE MODIFY CONFIG
    config = OmegaConf.load(path)
    # TODO: we should get rid of this in favor of the DB
    assert 'CARE_ROOT' in os.environ, "the CARE_ROOT environment variable must be set"
    config.CARE_ROOT = os.environ['CARE_ROOT']
    logger.info(f'{config.CARE_ROOT=}')
    if not os.path.isabs(config.train.run_dir):
        config.train.run_dir = os.path.join(os.environ['CARE_ROOT'], config.train.run_dir)
    logger.info(f'{config.train.run_dir=}')
    os.makedirs(config.train.run_dir, exist_ok=True)
    return config
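
# Usage sketch for `load_config` (hypothetical paths; assumes the YAML defines `train.run_dir`):
#   os.environ["CARE_ROOT"] = "/path/to/care"   # required by load_config
#   config = load_config("config.yml")          # resolves train.run_dir relative to CARE_ROOT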


def load_from_config(config: Mapping[str, Any], **kwargs):
    """Instantiate an object given a config and arguments."""
    assert 'class_name' in config and 'module_name' not in config
    config = copy.deepcopy(config)
    ckpt = None if 'ckpt' not in config else config.pop('ckpt')
    class_name = config.pop('class_name')
    object_class = load_class(class_name)
    instance = object_class(**config, **kwargs)
    if ckpt is not None:
        load_checkpoint(
            ckpt_path=ckpt.path,
            modules={ckpt.get('module_name', 'model'): instance},
            ignore_names=ckpt.get('ignore_names', []),
            strict=ckpt.get('strict', False),
        )
    return instance
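
# Usage sketch for `load_from_config` (hypothetical config; assumes `load_class` resolves a
# fully-qualified class path, and that the remaining keys become constructor kwargs):
#   cfg = OmegaConf.create({"class_name": "torch.nn.Linear", "in_features": 8, "out_features": 4})
#   layer = load_from_config(cfg)  # ~ torch.nn.Linear(in_features=8, out_features=4)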


def save_checkpoint(ckpt_path, modules: Dict[str, Any], iteration=None, keep_last_k=None):
    if keep_last_k is not None:
        raise NotImplementedError()
    ckpt_dict = {}
    # if a directory is given, save to `<dir>/<iteration:06d>.pt`
    if os.path.isdir(ckpt_path):
        assert iteration is not None
        ckpt_path = os.path.join(ckpt_path, f"{iteration:06d}.pt")
    if iteration is not None:
        ckpt_dict["iteration"] = iteration
    for name, mod in modules.items():
        # unwrap DDP/DataParallel-style wrappers before saving
        if hasattr(mod, "module"):
            mod = mod.module
        ckpt_dict[name] = mod.state_dict()
    th.save(ckpt_dict, ckpt_path)
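
# Usage sketch for `save_checkpoint` (hypothetical paths):
#   save_checkpoint("checkpoints/", {"model": model, "optimizer": optimizer}, iteration=1000)
#   # -> writes checkpoints/001000.pt; passing a directory requires an explicit iteration
#   save_checkpoint("run_dir/model.pt", {"model": model})  # explicit file path, no iteration needed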


def filter_params(params, ignore_names):
    """Drop state-dict entries whose keys match any of the given regex patterns."""
    return OrderedDict(
        [
            (k, v)
            for k, v in params.items()
            if not any(re.match(n, k) is not None for n in ignore_names)
        ]
    )
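
# Usage sketch for `filter_params` (patterns are matched against state-dict keys with re.match):
#   filter_params({"encoder.w": 1, "decoder.w": 2}, ["decoder.*"])
#   # -> OrderedDict([("encoder.w", 1)])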


def save_file_summaries(path: str, summaries: Dict[str, Tuple[Any, str]], prefix: str = ""):
    """Save regular summaries for monitoring purposes."""
    # TODO: not implemented yet; intended to iterate over summaries and do something like
    #   save(f'{path}/{name}.{ext}', value)
    raise NotImplementedError()


def load_checkpoint(
    ckpt_path: str,
    modules: Dict[str, Any],
    iteration: Optional[int] = None,
    strict: bool = False,
    map_location: Optional[str] = None,
    ignore_names: Optional[Dict[str, List[str]]] = None,
):
    """Load a checkpoint.

    Args:
        ckpt_path: directory or the full path to the checkpoint
    """
    if map_location is None:
        map_location = "cpu"
    # if a directory is given, resolve it to a concrete checkpoint file
    if os.path.isdir(ckpt_path):
        if iteration is None:
            # look up the latest iteration
            iteration = max(
                int(os.path.splitext(os.path.basename(p))[0])
                for p in glob.glob(os.path.join(ckpt_path, "*.pt"))
            )
        ckpt_path = os.path.join(ckpt_path, f"{iteration:06d}.pt")
    logger.info(f"loading checkpoint {ckpt_path}")
    ckpt_dict = th.load(ckpt_path, map_location=map_location)
    for name, mod in modules.items():
        params = ckpt_dict[name]
        if ignore_names is not None and name in ignore_names:
            logger.info(f"skipping: {ignore_names[name]}")
            params = filter_params(params, ignore_names[name])
        mod.load_state_dict(params, strict=strict)
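
# Usage sketch for `load_checkpoint` (hypothetical paths):
#   load_checkpoint("checkpoints/", {"model": model, "optimizer": optimizer})  # latest iteration
#   load_checkpoint("checkpoints/", {"model": model}, iteration=1000,
#                   ignore_names={"model": ["decoder.*"]})  # skip params matching the regex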


def train(
    model: nn.Module,
    loss_fn: nn.Module,
    optimizer: th.optim.Optimizer,
    train_data: Iterator,
    config: Mapping[str, Any],
    lr_scheduler: Optional[LRScheduler] = None,
    train_writer: Optional[SummaryWriter] = None,
    saving_enabled: bool = True,
    logging_enabled: bool = True,
    iteration: int = 0,
    device: Optional[Union[th.device, str]] = "cuda:0",
) -> None:
    for batch in train_data:
        if batch is None:
            logger.info("skipping empty batch")
            continue
        batch = to_device(batch, device)
        batch["iteration"] = iteration

        # keep only the inputs actually used by the model
        preds = model(**filter_inputs(batch, model, required_only=False))
        # TODO: switch to the old-school loss computation
        loss, loss_dict = loss_fn(preds, batch, iteration=iteration)

        if th.isnan(loss):
            # log the per-term loss breakdown before failing
            _loss_dict = process_losses(loss_dict)
            loss_str = " ".join([f"{k}={v:.4f}" for k, v in _loss_dict.items()])
            logger.info(f"iter={iteration}: {loss_str}")
            raise ValueError("loss is NaN")

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if logging_enabled and iteration % config.train.log_every_n_steps == 0:
            _loss_dict = process_losses(loss_dict)
            loss_str = " ".join([f"{k}={v:.4f}" for k, v in _loss_dict.items()])
            logger.info(f"iter={iteration}: {loss_str}")

            if train_writer is not None:
                for name, value in _loss_dict.items():
                    train_writer.add_scalar(f"Losses/{name}", value, global_step=iteration)
                train_writer.flush()

        if saving_enabled and iteration % config.train.ckpt_every_n_steps == 0:
            logger.info(f"iter={iteration}: saving checkpoint to `{config.train.ckpt_dir}`")
            save_checkpoint(
                config.train.ckpt_dir,
                {"model": model, "optimizer": optimizer},
                iteration=iteration,
            )

        if logging_enabled and iteration % config.train.summary_every_n_steps == 0:
            summaries = model.compute_summaries(preds, batch)
            save_file_summaries(config.train.run_dir, summaries, prefix="train")

        if lr_scheduler is not None and iteration and iteration % config.train.update_lr_every == 0:
            lr_scheduler.step()

        iteration += 1
        if iteration >= config.train.n_max_iters:
            logger.info(f"reached max number of iters ({config.train.n_max_iters})")
            break

    if saving_enabled:
        logger.info(f"saving the final checkpoint to `{config.train.run_dir}/model.pt`")
        save_checkpoint(f"{config.train.run_dir}/model.pt", {"model": model})