import torch

# Lazily constructed DWTForward module shared by every call; built on first
# use by _get_wavelet_loss().
_dwt = None


def _get_wavelet_loss(device, dtype):
    """Return the shared single-level Haar ``DWTForward`` module.

    The module is created on first use and cached in the module-level
    ``_dwt``. On every call it is moved to the requested ``device`` and
    ``dtype``, so a cached instance created for a different device/dtype is
    never handed back unchanged.

    Args:
        device: Target device, forwarded to ``Module.to``.
        dtype: Target dtype, forwarded to ``Module.to``.

    Returns:
        The cached ``DWTForward`` instance, placed on ``device``/``dtype``.

    Raises:
        ImportError: If ``pytorch_wavelets`` is not installed (raised on the
            first call only, when the module is constructed).
    """
    global _dwt
    if _dwt is None:
        # Imported lazily so pytorch_wavelets is only needed once this loss
        # is actually used.
        from pytorch_wavelets import DWTForward

        # Original note listed wave='db1' alongside wave='haar'.
        _dwt = DWTForward(J=1, mode='zero', wave='haar')
    # Previously the cache ignored device/dtype after the first call. Calling
    # .to() each time keeps the cached module in sync with the caller; it
    # leaves tensors untouched when they already match.
    _dwt = _dwt.to(device=device, dtype=dtype)
    return _dwt


def _stack_subbands(dwt, tensor):
    """Apply ``dwt`` to ``tensor`` and concatenate all subbands along dim 1."""
    low, high = dwt(tensor)
    # high[0] carries the three detail subbands along dim 2; split them out.
    detail_lh, detail_hl, detail_hh = torch.unbind(high[0], dim=2)
    return torch.cat([low, detail_lh, detail_hl, detail_hh], dim=1)


def wavelet_loss(model_pred, latents, noise):
    """Element-wise MSE between wavelet subbands of prediction and target.

    All inputs are cast to float32. The target is the wavelet decomposition
    of ``latents`` (computed without gradient tracking). The prediction is
    the decomposition of ``noise - model_pred``. For each, the low-pass band
    and the three detail bands are concatenated along dim 1.

    Args:
        model_pred: Model output tensor.
            (presumably (batch, channels, height, width) -- TODO confirm)
        latents: Clean latents used as the target; same layout as
            ``model_pred``.
        noise: Noise tensor; same layout as ``model_pred``.

    Returns:
        Un-reduced squared error (``reduction="none"``) with the same shape
        as the concatenated subband tensors.
    """
    model_pred = model_pred.float()
    latents = latents.float()
    noise = noise.float()

    dwt = _get_wavelet_loss(model_pred.device, model_pred.dtype)

    # The target must not contribute gradients.
    with torch.no_grad():
        model_input = _stack_subbands(dwt, latents)

    # reverse the noise to get the model prediction of the pure latents
    model_pred = noise - model_pred
    model_pred = _stack_subbands(dwt, model_pred)

    return torch.nn.functional.mse_loss(model_pred, model_input, reduction="none")