Spaces:
Running
on
Zero
Running
on
Zero
File size: 774 Bytes
9e426da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
import os.path
from typing import Optional, Dict, Any
import lightning.pytorch as pl
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from soupsieve.util import lower
class CheckpointHook(ModelCheckpoint):
    """Save checkpoint with only the incremental part of the model.

    This callback itself only:

    * redirects checkpoint output to the trainer's ``default_root_dir``,
    * turns off strict state-dict loading on the LightningModule, and
    * strips the per-callback state from every checkpoint it saves.

    Reducing the model weights to the "incremental part" is not done in this
    class; presumably the LightningModule's own checkpoint hooks handle it --
    TODO confirm.
    """

    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
        """Point checkpoint output at ``trainer.default_root_dir``.

        Args:
            trainer: The running trainer; only ``default_root_dir`` is read.
            pl_module: The module being trained; its ``strict_loading`` flag
                is set to ``False``.
            stage: Unused here.

        NOTE(review): this override does not call ``super().setup()``, so
        whatever the base-class ``setup`` does is skipped. Any ``dirpath``
        passed to the constructor is unconditionally overwritten here.
        """
        self.dirpath = trainer.default_root_dir
        # Not read anywhere in this class; presumably consumed elsewhere -- confirm.
        self.exception_ckpt_path = os.path.join(self.dirpath, "on_exception.pt")
        # Per the class docstring, saved checkpoints are meant to hold only part
        # of the weights, so restoring them must tolerate missing keys.
        pl_module.strict_loading = False

    def on_save_checkpoint(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        checkpoint: Dict[str, Any],
    ) -> None:
        """Remove the ``"callbacks"`` entry from ``checkpoint`` before it is written.

        Args:
            trainer: Unused here.
            pl_module: Unused here.
            checkpoint: The checkpoint dict being saved; mutated in place.

        ``pop`` with a default is used instead of ``del``. If the key is
        already absent (e.g. another instance of this callback removed it
        during the same save), the hook is then a no-op rather than raising
        ``KeyError``. Since callback state is no longer saved, it presumably
        cannot be restored on resume -- confirm this is intended.
        """
        checkpoint.pop("callbacks", None)