ramimu's picture
Upload 586 files
1c72248 verified
import os
import weakref
from _weakref import ReferenceType
from typing import TYPE_CHECKING, List, Union
import cv2
import torch
import random
from PIL import Image
from PIL.ImageOps import exif_transpose
from toolkit import image_utils
from toolkit.basic import get_quick_signature_string
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin
if TYPE_CHECKING:
from toolkit.config_modules import DatasetConfig
from toolkit.stable_diffusion_model import StableDiffusion
printed_messages = []


def print_once(msg):
    """Print ``msg`` the first time it is seen; later identical calls are silent.

    Already-emitted messages are remembered in the module-level
    ``printed_messages`` list (mutated in place, never rebound).
    """
    if msg in printed_messages:
        return
    print(msg)
    printed_messages.append(msg)
class FileItemDTO(
    LatentCachingFileItemDTOMixin,
    CaptionProcessingDTOMixin,
    ImageProcessingDTOMixin,
    ControlFileItemDTOMixin,
    InpaintControlFileItemDTOMixin,
    ClipImageFileItemDTOMixin,
    MaskFileItemDTOMixin,
    AugmentationFileItemDTOMixin,
    UnconditionalFileItemDTOMixin,
    PoiFileItemDTOMixin,
    ArgBreakMixin,
):
    """One image or video file from a dataset, plus its per-item training state.

    The file's native width/height are looked up in (or written to) a shared
    ``size_database`` keyed by path. A cached entry is only trusted when its
    stored signature matches ``get_quick_signature_string(path)``. All other
    per-item behavior (captions, latents, control images, masks, ...) is
    provided by the mixins.
    """

    def __init__(self, *args, **kwargs):
        """Build the item from keyword arguments.

        Keyword Args:
            path: Path to the image/video file.
            dataset_config: The owning ``DatasetConfig``. Required: its
                ``num_frames``, ``scale``, ``augments``, ``loss_multiplier``,
                ``network_weight`` and ``is_reg`` are read here.
            size_database: Mutable mapping ``file_key -> (w, h, signature)``
                used as a size cache; updated in place on a cache miss.
            dataset_root: If given, removed from ``path`` to form the cache
                key; otherwise the file's basename is the key.
            dataloader_transforms, raw_caption, scale_to_width,
            scale_to_height, crop_x, crop_y, crop_width, crop_height,
            flip_x, flip_y: Optional per-item overrides.

        ``*args``/``**kwargs`` are also forwarded to the mixins.

        Raises:
            Exception: if the file signature cannot be computed, or a video
                file cannot be opened.
        """
        self.path = kwargs.get('path', '')
        self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
        # more than one frame per sample => the file is read as a video
        self.is_video = self.dataset_config.num_frames > 1
        size_database = kwargs.get('size_database', {})
        dataset_root = kwargs.get('dataset_root', None)
        if dataset_root is not None:
            # remove dataset root from path
            file_key = self.path.replace(dataset_root, '')
        else:
            file_key = os.path.basename(self.path)

        file_signature = get_quick_signature_string(self.path)
        if file_signature is None:
            raise Exception(f"Error: Could not get file signature for {self.path}")

        # Only reuse a cached size when the stored signature still matches the file.
        db_entry = size_database[file_key] if file_key in size_database else None
        if db_entry is not None and len(db_entry) >= 3 and db_entry[2] == file_signature:
            w, h = db_entry[0], db_entry[1]
        else:
            w, h = self._read_media_dimensions()
            size_database[file_key] = (w, h, file_signature)

        self.width: int = w
        self.height: int = h
        self.dataloader_transforms = kwargs.get('dataloader_transforms', None)
        # width/height/dataloader_transforms are set before the mixins initialise,
        # in case they read them.
        super().__init__(*args, **kwargs)
        self.raw_caption: str = kwargs.get('raw_caption', None)
        # we scale first, then crop
        self.scale_to_width: int = kwargs.get('scale_to_width', int(self.width * self.dataset_config.scale))
        self.scale_to_height: int = kwargs.get('scale_to_height', int(self.height * self.dataset_config.scale))
        # crop values are from scaled size
        self.crop_x: int = kwargs.get('crop_x', 0)
        self.crop_y: int = kwargs.get('crop_y', 0)
        self.crop_width: int = kwargs.get('crop_width', self.scale_to_width)
        self.crop_height: int = kwargs.get('crop_height', self.scale_to_height)
        self.flip_x: bool = kwargs.get('flip_x', False)
        self.flip_y: bool = kwargs.get('flip_y', False)
        self.augments: List[str] = self.dataset_config.augments
        self.loss_multiplier: float = self.dataset_config.loss_multiplier
        self.network_weight: float = self.dataset_config.network_weight
        self.is_reg = self.dataset_config.is_reg
        self.tensor: Union[torch.Tensor, None] = None

    def _read_media_dimensions(self):
        """Return ``(width, height)`` of ``self.path`` by opening the file.

        Videos are probed with OpenCV. Images are opened with Pillow and
        EXIF-transposed so the size reflects the displayed orientation.
        Both handles are released even on error.
        """
        if self.is_video:
            video = cv2.VideoCapture(self.path)
            try:
                if not video.isOpened():
                    raise Exception(f"Error: Could not open video file {self.path}")
                width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
                height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
            finally:
                video.release()
            return width, height

        # The header-only reader (image_utils.get_image_size) is significantly
        # faster, but some images come back sideways (EXIF orientation is
        # ignored), so the slower full decode + exif_transpose is used for now.
        with Image.open(self.path) as img:
            w, h = exif_transpose(img).size
        return w, h

    def cleanup(self):
        """Release this item's tensor and every mixin-held resource."""
        self.tensor = None
        self.cleanup_latent()
        self.cleanup_control()
        self.cleanup_inpaint()
        self.cleanup_clip_image()
        self.cleanup_mask()
        self.cleanup_unconditional()
