ramimu's picture
Upload 586 files
1c72248 verified
raw
history blame
1.28 kB
import torch
_dwt = None
def _get_wavelet_loss(device, dtype):
global _dwt
if _dwt is not None:
return _dwt
# init wavelets
from pytorch_wavelets import DWTForward
# wave='db1' wave='haar'
dwt = DWTForward(J=1, mode='zero', wave='haar').to(
device=device, dtype=dtype)
_dwt = dwt
return dwt
def wavelet_loss(model_pred, latents, noise):
model_pred = model_pred.float()
latents = latents.float()
noise = noise.float()
dwt = _get_wavelet_loss(model_pred.device, model_pred.dtype)
with torch.no_grad():
model_input_xll, model_input_xh = dwt(latents)
model_input_xlh, model_input_xhl, model_input_xhh = torch.unbind(model_input_xh[0], dim=2)
model_input = torch.cat([model_input_xll, model_input_xlh, model_input_xhl, model_input_xhh], dim=1)
# reverse the noise to get the model prediction of the pure latents
model_pred = noise - model_pred
model_pred_xll, model_pred_xh = dwt(model_pred)
model_pred_xlh, model_pred_xhl, model_pred_xhh = torch.unbind(model_pred_xh[0], dim=2)
model_pred = torch.cat([model_pred_xll, model_pred_xlh, model_pred_xhl, model_pred_xhh], dim=1)
return torch.nn.functional.mse_loss(model_pred, model_input, reduction="none")