import os.path
from typing import Optional, Dict, Any

import lightning.pytorch as pl
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
# NOTE(review): `lower` is not referenced anywhere in this file and soupsieve is
# unrelated to checkpointing; this import looks accidental — consider removing it.
from soupsieve.util import lower


class CheckpointHook(ModelCheckpoint):
    """Save checkpoint with only the incremental part of the model.

    Behavior implemented by this callback:

    * checkpoints are written directly into ``trainer.default_root_dir``
      (``setup`` overwrites ``self.dirpath``, so any ``dirpath`` given to the
      constructor is ignored);
    * callback state is removed from every checkpoint before it is saved;
    * the LightningModule's ``strict_loading`` flag is switched off, so loading
      a checkpoint that does not contain every parameter does not fail on
      missing keys.

    NOTE(review): restricting the saved weights to the "incremental part" is
    not done in this class — presumably the LightningModule filters its own
    ``state_dict``; confirm against the module implementation.
    """

    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
        """Point the checkpoint directory at the trainer root and relax loading.

        Args:
            trainer: The running trainer; its ``default_root_dir`` becomes the
                checkpoint directory.
            pl_module: The module being trained; its ``strict_loading`` is set
                to ``False``.
            stage: The Lightning stage (``"fit"``, ``"validate"``, ...). It is
                not used here.
        """
        # NOTE(review): ``super().setup`` is not called, so ModelCheckpoint's own
        # directory resolution / cross-rank broadcast is skipped — confirm this is
        # intended for multi-process runs.
        self.dirpath = trainer.default_root_dir
        # Assigned here but never read in this class; presumably consumed by code
        # elsewhere that writes a checkpoint when training fails — verify.
        self.exception_ckpt_path = os.path.join(self.dirpath, "on_exception.pt")
        # Allow loading checkpoints that hold only a subset of the module's weights.
        pl_module.strict_loading = False

    def on_save_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
    ) -> None:
        """Strip callback state from ``checkpoint`` before it is written.

        Args:
            trainer: The running trainer (unused).
            pl_module: The module being checkpointed (unused).
            checkpoint: The checkpoint dict, mutated in place.
        """
        # ``pop`` with a default instead of ``del``: Lightning only adds the
        # "callbacks" entry for full (non weights-only) checkpoints, so ``del``
        # would raise KeyError when e.g. ``save_weights_only=True``.
        checkpoint.pop("callbacks", None)