# File size: 3,854 Bytes
# 9e426da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import lightning.pytorch as pl
from lightning.pytorch import Callback


import os.path
import numpy
from PIL import Image
from typing import Sequence, Any, Dict
from concurrent.futures import ThreadPoolExecutor

from lightning.pytorch.utilities.types import STEP_OUTPUT
from lightning_utilities.core.rank_zero import rank_zero_info

def process_fn(image, path):
    """Write a single HWC uint8 array to *path* as an image file."""
    img = Image.fromarray(image)
    img.save(path)

class SaveImagesHook(Callback):
    """Lightning callback that saves model output images during validation/prediction.

    Each rank asynchronously writes its own samples as individual image files
    via a thread pool; rank zero additionally accumulates the all-gathered
    samples and (optionally) dumps them as a single compressed ``output.npz``
    archive at epoch end.

    Args:
        save_dir: Subdirectory (under the trainer's ``default_root_dir``) that
            receives the per-run output folders.
        max_save_num: Upper bound on the number of individual image files
            written per run (0 disables per-file saving).
        compressed: If True, also write all gathered samples to ``output.npz``.
    """

    def __init__(self, save_dir="val", max_save_num=0, compressed=True):
        self.save_dir = save_dir
        self.max_save_num = max_save_num
        self.compressed = compressed

    def save_start(self, target_dir):
        """Prepare *target_dir* and a worker pool for a new save run.

        Raises:
            FileExistsError: If *target_dir* already exists, is non-empty, and
                is not a debug directory (refuses to clobber earlier results).
        """
        self.target_dir = target_dir
        self.executor_pool = ThreadPoolExecutor(max_workers=8)
        if not os.path.exists(self.target_dir):
            os.makedirs(self.target_dir, exist_ok=True)
        elif os.listdir(target_dir) and "debug" not in str(target_dir):
            raise FileExistsError(f'{self.target_dir} already exists and not empty!')
        self.samples = []
        self._have_saved_num = 0
        rank_zero_info(f"Save images to {self.target_dir}")

    def save_image(self, images, filenames):
        """Queue up to ``max_save_num`` images for asynchronous saving.

        Args:
            images: Tensor of shape (N, C, H, W); converted to HWC numpy arrays.
            filenames: One filename (or a sequence of candidates, e.g. a
                collated tuple) per image.
        """
        images = images.permute(0, 2, 3, 1).cpu().numpy()
        for sample, filename in zip(images, filenames):
            if self._have_saved_num >= self.max_save_num:
                break
            # BUG FIX: str is itself a Sequence, so the old check truncated a
            # plain-string filename to its first character.
            if isinstance(filename, Sequence) and not isinstance(filename, str):
                filename = filename[0]
            # BUG FIX: the filename was previously ignored and every image was
            # written to the same placeholder path, overwriting its
            # predecessor. Extension defaults to .png — confirm against the
            # dataset's metadata convention.
            path = os.path.join(self.target_dir, f'{filename}.png')
            self.executor_pool.submit(process_fn, sample, path)
            self._have_saved_num += 1

    def process_batch(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        samples: STEP_OUTPUT,
        batch: Any,
    ) -> None:
        """Save this rank's samples; accumulate gathered samples on rank zero."""
        _, c, h, w = samples.shape
        # batch layout assumed to be (noise/xT, labels, metadata) — the
        # metadata entry supplies per-sample filenames.
        xT, y, metadata = batch
        # all_gather adds a world-size dim; flatten it back into the batch dim.
        all_samples = pl_module.all_gather(samples).view(-1, c, h, w)
        self.save_image(samples, metadata)
        if trainer.is_global_zero:
            all_samples = all_samples.permute(0, 2, 3, 1).cpu().numpy()
            self.samples.append(all_samples)

    def save_end(self):
        """Flush pending writes, optionally dump ``output.npz``, reset state."""
        if self.compressed and len(self.samples) > 0:
            samples = numpy.concatenate(self.samples)
            numpy.savez(f'{self.target_dir}/output.npz', arr_0=samples)
        # Wait for all queued image writes before tearing the pool down.
        self.executor_pool.shutdown(wait=True)
        self.samples = []
        self.target_dir = None
        self._have_saved_num = 0
        self.executor_pool = None

    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        target_dir = os.path.join(trainer.default_root_dir, self.save_dir, f"iter_{trainer.global_step}")
        self.save_start(target_dir)

    def on_validation_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        return self.process_batch(trainer, pl_module, outputs, batch)

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.save_end()

    def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        target_dir = os.path.join(trainer.default_root_dir, self.save_dir, "predict")
        self.save_start(target_dir)

    def on_predict_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        samples: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        return self.process_batch(trainer, pl_module, samples, batch)

    def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.save_end()