# Source: LoRa_Streamlit / ai-toolkit / toolkit/dataloader_mixins.py
# Uploaded by ramimu ("Upload 586 files"), commit 1c72248 (verified); original size 98.1 kB.
import base64
import glob
import hashlib
import json
import math
import os
import random
from collections import OrderedDict
from typing import TYPE_CHECKING, List, Dict, Union
import traceback
import cv2
import numpy as np
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor
from toolkit.basic import flush, value_map
from toolkit.buckets import get_bucket_for_image_size, get_resolution
from toolkit.config_modules import ControlTypes
from toolkit.metadata import get_meta_for_safetensors
from toolkit.models.pixtral_vision import PixtralVisionImagePreprocessorCompatible
from toolkit.prompt_utils import inject_trigger_into_prompt
from torchvision import transforms
from PIL import Image, ImageFilter, ImageOps
from PIL.ImageOps import exif_transpose
import albumentations as A
from toolkit.print import print_acc
from toolkit.accelerator import get_accelerator
from toolkit.train_tools import get_torch_dtype
if TYPE_CHECKING:
from toolkit.data_loader import AiToolkitDataset
from toolkit.data_transfer_object.data_loader import FileItemDTO
from toolkit.stable_diffusion_model import StableDiffusion
# Module-level accelerator handle, obtained once at import time from toolkit.accelerator.get_accelerator().
# NOTE(review): not referenced in the visible part of this module; confirm it is still needed.
accelerator = get_accelerator()
# NOTE(review): appears to be the leftover signature of a helper that no longer exists here; delete once confirmed unused.
# def get_associated_caption_from_img_path(img_path):
# Albumentations augmentation playground: https://demo.albumentations.ai/
class Augments:
    """Configuration for one albumentations augmentation.

    Keyword Args:
        method: Name of the augmentation; stored as ``method_name`` (default ``None``).
        params: Keyword arguments for the augmentation (default ``{}``). String values
            of the form ``"cv2.NAME"`` are resolved to the matching ``cv2`` constant,
            case-insensitively (``"cv2.inter_linear"`` -> ``cv2.INTER_LINEAR``).

    Raises:
        ValueError: if a ``"cv2.NAME"`` value does not name a ``cv2`` attribute.
    """

    def __init__(self, **kwargs):
        self.method_name = kwargs.get('method', None)
        # Copy, so resolving cv2 enums never mutates the caller's config dict;
        # an explicit ``params=None`` is treated like "no params".
        self.params = dict(kwargs.get('params') or {})
        # Only values are reassigned below (no keys added/removed), so iterating is safe.
        for key, value in self.params.items():
            if not isinstance(value, str):
                continue
            split_string = value.split('.')
            if len(split_string) == 2 and split_string[0] == 'cv2':
                # Check the same (upper-cased) name that is actually fetched.
                enum_name = split_string[1].upper()
                if hasattr(cv2, enum_name):
                    self.params[key] = getattr(cv2, enum_name)
                else:
                    raise ValueError(f"invalid cv2 enum: {split_string[1]}")
# Named torchvision transforms that a file item's ``augments`` list can request;
# load_and_process_image applies the ones whose name is a key here.
transforms_dict = dict(
    ColorJitter=transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.03),
    RandomEqualize=transforms.RandomEqualize(p=0.2),
)
# Extensions (lower-case, with leading dot) probed, in this order, when looking
# for companion control / clip images that share a dataset image's base name.
img_ext_list = [
    '.jpg',
    '.jpeg',
    '.png',
    '.webp',
]
def standardize_images(images, mean=None, std=None):
    """
    Standardize a batch of images channel-wise as ``(img - mean) / std``.

    Expects float values in the 0 - 1 range. By default the OpenAI CLIP
    normalization statistics are used; pass ``mean`` / ``std`` for others.

    Args:
        images: A (N, C, H, W) tensor, or any iterable of (C, H, W) float tensors.
        mean: Per-channel means (defaults to the CLIP statistics).
        std: Per-channel standard deviations (defaults to the CLIP statistics).

    Returns:
        torch.Tensor: The standardized images stacked to shape (N, C, H, W).

    Raises:
        TypeError: if an image is not a floating point tensor.
        ValueError: if an image has fewer than 3 dims or any std is zero.
    """
    if mean is None:
        mean = [0.48145466, 0.4578275, 0.40821073]
    if std is None:
        std = [0.26862954, 0.26130258, 0.27577711]

    standardized = []
    for img in images:
        # Same input checks as torchvision.transforms.Normalize.
        if not img.is_floating_point():
            raise TypeError(f"Input tensor should be a float tensor. Got {img.dtype}.")
        if img.ndim < 3:
            raise ValueError(
                f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {tuple(img.shape)}")
        # Match the image's dtype/device so GPU and half-precision inputs work.
        mean_t = torch.as_tensor(mean, dtype=img.dtype, device=img.device).view(-1, 1, 1)
        std_t = torch.as_tensor(std, dtype=img.dtype, device=img.device).view(-1, 1, 1)
        if (std_t == 0).any():
            raise ValueError(f"std evaluated to zero after conversion to {img.dtype}, leading to division by zero.")
        standardized.append((img - mean_t) / std_t)
    return torch.stack(standardized)
def clean_caption(caption):
    """Return ``caption`` exactly as given.

    This used to normalize comma-separated tag captions (newlines to commas,
    stripping empty tags). That no longer suits free-form natural-language
    captions, so it is now an intentional pass-through hook.
    """
    cleaned = caption
    return cleaned
class CaptionMixin:
    """Dataset mixin providing ``get_caption_item`` (caption lookup by file index)."""

    def get_caption_item(self: 'AiToolkitDataset', index):
        """
        Resolve the caption for ``self.file_list[index]``.

        Lookup order:
          1. ``<image name><caption_ext>`` next to the image (``.json`` files use
             their ``caption`` key),
          2. ``default<caption_ext>`` in the image's folder,
          3. ``default.txt`` in the image's folder,
          4. ``self.default_prompt`` / ``self.default_caption`` if present, else ``''``.
        Each ``dataset_config.replacements`` entry ``"from|to"`` is applied afterwards.

        Raises:
            Exception: if the instance lacks ``caption_type`` or ``file_list``.
        """
        if not hasattr(self, 'caption_type'):
            raise Exception('caption_type not found on class instance')
        if not hasattr(self, 'file_list'):
            raise Exception('file_list not found on class instance')

        item = self.file_list[index]
        # Tuples (paired items) take their caption from the first entry.
        if isinstance(item, tuple):
            item = item[0]
        img_path = item if isinstance(item, str) else item.path

        # The extension is needed for both the per-image and the folder-default
        # lookups. Accept it with or without a leading dot ("txt" or ".txt").
        ext = self.dataset_config.caption_ext
        if ext and not ext.startswith('.'):
            ext = '.' + ext

        folder = os.path.dirname(img_path)
        prompt_path = os.path.splitext(img_path)[0] + ext
        default_prompt_path = os.path.join(folder, 'default.txt')
        default_prompt_path_with_ext = os.path.join(folder, 'default' + ext)

        if os.path.exists(prompt_path):
            with open(prompt_path, 'r', encoding='utf-8') as f:
                prompt = f.read()
            if prompt_path.endswith('.json'):
                prompt_json = json.loads(prompt)
                # Without a 'caption' key, keep the raw file text rather than a dict.
                if isinstance(prompt_json, dict) and 'caption' in prompt_json:
                    prompt = prompt_json['caption']
            prompt = clean_caption(prompt)
        elif os.path.exists(default_prompt_path_with_ext):
            # Read the file whose existence was just checked.
            with open(default_prompt_path_with_ext, 'r', encoding='utf-8') as f:
                prompt = f.read()
            prompt = clean_caption(prompt)
        elif os.path.exists(default_prompt_path):
            with open(default_prompt_path, 'r', encoding='utf-8') as f:
                prompt = f.read()
            prompt = clean_caption(prompt)
        else:
            prompt = ''
            # Instance-level defaults; default_caption wins over default_prompt.
            if hasattr(self, 'default_prompt'):
                prompt = self.default_prompt
            if hasattr(self, 'default_caption'):
                prompt = self.default_caption

        replacement_list = self.dataset_config.replacements if isinstance(self.dataset_config.replacements, list) else []
        for replacement in replacement_list:
            from_string, to_string = replacement.split('|')
            prompt = prompt.replace(from_string, to_string)
        return prompt
if TYPE_CHECKING:
from toolkit.config_modules import DatasetConfig
from toolkit.data_transfer_object.data_loader import FileItemDTO
class Bucket:
    """A group of dataset items that share one crop resolution.

    Attributes:
        width, height: Crop size (pixels) shared by every item in the bucket.
        file_list_idx: Indices into the owning dataset's ``file_list``.
    """

    def __init__(self, width: int, height: int):
        self.width, self.height = width, height
        self.file_list_idx: List[int] = list()
class BucketsMixin:
    """Group a dataset's files into same-resolution buckets and cut them into batches.

    The host class is expected to provide ``file_list``, ``dataset_config``,
    ``batch_size``, ``epoch_num`` and ``dataset_path``.
    """

    def __init__(self):
        self.buckets: Dict[str, Bucket] = {}
        self.batch_indices: List[List[int]] = []

    def build_batch_indices(self: 'AiToolkitDataset'):
        """Chunk each bucket's index list into consecutive runs of ``batch_size``.

        The last chunk of a bucket may be shorter. Batches never mix buckets.
        """
        size = self.batch_size
        self.batch_indices = [
            bucket.file_list_idx[start:start + size]
            for bucket in self.buckets.values()
            for start in range(0, len(bucket.file_list_idx), size)
        ]

    def shuffle_buckets(self: 'AiToolkitDataset'):
        """Shuffle the index order inside every bucket, in place."""
        for bucket in self.buckets.values():
            random.shuffle(bucket.file_list_idx)

    def setup_buckets(self: 'AiToolkitDataset', quiet=False):
        """Give every file a scale + crop and group the files by crop size.

        After the first epoch this is a no-op unless ``dataset_config.poi`` is set.
        Each file gets one of three treatments: a square center crop
        (``square_crop``), the crop its point of interest chose, or a fit to the
        closest bucket from ``get_bucket_for_image_size`` (random or centered crop).
        The buckets are then shuffled and batched, and their sizes are printed
        unless ``quiet`` is set.

        Raises:
            Exception: if ``file_list`` or ``dataset_config`` is missing.
        """
        for required in ('file_list', 'dataset_config'):
            if not hasattr(self, required):
                raise Exception(f'{required} not found on class instance {self.__class__.__name__}')

        config: 'DatasetConfig' = self.dataset_config
        if self.epoch_num > 0 and config.poi is None:
            # Nothing changes between epochs without a point of interest.
            # todo handle random cropping for buckets
            return

        self.buckets = {}
        resolution = config.resolution
        divisibility = config.bucket_tolerance

        for idx, item in enumerate(self.file_list):
            item: 'FileItemDTO'
            width = int(item.width * item.dataset_config.scale)
            height = int(item.height * item.dataset_config.scale)

            # setup_poi_bucket() does nothing if the image is smaller than the resolution.
            poi_done = item.setup_poi_bucket() if item.has_point_of_interest else False

            if config.square_crop:
                # Scale so the shorter side matches the resolution, then center crop.
                factor = max(resolution / width, resolution / height)
                item.scale_to_width = math.ceil(width * factor)
                item.scale_to_height = math.ceil(height * factor)
                item.crop_width = resolution
                item.crop_height = resolution
                landscape = width > height
                item.crop_x = int(item.scale_to_width / 2 - resolution / 2) if landscape else 0
                item.crop_y = 0 if landscape else int(item.scale_to_height / 2 - resolution / 2)
            elif not poi_done:
                bucket_res = get_bucket_for_image_size(
                    width, height,
                    resolution=resolution,
                    divisibility=divisibility
                )
                target_w = bucket_res["width"]
                target_h = bucket_res["height"]
                # Use the larger factor so both sides cover the bucket; round up.
                factor = max(target_w / width, target_h / height)
                item.scale_to_width = int(math.ceil(width * factor))
                item.scale_to_height = int(math.ceil(height * factor))
                item.crop_height = target_h
                item.crop_width = target_w
                slack_x = item.scale_to_width - target_w
                slack_y = item.scale_to_height - target_h
                if config.random_crop:
                    # x is drawn before y.
                    item.crop_x = random.randint(0, slack_x)
                    item.crop_y = random.randint(0, slack_y)
                else:
                    item.crop_x = int(slack_x / 2)
                    item.crop_y = int(slack_y / 2)
                if item.crop_y < 0 or item.crop_x < 0:
                    print_acc('debug')

            key = f'{item.crop_width}x{item.crop_height}'
            bucket = self.buckets.get(key)
            if bucket is None:
                bucket = Bucket(item.crop_width, item.crop_height)
                self.buckets[key] = bucket
            bucket.file_list_idx.append(idx)

        self.shuffle_buckets()
        self.build_batch_indices()
        if quiet:
            return
        print_acc(f'Bucket sizes for {self.dataset_path}:')
        for key, bucket in self.buckets.items():
            print_acc(f'{key}: {len(bucket.file_list_idx)} files')
        print_acc(f'{len(self.buckets)} buckets made')
class CaptionProcessingDTOMixin:
    """FileItemDTO mixin that loads raw captions and renders training captions."""

    def __init__(self: 'FileItemDTO', *args, **kwargs):
        if hasattr(super(), '__init__'):
            super().__init__(*args, **kwargs)
        self.raw_caption: str = None
        self.raw_caption_short: str = None
        self.caption: str = None
        self.caption_short: str = None
        dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
        # A JSON caption file may override this per item (see load_caption).
        self.extra_values: List[float] = dataset_config.extra_values

    # todo allow for loading from sd-scripts style dict
    def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]):
        """
        Fill ``raw_caption`` / ``raw_caption_short`` (once), then render ``caption`` / ``caption_short``.

        Sources, in priority order:
          1. an already-loaded ``raw_caption`` (left untouched),
          2. ``caption_dict[self.path]`` when it has a ``"caption"`` key,
          3. ``<image name>.<caption_ext>`` next to the image; ``.json`` files may
             provide ``caption``, ``caption_short`` and ``extra_values``,
          4. ``dataset_config.default_caption`` (or ``''``) when no caption file exists.
        With ``use_short_captions``, an available short caption replaces the main one.
        """
        if self.raw_caption is not None:
            # already loaded
            pass
        elif caption_dict is not None and self.path in caption_dict and "caption" in caption_dict[self.path]:
            entry = caption_dict[self.path]
            self.raw_caption = entry["caption"]
            if 'caption_short' in entry:
                self.raw_caption_short = entry["caption_short"]
                if self.dataset_config.use_short_captions:
                    self.raw_caption = entry["caption_short"]
        else:
            prompt_ext = self.dataset_config.caption_ext
            # The path is built as "<name>.<ext>", so tolerate an ext given as ".txt".
            if isinstance(prompt_ext, str):
                prompt_ext = prompt_ext.lstrip('.')
            prompt_path = f"{os.path.splitext(self.path)[0]}.{prompt_ext}"
            short_caption = None
            if os.path.exists(prompt_path):
                with open(prompt_path, 'r', encoding='utf-8') as f:
                    prompt = f.read()
                if prompt_path.endswith('.json'):
                    # Flatten raw line breaks (\r\n, \n, \r) to spaces before parsing.
                    prompt = prompt.replace('\r\n', ' ').replace('\n', ' ').replace('\r', ' ')
                    prompt_json = json.loads(prompt)
                    if 'caption' in prompt_json:
                        prompt = prompt_json['caption']
                    if 'caption_short' in prompt_json:
                        short_caption = prompt_json['caption_short']
                        if self.dataset_config.use_short_captions:
                            prompt = short_caption
                    if 'extra_values' in prompt_json:
                        self.extra_values = prompt_json['extra_values']
                prompt = clean_caption(prompt)
                if short_caption is not None:
                    short_caption = clean_caption(short_caption)
            else:
                prompt = ''
                if self.dataset_config.default_caption is not None:
                    prompt = self.dataset_config.default_caption
            # NOTE(review): this also applies when a caption file exists without a short caption; confirm intended.
            if short_caption is None:
                short_caption = self.dataset_config.default_caption
            self.raw_caption = prompt
            self.raw_caption_short = short_caption

        self.caption = self.get_caption()
        if self.raw_caption_short is not None:
            self.caption_short = self.get_caption(short_caption=True)

    def get_caption(
            self: 'FileItemDTO',
            trigger=None,
            to_replace_list=None,
            add_if_not_present=False,
            short_caption=False
    ):
        """
        Render a caption from ``raw_caption`` (or ``raw_caption_short``).

        Tokens are the comma-separated, whitespace-stripped, non-empty pieces.
        Caption dropout (returns ``''``) and token dropout (the first ``keep_tokens``
        tokens are always kept) apply only to the main caption. ``shuffle_tokens``
        shuffles the tokens. ``random_triggers`` appends up to
        ``random_triggers_max`` distinct random triggers (a random count when the
        max is > 1).

        ``trigger``, ``to_replace_list`` and ``add_if_not_present`` are currently
        unused (trigger injection is disabled) and are kept for API compatibility.

        Returns:
            str: The ", "-joined caption.
        """
        config = self.dataset_config
        raw_caption = self.raw_caption_short if short_caption else self.raw_caption
        if raw_caption is None:
            raw_caption = ''

        # caption dropout: drop the whole caption
        if config.caption_dropout_rate > 0 and not short_caption:
            if random.random() < config.caption_dropout_rate:
                return ''

        token_list = [x.strip() for x in raw_caption.split(',')]
        token_list = [x for x in token_list if x]

        if config.token_dropout_rate > 0 and not short_caption:
            keep_tokens: int = config.keep_tokens
            new_token_list = []
            for idx, token in enumerate(token_list):
                if idx < keep_tokens:
                    new_token_list.append(token)
                elif config.token_dropout_rate >= 1.0:
                    pass  # drop every token past keep_tokens
                elif random.random() > config.token_dropout_rate:
                    new_token_list.append(token)
            token_list = new_token_list

        if config.shuffle_tokens:
            random.shuffle(token_list)

        caption = ', '.join(token_list)
        # caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)

        if config.random_triggers:
            num_triggers = config.random_triggers_max
            if num_triggers > 1:
                num_triggers = random.randint(0, num_triggers)
            # random.sample raises ValueError when asked for more items than exist.
            num_triggers = min(num_triggers, len(config.random_triggers))
            if num_triggers > 0:
                triggers = random.sample(config.random_triggers, num_triggers)
                # Avoid a leading ", " when the caption itself is empty.
                caption = ', '.join([caption] + triggers) if caption else ', '.join(triggers)

        if config.shuffle_tokens:
            # shuffle again so appended triggers are mixed in too
            token_list = [x.strip() for x in caption.split(',')]
            token_list = [x for x in token_list if x]
            random.shuffle(token_list)
            caption = ', '.join(token_list)
        return caption
