Spaces:
Paused
Paused
File size: 1,284 Bytes
1c72248 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
import torch
# Lazily-built, process-wide DWT module shared by every call; None until first use.
_dwt = None


def _get_wavelet_loss(device: torch.device, dtype: torch.dtype) -> torch.nn.Module:
    """Return the shared single-level Haar DWT module on ``device``/``dtype``.

    Despite the name, this returns the wavelet *transform* module used by the
    loss, not a loss value.

    The module is constructed once and cached in the module-level ``_dwt``.
    On every call the cached module is moved to the requested device and
    dtype, so a caller that switches device (or dtype) after the first call
    does not receive a transform that still lives on the old one.

    Args:
        device: Device the transform should live on.
        dtype: Floating-point dtype for the transform.

    Returns:
        The cached ``DWTForward(J=1, mode='zero', wave='haar')`` instance.
    """
    global _dwt
    if _dwt is None:
        # Imported lazily so the dependency is only required when the
        # wavelet loss is actually used.
        from pytorch_wavelets import DWTForward

        # wave='haar' (also known as 'db1')
        _dwt = DWTForward(J=1, mode='zero', wave='haar')
    # Module.to() works in place and returns the module itself; it leaves
    # tensors untouched when device and dtype already match.
    _dwt = _dwt.to(device=device, dtype=dtype)
    return _dwt
def wavelet_loss(model_pred, latents, noise):
    """Element-wise MSE between wavelet sub-bands of predicted and true latents.

    All three inputs are cast to float32. The predicted latents are recovered
    as ``noise - model_pred``; both the predicted and the reference latents
    are then decomposed with the shared DWT module, and their sub-bands are
    concatenated along dim 1 before comparison.

    Args:
        model_pred: Raw model output tensor.
        latents: Reference latents tensor.
        noise: Noise tensor; ``noise - model_pred`` is taken as the model's
            estimate of the clean latents.
            NOTE(review): presumably all three share one shape and device
            -- confirm against the caller.

    Returns:
        The unreduced (``reduction="none"``) squared-error tensor over the
        concatenated sub-bands.
    """
    pred_f32 = model_pred.float()
    latents_f32 = latents.float()
    noise_f32 = noise.float()

    transform = _get_wavelet_loss(pred_f32.device, pred_f32.dtype)

    def _stacked_subbands(x):
        # The transform returns a low-pass tensor plus a list of high-pass
        # tensors; the first high-pass entry is split three ways along dim 2.
        low, highs = transform(x)
        band_lh, band_hl, band_hh = highs[0].unbind(dim=2)
        return torch.cat((low, band_lh, band_hl, band_hh), dim=1)

    # The reference branch is a fixed target, so no graph is recorded for it.
    with torch.no_grad():
        target_bands = _stacked_subbands(latents_f32)

    # reverse the noise to get the model prediction of the pure latents
    pred_bands = _stacked_subbands(noise_f32 - pred_f32)

    return torch.nn.functional.mse_loss(pred_bands, target_bands, reduction="none")