# Copyright 2024 MIT Han Lab # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # SPDX-License-Identifier: Apache-2.0 import os import pathlib from typing import Any, Callable, Optional, Union import numpy as np from PIL import Image from torch.utils.data.dataset import Dataset from torchvision.datasets import ImageFolder __all__ = ["load_image", "load_image_from_dir", "DMCrop", "CustomImageFolder", "ImageDataset"] def load_image(data_path: str, mode="rgb") -> Image.Image: img = Image.open(data_path) if mode == "rgb": img = img.convert("RGB") return img def load_image_from_dir( dir_path: str, suffix: Union[str, tuple[str, ...], list[str]] = (".jpg", ".JPEG", ".png"), return_mode="path", k: Optional[int] = None, shuffle_func: Optional[Callable] = None, ) -> Union[list, tuple[list, list]]: suffix = [suffix] if isinstance(suffix, str) else suffix file_list = [] for dirpath, _, fnames in os.walk(dir_path): for fname in fnames: if pathlib.Path(fname).suffix not in suffix: continue image_path = os.path.join(dirpath, fname) file_list.append(image_path) if shuffle_func is not None and k is not None: shuffle_file_list = shuffle_func(file_list) file_list = shuffle_file_list or file_list file_list = file_list[:k] file_list = sorted(file_list) if return_mode == "path": return file_list else: files = [] path_list = [] for file_path in file_list: try: files.append(load_image(file_path)) path_list.append(file_path) except Exception: print(f"Fail to load {file_path}") if return_mode == "image": return files else: return path_list, files class DMCrop: """center/random crop used in diffusion models""" def __init__(self, size: int) -> None: self.size = size def __call__(self, pil_image: Image.Image) -> Image.Image: """ Center cropping implementation from ADM. https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 """ image_size = self.size if pil_image.size == (image_size, image_size): return pil_image while min(*pil_image.size) >= 2 * image_size: pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX) scale = image_size / min(*pil_image.size) pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC) arr = np.array(pil_image) crop_y = (arr.shape[0] - image_size) // 2 crop_x = (arr.shape[1] - image_size) // 2 return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]) class CustomImageFolder(ImageFolder): def __init__(self, root: str, transform: Optional[Callable] = None, return_dict: bool = False): root = os.path.expanduser(root) self.return_dict = return_dict super().__init__(root, transform) def __getitem__(self, index: int) -> Union[dict[str, Any], tuple[Any, Any]]: path, target = self.samples[index] image = load_image(path) if self.transform is not None: image = self.transform(image) if self.return_dict: return { "index": index, "image_path": path, "image": image, "label": target, } else: return image, target class ImageDataset(Dataset): def __init__( self, data_dirs: Union[str, list[str]], splits: Optional[Union[str, list[Optional[str]]]] = None, transform: Optional[Callable] = None, suffix=(".jpg", ".JPEG", ".png"), pil=True, return_dict=True, ) -> None: super().__init__() self.data_dirs = [data_dirs] if isinstance(data_dirs, str) else data_dirs if isinstance(splits, list): assert len(splits) == len(self.data_dirs) self.splits = splits elif isinstance(splits, str): assert len(self.data_dirs) == 1 self.splits = [splits] else: self.splits = [None for _ in range(len(self.data_dirs))] self.transform = transform self.pil = pil self.return_dict = return_dict # load all images [image_path] self.samples = [] for data_dir, split in zip(self.data_dirs, self.splits): if split is None: samples = load_image_from_dir(data_dir, suffix, return_mode="path") else: samples = [] with open(split) as fin: for line in fin.readlines(): relative_path = line[:-1] full_path = os.path.join(data_dir, relative_path) samples.append(full_path) self.samples += samples def __len__(self) -> int: return len(self.samples) def __getitem__(self, index: int, skip_image=False) -> dict[str, Any]: image_path = self.samples[index] if skip_image: image = None else: try: image = load_image(image_path, return_pil=self.pil) except Exception: print(f"Fail to load {image_path}") raise OSError if self.transform is not None: image = self.transform(image) if self.return_dict: return { "index": index, "image_path": image_path, "image_name": os.path.basename(image_path), "data": image, } else: return image