import torch
import lightning.pytorch as pl
from lightning.pytorch.utilities import grad_norm
from torch.optim import Optimizer


class GradientMonitor(pl.Callback):
    """Logs the total gradient norm and the largest per-parameter gradient norm."""

    def __init__(self, norm_type: float = 2.0):
        super().__init__()
        norm_type = float(norm_type)
        if norm_type <= 0:
            raise ValueError(f"`norm_type` must be a positive number or 'inf' (infinity norm). Got {norm_type}")
        self.norm_type = norm_type

    def on_before_optimizer_step(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        optimizer: Optimizer,
    ) -> None:
        # grad_norm returns a dict mapping "grad_{norm_type}_norm/{param_name}"
        # to each parameter's gradient norm, plus a "..._norm_total" entry.
        norms = grad_norm(pl_module, norm_type=self.norm_type)
        total_key = f"grad_{self.norm_type}_norm_total"
        # Largest single-parameter gradient norm, excluding the total entry.
        max_grad = torch.tensor([v for k, v in norms.items() if k != total_key]).max()
        pl_module.log_dict({
            "train/grad/max": max_grad,
            "train/grad/total": norms[total_key],
        })
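
A minimal usage sketch showing how the callback plugs into a Trainer; `MyLitModel` and `train_loader` are hypothetical stand-ins for your own LightningModule and DataLoader, and the `max_epochs` value is arbitrary.

# Attach the callback so gradient norms are logged before every optimizer step.
trainer = pl.Trainer(
    max_epochs=1,
    callbacks=[GradientMonitor(norm_type=2)],
)
trainer.fit(MyLitModel(), train_dataloaders=train_loader)

The logged values land under "train/grad/max" and "train/grad/total" in whichever logger the Trainer is configured with (TensorBoard by default).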