|
import os |
|
import copy |
|
import matplotlib.pyplot as plt |
|
from typing import Any |
|
|
|
import torch |
|
from torch.optim.lr_scheduler import ReduceLROnPlateau |
|
from torch.optim import AdamW, Optimizer |
|
from torch.utils.data import DataLoader |
|
from lightning import LightningModule |
|
|
|
from torchmetrics import MetricCollection |
|
from torchmetrics.image import PeakSignalNoiseRatio as PSNR |
|
from torchmetrics.image import StructuralSimilarityIndexMeasure as SSIM |
|
from torchmetrics.image import LearnedPerceptualImagePatchSimilarity as LPIPS |
|
|
|
from model.model import MultiInputResShift |
|
|
|
from utils.utils import denorm, make_grid_images |
|
from utils.ema import EMA |
|
from utils.inter_frame_idx import get_inter_frame_temp_index |
|
from utils.raft import raft_flow |
|
|
|
|
|
class TrainPipline(LightningModule): |
|
def __init__(self, |
|
confg: dict, |
|
test_dataloader: DataLoader): |
|
super(TrainPipline, self).__init__() |
|
|
|
self.test_dataloader = test_dataloader |
|
|
|
self.confg = confg |
|
|
|
self.mean, self.sd = confg["data_confg"]["mean"], confg["data_confg"]["sd"] |
|
|
|
self.model = MultiInputResShift(**confg["model_confg"]) |
|
self.model.flow_model.requires_grad_(False).eval() |
|
|
|
self.ema = EMA(beta=0.995) |
|
self.ema_model = copy.deepcopy(self.model).eval().requires_grad_(False) |
|
|
|
self.charbonnier_loss = lambda x, y: torch.mean(torch.sqrt((x - y)**2 + 1e-6)) |
|
self.lpips_loss = LPIPS(net_type='vgg') |
|
|
|
self.train_metrics = MetricCollection({ |
|
"train_lpips": LPIPS(net_type='alex'), |
|
"train_psnr": PSNR(), |
|
"train_ssim": SSIM() |
|
}) |
|
self.val_metrics = MetricCollection({ |
|
"val_lpips": LPIPS(net_type='alex'), |
|
"val_psnr": PSNR(), |
|
"val_ssim": SSIM() |
|
}) |
|
|
|
def loss_fn(self, |
|
x: torch.Tensor, |
|
predicted_x: torch.Tensor) -> torch.Tensor: |
|
percep_loss = 0.2 * self.lpips_loss(x, predicted_x.clamp(-1, 1)) |
|
pix2pix_loss = self.charbonnier_loss(x, predicted_x) |
|
return percep_loss + pix2pix_loss |
|
|
|
def sample_t(self, |
|
shape: tuple[int, ...], |
|
max_t: int, |
|
device: torch.device) -> torch.Tensor: |
|
p = torch.linspace(1, max_t, steps=max_t, device=device) ** 2 |
|
p = p / p.sum() |
|
t = torch.multinomial(p, num_samples=shape[0], replacement=True) |
|
return t |
|
|
|
def forward(self, |
|
I0: torch.Tensor, |
|
It: torch.Tensor, |
|
I1: torch.Tensor) -> torch.Tensor: |
|
flow0tot = raft_flow(I0, It, 'animation') |
|
flow1tot = raft_flow(I1, It, 'animation') |
|
mid_idx = get_inter_frame_temp_index(I0, It, I1, flow0tot, flow1tot).to(It.dtype) |
|
|
|
tau = torch.stack([mid_idx, 1 - mid_idx], dim=1) |
|
|
|
if self.current_epoch > 5: |
|
t = torch.randint(low=1, high=self.model.timesteps, size=(It.shape[0],), device=It.device, dtype=torch.long) |
|
else: |
|
t = self.sample_t(shape=(It.shape[0],), max_t=self.model.timesteps, device=It.device) |
|
|
|
predicted_It = self.model(I0, It, I1, tau=tau, t=t) |
|
return predicted_It |
|
|
|
def get_step_plt_images(self, |
|
It: torch.Tensor, |
|
predicted_It: torch.Tensor) -> plt.Figure: |
|
fig, ax = plt.subplots(1, 2, figsize=(20, 10)) |
|
ax[0].imshow(denorm(predicted_It.clamp(-1, 1), self.mean, self.sd)[0].permute(1, 2, 0).cpu().numpy()) |
|
ax[0].axis("off") |
|
ax[0].set_title("Predicted") |
|
ax[1].imshow(denorm(It, self.mean, self.sd)[0].permute(1, 2, 0).cpu().numpy()) |
|
ax[1].axis("off") |
|
ax[1].set_title("Ground Truth") |
|
plt.tight_layout() |
|
|
|
|
|
plt.close(fig) |
|
return fig |
|
|
|
def training_step(self, batch: tuple[torch.Tensor, ...], _) -> torch.Tensor: |
|
I0, It, I1 = batch |
|
predicted_It = self(I0, It, I1) |
|
loss = self.loss_fn(It, predicted_It) |
|
|
|
self.log("lr", self.trainer.optimizers[0].param_groups[0]["lr"], prog_bar=True, on_step=True, on_epoch=False, sync_dist=True) |
|
self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=False, sync_dist=True) |
|
|
|
self.ema.step_ema(self.ema_model, self.model) |
|
with torch.inference_mode(): |
|
fig = self.get_step_plt_images(It, predicted_It) |
|
self.logger.experiment.add_figure("Train Predictions", fig, self.global_step) |
|
mets = self.train_metrics(It, predicted_It.clamp(-1, 1)) |
|
self.log_dict(mets, prog_bar=True, on_step=True,on_epoch=False) |
|
return loss |
|
|
|
@torch.no_grad() |
|
def validation_step(self, batch: tuple[torch.Tensor, ...], _) -> None: |
|
I0, It, I1 = batch |
|
predicted_It = self(I0, It, I1) |
|
loss = self.loss_fn(It, predicted_It) |
|
|
|
self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True) |
|
|
|
mets = self.val_metrics(It, predicted_It.clamp(-1, 1)) |
|
self.log_dict(mets, prog_bar=True, on_step=False, on_epoch=True) |
|
|
|
@torch.inference_mode() |
|
def on_train_epoch_end(self) -> None: |
|
torch.save(self.ema_model.state_dict(), |
|
os.path.join("_checkpoint", f"resshift_diff_{self.current_epoch}.pth")) |
|
|
|
batch = next(iter(self.test_dataloader)) |
|
I0, It, I1 = batch |
|
I0, It, I1 = I0.to(self.device), It.to(self.device), I1.to(self.device) |
|
|
|
flow0tot = raft_flow(I0, It, 'animation') |
|
flow1tot = raft_flow(I1, It, 'animation') |
|
mid_idx = get_inter_frame_temp_index(I0, It, I1, flow0tot, flow1tot).to(It.dtype) |
|
tau = torch.stack([mid_idx, 1 - mid_idx], dim=1) |
|
|
|
predicted_It = self.ema_model.reverse_process([I0, I1], tau) |
|
|
|
I0 = denorm(I0, self.mean, self.sd) |
|
I1 = denorm(I1, self.mean, self.sd) |
|
It = denorm(It, self.mean, self.sd) |
|
predicted_It = denorm(predicted_It.clamp(-1, 1), self.mean, self.sd) |
|
|
|
|
|
grid = make_grid_images([I0, It, predicted_It, I1], nrow=1) |
|
self.logger.experiment.add_image("Predicted Images", grid, self.global_step) |
|
|
|
def configure_optimizers(self) -> tuple[list[Optimizer], list[dict[str, Any]]]: |
|
optimizer = [AdamW( |
|
self.model.parameters(), |
|
**self.confg["optim_confg"]['optimizer_confg'] |
|
)] |
|
|
|
scheduler = [{ |
|
'scheduler': ReduceLROnPlateau( |
|
optimizer[0], |
|
**self.confg["optim_confg"]['scheduler_confg'] |
|
), |
|
'monitor': 'val_loss', |
|
'interval': 'epoch', |
|
'frequency': 1, |
|
'strict': True, |
|
}] |
|
|
|
return optimizer, scheduler |
|
|
|
|