# whyun13's picture
# Upload folder using huggingface_hub
# 882f6e2 verified
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import torch as th
import os
import re
import glob
import copy
from typing import Dict, Any, Iterator, Mapping, Optional, Union, Tuple, List
from collections import OrderedDict
from torch.utils.tensorboard import SummaryWriter
from omegaconf import OmegaConf, DictConfig
from torch.optim.lr_scheduler import LRScheduler
from visualize.ca_body.utils.torch import to_device
from visualize.ca_body.utils.module_loader import load_class, build_optimizer
import torch.nn as nn
import logging
# Configure the root logger at import time: timestamped records at INFO level
# and above. `basicConfig` does nothing if the root logger already has
# handlers, so an application that set up logging earlier keeps its own setup.
logging.basicConfig(
    format="[%(asctime)s][%(levelname)s][%(name)s]:%(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
def process_losses(
    loss_dict: Dict[str, Any], reduce: bool = True, detach: bool = True
) -> Dict[str, Union[th.Tensor, float]]:
    """Select the loss terms from a dict of outputs and prepare them for logging.

    Only entries whose key starts with ``"loss_"`` are kept; that leading
    prefix is removed from the key, everything else is dropped.

    Args:
        loss_dict: mapping of names to values. The kept entries must support
            `.detach()` / `.mean().item()` when the matching flag is set.
        reduce: if True, replace each kept value with the Python float of its mean.
        detach: if True, call `.detach()` on each kept value first.

    Returns:
        A new dict keyed by the loss name without the ``"loss_"`` prefix.
        Values are floats when `reduce` is True, otherwise the (optionally
        detached) original values.
    """
    prefix = "loss_"
    # Strip only the leading prefix. `str.replace` would also delete later
    # occurrences, turning e.g. "loss_kl_loss_w" into "kl_w".
    result = {k[len(prefix):]: v for k, v in loss_dict.items() if k.startswith(prefix)}
    if detach:
        result = {k: v.detach() for k, v in result.items()}
    if reduce:
        result = {k: float(v.mean().item()) for k, v in result.items()}
    return result
def load_config(path: str) -> DictConfig:
    """Load the config file at `path` and anchor its run directory.

    The `CARE_ROOT` environment variable must be set: its value is stored on
    the config as `CARE_ROOT`, and a relative `train.run_dir` is resolved
    against it. The run directory is created if it does not exist yet.

    Args:
        path: location of the config file, passed to `OmegaConf.load`.

    Returns:
        The loaded config with `CARE_ROOT` and `train.run_dir` filled in.
    """
    # NOTE: THIS IS THE ONLY PLACE WHERE WE MODIFY CONFIG
    config = OmegaConf.load(path)

    # TODO: we should need to get rid of this in favor of DB
    assert 'CARE_ROOT' in os.environ
    care_root = os.environ['CARE_ROOT']
    config.CARE_ROOT = care_root
    logger.info(f'{config.CARE_ROOT=}')

    run_dir = config.train.run_dir
    if not os.path.isabs(run_dir):
        # relative run directories live under CARE_ROOT
        config.train.run_dir = os.path.join(care_root, run_dir)
    logger.info(f'{config.train.run_dir=}')

    os.makedirs(config.train.run_dir, exist_ok=True)
    return config
def load_from_config(config: Mapping[str, Any], **kwargs):
    """Build an object from a config entry, optionally restoring a checkpoint.

    `config` must contain `class_name` (resolved with `load_class`) and must
    not contain `module_name`. An optional `ckpt` entry describes a checkpoint
    to load into the new instance. All remaining entries, together with
    `kwargs`, are passed to the class constructor.

    Returns:
        The constructed instance.
    """
    assert 'class_name' in config and 'module_name' not in config

    # Work on a copy so the pops below never touch the caller's config.
    params = copy.deepcopy(config)
    if 'ckpt' in params:
        ckpt = params.pop('ckpt')
    else:
        ckpt = None

    object_class = load_class(params.pop('class_name'))
    instance = object_class(**params, **kwargs)
    if ckpt is None:
        return instance

    load_checkpoint(
        ckpt_path=ckpt.path,
        modules={ckpt.get('module_name', 'model'): instance},
        ignore_names=ckpt.get('ignore_names', []),
        strict=ckpt.get('strict', False),
    )
    return instance
def save_checkpoint(ckpt_path, modules: Dict[str, Any], iteration=None, keep_last_k=None):
    """Write the state dicts of `modules` to a single checkpoint file.

    Args:
        ckpt_path: an existing directory or a full file path. For a directory,
            the file is named after the zero-padded `iteration`, and the
            iteration is stored in the checkpoint under the "iteration" key.
        modules: name -> object exposing `state_dict()`. Objects with a
            `module` attribute are saved through that attribute.
        iteration: required when `ckpt_path` is a directory.
        keep_last_k: not supported; any value other than None raises
            NotImplementedError.
    """
    if keep_last_k is not None:
        raise NotImplementedError()

    state = {}
    if os.path.isdir(ckpt_path):
        assert iteration is not None
        ckpt_path = os.path.join(ckpt_path, f"{iteration:06d}.pt")
        state["iteration"] = iteration

    for key, module in modules.items():
        # save the inner object when one is exposed as `.module`
        target = module.module if hasattr(module, "module") else module
        state[key] = target.state_dict()

    th.save(state, ckpt_path)
def filter_params(params, ignore_names):
    """Return an ordered copy of `params` without the ignored entries.

    An entry is dropped when its key matches any regular expression in
    `ignore_names`. Matching uses `re.match`, so a pattern has to match from
    the start of the key.
    """
    kept = OrderedDict()
    for key, value in params.items():
        matching = [pattern for pattern in ignore_names if re.match(pattern, key)]
        if not matching:
            kept[key] = value
    return kept
def save_file_summaries(
    path: str, summaries: Dict[str, Tuple[Any, str]], prefix: Optional[str] = None
):
    """Save regular summaries for monitoring purposes.

    NOTE: writing is not implemented yet. The function is a no-op for an empty
    `summaries` and raises as soon as there is something to save.

    Args:
        path: output directory for the summary files.
        summaries: name -> (value, file extension).
        prefix: optional tag supplied by callers (e.g. "train"); accepted so
            that keyword calls do not fail, currently unused.

    Raises:
        NotImplementedError: if `summaries` contains at least one entry.
    """
    for name, (value, ext) in summaries.items():
        # TODO: write `value` to f'{path}/{name}.{ext}' once a saving backend exists
        raise NotImplementedError()
def load_checkpoint(
    ckpt_path: str,
    modules: Dict[str, Any],
    iteration: Optional[int] = None,
    strict: bool = False,
    map_location: Optional[str] = None,
    ignore_names: Optional[Dict[str, List[str]]] = None,
):
    """Load a checkpoint into the given modules.

    Args:
        ckpt_path: directory or the full path to the checkpoint
        modules: name -> object exposing `load_state_dict`; every name must be
            a key of the checkpoint.
        iteration: which `<iteration:06d>.pt` file to pick when `ckpt_path` is
            a directory; None selects the highest-numbered one.
        strict: forwarded to `load_state_dict`.
        map_location: forwarded to `th.load`; defaults to "cpu".
        ignore_names: module name -> regex patterns; parameters whose name
            matches a pattern are removed before loading that module.

    Raises:
        FileNotFoundError: if `ckpt_path` is a directory without any
            numerically named `*.pt` file and `iteration` is None.
    """
    if map_location is None:
        map_location = "cpu"

    if os.path.isdir(ckpt_path):
        if iteration is None:
            # Look up the latest iteration. Files that are not named after an
            # iteration number (e.g. a final `model.pt`) are skipped rather
            # than crashing the lookup.
            iterations = []
            for p in glob.glob(os.path.join(ckpt_path, "*.pt")):
                stem = os.path.splitext(os.path.basename(p))[0]
                try:
                    iterations.append(int(stem))
                except ValueError:
                    continue
            if not iterations:
                raise FileNotFoundError(
                    f"no iteration checkpoints (`<iteration>.pt`) found in `{ckpt_path}`"
                )
            iteration = max(iterations)
        ckpt_path = os.path.join(ckpt_path, f"{iteration:06d}.pt")

    logger.info(f"loading checkpoint {ckpt_path}")
    # NOTE(review): `th.load` unpickles the file; only load checkpoints from
    # trusted sources.
    ckpt_dict = th.load(ckpt_path, map_location=map_location)
    for name, mod in modules.items():
        params = ckpt_dict[name]
        if ignore_names is not None and name in ignore_names:
            logger.info(f"skipping: {ignore_names[name]}")
            params = filter_params(params, ignore_names[name])
        mod.load_state_dict(params, strict=strict)
def _log_losses(iteration: int, loss_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Log the reduced loss terms for `iteration` and return them."""
    reduced = process_losses(loss_dict)
    loss_str = " ".join([f"{k}={v:.4f}" for k, v in reduced.items()])
    logger.info(f"iter={iteration}: {loss_str}")
    return reduced


def train(
    model: nn.Module,
    loss_fn: nn.Module,
    optimizer: th.optim.Optimizer,
    train_data: Iterator,
    config: Mapping[str, Any],
    lr_scheduler: Optional[LRScheduler] = None,
    train_writer: Optional[SummaryWriter] = None,
    saving_enabled: bool = True,
    logging_enabled: bool = True,
    iteration: int = 0,
    device: Optional[Union[th.device, str]] = "cuda:0",
) -> None:
    """Run the optimization loop over `train_data`.

    For every non-empty batch: move it to `device`, run the model and the
    loss, take one optimizer step, then perform the periodic logging,
    checkpointing, summary and learning-rate updates configured under
    `config.train`. The loop ends when the data is exhausted or
    `config.train.n_max_iters` is reached.

    Args:
        model: network to train; also asked for `compute_summaries`.
        loss_fn: called as `loss_fn(preds, batch, iteration=...)`, returns
            `(loss, loss_dict)`.
        optimizer: optimizer stepping the model parameters.
        train_data: iterable of batches; `None` batches are skipped.
        config: config providing the `train.*` settings read below.
        lr_scheduler: stepped every `train.update_lr_every` iterations.
        train_writer: optional writer receiving the scalar losses.
        saving_enabled: enables periodic and final checkpoints.
        logging_enabled: enables loss logging and summaries.
        iteration: iteration counter to start from.
        device: device the batches are moved to.

    Raises:
        ValueError: if the loss becomes NaN (the loss terms are logged first).
    """
    for batch in train_data:
        if batch is None:
            logger.info("skipping empty batch")
            continue
        batch = to_device(batch, device)
        batch["iteration"] = iteration

        # leaving only inputs actually used by the model
        # NOTE(review): `filter_inputs` is not defined or imported in the
        # visible part of this file; confirm where it comes from.
        preds = model(**filter_inputs(batch, model, required_only=False))

        # TODO: switch to the old-school loss computation
        loss, loss_dict = loss_fn(preds, batch, iteration=iteration)

        # No assert ahead of this check: it would abort before the individual
        # loss terms are logged, making this diagnostic branch unreachable.
        if th.isnan(loss):
            _log_losses(iteration, loss_dict)
            raise ValueError("loss is NaN")

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if logging_enabled and iteration % config.train.log_every_n_steps == 0:
            _loss_dict = _log_losses(iteration, loss_dict)
            if train_writer:
                for name, value in _loss_dict.items():
                    train_writer.add_scalar(f"Losses/{name}", value, global_step=iteration)
                train_writer.flush()

        if saving_enabled and iteration % config.train.ckpt_every_n_steps == 0:
            logger.info(f"iter={iteration}: saving checkpoint to `{config.train.ckpt_dir}`")
            save_checkpoint(
                config.train.ckpt_dir,
                {"model": model, "optimizer": optimizer},
                iteration=iteration,
            )

        if logging_enabled and iteration % config.train.summary_every_n_steps == 0:
            summaries = model.compute_summaries(preds, batch)
            save_file_summaries(config.train.run_dir, summaries, prefix="train")

        # `iteration and ...` skips the scheduler step at iteration 0
        if lr_scheduler is not None and iteration and iteration % config.train.update_lr_every == 0:
            lr_scheduler.step()

        iteration += 1
        if iteration >= config.train.n_max_iters:
            logger.info(f"reached max number of iters ({config.train.n_max_iters})")
            break

    if saving_enabled:
        logger.info(f"saving the final checkpoint to `{config.train.run_dir}/model.pt`")
        save_checkpoint(f"{config.train.run_dir}/model.pt", {"model": model})