Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,854 Bytes
9e426da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import lightning.pytorch as pl
from lightning.pytorch import Callback
import os.path
import numpy
from PIL import Image
from typing import Sequence, Any, Dict
from concurrent.futures import ThreadPoolExecutor
from lightning.pytorch.utilities.types import STEP_OUTPUT
from lightning_utilities.core.rank_zero import rank_zero_info
def process_fn(image, path):
    """Convert an image array to a PIL image and write it to ``path``.

    Args:
        image: Array handed to ``Image.fromarray``.
            NOTE(review): presumably an (H, W, C) uint8 array -- confirm
            against the producer of the samples.
        path: Destination file path passed to ``Image.save``.
    """
    picture = Image.fromarray(image)
    picture.save(path)
class SaveImagesHook(Callback):
    """Lightning callback that writes validation/prediction samples to disk.

    For every validation or prediction epoch the callback:

    * picks a target directory under ``trainer.default_root_dir``;
    * submits up to ``max_save_num`` individual images per epoch to a thread
      pool (this is not restricted to global rank zero);
    * on global rank zero, accumulates the samples gathered from all
      processes and, when ``compressed`` is true, writes them to
      ``output.npz`` at the end of the epoch.
    """

    def __init__(self, save_dir="val", max_save_num=0, compressed=True):
        """Store the configuration and initialise the per-epoch state.

        Args:
            save_dir: Sub-directory of ``trainer.default_root_dir`` that
                receives the outputs.
            max_save_num: Maximum number of individual image files submitted
                for saving per epoch. The default ``0`` submits none.
            compressed: When true, the accumulated samples are written to
                ``output.npz`` at epoch end. The archive is produced with
                ``numpy.savez``, which stores arrays uncompressed; the flag
                only controls whether the archive is written.
        """
        super().__init__()
        self.save_dir = save_dir
        self.max_save_num = max_save_num
        self.compressed = compressed
        # Per-epoch state. Defined here so that ``save_end`` is safe to call
        # even if ``save_start`` never ran.
        self.target_dir = None
        self.executor_pool = None
        self.samples = []
        self._have_saved_num = 0

    def save_start(self, target_dir):
        """Prepare ``target_dir`` and reset the per-epoch state.

        Args:
            target_dir: Directory that will receive this epoch's outputs.

        Raises:
            FileExistsError: If ``target_dir`` already exists, is not empty
                and its path does not contain the substring ``"debug"``.
        """
        self.target_dir = target_dir
        if not os.path.exists(self.target_dir):
            os.makedirs(self.target_dir, exist_ok=True)
        elif os.listdir(target_dir) and "debug" not in str(target_dir):
            # NOTE(review): with several processes sharing one directory this
            # check may race with another process that has already started
            # writing -- confirm whether that can happen in practice.
            raise FileExistsError(f'{self.target_dir} already exists and not empty!')
        # Create the pool only after the directory check has passed, so a
        # rejected directory does not leave an executor behind.
        self.executor_pool = ThreadPoolExecutor(max_workers=8)
        self.samples = []
        self._have_saved_num = 0
        rank_zero_info(f"Save images to {self.target_dir}")

    def save_image(self, images, filenames):
        """Submit images for saving until the per-epoch limit is reached.

        Args:
            images: Tensor indexed as (N, C, H, W); it is permuted to
                channel-last and moved to the CPU before saving.
            filenames: One entry per image. A non-string sequence entry
                (e.g. a list or tuple) is reduced to its first element; a
                plain string is used as-is.
        """
        remaining = self.max_save_num - self._have_saved_num
        if remaining <= 0:
            # Nothing left to save: skip the device-to-host copy entirely.
            return
        # Only transfer the images that can still be written this epoch.
        images = images[:remaining].permute(0, 2, 3, 1).cpu().numpy()
        for sample, filename in zip(images, filenames):
            # ``str`` is itself a Sequence; unwrapping it would truncate the
            # filename to its first character, so strings are excluded.
            if isinstance(filename, Sequence) and not isinstance(filename, str):
                filename = filename[0]
            path = f'{self.target_dir}/{filename}'
            # The returned future is not inspected, so an exception raised
            # while saving is not surfaced here.
            self.executor_pool.submit(process_fn, sample, path)
            self._have_saved_num += 1

    def process_batch(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        samples: STEP_OUTPUT,
        batch: Any,
    ) -> None:
        """Save one batch of samples and accumulate the gathered samples.

        Args:
            trainer: Trainer running the loop; ``is_global_zero`` decides
                whether the gathered samples are kept.
            pl_module: Module whose ``all_gather`` collects the samples.
            samples: 4-D tensor unpacked as (batch, channels, height, width).
            batch: Three-element batch; its last element is passed to
                ``save_image`` as the filenames.
        """
        _, c, h, w = samples.shape
        _, _, metadata = batch
        # ``all_gather`` is called on every process; only global rank zero
        # keeps the result.
        all_samples = pl_module.all_gather(samples).view(-1, c, h, w)
        self.save_image(samples, metadata)
        if trainer.is_global_zero:
            all_samples = all_samples.permute(0, 2, 3, 1).cpu().numpy()
            self.samples.append(all_samples)

    def save_end(self):
        """Write ``output.npz`` if requested, then release per-epoch state.

        The thread pool is shut down (waiting for submitted saves) and the
        state is reset even if writing the archive raises.
        """
        try:
            if self.compressed and self.samples:
                samples = numpy.concatenate(self.samples)
                numpy.savez(f'{self.target_dir}/output.npz', arr_0=samples)
        finally:
            if self.executor_pool is not None:
                self.executor_pool.shutdown(wait=True)
            self.samples = []
            self.target_dir = None
            self._have_saved_num = 0
            self.executor_pool = None

    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Start saving into ``<root>/<save_dir>/iter_<global_step>``."""
        target_dir = os.path.join(trainer.default_root_dir, self.save_dir, f"iter_{trainer.global_step}")
        self.save_start(target_dir)

    def on_validation_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Forward the validation step outputs to ``process_batch``."""
        return self.process_batch(trainer, pl_module, outputs, batch)

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Finish the validation epoch via ``save_end``."""
        self.save_end()

    def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Start saving into ``<root>/<save_dir>/predict``."""
        target_dir = os.path.join(trainer.default_root_dir, self.save_dir, "predict")
        self.save_start(target_dir)

    def on_predict_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        samples: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Forward the prediction step outputs to ``process_batch``."""
        return self.process_batch(trainer, pl_module, samples, batch)

    def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Finish the prediction epoch via ``save_end``."""
        self.save_end()