from typing import Callable, Iterable, Any, Optional, Union, Sequence, Mapping, Dict
import os.path
import copy

import torch
import torch.nn as nn
import lightning.pytorch as pl
from lightning.pytorch.utilities.types import OptimizerLRScheduler, STEP_OUTPUT
from torch.optim.lr_scheduler import LRScheduler
from torch.optim import Optimizer
from lightning.pytorch.callbacks import Callback

from src.models.vae import BaseVAE, fp2uint8
from src.models.conditioner import BaseConditioner
from src.utils.model_loader import ModelLoader
from src.callbacks.simple_ema import SimpleEMA
from src.diffusion.base.sampling import BaseSampler
from src.diffusion.base.training import BaseTrainer
from src.utils.no_grad import no_grad, filter_nograd_tensors
from src.utils.copy import copy_params

EMACallable = Callable[[nn.Module, nn.Module], SimpleEMA]
OptimizerCallable = Callable[[Iterable], Optimizer]
LRSchedulerCallable = Callable[[Optimizer], LRScheduler]
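
# The aliases above are factory callables: `optimizer` and `lr_scheduler` are only
# called later, inside configure_optimizers, with the trainable parameters / the
# built optimizer, and `ema_tracker` is called inside configure_callbacks with the
# online and EMA denoiser modules. This matches a config-driven, partially-applied
# instantiation style (e.g. LightningCLI/jsonargparse); how the factories are
# actually supplied is an assumption, not something this file enforces.
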
class LightningModel(pl.LightningModule):
    def __init__(self,
                 vae: BaseVAE,
                 conditioner: BaseConditioner,
                 denoiser: nn.Module,
                 diffusion_trainer: BaseTrainer,
                 diffusion_sampler: BaseSampler,
                 ema_tracker: Optional[EMACallable] = None,
                 optimizer: Optional[OptimizerCallable] = None,
                 lr_scheduler: Optional[LRSchedulerCallable] = None,
                 ):
        super().__init__()
        self.vae = vae
        self.conditioner = conditioner
        self.denoiser = denoiser
        self.ema_denoiser = copy.deepcopy(self.denoiser)
        self.diffusion_sampler = diffusion_sampler
        self.diffusion_trainer = diffusion_trainer
        self.ema_tracker = ema_tracker
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        # self.model_loader = ModelLoader()
        self._strict_loading = False
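
    # configure_model runs once the training strategy is set up: it syncs the EMA
    # copy with the online denoiser and freezes the modules that are never
    # optimized (conditioner, VAE, sampler, EMA copy).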
    def configure_model(self) -> None:
        self.trainer.strategy.barrier()
        # self.denoiser = self.model_loader.load(self.denoiser)
        copy_params(src_model=self.denoiser, dst_model=self.ema_denoiser)
        # self.denoiser = torch.compile(self.denoiser)
        # disable grad for the modules that are not trained
        no_grad(self.conditioner)
        no_grad(self.vae)
        no_grad(self.diffusion_sampler)
        no_grad(self.ema_denoiser)
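
    # Builds the EMA callback from the injected factory; it receives the online
    # denoiser and its EMA copy and keeps the latter updated during training.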
    def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
        ema_tracker = self.ema_tracker(self.denoiser, self.ema_denoiser)
        return [ema_tracker]
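
    # Only tensors that still require grad are optimized: the denoiser's
    # parameters plus any learnable parameters owned by the diffusion trainer.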
    def configure_optimizers(self) -> OptimizerLRScheduler:
        params_denoiser = filter_nograd_tensors(self.denoiser.parameters())
        params_trainer = filter_nograd_tensors(self.diffusion_trainer.parameters())
        optimizer: torch.optim.Optimizer = self.optimizer([*params_trainer, *params_denoiser])
        if self.lr_scheduler is None:
            return dict(
                optimizer=optimizer
            )
        else:
            lr_scheduler = self.lr_scheduler(optimizer)
            return dict(
                optimizer=optimizer,
                lr_scheduler=lr_scheduler
            )
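
    # Training: images are encoded to latents by the frozen VAE and labels are
    # turned into (condition, uncondition) embeddings under no_grad; the diffusion
    # trainer then computes the loss dict, whose "loss" entry is backpropagated.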
    def training_step(self, batch, batch_idx):
        raw_images, x, y = batch
        with torch.no_grad():
            x = self.vae.encode(x)
            condition, uncondition = self.conditioner(y)
        loss = self.diffusion_trainer(self.denoiser, self.ema_denoiser, raw_images, x, condition, uncondition)
        self.log_dict(loss, prog_bar=True, on_step=True, sync_dist=False)
        return loss["loss"]
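
    # Prediction: starting from the provided noise xT, the sampler denoises in
    # latent space, the VAE decodes to image space, and the result is converted
    # from float in [-1, 1] to uint8 in [0, 255].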
    def predict_step(self, batch, batch_idx):
        xT, y, metadata = batch
        with torch.no_grad():
            condition, uncondition = self.conditioner(y)
            # Sample images:
            samples = self.diffusion_sampler(self.denoiser, xT, condition, uncondition)
            samples = self.vae.decode(samples)
        # fp32 -1,1 -> uint8 0,255
        samples = fp2uint8(samples)
        return samples
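
    # Validation simply reuses the prediction path; any image-quality metrics
    # (e.g. FID) are assumed to be computed by downstream callbacks on the
    # returned samples.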
    def validation_step(self, batch, batch_idx):
        samples = self.predict_step(batch, batch_idx)
        return samples
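
    # Checkpoints intentionally contain only the denoiser, its EMA copy and the
    # diffusion trainer; the frozen VAE and conditioner are excluded, which is
    # why _strict_loading is disabled in __init__.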
    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        if destination is None:
            destination = {}
        self._save_to_state_dict(destination, prefix, keep_vars)
        self.denoiser.state_dict(
            destination=destination,
            prefix=prefix + "denoiser.",
            keep_vars=keep_vars)
        self.ema_denoiser.state_dict(
            destination=destination,
            prefix=prefix + "ema_denoiser.",
            keep_vars=keep_vars)
        self.diffusion_trainer.state_dict(
            destination=destination,
            prefix=prefix + "diffusion_trainer.",
            keep_vars=keep_vars)
        return destination
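
# Minimal usage sketch (not taken from this repo's configs): the factory arguments
# can be supplied with functools.partial. The concrete vae/conditioner/denoiser/
# trainer/sampler instances and the SimpleEMA/AdamW keyword arguments below are
# illustrative assumptions, not values defined in this file.
#
#   from functools import partial
#   model = LightningModel(
#       vae=vae, conditioner=conditioner, denoiser=denoiser,
#       diffusion_trainer=diffusion_trainer, diffusion_sampler=diffusion_sampler,
#       ema_tracker=partial(SimpleEMA, decay=0.9999),
#       optimizer=partial(torch.optim.AdamW, lr=1e-4, weight_decay=0.0),
#       lr_scheduler=partial(torch.optim.lr_scheduler.CosineAnnealingLR, T_max=1000),
#   )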