from typing import Union

import torch
import lightning.pytorch as pl
from lightning.pytorch.utilities import grad_norm
from torch.optim import Optimizer

class GradientMonitor(pl.Callback):
    """Logs the largest per-parameter gradient norm and the total gradient norm
    before every optimizer step."""

    def __init__(self, norm_type: Union[int, float, str] = 2):
        super().__init__()
        norm_type = float(norm_type)  # accepts e.g. 2, 2.0, or "inf"
        if norm_type <= 0:
            raise ValueError(f"`norm_type` must be a positive number or 'inf' (infinity norm). Got {norm_type}")
        self.norm_type = norm_type

    def on_before_optimizer_step(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        optimizer: Optimizer,
    ) -> None:
        # Keys look like f"grad_{norm_type}_norm/<param_name>", plus one "_total" entry.
        norms = grad_norm(pl_module, norm_type=self.norm_type)
        if not norms:  # no parameter received a gradient this step
            return
        total_key = f"grad_{self.norm_type}_norm_total"
        # torch.stack avoids the warning torch.tensor() emits for a list of tensors.
        max_grad = torch.stack([v for k, v in norms.items() if k != total_key]).max()
        pl_module.log_dict({
            "train/grad/max": max_grad,
            "train/grad/total": norms[total_key],
        })
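

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): `ToyModule` and the
# random dataset below are stand-in assumptions; in practice you would attach
# GradientMonitor to your own LightningModule via Trainer(callbacks=[...]).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    class ToyModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

    data = DataLoader(TensorDataset(torch.randn(64, 32), torch.randn(64, 1)), batch_size=16)
    trainer = pl.Trainer(max_epochs=1, callbacks=[GradientMonitor(norm_type=2)], logger=False)
    trainer.fit(ToyModule(), data)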