# Source: Hugging Face Space ("Running on Zero" / ZeroGPU)
# File size: 920 bytes — commit 9e426da
# (Page gutter/line-number residue from extraction removed.)
import torch
import lightning.pytorch as pl
from lightning.pytorch.utilities import grad_norm
from torch.optim import Optimizer
class GradientMonitor(pl.Callback):
    """Logs the gradient norm before each optimizer step.

    Logs two scalars through the LightningModule:
      * ``train/grad/max``   – largest per-parameter gradient norm
      * ``train/grad/total`` – total gradient norm over all parameters
    """

    def __init__(self, norm_type: float = 2) -> None:
        """
        Args:
            norm_type: Order of the norm. Any positive number, or
                ``float('inf')`` / ``'inf'`` for the infinity norm
                (``float()`` accepts the string form).

        Raises:
            ValueError: If ``norm_type`` is not a positive number.
        """
        super().__init__()
        norm_type = float(norm_type)
        if norm_type <= 0:
            raise ValueError(f"`norm_type` must be a positive number or 'inf' (infinity norm). Got {norm_type}")
        self.norm_type = norm_type

    def on_before_optimizer_step(
        self, trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        optimizer: Optimizer
    ) -> None:
        """Compute per-parameter and total gradient norms and log them."""
        norms = grad_norm(pl_module, norm_type=self.norm_type)
        # grad_norm keys per-parameter entries plus one aggregate total.
        total_key = f"grad_{self.norm_type}_norm_total"
        per_param = [v for k, v in norms.items() if k != total_key]
        # Guard: if no parameter received a gradient this step there is
        # nothing to reduce — torch.tensor([]).max() would raise.
        if not per_param:
            return
        max_grad = torch.tensor(per_param).max()
        pl_module.log_dict({'train/grad/max': max_grad, 'train/grad/total': norms[total_key]})