class DataLoaderBatchDTO:
    """A collated batch of ``FileItemDTO`` objects.

    Per-item tensors are concatenated along a new leading batch dimension.
    The optional per-item tensors are control, inpaint, clip image, mask,
    unaugmented and unconditional. For each one: if no item has it, the batch
    attribute is ``None``; otherwise items lacking it contribute zeros shaped
    like the first item that has it.
    """

    def __init__(self, **kwargs):
        """Collate ``kwargs['file_items']`` into batch tensors.

        Raises:
            Exception: if some but not all items carry ``clip_image_embeds``
                (or ``clip_image_embeds_unconditional``). Any error during
                collation is printed and then re-raised.
        """
        try:
            self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
            # assumes every item in the batch agrees with the first on latent caching
            is_latents_cached = self.file_items[0].is_latent_cached

            self.tensor: Union[torch.Tensor, None] = None
            self.latents: Union[torch.Tensor, None] = None
            self.control_tensor: Union[torch.Tensor, None] = None
            self.inpaint_tensor: Union[torch.Tensor, None] = None
            self.clip_image_tensor: Union[torch.Tensor, None] = None
            self.mask_tensor: Union[torch.Tensor, None] = None
            self.unaugmented_tensor: Union[torch.Tensor, None] = None
            self.unconditional_tensor: Union[torch.Tensor, None] = None
            self.unconditional_latents: Union[torch.Tensor, None] = None
            self.clip_image_embeds: Union[List[dict], None] = None
            self.clip_image_embeds_unconditional: Union[List[dict], None] = None
            self.sigmas: Union[torch.Tensor, None] = None  # can be added elsewhere and passed along training code
            self.extra_values: Union[torch.Tensor, None] = (
                torch.tensor([x.extra_values for x in self.file_items])
                if len(self.file_items[0].extra_values) > 0
                else None
            )

            if is_latents_cached:
                # encoded latents replace pixel tensors when they are cached
                self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
            else:
                # only return a tensor if latents are not cached
                self.tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])

            items = self.file_items
            self.control_tensor = self._stack_with_zero_fill([x.control_tensor for x in items])
            self.inpaint_tensor = self._stack_with_zero_fill([x.inpaint_tensor for x in items])
            self.loss_multiplier_list: List[float] = [x.loss_multiplier for x in items]
            self.clip_image_tensor = self._stack_with_zero_fill([x.clip_image_tensor for x in items])
            self.mask_tensor = self._stack_with_zero_fill([x.mask_tensor for x in items])
            # add unaugmented tensors for ones with augments
            self.unaugmented_tensor = self._stack_with_zero_fill([x.unaugmented_tensor for x in items])
            self.unconditional_tensor = self._stack_with_zero_fill([x.unconditional_tensor for x in items])

            # embeddings are kept as per-item lists and must be present for all items or none
            self.clip_image_embeds = self._collect_all_or_none(
                [x.clip_image_embeds for x in items], 'clip_image_embeds')
            self.clip_image_embeds_unconditional = self._collect_all_or_none(
                [x.clip_image_embeds_unconditional for x in items], 'clip_image_embeds_unconditional')
        except Exception as e:
            print(e)
            raise

    @staticmethod
    def _stack_with_zero_fill(tensors):
        """Concatenate ``tensors`` on a new leading dim, zero-filling ``None`` entries.

        Returns ``None`` when every entry is ``None``. Each missing entry is
        replaced by ``torch.zeros_like`` of the first non-``None`` tensor.
        Present tensors are expected to share a shape, which ``torch.cat``
        requires.
        """
        base = next((t for t in tensors if t is not None), None)
        if base is None:
            return None
        filled = [torch.zeros_like(base) if t is None else t for t in tensors]
        return torch.cat([t.unsqueeze(0) for t in filled])

    @staticmethod
    def _collect_all_or_none(values, name):
        """Return ``values`` as a list if all are set, or ``None`` if none are.

        Raises:
            Exception: if only some of the entries are ``None``.
        """
        if all(v is None for v in values):
            return None
        if any(v is None for v in values):
            raise Exception(f"{name} is None for some file items")
        return list(values)

    def get_is_reg_list(self):
        """Return each item's ``is_reg`` flag, in batch order."""
        return [x.is_reg for x in self.file_items]

    def get_network_weight_list(self):
        """Return each item's ``network_weight``, in batch order."""
        return [x.network_weight for x in self.file_items]

    def get_caption_list(
            self,
            trigger=None,
            to_replace_list=None,
            add_if_not_present=True
    ):
        """Return each item's ``caption``, in batch order.

        ``trigger``, ``to_replace_list`` and ``add_if_not_present`` are
        accepted for API compatibility but currently ignored.
        """
        return [x.caption for x in self.file_items]

    def get_caption_short_list(
            self,
            trigger=None,
            to_replace_list=None,
            add_if_not_present=True
    ):
        """Return each item's ``caption_short``, in batch order.

        The keyword parameters are accepted for API compatibility but
        currently ignored.
        """
        return [x.caption_short for x in self.file_items]

    def cleanup(self):
        """Drop the batch tensors and ask every file item to clean up.

        Uses ``del``, so ``latents``, ``tensor`` and ``control_tensor`` no
        longer exist afterwards; a second call raises ``AttributeError``.
        """
        del self.latents
        del self.tensor
        del self.control_tensor
        for file_item in self.file_items:
            file_item.cleanup()