class ImageProcessingDTOMixin:
    """FileItemDTO mixin that turns the item's file into ``self.tensor``.

    Handles still images and, when ``dataset_config.num_frames > 1``, videos. It
    applies the item's flip flags, the bucket resize/crop (or the non-bucket
    scale/crop path), named ``augments``, augmentation pipelines and the caller's
    ``transform``.
    """

    @staticmethod
    def _evenly_spaced_frame_indices(max_frame_index, num_frames):
        """Return ``num_frames`` frame indices spread evenly over ``[0, max_frame_index]``."""
        interval = max_frame_index / (num_frames - 1) if num_frames > 1 else 0
        return [min(int(round(i * interval)), max_frame_index) for i in range(num_frames)]

    @staticmethod
    def _print_video_diagnostics(cap):
        """Print the capture's state and whether its first/last frames can be read."""
        is_open = cap.isOpened()
        cap_status = "Opened" if is_open else "Closed"
        current_pos = int(cap.get(cv2.CAP_PROP_POS_FRAMES)) if is_open else "Unknown"
        reported_total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) if is_open else "Unknown"
        print_acc(f"Video details when error occurred:")
        print_acc(f" Cap status: {cap_status}")
        print_acc(f" Current position: {current_pos}")
        print_acc(f" Reported total frames: {reported_total}")
        if is_open:
            # Probe the start and end, to tell a corrupted file from a bad seek.
            cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
            start_ret, _ = cap.read()
            end_ret = False  # reported as-is when there is no last frame to probe
            if reported_total > 0:
                cap.set(cv2.CAP_PROP_POS_FRAMES, reported_total - 1)
                end_ret, _ = cap.read()
            print_acc(f" Can read first frame: {start_ret}, Can read last frame: {end_ret}")

    def _load_auxiliary_images(self: 'FileItemDTO'):
        """Load each companion image this item has (control, inpaint, clip, mask, unconditional)."""
        if self.has_control_image:
            self.load_control_image()
        if self.has_inpaint_image:
            self.load_inpaint_image()
        if self.has_clip_image:
            self.load_clip_image()
        if self.has_mask_image:
            self.load_mask_image()
        if self.has_unconditional:
            self.load_unconditional_image()

    def load_and_process_video(
            self: 'FileItemDTO',
            transform: Union[None, transforms.Compose],
            only_load_latents=False
    ):
        """
        Decode ``dataset_config.num_frames`` frames of ``self.path`` into ``self.tensor``.

        Frame selection:
          - Frames are spread evenly over the whole clip when
            ``shrink_video_to_frames`` is set or the clip has too few frames.
          - Otherwise a run starting at a random frame is taken, stepping by
            ``round(video_fps / dataset_config.fps)`` (at least 1). It falls back to
            even spreading if the run would not fit.
        Unreadable frames fall back to a nearby frame (offsets ±1, ±5, ±10).
        Each frame gets the same flip, bucket resize+crop and ``transform`` as a
        still image, and the results are stacked with ``torch.stack``.

        ``only_load_latents`` is accepted for signature parity and is unused here.

        Raises:
            Exception: on cached latents, configured augments/augmentations,
                disabled buckets, or any failure to open/read/process the video
                (wrapped as "Video loading error (...)").
        """
        if self.is_latent_cached:
            raise Exception('Latent caching not supported for videos')
        if self.augments is not None and len(self.augments) > 0:
            raise Exception('Augments not supported for videos')
        if self.has_augmentations:
            raise Exception('Augmentations not supported for videos')
        if not self.dataset_config.buckets:
            raise Exception('Buckets required for video processing')

        debug = getattr(self.dataset_config, 'debug', False)
        num_frames = self.dataset_config.num_frames
        # Bound before the try, so the handler and finally can always see it.
        cap = None
        try:
            cap = cv2.VideoCapture(self.path)
            if not cap.isOpened():
                raise Exception(f"Failed to open video file: {self.path}")

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            video_fps = cap.get(cv2.CAP_PROP_FPS)
            max_frame_index = total_frames - 1  # frames are zero-indexed
            if debug:
                print_acc(f"Video properties: {self.path}")
                print_acc(f" Total frames: {total_frames}")
                print_acc(f" Max valid frame index: {max_frame_index}")
                print_acc(f" FPS: {video_fps}")

            if self.dataset_config.shrink_video_to_frames or total_frames < num_frames:
                frames_to_extract = self._evenly_spaced_frame_indices(max_frame_index, num_frames)
            else:
                frame_interval = max(1, int(round(video_fps / self.dataset_config.fps)))
                if total_frames // frame_interval < num_frames:
                    # Not enough frames at the desired fps, so stretch instead.
                    frames_to_extract = self._evenly_spaced_frame_indices(max_frame_index, num_frames)
                else:
                    max_start_frame = max_frame_index - ((num_frames - 1) * frame_interval)
                    start_frame = random.randint(0, max(0, max_start_frame))
                    frames_to_extract = [start_frame + (i * frame_interval) for i in range(num_frames)]
            # Safety net: never seek past the last valid frame.
            frames_to_extract = [min(frame_idx, max_frame_index) for frame_idx in frames_to_extract]
            if debug:
                print_acc(f" Frames to extract: {frames_to_extract}")

            frames = []
            for frame_idx in frames_to_extract:
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
                if debug:
                    actual_pos = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
                    if actual_pos != frame_idx:
                        print_acc(f"Warning: Failed to set exact frame position. Requested: {frame_idx}, Actual: {actual_pos}")

                ret, frame = cap.read()
                if not ret:
                    actual_frame = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
                    frame_pos_info = f"Requested frame: {frame_idx}, Actual frame position: {actual_frame}"
                    # Try nearby frames before giving up.
                    for fallback_offset in [1, -1, 5, -5, 10, -10]:
                        fallback_pos = max(0, min(frame_idx + fallback_offset, max_frame_index))
                        cap.set(cv2.CAP_PROP_POS_FRAMES, fallback_pos)
                        fallback_ret, fallback_frame = cap.read()
                        if fallback_ret:
                            if debug:
                                print_acc(f"Falling back to nearby frame {fallback_pos} instead of {frame_idx}")
                            frame = fallback_frame
                            break
                    else:
                        video_info = f"Video: {self.path}, Total frames: {total_frames}, FPS: {video_fps}"
                        raise Exception(f"Failed to read frame {frame_idx} from video. {frame_pos_info}. {video_info}")

                # OpenCV decodes BGR.
                img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).convert('RGB')
                if self.flip_x:
                    img = img.transpose(Image.FLIP_LEFT_RIGHT)
                if self.flip_y:
                    img = img.transpose(Image.FLIP_TOP_BOTTOM)
                img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
                img = img.crop((
                    self.crop_x,
                    self.crop_y,
                    self.crop_x + self.crop_width,
                    self.crop_y + self.crop_height
                ))
                if transform:
                    img = transform(img)
                frames.append(img)

            self.tensor = torch.stack(frames)
            if debug:
                print_acc(f"Successfully loaded video with {len(frames)} frames: {self.path}")
        except Exception as e:
            traceback.print_exc()
            error_msg = str(e)
            try:
                if 'Failed to read frame' in error_msg and cap is not None:
                    self._print_video_diagnostics(cap)
            except Exception as debug_err:
                print_acc(f"Error during error diagnosis: {debug_err}")
            print_acc(f"Error: {error_msg}")
            print_acc(f"Error loading video: {self.path}")
            raise Exception(f"Video loading error ({self.path}): {error_msg}") from e
        finally:
            # Release on every path: success, read failure, transform failure, ...
            if cap is not None:
                cap.release()

    def load_and_process_image(
            self: 'FileItemDTO',
            transform: Union[None, transforms.Compose],
            only_load_latents=False
    ):
        """
        Load ``self.path`` into ``self.tensor``, or delegate to the video loader.

        With cached latents, only ``get_latent()`` and the companion images are
        loaded. Otherwise the steps are:
          1. EXIF-rotate the image (best effort).
          2. Optionally strip alpha (``use_alpha_as_mask``) and convert to RGB.
          3. Flip as configured.
          4. Either resize and crop to the item's bucket, or (non-bucket datasets)
             scale by ``dataset_config.scale`` and random- or center-crop to
             ``resolution``.
          5. Apply named ``augments``, then the augmentation pipeline or ``transform``.
        Companion images are loaded unless ``only_load_latents`` is set.

        Raises:
            Exception: re-raised from ``Image.open`` when the file cannot be opened.
        """
        if self.dataset_config.num_frames > 1:
            self.load_and_process_video(transform, only_load_latents)
            return

        if self.is_latent_cached:
            # The pixels are not needed, only the cached latent and companions.
            self.get_latent()
            self._load_auxiliary_images()
            return

        try:
            img = Image.open(self.path)
        except Exception as e:
            print_acc(f"Error: {e}")
            print_acc(f"Error loading image: {self.path}")
            # There is nothing to fall back to; continuing would fail later on an unbound ``img``.
            raise
        try:
            img = exif_transpose(img)
        except Exception as e:
            # Best effort: keep the un-rotated image.
            print_acc(f"Error: {e}")
            print_acc(f"Error loading image: {self.path}")

        if self.use_alpha_as_mask:
            # Slice the alpha channel off directly, so it is not replaced with another colour.
            np_img = np.array(img)
            if np_img.ndim == 3:  # single-channel / palette images have no channel axis to strip
                img = Image.fromarray(np_img[:, :, :3])
        img = img.convert('RGB')

        w, h = img.size
        if (w > h and self.scale_to_width < self.scale_to_height) or (
                h > w and self.scale_to_height < self.scale_to_width):
            # The orientation disagrees with the bucket; report it but keep going.
            print_acc(
                f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")

        if self.flip_x:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
        if self.flip_y:
            img = img.transpose(Image.FLIP_TOP_BOTTOM)

        if self.dataset_config.buckets:
            img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
            if img.width < self.crop_x + self.crop_width or img.height < self.crop_y + self.crop_height:
                # todo look into this. This still happens sometimes
                print_acc('size mismatch')
            img = img.crop((
                self.crop_x,
                self.crop_y,
                self.crop_x + self.crop_width,
                self.crop_y + self.crop_height
            ))
        else:
            resolution = self.dataset_config.resolution
            scale = self.dataset_config.scale
            # Downscale the source image first
            # TODO this is not right
            img = img.resize((int(img.size[0] * scale), int(img.size[1] * scale)), Image.BICUBIC)
            min_img_size = min(img.size)
            if self.dataset_config.random_crop:
                if self.dataset_config.random_scale and min_img_size > resolution:
                    # min_img_size > resolution here, so randint's range is never empty.
                    scale_size = random.randint(resolution, int(min_img_size))
                    scaler = scale_size / min_img_size
                    img = img.resize(
                        (int((img.width + 5) * scaler), int((img.height + 5) * scaler)),
                        Image.BICUBIC)
                img = transforms.RandomCrop(resolution)(img)
            else:
                img = transforms.CenterCrop(min_img_size)(img)
                img = img.resize((resolution, resolution), Image.BICUBIC)

        if self.augments is not None and len(self.augments) > 0:
            for augment in self.augments:
                if augment in transforms_dict:
                    img = transforms_dict[augment](img)

        if self.has_augmentations:
            # The augmentation pipeline applies ``transform`` itself.
            img = self.augment_image(img, transform=transform)
        elif transform:
            img = transform(img)
        self.tensor = img

        if not only_load_latents:
            self._load_auxiliary_images()
class InpaintControlFileItemDTOMixin:
    """FileItemDTO mixin for an optional inpaint image that carries an alpha channel.

    The inpaint image is the ``.png`` or ``.webp`` file in
    ``dataset_config.inpaint_path`` that shares the item's base file name.
    """

    def __init__(self: 'FileItemDTO', *args, **kwargs):
        if hasattr(super(), '__init__'):
            super().__init__(*args, **kwargs)
        self.has_inpaint_image = False
        self.inpaint_path: Union[str, None] = None
        self.inpaint_tensor: Union[torch.Tensor, None] = None
        dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
        if dataset_config.inpaint_path is not None:
            inpaint_dir = dataset_config.inpaint_path
            img_path = kwargs.get('path', None)
            file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
            # Only formats that can store an alpha channel are searched.
            img_inpaint_ext_list = ['.png', '.webp']
            for ext in img_inpaint_ext_list:
                p = os.path.join(inpaint_dir, file_name_no_ext + ext)
                if os.path.exists(p):
                    self.inpaint_path = p
                    self.has_inpaint_image = True
                    break

    def load_inpaint_image(self: 'FileItemDTO'):
        """
        Load the inpaint image into ``self.inpaint_tensor`` as an RGBA ToTensor output.

        It gets the main image's flip and bucket resize+crop, plus replayed spatial
        augmentations when ``aug_replay_spatial_transforms`` is set. This is best
        effort: errors (including orientation mismatches and non-bucket datasets)
        are printed, not raised, and leave ``inpaint_tensor`` unchanged.
        """
        try:
            img = Image.open(self.inpaint_path)
            if img.mode != 'RGBA':
                if 'A' in img.getbands() or 'transparency' in img.info:
                    # LA / PA / transparency-keyed images do carry alpha; normalize the mode.
                    img = img.convert('RGBA')
                else:
                    # The alpha channel is the inpaint mask, so this file is unusable.
                    print_acc(f"Inpaint image has no alpha channel, skipping: {self.inpaint_path}")
                    return
            img = exif_transpose(img)
            w, h = img.size
            if w > h and self.scale_to_width < self.scale_to_height:
                raise ValueError(
                    f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
            elif h > w and self.scale_to_height < self.scale_to_width:
                raise ValueError(
                    f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")

            if self.flip_x:
                img = img.transpose(Image.FLIP_LEFT_RIGHT)
            if self.flip_y:
                img = img.transpose(Image.FLIP_TOP_BOTTOM)

            if self.dataset_config.buckets:
                # Scale and crop exactly like the main image.
                img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
                img = img.crop((
                    self.crop_x,
                    self.crop_y,
                    self.crop_x + self.crop_width,
                    self.crop_y + self.crop_height
                ))
            else:
                raise Exception("Inpaint images not supported for non-bucket datasets")

            transform = transforms.Compose([
                transforms.ToTensor(),
            ])
            if self.aug_replay_spatial_transforms:
                tensor = self.augment_spatial_control(img, transform=transform)
            else:
                tensor = transform(img)

            # 0 to 1, alpha channel included
            self.inpaint_tensor = tensor
        except Exception as e:
            print_acc(f"Error: {e}")
            print_acc(f"Error loading image: {self.inpaint_path}")

    def cleanup_inpaint(self: 'FileItemDTO'):
        """Drop the loaded inpaint tensor."""
        self.inpaint_tensor = None
class ControlFileItemDTOMixin:
    """FileItemDTO mixin for optional control images.

    For each folder in ``dataset_config.control_path``, the first file sharing the
    item's base name (extensions tried in ``img_ext_list`` order) is used.
    ``control_path`` ends up as ``None``, a single path string, or a list of paths.
    """

    def __init__(self: 'FileItemDTO', *args, **kwargs):
        if hasattr(super(), '__init__'):
            super().__init__(*args, **kwargs)
        self.has_control_image = False
        self.control_path: Union[str, List[str], None] = None
        self.control_tensor: Union[torch.Tensor, None] = None
        dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
        self.full_size_control_images = False
        if dataset_config.control_path is not None:
            control_path_list = dataset_config.control_path
            if not isinstance(control_path_list, list):
                control_path_list = [control_path_list]
            self.full_size_control_images = dataset_config.full_size_control_images
            img_path = kwargs.get('path', None)
            file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
            found_control_images = []
            for control_path in control_path_list:
                for ext in img_ext_list:
                    candidate = os.path.join(control_path, file_name_no_ext + ext)
                    if os.path.exists(candidate):
                        found_control_images.append(candidate)
                        self.has_control_image = True
                        break
            if len(found_control_images) == 0:
                self.control_path = None
            elif len(found_control_images) == 1:
                # a single match is stored as a plain path
                self.control_path = found_control_images[0]
            else:
                self.control_path = found_control_images

    def load_control_image(self: 'FileItemDTO'):
        """
        Load every path in ``self.control_path`` into ``self.control_tensor``.

        Two modes:
          - Without ``full_size_control_images``, each image is resized to 512x512
            with no flip or crop.
          - Otherwise it gets the main image's flip and bucket resize+crop.
        Spatial augmentations are replayed when ``aug_replay_spatial_transforms`` is
        set. One image yields its ToTensor output; several are stacked on dim 0.

        Raises:
            Exception: re-raised when a control image cannot be opened/decoded
                (after printing it), or for full-size images without buckets.
            ValueError: when a full-size control image's orientation disagrees
                with the item's bucket scale.
        """
        control_tensors = []
        control_path_list = self.control_path
        if not isinstance(self.control_path, list):
            control_path_list = [self.control_path]
        to_tensor = transforms.Compose([
            transforms.ToTensor(),
        ])

        for control_path in control_path_list:
            try:
                img = Image.open(control_path).convert('RGB')
            except Exception as e:
                print_acc(f"Error: {e}")
                print_acc(f"Error loading image: {control_path}")
                # Falling through would reuse the previous iteration's image
                # (or hit an unbound ``img`` on the first one).
                raise
            try:
                img = exif_transpose(img)
            except Exception as e:
                # Best effort: keep the un-rotated image.
                print_acc(f"Error: {e}")
                print_acc(f"Error loading image: {control_path}")

            if not self.full_size_control_images:
                # we just scale them to 512x512
                img = img.resize((512, 512), Image.BICUBIC)
            else:
                w, h = img.size
                if w > h and self.scale_to_width < self.scale_to_height:
                    raise ValueError(
                        f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
                elif h > w and self.scale_to_height < self.scale_to_width:
                    raise ValueError(
                        f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")

                if self.flip_x:
                    img = img.transpose(Image.FLIP_LEFT_RIGHT)
                if self.flip_y:
                    img = img.transpose(Image.FLIP_TOP_BOTTOM)

                if self.dataset_config.buckets:
                    # Scale and crop exactly like the main image.
                    img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
                    img = img.crop((
                        self.crop_x,
                        self.crop_y,
                        self.crop_x + self.crop_width,
                        self.crop_y + self.crop_height
                    ))
                else:
                    raise Exception("Control images not supported for non-bucket datasets")

            if self.aug_replay_spatial_transforms:
                tensor = self.augment_spatial_control(img, transform=to_tensor)
            else:
                tensor = to_tensor(img)
            control_tensors.append(tensor)

        if len(control_tensors) == 0:
            self.control_tensor = None
        elif len(control_tensors) == 1:
            self.control_tensor = control_tensors[0]
        else:
            self.control_tensor = torch.stack(control_tensors, dim=0)

    def cleanup_control(self: 'FileItemDTO'):
        """Drop the loaded control tensor(s)."""
        self.control_tensor = None
class ClipImageFileItemDTOMixin:
    """File-item mixin that loads the image fed to the adapter's vision encoder.

    The vision image is either a same-named file in ``dataset_config.clip_image_path`` or,
    with ``dataset_config.clip_image_from_same_folder``, an image picked at random from the
    training image's own folder. When ``is_vision_clip_cached`` is set, precomputed
    embeddings are loaded from ``get_clip_vision_embeddings_path()`` instead.
    """

    def __init__(self: 'FileItemDTO', *args, **kwargs):
        if hasattr(super(), '__init__'):
            super().__init__(*args, **kwargs)
        self.has_clip_image = False
        self.clip_image_path: Union[str, None] = None
        self.clip_image_tensor: Union[torch.Tensor, None] = None
        self.clip_image_embeds: Union[dict, None] = None
        self.clip_image_embeds_unconditional: Union[dict, None] = None
        self.has_clip_augmentations = False
        self.clip_image_aug_transform: Union[None, A.Compose] = None
        self.clip_image_processor: Union[None, CLIPImageProcessor] = None
        self.clip_image_encoder_path: Union[str, None] = None
        self.is_caching_clip_vision_to_disk = False
        self.is_vision_clip_cached = False
        # when True the vision image is a 2x2 grid (see augment_clip_image)
        self.clip_vision_is_quad = False
        self.clip_vision_load_device = 'cpu'
        self.clip_vision_unconditional_paths: Union[List[str], None] = None
        self._clip_vision_embeddings_path: Union[str, None] = None
        dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
        if dataset_config.clip_image_path is not None or dataset_config.clip_image_from_same_folder:
            # share the adapter's processor so the dataloader can preprocess vision images itself.
            # getattr() tolerates sd being None or lacking an adapter attribute.
            adapter = getattr(kwargs.get('sd', None), 'adapter', None)
            if hasattr(adapter, 'clip_image_processor'):
                self.clip_image_processor = adapter.clip_image_processor
        if dataset_config.clip_image_path is not None:
            # look for a vision image with the same file stem as the training image
            clip_image_dir = dataset_config.clip_image_path
            img_path = kwargs.get('path', None)
            file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
            for ext in img_ext_list:
                candidate = os.path.join(clip_image_dir, file_name_no_ext + ext)
                if os.path.exists(candidate):
                    self.clip_image_path = candidate
                    self.has_clip_image = True
                    break
            self.build_clip_imag_augmentation_transform()
        if dataset_config.clip_image_from_same_folder:
            # assume the folder has a usable image; one is picked at load time
            self.has_clip_image = True
            self.build_clip_imag_augmentation_transform()

    def build_clip_imag_augmentation_transform(self: 'FileItemDTO'):
        """(Re)build ``self.clip_image_aug_transform`` from ``dataset_config.clip_image_augmentations``.

        The order is shuffled when ``clip_image_shuffle_augmentations`` is set. Does nothing
        when no augmentations are configured.
        """
        if self.dataset_config.clip_image_augmentations is not None and len(self.dataset_config.clip_image_augmentations) > 0:
            self.has_clip_augmentations = True
            augmentations = [Augments(**aug) for aug in self.dataset_config.clip_image_augmentations]
            if self.dataset_config.clip_image_shuffle_augmentations:
                random.shuffle(augmentations)
            augmentation_list = []
            for aug in augmentations:
                # make sure method name is valid
                assert hasattr(A, aug.method_name), f"invalid augmentation method: {aug.method_name}"
                method = getattr(A, aug.method_name)
                augmentation_list.append(method(**aug.params))
            self.clip_image_aug_transform = A.Compose(augmentation_list)

    def augment_clip_image(self: 'FileItemDTO', img: 'Image.Image', transform: Union[None, 'transforms.Compose']):
        """Apply the vision-image augmentations to ``img`` and return a tensor.

        Args:
            img: PIL image, presumably RGB (channels are reversed to BGR before augmenting
                and converted back afterwards).
            transform: applied to the augmented PIL image; ``transforms.ToTensor()`` when None.

        With ``clip_vision_is_quad``, the image is treated as a 2x2 grid. Each quadrant is
        augmented independently and put back in its original position.
        """
        if self.dataset_config.clip_image_shuffle_augmentations:
            self.build_clip_imag_augmentation_transform()
        open_cv_image = np.array(img)
        # Convert RGB to BGR
        open_cv_image = open_cv_image[:, :, ::-1].copy()
        if self.clip_vision_is_quad:
            # array_split (unlike hsplit/vsplit) also accepts odd widths/heights
            left, right = np.array_split(open_cv_image, 2, axis=1)
            top_left, bottom_left = np.array_split(left, 2, axis=0)
            top_right, bottom_right = np.array_split(right, 2, axis=0)
            # same augmentation call order as before: TL, BL, TR, BR
            top_left = self.clip_image_aug_transform(image=top_left)["image"]
            bottom_left = self.clip_image_aug_transform(image=bottom_left)["image"]
            top_right = self.clip_image_aug_transform(image=top_right)["image"]
            bottom_right = self.clip_image_aug_transform(image=bottom_right)["image"]
            # rebuild the two columns first, then join them, so no quadrant changes position
            augmented = np.hstack((
                np.vstack((top_left, bottom_left)),
                np.vstack((top_right, bottom_right)),
            ))
        else:
            augmented = self.clip_image_aug_transform(image=open_cv_image)["image"]
        # back to RGB, then to PIL for the tensor transform
        augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)
        augmented = Image.fromarray(augmented)
        return transforms.ToTensor()(augmented) if transform is None else transform(augmented)

    def get_clip_vision_info_dict(self: 'FileItemDTO'):
        """Return the ordered settings that identify a cached vision embedding."""
        item = OrderedDict([
            ("image_encoder_path", self.clip_image_encoder_path),
            ("filename", os.path.basename(self.clip_image_path)),
            ("is_quad", self.clip_vision_is_quad)
        ])
        # when adding items, do it after so we dont change old latents
        if self.flip_x:
            item["flip_x"] = True
        if self.flip_y:
            item["flip_y"] = True
        return item

    def get_clip_vision_embeddings_path(self: 'FileItemDTO', recalculate=False):
        """Return (and memoize) the cache file path for this item's vision embedding.

        The file lives in ``_clip_vision_cache`` beside the vision image. Its name is the image
        stem plus an unpadded URL-safe base64 MD5 of ``get_clip_vision_info_dict()``.
        """
        if self._clip_vision_embeddings_path is not None and not recalculate:
            return self._clip_vision_embeddings_path
        img_dir = os.path.dirname(self.clip_image_path)
        cache_dir = os.path.join(img_dir, '_clip_vision_cache')
        hash_dict = self.get_clip_vision_info_dict()
        filename_no_ext = os.path.splitext(os.path.basename(self.clip_image_path))[0]
        hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
        hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii')
        hash_str = hash_str.replace('=', '')
        self._clip_vision_embeddings_path = os.path.join(cache_dir, f'{filename_no_ext}_{hash_str}.safetensors')
        return self._clip_vision_embeddings_path

    def get_new_clip_image_path(self: 'FileItemDTO'):
        """Return the vision image path to load for this step.

        With ``clip_image_from_same_folder`` a random image from the training image's folder
        is chosen, preferring one other than the training image. The training image itself
        is returned when it is the only candidate or nothing else matches.
        """
        if not self.dataset_config.clip_image_from_same_folder:
            return self.clip_image_path
        pool_folder = os.path.dirname(self.path)
        img_files = []
        for ext in img_ext_list:
            img_files += glob.glob(os.path.join(pool_folder, f'*{ext}'))
        # filtering (instead of list.remove) cannot raise when self.path is missing from the
        # glob result, and drops every duplicate match of it
        candidates = [f for f in img_files if f != self.path]
        if not candidates:
            return self.path
        return random.choice(candidates)

    def load_clip_image(self: 'FileItemDTO'):
        """Populate ``clip_image_tensor`` (or ``clip_image_embeds`` when cached).

        Non-square images are center-cropped (``square_crop``) or squished to a square,
        unless the processor handles arbitrary sizes (Pixtral/SigLIP) or there is no
        processor. The tensor is then optionally augmented and, when a processor is set,
        run through it with ``do_resize=True, do_rescale=False``.
        """
        is_dynamic_size_and_aspect = isinstance(
            self.clip_image_processor,
            (PixtralVisionImagePreprocessorCompatible, SiglipImageProcessor),
        )
        if self.clip_image_processor is None:
            is_dynamic_size_and_aspect = True  # serving it raw
        if self.is_vision_clip_cached:
            self.clip_image_embeds = load_file(self.get_clip_vision_embeddings_path())
            # get a random unconditional embedding
            if self.clip_vision_unconditional_paths is not None:
                unconditional_path = random.choice(self.clip_vision_unconditional_paths)
                self.clip_image_embeds_unconditional = load_file(unconditional_path)
            return
        clip_image_path = self.get_new_clip_image_path()
        try:
            img = Image.open(clip_image_path).convert('RGB')
            img = exif_transpose(img)
        except Exception as e:
            # fall back to a blank (black) square image so loading can continue
            img = Image.new('RGB', (self.dataset_config.resolution, self.dataset_config.resolution))
            print_acc(f"Error: {e}")
            print_acc(f"Error loading image: {clip_image_path}")
        img = img.convert('RGB')
        if self.flip_x:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
        if self.flip_y:
            img = img.transpose(Image.FLIP_TOP_BOTTOM)
        if is_dynamic_size_and_aspect:
            pass  # let the image processor handle it
        elif img.width != img.height:
            min_size = min(img.width, img.height)
            if self.dataset_config.square_crop:
                img = transforms.CenterCrop(min_size)(img)
            else:
                # squish to a square instead of cropping so no content is lost
                img = img.resize((min_size, min_size), Image.BICUBIC)
        if self.has_clip_augmentations:
            self.clip_image_tensor = self.augment_clip_image(img, transform=None)
        else:
            self.clip_image_tensor = transforms.ToTensor()(img)
        if self.clip_image_processor is not None:
            # the tensor is already scaled to 0-1, so the processor must not rescale again
            tensors_0_1 = self.clip_image_tensor.to(dtype=torch.float16)
            clip_out = self.clip_image_processor(
                images=tensors_0_1,
                return_tensors="pt",
                do_resize=True,
                do_rescale=False,
            ).pixel_values
            self.clip_image_tensor = clip_out.squeeze(0).clone().detach()

    def cleanup_clip_image(self: 'FileItemDTO'):
        """Drop the loaded vision tensor and embeddings."""
        self.clip_image_tensor = None
        self.clip_image_embeds = None
class AugmentationFileItemDTOMixin:
    """File-item mixin that applies albumentations augmentations to the training image.

    The pipeline comes from ``dataset_config.augmentations``. When
    ``dataset_config.replay_transforms`` is set, the spatial transforms applied to the main
    image are recorded, so ``augment_spatial_control`` can replay them on paired images and
    keep those aligned.
    """

    def __init__(self: 'FileItemDTO', *args, **kwargs):
        if hasattr(super(), '__init__'):
            super().__init__(*args, **kwargs)
        self.has_augmentations = False
        # tensor of the image before augmentation, filled by augment_image()
        self.unaugmented_tensor: Union[torch.Tensor, None] = None
        self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
        self.aug_transform: Union[None, A.ReplayCompose] = None
        # spatial-only replay record of the last augment_image() call (or None)
        self.aug_replay_spatial_transforms = None
        self.build_augmentation_transform()

    def build_augmentation_transform(self: 'FileItemDTO'):
        """(Re)build ``self.aug_transform`` from ``dataset_config.augmentations``.

        The order is shuffled when ``dataset_config.shuffle_augmentations`` is set. Does
        nothing when no augmentations are configured.
        """
        if self.dataset_config.augmentations is not None and len(self.dataset_config.augmentations) > 0:
            self.has_augmentations = True
            augmentations = [Augments(**aug) for aug in self.dataset_config.augmentations]
            if self.dataset_config.shuffle_augmentations:
                random.shuffle(augmentations)
            augmentation_list = []
            for aug in augmentations:
                # make sure method name is valid
                assert hasattr(A, aug.method_name), f"invalid augmentation method: {aug.method_name}"
                method = getattr(A, aug.method_name)
                augmentation_list.append(method(**aug.params))
            # ReplayCompose records the applied parameters so they can be replayed later;
            # 'image2' is registered as an additional image target
            self.aug_transform = A.ReplayCompose(augmentation_list, additional_targets={'image2': 'image'})

    def augment_image(self: 'FileItemDTO', img: 'Image.Image', transform: Union[None, 'transforms.Compose']):
        """Augment ``img`` and return it as a tensor.

        Args:
            img: PIL image, presumably RGB (channels are reversed to BGR before augmenting
                and converted back afterwards).
            transform: builds the tensor from a PIL image; ``transforms.ToTensor()`` when None.

        Side effects: stores the unaugmented tensor in ``self.unaugmented_tensor``. When
        ``dataset_config.replay_transforms`` is set, also stores the spatial part of the
        replay record in ``self.aug_replay_spatial_transforms``.
        """
        # rebuild each time if shuffle
        if self.dataset_config.shuffle_augmentations:
            self.build_augmentation_transform()
        self.unaugmented_tensor = transforms.ToTensor()(img) if transform is None else transform(img)
        open_cv_image = np.array(img)
        # Convert RGB to BGR
        open_cv_image = open_cv_image[:, :, ::-1].copy()
        transformed = self.aug_transform(image=open_cv_image)
        augmented = transformed["image"]
        # keep only the spatial transforms, so controls and masks get geometry changes
        # but not color changes
        augmented_params = transformed["replay"]
        spatial_transforms = {'Rotate', 'Flip', 'HorizontalFlip', 'VerticalFlip', 'Resize', 'Crop', 'RandomCrop',
                              'ElasticTransform', 'GridDistortion', 'OpticalDistortion'}
        augmented_params['transforms'] = [
            t for t in augmented_params['transforms']
            if t['__class_fullname__'].split('.')[-1] in spatial_transforms
        ]
        if self.dataset_config.replay_transforms:
            self.aug_replay_spatial_transforms = augmented_params
        # back to RGB, then to PIL for the tensor transform
        augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)
        augmented = Image.fromarray(augmented)
        return transforms.ToTensor()(augmented) if transform is None else transform(augmented)

    def augment_spatial_control(self: 'FileItemDTO', img: 'Image.Image', transform: Union[None, 'transforms.Compose']):
        """Replay the recorded spatial transforms on a paired image and return a tensor.

        The image's original PIL mode is restored after replaying. ``transform`` defaults to
        ``transforms.ToTensor()`` when None, including when there is nothing to replay.
        """
        if self.aug_replay_spatial_transforms is None:
            # no transforms recorded; previously this crashed when transform was None
            return transforms.ToTensor()(img) if transform is None else transform(img)
        colorspace = img.mode
        img = img.convert('RGB')
        open_cv_image = np.array(img)
        # Convert RGB to BGR
        open_cv_image = open_cv_image[:, :, ::-1].copy()
        transformed = A.ReplayCompose.replay(self.aug_replay_spatial_transforms, image=open_cv_image)
        augmented = transformed["image"]
        augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)
        augmented = Image.fromarray(augmented)
        # convert back to original colorspace (e.g. 'L' for masks)
        augmented = augmented.convert(colorspace)
        return transforms.ToTensor()(augmented) if transform is None else transform(augmented)

    def cleanup_control(self: 'FileItemDTO'):
        """Release the unaugmented tensor.

        Another mixin in this file also defines ``cleanup_control``. To avoid shadowing it,
        this chains to the next implementation in the MRO (if there is one).
        """
        self.unaugmented_tensor = None
        parent_cleanup = getattr(super(), 'cleanup_control', None)
        if parent_cleanup is not None:
            parent_cleanup()
class MaskFileItemDTOMixin:
    """File-item mixin that loads a per-image mask into ``self.mask_tensor``.

    The mask is either the training image's own alpha channel (``dataset_config.alpha_mask``)
    or a same-named file in ``dataset_config.mask_path``.
    """

    def __init__(self: 'FileItemDTO', *args, **kwargs):
        if hasattr(super(), '__init__'):
            super().__init__(*args, **kwargs)
        self.has_mask_image = False
        self.mask_path: Union[str, None] = None
        self.mask_tensor: Union[torch.Tensor, None] = None
        self.use_alpha_as_mask: bool = False
        dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
        # lower target of the value_map() remap in load_mask_image()
        self.mask_min_value = dataset_config.mask_min_value
        if dataset_config.alpha_mask:
            # the mask is read from the training image's own alpha channel
            self.use_alpha_as_mask = True
            self.mask_path = kwargs.get('path', None)
            self.has_mask_image = True
        elif dataset_config.mask_path is not None:
            # look for a mask with the same file stem as the training image
            mask_dir = dataset_config.mask_path
            img_path = kwargs.get('path', None)
            file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
            for ext in img_ext_list:
                candidate = os.path.join(mask_dir, file_name_no_ext + ext)
                if os.path.exists(candidate):
                    self.mask_path = candidate
                    self.has_mask_image = True
                    break

    def load_mask_image(self: 'FileItemDTO'):
        """Load, align and normalize the mask into ``self.mask_tensor``.

        The mask receives the same flips, bucket resize and crop as the training image, plus
        the replayed spatial augmentations when they were recorded. It also gets a random
        Gaussian blur of up to 0.5% of its shorter side. It is converted to grayscale and
        passed through ``value_map(tensor, 0, 1.0, self.mask_min_value, 1.0)``.

        Raises:
            Exception: re-raised when the mask file cannot be opened, or raised for datasets
                without buckets.
        """
        try:
            img = Image.open(self.mask_path)
            img = exif_transpose(img)
        except Exception as e:
            print_acc(f"Error: {e}")
            print_acc(f"Error loading image: {self.mask_path}")
            # nothing below works without an image: surface the real error instead of a
            # later UnboundLocalError on `img`
            raise
        if self.use_alpha_as_mask:
            # the pipeline expects an RGB image, so copy the alpha channel into all channels.
            # Converting to RGBA first also covers LA/P images and images without alpha
            # (which become a fully opaque mask).
            alpha = img.convert('RGBA').getchannel('A')
            img = Image.merge('RGB', (alpha, alpha, alpha))
        img = img.convert('RGB')
        if self.dataset_config.invert_mask:
            img = ImageOps.invert(img)
        w, h = img.size
        # If the mask orientation disagrees with the bucket, log it and swap this item's
        # target size and crop to match the mask instead of failing.
        fix_size = False
        if w > h and self.scale_to_width < self.scale_to_height:
            print_acc(f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
            fix_size = True
        elif h > w and self.scale_to_height < self.scale_to_width:
            print_acc(f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
            fix_size = True
        if fix_size:
            self.scale_to_width, self.scale_to_height = self.scale_to_height, self.scale_to_width
            self.crop_width, self.crop_height = self.crop_height, self.crop_width
            self.crop_x, self.crop_y = self.crop_y, self.crop_x
        if self.flip_x:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
        if self.flip_y:
            img = img.transpose(Image.FLIP_TOP_BOTTOM)
        # randomly apply a blur up to 0.5% of the size of the min (width, height)
        min_size = min(img.width, img.height)
        blur_radius = int(min_size * random.random() * 0.005)
        img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))
        img = img.convert('L')
        if self.dataset_config.buckets:
            # scale and crop exactly like the training image
            img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
            img = img.crop((
                self.crop_x,
                self.crop_y,
                self.crop_x + self.crop_width,
                self.crop_y + self.crop_height
            ))
        else:
            raise Exception("Mask images not supported for non-bucket datasets")
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        if self.aug_replay_spatial_transforms:
            self.mask_tensor = self.augment_spatial_control(img, transform=transform)
        else:
            self.mask_tensor = transform(img)
        self.mask_tensor = value_map(self.mask_tensor, 0, 1.0, self.mask_min_value, 1.0)

    def cleanup_mask(self: 'FileItemDTO'):
        """Drop the loaded mask tensor."""
        self.mask_tensor = None
class UnconditionalFileItemDTOMixin:
    """File-item mixin for an optional paired "unconditional" image.

    The image is a same-named file in ``dataset_config.unconditional_path``. It is flipped,
    resized and cropped like the training image, then transformed with the item's
    ``dataloader_transforms``.
    """

    def __init__(self: 'FileItemDTO', *args, **kwargs):
        if hasattr(super(), '__init__'):
            super().__init__(*args, **kwargs)
        self.has_unconditional = False
        self.unconditional_path: Union[str, None] = None
        self.unconditional_tensor: Union[torch.Tensor, None] = None
        self.unconditional_latent: Union[torch.Tensor, None] = None
        # dataloader_transforms is expected to be set earlier in the init chain
        self.unconditional_transforms = self.dataloader_transforms
        dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
        if dataset_config.unconditional_path is not None:
            # look for an image with the same file stem as the training image
            img_path = kwargs.get('path', None)
            file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
            for ext in img_ext_list:
                candidate = os.path.join(dataset_config.unconditional_path, file_name_no_ext + ext)
                if os.path.exists(candidate):
                    self.unconditional_path = candidate
                    self.has_unconditional = True
                    break

    def load_unconditional_image(self: 'FileItemDTO'):
        """Load the unconditional image into ``self.unconditional_tensor``.

        Raises:
            Exception: re-raised when the file cannot be opened.
            ValueError: when the image orientation disagrees with the bucket.
            Exception: for datasets without buckets.
        """
        try:
            img = Image.open(self.unconditional_path)
            img = exif_transpose(img)
        except Exception as e:
            print_acc(f"Error: {e}")
            # report the file that actually failed (previously printed self.mask_path)
            print_acc(f"Error loading image: {self.unconditional_path}")
            # nothing below works without an image: surface the real error
            raise
        img = img.convert('RGB')
        w, h = img.size
        # the unconditional image must have the same orientation as the bucket
        if w > h and self.scale_to_width < self.scale_to_height:
            raise ValueError(
                f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
        elif h > w and self.scale_to_height < self.scale_to_width:
            raise ValueError(
                f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
        if self.flip_x:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
        if self.flip_y:
            img = img.transpose(Image.FLIP_TOP_BOTTOM)
        if self.dataset_config.buckets:
            # scale and crop exactly like the training image
            img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
            img = img.crop((
                self.crop_x,
                self.crop_y,
                self.crop_x + self.crop_width,
                self.crop_y + self.crop_height
            ))
        else:
            raise Exception("Unconditional images are not supported for non-bucket datasets")
        if self.aug_replay_spatial_transforms:
            self.unconditional_tensor = self.augment_spatial_control(img, transform=self.unconditional_transforms)
        else:
            self.unconditional_tensor = self.unconditional_transforms(img)

    def cleanup_unconditional(self: 'FileItemDTO'):
        """Drop the unconditional tensor and latent."""
        self.unconditional_tensor = None
        self.unconditional_latent = None
class PoiFileItemDTOMixin:
    """File-item mixin for a point-of-interest (poi) box used for subject-preserving crops.

    The box named ``dataset_config.poi`` is read from the image's ``.json`` caption file
    (keys ``x``, ``y``, ``width``, ``height`` under ``"poi"``). ``setup_poi_bucket`` then
    picks a random crop that always contains it.
    """

    def __init__(self: 'FileItemDTO', *args, **kwargs):
        if hasattr(super(), '__init__'):
            super().__init__(*args, **kwargs)
        dataset_config = kwargs.get('dataset_config', None)
        path = kwargs.get('path', None)
        self.poi: Union[str, None] = dataset_config.poi
        self.has_point_of_interest = self.poi is not None
        self.poi_x: Union[int, None] = None
        self.poi_y: Union[int, None] = None
        self.poi_width: Union[int, None] = None
        self.poi_height: Union[int, None] = None
        if self.poi is None:
            return
        # random poi crops would invalidate cached latents
        if dataset_config.cache_latents or dataset_config.cache_latents_to_disk:
            raise Exception(
                f"Error: poi is not supported when caching latents. Please set cache_latents and cache_latents_to_disk to False in the dataset config"
            )
        if dataset_config.caption_ext != 'json':
            raise Exception(
                f"Error: poi is only supported when using json captions. Please set caption_ext to json in the dataset config"
            )
        self.poi = self.poi.strip()
        caption_path = os.path.splitext(path)[0] + '.json'
        if not os.path.exists(caption_path):
            raise Exception(f"Error: caption file not found for poi: {caption_path}")
        with open(caption_path, 'r', encoding='utf-8') as f:
            json_data = json.load(f)
        # a missing or null "poi" section is treated as "no box" instead of raising KeyError
        poi_boxes = json_data.get('poi') or {}
        if self.poi not in poi_boxes:
            print_acc(f"Warning: poi not found in caption file: {caption_path}")
        # default to the full image when no usable box is found
        self.poi_x = 0
        self.poi_y = 0
        self.poi_width = self.width
        self.poi_height = self.height
        if self.poi in poi_boxes:
            try:
                box = poi_boxes[self.poi]
                x, y = int(box['x']), int(box['y'])
                width, height = int(box['width']), int(box['height'])
            except (KeyError, TypeError, ValueError, OverflowError):
                # malformed box: keep the full-image fallback rather than a half-applied box
                pass
            else:
                self.poi_x, self.poi_y = x, y
                self.poi_width, self.poi_height = width, height
        # mirror the box to follow image flips
        if kwargs.get('flip_x', False):
            self.poi_x = self.width - self.poi_x - self.poi_width
        if kwargs.get('flip_y', False):
            self.poi_y = self.height - self.poi_y - self.poi_height

    def setup_poi_bucket(self: 'FileItemDTO'):
        """Pick a random crop containing the poi and set ``scale_to_*`` / ``crop_*`` for it.

        Returns:
            True when the bucket was set up here. False when normal bucketing should be used
            instead: the scaled image is not larger than the dataset resolution, or no crop
            reached that resolution within 100 attempts.
        """
        initial_width = int(self.width * self.dataset_config.scale)
        initial_height = int(self.height * self.dataset_config.scale)
        # small images go through normal bucketing
        img_resolution = get_resolution(initial_width, initial_height)
        if img_resolution <= self.dataset_config.resolution:
            return False
        bucket_tolerance = self.dataset_config.bucket_tolerance
        poi_x = int(self.poi_x * self.dataset_config.scale)
        poi_y = int(self.poi_y * self.dataset_config.scale)
        poi_width = int(self.poi_width * self.dataset_config.scale)
        poi_height = int(self.poi_height * self.dataset_config.scale)
        # keep expanding the crop randomly around the poi until it is large enough
        num_loops = 0
        while True:
            # extend left edge
            if poi_x > 0:
                poi_x = random.randint(0, poi_x)
            else:
                poi_x = 0
            # extend right edge
            min_right = poi_x + poi_width
            if min_right < initial_width:
                crop_right = random.randint(min_right, initial_width)
            else:
                crop_right = initial_width
            poi_width = crop_right - poi_x
            # extend top edge
            if poi_y > 0:
                poi_y = random.randint(0, poi_y)
            else:
                poi_y = 0
            # extend bottom edge
            min_bottom = poi_y + poi_height
            if min_bottom < initial_height:
                crop_bottom = random.randint(min_bottom, initial_height)
            else:
                crop_bottom = initial_height
            poi_height = crop_bottom - poi_y
            try:
                current_resolution = get_resolution(poi_width, poi_height)
            except Exception as e:
                print_acc(f"Error: {e}")
                print_acc(f"Error getting resolution: {self.path}")
                raise
            if current_resolution >= self.dataset_config.resolution:
                break
            num_loops += 1
            if num_loops > 100:
                print_acc(
                    f"Warning: poi bucketing looped too many times. This should not happen. Please report this issue.")
                return False
        new_width = poi_width
        new_height = poi_height
        bucket_resolution = get_bucket_for_image_size(
            new_width, new_height,
            resolution=self.dataset_config.resolution,
            divisibility=bucket_tolerance
        )
        width_scale_factor = bucket_resolution["width"] / new_width
        height_scale_factor = bucket_resolution["height"] / new_height
        # the larger factor guarantees both sides reach the bucket size
        max_scale_factor = max(width_scale_factor, height_scale_factor)
        self.scale_to_width = math.ceil(initial_width * max_scale_factor)
        self.scale_to_height = math.ceil(initial_height * max_scale_factor)
        self.crop_width = bucket_resolution['width']
        self.crop_height = bucket_resolution['height']
        self.crop_x = int(poi_x * max_scale_factor)
        self.crop_y = int(poi_y * max_scale_factor)
        if self.scale_to_width < self.crop_x + self.crop_width or self.scale_to_height < self.crop_y + self.crop_height:
            # TODO: the crop can still exceed the scaled image in rare cases
            print_acc('size mismatch')
        return True
class ArgBreakMixin:
    """Terminal link of the mixins' cooperative ``__init__`` chain.

    It accepts and discards all arguments and deliberately does not call ``super()``. The
    ``super().__init__(*args, **kwargs)`` calls made by earlier mixins therefore stop here
    and never reach ``object.__init__`` with extra arguments.
    """

    def __init__(self, *_args, **_kwargs):
        """Swallow every positional and keyword argument."""
class LatentCachingFileItemDTOMixin:
    """File-item mixin that tracks whether, and where, this item's encoded latent is cached.

    Cache files live in a ``_latent_cache`` folder beside the image. Each file name embeds a
    hash of every setting that shapes the latent (size, crop, flips, latent space and
    format version), so changing any of them selects a different file.
    """

    def __init__(self, *args, **kwargs):
        parent_init = getattr(super(), '__init__', None)
        if parent_init is not None:
            parent_init(*args, **kwargs)
        self._encoded_latent: Union[torch.Tensor, None] = None
        self._latent_path: Union[str, None] = None
        self.is_latent_cached = False
        self.is_caching_to_disk = False
        self.is_caching_to_memory = False
        self.latent_load_device = 'cpu'
        # e.g. 'sd1', 'sdxl'; set per model by the dataset before caching
        self.latent_space_version = 'sd1'
        # bump to invalidate every existing cache file after a latent format change
        self.latent_version = 1

    def get_latent_info_dict(self: 'FileItemDTO'):
        """Return the ordered settings hashed into the latent cache file name.

        ``flip_x`` / ``flip_y`` are appended only when set, so hashes of pre-existing
        unflipped caches stay the same.
        """
        info = OrderedDict(filename=os.path.basename(self.path))
        for field_name in ('scale_to_width', 'scale_to_height', 'crop_x', 'crop_y', 'crop_width',
                           'crop_height', 'latent_space_version', 'latent_version'):
            info[field_name] = getattr(self, field_name)
        for flag_name in ('flip_x', 'flip_y'):
            if getattr(self, flag_name):
                info[flag_name] = True
        return info

    def get_latent_path(self: 'FileItemDTO', recalculate=False):
        """Return (and memoize) the latent cache path; pass ``recalculate=True`` to rebuild it."""
        if recalculate or self._latent_path is None:
            cache_dir = os.path.join(os.path.dirname(self.path), '_latent_cache')
            stem = os.path.splitext(os.path.basename(self.path))[0]
            payload = json.dumps(self.get_latent_info_dict(), sort_keys=True).encode('utf-8')
            # unpadded URL-safe base64 of the MD5 digest keeps the file name short and safe
            digest = base64.urlsafe_b64encode(hashlib.md5(payload).digest()).decode('ascii').replace('=', '')
            self._latent_path = os.path.join(cache_dir, f'{stem}_{digest}.safetensors')
        return self._latent_path

    def cleanup_latent(self):
        """Drop a disk-cached latent, or move a memory-cached latent back to the CPU."""
        if self._encoded_latent is None:
            return
        if self.is_caching_to_memory:
            self._encoded_latent = self._encoded_latent.to('cpu')
        else:
            self._encoded_latent = None

    def get_latent(self, device=None):
        """Return the cached latent, loading it from disk onto the CPU if needed.

        Returns None when the item is not cached. ``device`` is currently ignored; the latent
        is always loaded with ``device='cpu'``.
        """
        if not self.is_latent_cached:
            return None
        if self._encoded_latent is None:
            self._encoded_latent = load_file(self.get_latent_path(), device='cpu')['latent']
        return self._encoded_latent
class LatentCachingMixin:
    """Dataset mixin that pre-encodes every file item's image into latents with ``self.sd``.

    Latents are written to each item's ``get_latent_path()`` (disk caching), kept on the CPU
    in ``file_item._encoded_latent`` (memory caching), or both.
    """

    def __init__(self: 'AiToolkitDataset', **kwargs):
        # cooperative init so the other dataset mixins still run
        if hasattr(super(), '__init__'):
            super().__init__(**kwargs)
        self.latent_cache = {}

    def _resolve_latent_space_version(self: 'AiToolkitDataset'):
        """Return the latent-space tag recorded with cached latents for the current model."""
        model_config = self.sd.model_config
        if model_config.latent_space_version is not None:
            return model_config.latent_space_version
        if self.sd.is_xl:
            return 'sdxl'
        if self.sd.is_v3:
            return 'sd3'
        if self.sd.is_auraflow:
            return 'sdxl'
        if self.sd.is_flux:
            return 'flux1'
        if model_config.is_pixart_sigma:
            return 'sdxl'
        return model_config.arch

    def _cache_file_item_latent(self: 'AiToolkitDataset', file_item, to_disk, to_memory):
        """Load ``file_item``'s existing cache file, or encode its image and cache the latent."""
        latent_path = file_item.get_latent_path(recalculate=True)
        if os.path.exists(latent_path):
            # already on disk: only load it when it should also live in memory
            if to_memory:
                state_dict = load_file(latent_path, device='cpu')
                file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype)
            return
        # not cached yet: load the image and encode it
        file_item.load_and_process_image(self.transform, only_load_latents=True)
        dtype = self.sd.torch_dtype
        device = self.sd.device_torch
        try:
            # add a batch dimension for the encoder, then remove it from the result
            imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
            latent = self.sd.encode_images(imgs).squeeze(0)
        except Exception as e:
            print_acc(f"Error processing image: {file_item.path}")
            print_acc(f"Error: {str(e)}")
            raise
        if to_disk:
            state_dict = OrderedDict([
                ('latent', latent.clone().detach().cpu()),
            ])
            meta = get_meta_for_safetensors(file_item.get_latent_info_dict())
            os.makedirs(os.path.dirname(latent_path), exist_ok=True)
            save_file(state_dict, latent_path, metadata=meta)
        if to_memory:
            file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype)
        # release the loaded image tensor now that the latent is stored
        del file_item.tensor

    def cache_latents_all_latents(self: 'AiToolkitDataset'):
        """Encode (or load) the latent for every item in ``self.file_list``.

        Items whose cache file already exists are not re-encoded. The model is switched to
        the ``'cache_latents'`` device preset for the duration, and its device state is
        restored afterwards, also when encoding fails.

        Raises:
            Exception: for multi-frame datasets, or re-raised from image encoding.
        """
        if self.dataset_config.num_frames > 1:
            raise Exception("Error: caching latents is not supported for multi-frame datasets")
        with accelerator.main_process_first():
            print_acc(f"Caching latents for {self.dataset_path}")
            to_disk = self.is_caching_latents_to_disk
            to_memory = self.is_caching_latents_to_memory
            if to_disk:
                print_acc(" - Saving latents to disk")
            if to_memory:
                print_acc(" - Keeping latents in memory")
            # move sd items to cpu except for vae
            self.sd.set_device_state_preset('cache_latents')
            try:
                # the tag depends only on the model, so resolve it once rather than per item
                latent_space_version = self._resolve_latent_space_version()
                # caching never needs gradients; CLIPCachingMixin.cache_clip_vision_to_disk
                # also runs under no_grad
                with torch.no_grad():
                    for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'):
                        # must be set before get_latent_path(), which hashes these fields
                        file_item.latent_space_version = latent_space_version
                        file_item.is_caching_to_disk = to_disk
                        file_item.is_caching_to_memory = to_memory
                        file_item.latent_load_device = self.sd.device
                        self._cache_file_item_latent(file_item, to_disk, to_memory)
                        file_item.is_latent_cached = True
            finally:
                # restore device state
                self.sd.restore_device_state()
class CLIPCachingMixin:
    """Dataset mixin that caches CLIP vision encoder outputs to disk.

    ``cache_clip_vision_to_disk`` encodes every item of ``self.file_list`` with the
    adapter's vision encoder and writes one safetensors file per item. It also caches
    a pool of "unconditional" embeddings, computed from an all-zero or random-noise
    image, whose paths are handed to every file item.
    """

    def __init__(self: 'AiToolkitDataset', **kwargs):
        # if we have super, call it
        if hasattr(super(), '__init__'):
            super().__init__(**kwargs)
        # how many unconditional embeddings to cache when they come from random noise;
        # cache_clip_vision_to_disk lowers this to 1 when the unconditional image is all zeros
        self.clip_vision_num_unconditional_cache = 20
        # paths of the cached unconditional embedding files (filled by cache_clip_vision_to_disk)
        self.clip_vision_unconditional_cache = []

    @staticmethod
    def _clip_vision_state_dict(vision_encoder, clip_image: torch.Tensor, is_quad: bool, device, dtype):
        """Encode ``clip_image`` and return the tensors that are written to a cache file.

        Args:
            vision_encoder: callable vision model, called with ``output_hidden_states=True``;
                its output must expose ``image_embeds`` and ``hidden_states``.
            clip_image: preprocessed image batch. When ``is_quad`` it is split in half along
                dims 2 and 3 (height/width, assuming NCHW).
            is_quad: treat the image as a 2x2 grid and encode its four quadrants as a batch
                of four, ordered top-left, bottom-left, top-right, bottom-right.
            device: device the encoder input is moved to.
            dtype: dtype the encoder input is cast to.

        Returns:
            OrderedDict of CPU tensors: ``image_embeds``, ``last_hidden_state``
            (``hidden_states[-1]``) and ``penultimate_hidden_states`` (``hidden_states[-2]``).
        """
        if is_quad:
            # split the 2x2 grid into quadrants and stack them on the batch dim
            top, bottom = clip_image.chunk(2, dim=2)
            top_left, top_right = top.chunk(2, dim=3)
            bottom_left, bottom_right = bottom.chunk(2, dim=3)
            clip_image = torch.cat([top_left, bottom_left, top_right, bottom_right], dim=0).detach()
        clip_output = vision_encoder(
            clip_image.to(device, dtype=dtype),
            output_hidden_states=True
        )
        return OrderedDict([
            ('image_embeds', clip_output.image_embeds.clone().detach().cpu()),
            ('last_hidden_state', clip_output.hidden_states[-1].clone().detach().cpu()),
            ('penultimate_hidden_states', clip_output.hidden_states[-2].clone().detach().cpu()),
        ])

    def cache_clip_vision_to_disk(self: 'AiToolkitDataset'):
        """Encode every dataset image with the adapter's vision encoder and cache the outputs on disk.

        No-op unless ``self.is_caching_clip_vision_to_disk`` is set. Requires ``self.sd.adapter``
        to provide a ``clip_image_processor`` and a ``vision_encoder`` or ``image_encoder``.
        Embedding files that already exist are reused. The device state set for caching is
        restored afterwards, including when an error is raised.

        Raises:
            Exception: if the adapter, its clip image processor or its vision encoder is
                missing, or if a file item has clip augmentations enabled.
        """
        if not self.is_caching_clip_vision_to_disk:
            return
        with torch.no_grad():
            print_acc(f"Caching clip vision for {self.dataset_path}")
            print_acc(" - Saving clip to disk")
            # switch the model to the device layout used while caching clip embeddings
            self.sd.set_device_state_preset('cache_clip')
            try:
                adapter = self.sd.adapter
                if adapter is None:
                    raise Exception("Error: must have an adapter to cache clip vision to disk")
                clip_image_processor = getattr(adapter, 'clip_image_processor', None)
                if clip_image_processor is None:
                    raise Exception("Error: must have a clip image processor to cache clip vision to disk")
                # prefer `vision_encoder`; fall back to `image_encoder` when it is missing or None
                vision_encoder = getattr(adapter, 'vision_encoder', None)
                if vision_encoder is None:
                    vision_encoder = getattr(adapter, 'image_encoder', None)
                if vision_encoder is None:
                    raise Exception("Error: must have a vision encoder to cache clip vision to disk")
                # move vision encoder to device
                vision_encoder.to(self.sd.device)

                is_quad = adapter.config.quad_image
                image_encoder_path = adapter.config.image_encoder_path
                dtype = self.sd.torch_dtype
                device = self.sd.device_torch

                # with clip_noise_zero the unconditional image is random noise, so several samples
                # are cached; otherwise it is all zeros and never changes, so one is enough
                is_noise_zero = getattr(adapter, 'clip_noise_zero', False)
                if not is_noise_zero:
                    self.clip_vision_num_unconditional_cache = 1

                # cache unconditionals
                print_acc(f" - Caching {self.clip_vision_num_unconditional_cache} unconditional clip vision to disk")
                clip_vision_cache_path = os.path.join(self.dataset_config.clip_image_path, '_clip_vision_cache')
                # the file-name hash depends only on the encoder settings, so compute it once:
                # url-safe base64 of the md5 digest of the settings, with '=' padding stripped
                hash_dict = OrderedDict([
                    ("image_encoder_path", image_encoder_path),
                    ("is_quad", is_quad),
                    ("is_noise_zero", is_noise_zero),
                ])
                hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
                hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii')
                hash_str = hash_str.replace('=', '')

                unconditional_paths = []
                for idx in range(self.clip_vision_num_unconditional_cache):
                    uncond_path = os.path.join(clip_vision_cache_path, f'uncond_{hash_str}_{idx}.safetensors')
                    unconditional_paths.append(uncond_path)
                    if os.path.exists(uncond_path):
                        # already cached
                        continue
                    img_shape = (1, 3, adapter.input_size, adapter.input_size)
                    if is_noise_zero:
                        tensors_0_1 = torch.rand(img_shape).to(device, dtype=torch.float32)
                    else:
                        tensors_0_1 = torch.zeros(img_shape).to(device, dtype=torch.float32)
                    clip_image = clip_image_processor(
                        images=tensors_0_1,
                        return_tensors="pt",
                        do_resize=True,
                        do_rescale=False,
                    ).pixel_values
                    state_dict = self._clip_vision_state_dict(vision_encoder, clip_image, is_quad, device, dtype)
                    os.makedirs(os.path.dirname(uncond_path), exist_ok=True)
                    save_file(state_dict, uncond_path)
                self.clip_vision_unconditional_cache = unconditional_paths

                for file_item in tqdm(self.file_list, desc='Caching clip vision to disk'):
                    file_item.is_caching_clip_vision_to_disk = True
                    file_item.clip_vision_load_device = self.sd.device
                    file_item.clip_vision_is_quad = is_quad
                    file_item.clip_image_encoder_path = image_encoder_path
                    file_item.clip_vision_unconditional_paths = unconditional_paths
                    if file_item.has_clip_augmentations:
                        raise Exception("Error: clip vision caching is not supported with clip augmentations")
                    embedding_path = file_item.get_clip_vision_embeddings_path(recalculate=True)
                    # only encode items that are not saved to disk yet
                    if not os.path.exists(embedding_path):
                        file_item.load_clip_image()
                        # add batch dimension
                        clip_image = file_item.clip_image_tensor.unsqueeze(0).to(device, dtype=dtype)
                        state_dict = self._clip_vision_state_dict(vision_encoder, clip_image, is_quad, device, dtype)
                        meta = get_meta_for_safetensors(file_item.get_clip_vision_info_dict())
                        os.makedirs(os.path.dirname(embedding_path), exist_ok=True)
                        save_file(state_dict, embedding_path, metadata=meta)
                        # release the device tensor and the loaded image before the next item
                        del clip_image
                        del file_item.clip_image_tensor
                    file_item.is_vision_clip_cached = True
            finally:
                # restore device state even if caching failed part-way
                self.sd.restore_device_state()
class ControlCachingMixin:
    """Dataset mixin that finds or generates control images for each file item.

    Control images are stored in a ``_controls`` folder next to the source image and are
    named ``<stem>.<control_type><ext>``. ``setup_controls`` reuses any that already exist
    and generates the missing ones ('depth', 'pose', 'line', 'inpaint', 'mask') with
    annotator models that are loaded on first use and released afterwards.
    """

    def __init__(self: 'AiToolkitDataset', **kwargs):
        # if we have super, call it
        if hasattr(super(), '__init__'):
            super().__init__(**kwargs)
        # annotator models, loaded lazily by setup_controls and released when it finishes
        self.control_depth_model = None
        self.control_pose_model = None
        self.control_line_model = None
        self.control_bg_remover = None

    @staticmethod
    def _controls_folder_and_stem(file_item: 'FileItemDTO'):
        """Return ``(controls_folder, file_name_no_ext)`` for ``file_item.path``."""
        controls_folder = os.path.join(os.path.dirname(file_item.path), '_controls')
        file_name_no_ext = os.path.splitext(os.path.basename(file_item.path))[0]
        return controls_folder, file_name_no_ext

    def get_control_path(self: 'AiToolkitDataset', file_item: 'FileItemDTO', control_type: ControlTypes):
        """Return the path of an existing control image for ``file_item``, or None.

        Checks ``_controls/<stem>.<control_type><ext>`` for each extension in
        ``img_ext_list`` and returns the first path that exists.
        """
        controls_folder, file_name_no_ext = self._controls_folder_and_stem(file_item)
        file_name_no_ext_control = f"{file_name_no_ext}.{control_type}"
        for ext in img_ext_list:
            possible_path = os.path.join(controls_folder, file_name_no_ext_control + ext)
            if os.path.exists(possible_path):
                return possible_path
        # nothing on disk; the caller has to generate the control
        return None

    def add_control_path_to_file_item(self: 'AiToolkitDataset', file_item: 'FileItemDTO', control_path: str, control_type: ControlTypes):
        """Attach ``control_path`` to ``file_item`` according to ``control_type``.

        'inpaint' and 'mask' set a single path and the matching ``has_*`` flag. Every other
        type is appended to ``file_item.control_path``, which is normalized to a list.

        Raises:
            Exception: if ``file_item.control_path`` is not None, a str, or a list.
        """
        if control_type == 'inpaint':
            file_item.inpaint_path = control_path
            file_item.has_inpaint_image = True
        elif control_type == 'mask':
            file_item.mask_path = control_path
            file_item.has_mask_image = True
        else:
            if file_item.control_path is None:
                file_item.control_path = [control_path]
            elif isinstance(file_item.control_path, str):
                file_item.control_path = [file_item.control_path, control_path]
            elif isinstance(file_item.control_path, list):
                file_item.control_path.append(control_path)
            else:
                raise Exception(f"Error: control_path is not a string or list: {file_item.control_path}")
            file_item.has_control_image = True

    @staticmethod
    def _load_control_source_image(path: str):
        """Load ``path`` as RGB, apply its EXIF orientation and downscale to at most ~1 megapixel."""
        with Image.open(path) as src:
            image = src.convert('RGB')
        image = exif_transpose(image)
        # resize to a max of 1mp, keeping the aspect ratio
        max_size = 1024 * 1024
        w, h = image.size
        if w * h > max_size:
            scale = math.sqrt(max_size / (w * h))
            image = image.resize((int(w * scale), int(h * scale)), Image.BICUBIC)
        return image

    def _generate_depth_control(self: 'AiToolkitDataset', image, device):
        """Return a grayscale Depth-Anything-V2 depth map of ``image``, resized to ``image.size``."""
        if self.control_depth_model is None:
            from transformers import pipeline
            self.control_depth_model = pipeline(
                task="depth-estimation",
                model="depth-anything/Depth-Anything-V2-Large-hf",
                device=device,
                torch_dtype=torch.float16
            )
        img = image.copy()
        in_size = img.size
        output = self.control_depth_model(img)
        # NOTE(review): the raw predicted depth is clamped to 0-255 and cast to uint8 as-is;
        # confirm the model's output range, since values outside it are clipped
        out_tensor = output["predicted_depth"].clamp(0, 255)
        out_tensor = out_tensor.squeeze(0).cpu().numpy()
        depth = Image.fromarray(out_tensor.astype('uint8'))
        return depth.resize(in_size, Image.LANCZOS)

    def _generate_pose_control(self: 'AiToolkitDataset', image, device):
        """Return an RGB OpenPose rendering (with hands and face) of ``image``."""
        if self.control_pose_model is None:
            from controlnet_aux import OpenposeDetector
            self.control_pose_model = OpenposeDetector.from_pretrained("lllyasviel/Annotators").to(device)
        img = image.copy()
        # detect at the side length of a square with the same pixel count as the image
        detect_res = int(math.sqrt(img.size[0] * img.size[1]))
        img = self.control_pose_model(img, hand_and_face=True, detect_resolution=detect_res, image_resolution=detect_res)
        return img.convert('RGB')

    def _generate_line_control(self: 'AiToolkitDataset', image, device):
        """Return an RGB TEED line drawing of ``image``."""
        if self.control_line_model is None:
            from controlnet_aux import TEEDdetector
            self.control_line_model = TEEDdetector.from_pretrained("fal-ai/teed", filename="5_model.pth").to(device)
        img = self.control_line_model(image.copy(), detect_resolution=1024)
        return img.convert('RGB')

    def _generate_mask_control(self: 'AiToolkitDataset', image, control_type, device):
        """Segment ``image`` with BiRefNet and return the 'mask' or 'inpaint' control image.

        For 'mask' the predicted mask is returned as RGB. For 'inpaint' a copy of ``image``
        is returned whose alpha channel is the inverted mask.
        """
        if self.control_bg_remover is None:
            from transformers import AutoModelForImageSegmentation
            self.control_bg_remover = AutoModelForImageSegmentation.from_pretrained(
                'ZhengPeng7/BiRefNet_HR',
                trust_remote_code=True,
                revision="595e212b3eaa6a1beaad56cee49749b1e00b1596",
                torch_dtype=torch.float16
            ).to(device)
            self.control_bg_remover.eval()
        img = image.copy()
        transform_image = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        # put the input on the same device the model was moved to
        input_images = transform_image(img).unsqueeze(0).to(device, dtype=torch.float16)
        # Prediction: the last model output is passed through a sigmoid to get the mask
        preds = self.control_bg_remover(input_images)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()
        mask = transforms.ToPILImage()(pred).resize(img.size)
        if control_type == 'inpaint':
            # inpainting feature currently only supports "erased" section desired to inpaint,
            # so the inverted mask becomes the alpha channel (kept by saving as .webp)
            img.putalpha(ImageOps.invert(mask))
            return img
        return mask.convert('RGB')

    def _generate_control(self: 'AiToolkitDataset', image, control_type, controls_folder, file_name_no_ext, device):
        """Generate the ``control_type`` control for ``image``, save it and return its path.

        'inpaint' controls are saved as ``.webp`` (they carry an alpha channel); all others as ``.jpg``.

        Raises:
            Exception: for an unknown ``control_type``.
        """
        if control_type == 'depth':
            control_image = self._generate_depth_control(image, device)
        elif control_type == 'pose':
            control_image = self._generate_pose_control(image, device)
        elif control_type == 'line':
            control_image = self._generate_line_control(image, device)
        elif control_type == 'inpaint' or control_type == 'mask':
            control_image = self._generate_mask_control(image, control_type, device)
        else:
            raise Exception(f"Error: unknown control type {control_type}")
        ext = '.webp' if control_type == 'inpaint' else '.jpg'
        save_path = os.path.join(controls_folder, f"{file_name_no_ext}.{control_type}{ext}")
        os.makedirs(controls_folder, exist_ok=True)
        control_image.save(save_path)
        return save_path

    def setup_controls(self: 'AiToolkitDataset'):
        """Find or generate every control in ``self.dataset_config.controls`` for each file item.

        No-op unless ``self.is_generating_controls`` is set. Existing control files are
        attached as-is. Before the first control has to be generated, the model is switched
        to the 'unload' device preset. Afterwards, even if generation fails, the annotator
        models are released and the device state is restored if it was changed.
        """
        if not self.is_generating_controls:
            return
        with torch.no_grad():
            print_acc(f"Generating controls for {self.dataset_path}")
            has_unloaded = False
            device = self.sd.device
            try:
                # controls 'depth', 'line', 'pose', 'inpaint', 'mask'
                for file_item in tqdm(self.file_list, desc='Generating Controls'):
                    controls_folder, file_name_no_ext = self._controls_folder_and_stem(file_item)
                    # source image is loaded lazily and shared by every control generated for this item
                    image = None
                    for control_type in self.dataset_config.controls:
                        control_path = self.get_control_path(file_item, control_type)
                        if control_path is None:
                            # we need to generate the control. Unload model if not unloaded
                            if not has_unloaded:
                                print_acc("Unloading model to generate controls")
                                self.sd.set_device_state_preset('unload')
                                has_unloaded = True
                            if image is None:
                                image = self._load_control_source_image(file_item.path)
                            control_path = self._generate_control(
                                image, control_type, controls_folder, file_name_no_ext, device
                            )
                        self.add_control_path_to_file_item(file_item, control_path, control_type)
            finally:
                # remove models
                self.control_depth_model = None
                self.control_pose_model = None
                self.control_line_model = None
                self.control_bg_remover = None
                flush()
                # restore device state
                if has_unloaded:
                    self.sd.restore_device_state()