import copy
from typing import Callable

import timm
import torch
from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize

from src.utils.no_grad import no_grad
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
# Loss-weighting functions: each maps the scheduler's (alpha, sigma) values at the
# sampled timestep to a weight for the flow-matching loss.
def inverse_sigma(alpha, sigma):
    return 1 / sigma ** 2

def snr(alpha, sigma):
    return alpha / sigma

def minsnr(alpha, sigma, threshold=5):
    # Floor the SNR at `threshold` (values below are raised to it).
    return torch.clip(alpha / sigma, min=threshold)

def maxsnr(alpha, sigma, threshold=5):
    # Cap the SNR at `threshold`, i.e. min(SNR, threshold).
    return torch.clip(alpha / sigma, max=threshold)

def constant(alpha, sigma):
    return 1
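
# Illustrative sketch (not used by the trainer): how the weight functions above behave
# on a few hand-picked (alpha, sigma) pairs. The tensor values are assumptions for
# demonstration, not scheduler outputs.
def _loss_weight_demo():
    alpha = torch.tensor([0.99, 0.50, 0.01])
    sigma = torch.tensor([0.01, 0.50, 0.99])
    print(snr(alpha, sigma))       # tensor([99.0000,  1.0000,  0.0101])
    print(maxsnr(alpha, sigma))    # tensor([ 5.0000,  1.0000,  0.0101])  -- capped at 5
    print(minsnr(alpha, sigma))    # tensor([99.,  5.,  5.])              -- floored at 5
    print(constant(alpha, sigma))  # 1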
class DINOv2(nn.Module):
    """Frozen DINOv2 feature extractor used as the alignment target."""
    def __init__(self, weight_path: str):
        super().__init__()
        self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path)
        # Keep a pristine copy of the positional embedding so it can be resampled
        # to arbitrary resolutions without accumulating interpolation error.
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        self.precomputed_pos_embed = dict()

    def fetch_pos(self, h, w):
        # Resample the absolute positional embedding to an h x w patch grid,
        # caching the result per resolution.
        key = (h, w)
        if key in self.precomputed_pos_embed:
            return self.precomputed_pos_embed[key]
        value = timm.layers.pos_embed.resample_abs_pos_embed(
            self.pos_embed.data, [h, w],
        )
        self.precomputed_pos_embed[key] = value
        return value
    def forward(self, x):
        b, c, h, w = x.shape
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        # Rescale so a 256px input lands at DINOv2's native 224px resolution.
        x = torch.nn.functional.interpolate(x, (int(224 * h / 256), int(224 * w / 256)), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h // self.patch_size[0], w // self.patch_size[1]
        pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
        self.encoder.pos_embed.data = pos_embed_data
        # Return the normalized patch tokens (no CLS token) as alignment targets.
        feature = self.encoder.forward_features(x)['x_norm_patchtokens']
        return feature
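
# Usage sketch (assumptions: the 'dinov2_vits14' hub entry point and a 256px input;
# weights download on first call). A [B, 3, 256, 256] batch is resized to 224px,
# giving a 16x16 grid of 14px patches and 384-dim tokens for ViT-S/14.
def _dinov2_demo():
    encoder = DINOv2('dinov2_vits14')
    images = torch.rand(2, 3, 256, 256)  # values in [0, 1] before normalization
    with torch.no_grad():
        feats = encoder(images)
    assert feats.shape == (2, 256, 384)  # (batch, 16*16 patches, embed dim)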
class REPATrainer(BaseTrainer):
    """Flow-matching trainer with REPA-style feature alignment against DINOv2."""
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            feat_loss_weight: float = 0.5,
            lognorm_t=False,
            encoder_weight_path=None,
            align_layer=8,
            proj_denoiser_dim=256,
            proj_hidden_dim=256,
            proj_encoder_dim=256,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.feat_loss_weight = feat_loss_weight
        self.align_layer = align_layer
        # Frozen target encoder: no gradients flow into DINOv2.
        self.encoder = DINOv2(encoder_weight_path)
        no_grad(self.encoder)
        # MLP projector from the denoiser's hidden width to the encoder's width.
        self.proj = nn.Sequential(
            nn.Linear(proj_denoiser_dim, proj_hidden_dim),
            nn.SiLU(),
            nn.Linear(proj_hidden_dim, proj_hidden_dim),
            nn.SiLU(),
            nn.Linear(proj_hidden_dim, proj_encoder_dim),
        )
    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size, c, height, width = x.shape
        # Sample timesteps: logit-normal (sigmoid of a Gaussian) or uniform.
        if self.lognorm_t:
            base_t = torch.randn((batch_size,), device=x.device, dtype=x.dtype).sigmoid()
        else:
            base_t = torch.rand((batch_size,), device=x.device, dtype=x.dtype)
        t = base_t
        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)
        # Interpolant x_t and its time derivative v_t (the velocity target).
        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise

        # Capture the denoiser's hidden states at the alignment layer via a forward hook.
        src_feature = []
        def forward_hook(net, input, output):
            src_feature.append(output)
        handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
        out = net(x_t, t, y)
        handle.remove()
        src_feature = self.proj(src_feature[0])

        # Alignment target: frozen DINOv2 patch tokens of the clean images.
        with torch.no_grad():
            dst_feature = self.encoder(raw_images)

        # REPA loss: 1 - cosine similarity between projected and target features.
        cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
        cos_loss = 1 - cos_sim

        # Weighted flow-matching loss on the predicted velocity.
        weight = self.loss_weight_fn(alpha, sigma)
        fm_loss = weight * (out - v_t) ** 2

        out = dict(
            fm_loss=fm_loss.mean(),
            cos_loss=cos_loss.mean(),
            loss=fm_loss.mean() + self.feat_loss_weight * cos_loss.mean(),
        )
        return out
    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        # Persist only the projector; the frozen DINOv2 encoder is reloaded from
        # the hub and carries no trainable state worth checkpointing.
        return self.proj.state_dict(
            destination=destination,
            prefix=prefix + "proj.",
            keep_vars=keep_vars)
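
# Standalone sketch (toy module, illustrative only): the hook pattern used in
# _impl_trainstep above. A forward hook records an intermediate block's output,
# and removing the handle keeps later forward passes clean.
def _forward_hook_demo():
    toy = nn.Sequential(nn.Linear(8, 8), nn.SiLU(), nn.Linear(8, 8))
    captured = []
    handle = toy[0].register_forward_hook(lambda mod, inp, out: captured.append(out))
    _ = toy(torch.rand(4, 8))
    handle.remove()
    assert captured[0].shape == (4, 8)  # hidden activations of the first layer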