Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,686 Bytes
9e426da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import torch
import copy
import timm
from torch.nn import Parameter
from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
def inverse_sigma(alpha, sigma):
    """Return the loss weight ``1 / sigma**2``.

    Args:
        alpha: Accepted so the signature matches the other weighting
            functions in this module; it is not used.
        sigma: Noise coefficient (number or tensor).

    Returns:
        The reciprocal of ``sigma`` squared.
    """
    sigma_squared = sigma ** 2
    return 1 / sigma_squared
def snr(alpha, sigma):
    """Return the ratio ``alpha / sigma`` used as a loss weight.

    Args:
        alpha: Numerator coefficient (number or tensor).
        sigma: Denominator coefficient (number or tensor).

    Returns:
        ``alpha`` divided by ``sigma``.
    """
    ratio = alpha / sigma
    return ratio
def minsnr(alpha, sigma, threshold=5):
    """Return ``alpha / sigma`` with every element raised to at least ``threshold``.

    NOTE(review): despite the name, this applies a *lower* bound, so the
    result is ``max(alpha / sigma, threshold)`` element-wise. Confirm that
    this is the intended weighting.

    Args:
        alpha: Numerator coefficient tensor.
        sigma: Denominator coefficient tensor.
        threshold: Lower bound applied to the ratio.

    Returns:
        Tensor of the ratio, clamped from below at ``threshold``.
    """
    ratio = alpha / sigma
    return torch.clamp(ratio, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
    """Return ``alpha / sigma`` with every element capped at ``threshold``.

    NOTE(review): despite the name, this applies an *upper* bound, so the
    result is ``min(alpha / sigma, threshold)`` element-wise. Confirm that
    this is the intended weighting.

    Args:
        alpha: Numerator coefficient tensor.
        sigma: Denominator coefficient tensor.
        threshold: Upper bound applied to the ratio.

    Returns:
        Tensor of the ratio, clamped from above at ``threshold``.
    """
    ratio = alpha / sigma
    return torch.clamp(ratio, max=threshold)
def constant(alpha, sigma):
    """Return the uniform loss weight ``1``.

    Both arguments are accepted so the signature matches the other
    weighting functions in this module; neither is used.
    """
    uniform_weight = 1
    return uniform_weight
class DINOv2(nn.Module):
    """Wrap a ``facebookresearch/dinov2`` hub model and return its patch-token features.

    The hub model's positional embedding is copied at construction time and
    resampled on demand to the patch grid of each input; resampled
    embeddings are cached per ``(rows, cols)`` grid.
    """

    def __init__(self, weight_path:str):
        """Load the backbone from ``torch.hub``.

        Args:
            weight_path: Passed to ``torch.hub.load`` as the entry-point
                name within the ``facebookresearch/dinov2`` repository.
        """
        super().__init__()
        backbone = torch.hub.load('facebookresearch/dinov2', weight_path)
        self.encoder = backbone
        # Keep a separate copy of the original embedding: ``forward``
        # overwrites ``encoder.pos_embed.data`` on every call.
        self.pos_embed = copy.deepcopy(backbone.pos_embed)
        backbone.head = torch.nn.Identity()
        self.patch_size = backbone.patch_embed.patch_size
        # (patch rows, patch cols) -> resampled positional embedding.
        self.precomputed_pos_embed = {}

    def fetch_pos(self, h, w):
        """Return the positional embedding resampled to an ``h`` x ``w`` patch grid.

        The result is computed once per grid size and then served from the cache.
        """
        grid = (h, w)
        cache = self.precomputed_pos_embed
        if grid not in cache:
            cache[grid] = timm.layers.pos_embed.resample_abs_pos_embed(
                self.pos_embed.data, [h, w],
            )
        return cache[grid]

    def forward(self, x):
        """Encode a batch of images into patch-token features.

        Args:
            x: 4-D image tensor (unpacked as batch, channels, height, width).

        Returns:
            The ``'x_norm_patchtokens'`` entry of the encoder's
            ``forward_features`` output.
        """
        _, _, src_h, src_w = x.shape
        normalize = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
        x = normalize(x)
        # Rescale spatial size by 224/256 before feeding the encoder.
        target_size = (int(224*src_h/256), int(224*src_w/256))
        x = torch.nn.functional.interpolate(x, target_size, mode='bicubic')
        _, _, new_h, new_w = x.shape
        grid_h = new_h//self.patch_size[0]
        grid_w = new_w//self.patch_size[1]
        # Swap in the embedding that matches this input's patch grid.
        self.encoder.pos_embed.data = self.fetch_pos(grid_h, grid_w)
        features = self.encoder.forward_features(x)
        return features['x_norm_patchtokens']
class REPATrainer(BaseTrainer):
    """Trainer combining a weighted regression loss with a feature-alignment loss.

    On each step the denoiser's intermediate features (captured from
    ``net.blocks[align_layer - 1]``) are projected by an MLP and compared,
    via cosine similarity, with features of the raw images produced by a
    ``DINOv2`` encoder.
    """

    def __init__(
        self,
        scheduler: BaseScheduler,
        loss_weight_fn:Callable=constant,
        feat_loss_weight: float=0.5,
        lognorm_t=False,
        encoder_weight_path=None,
        align_layer=8,
        proj_denoiser_dim=256,
        proj_hidden_dim=256,
        proj_encoder_dim=256,
        *args,
        **kwargs
    ):
        """Build the trainer.

        Args:
            scheduler: Provides ``alpha``, ``sigma`` and their derivatives
                ``dalpha``, ``dsigma`` as functions of ``t``.
            loss_weight_fn: Called as ``loss_weight_fn(alpha, sigma)`` to
                weight the regression loss.
            feat_loss_weight: Multiplier on the cosine alignment loss.
            lognorm_t: If true, sample ``t`` as ``sigmoid(randn)``;
                otherwise sample it uniformly with ``rand``.
            encoder_weight_path: Forwarded to ``DINOv2``.
            align_layer: 1-based index of the denoiser block whose output
                is aligned with the encoder features.
            proj_denoiser_dim: Input width of the projection MLP.
            proj_hidden_dim: Hidden width of the projection MLP.
            proj_encoder_dim: Output width of the projection MLP.
            *args: Forwarded to the base trainer.
            **kwargs: Forwarded to the base trainer.
        """
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.feat_loss_weight = feat_loss_weight
        self.align_layer = align_layer
        self.encoder = DINOv2(encoder_weight_path)
        # NOTE(review): presumably disables gradients for the encoder's
        # parameters; confirm against src.utils.no_grad.
        no_grad(self.encoder)
        # The nested Sequential is kept on purpose: flattening it would
        # change the state-dict keys ("proj.0.<i>...") of saved checkpoints.
        self.proj = nn.Sequential(
            nn.Sequential(
                nn.Linear(proj_denoiser_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_encoder_dim),
            )
        )

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        """Run one training step and return the loss terms.

        Args:
            net: Denoiser; called as ``net(x_t, t, y)`` and expected to
                expose an indexable ``blocks`` attribute of modules.
            ema_net: Unused here.
            raw_images: Images passed to the encoder for the alignment target.
            x: Clean inputs that are mixed with noise; batch on dim 0.
            y: Conditioning forwarded to ``net``.

        Returns:
            Dict with ``fm_loss``, ``cos_loss`` and the combined ``loss``
            (``fm_loss + feat_loss_weight * cos_loss``), each a scalar mean.
        """
        batch_size = x.shape[0]
        if self.lognorm_t:
            base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
        else:
            base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
        t = base_t

        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)

        # Noised input and the regression target built from the derivatives.
        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise

        captured = []

        def forward_hook(module, inputs, output):
            captured.append(output)

        handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
        try:
            out = net(x_t, t, y)
        finally:
            # Always detach the hook, even if the forward pass raises, so
            # it cannot pile up on the block across steps.
            handle.remove()
        src_feature = self.proj(captured[0])

        with torch.no_grad():
            dst_feature = self.encoder(raw_images)

        cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
        cos_loss = 1 - cos_sim

        weight = self.loss_weight_fn(alpha, sigma)
        fm_loss = weight*(out - v_t)**2

        # Reduce each term once and reuse it for the combined loss.
        fm_loss_mean = fm_loss.mean()
        cos_loss_mean = cos_loss.mean()
        return dict(
            fm_loss=fm_loss_mean,
            cos_loss=cos_loss_mean,
            loss=fm_loss_mean + self.feat_loss_weight*cos_loss_mean,
        )

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        """Return a state dict containing only the projection MLP's entries.

        Entries are written under ``prefix + "proj."``. Positional
        ``*args`` are accepted for signature compatibility and ignored.

        Returns:
            The dict returned by ``self.proj.state_dict`` (``destination``
            itself when one is supplied).
        """
        return self.proj.state_dict(
            destination=destination,
            prefix=prefix + "proj.",
            keep_vars=keep_vars)
|