Spaces:
Paused
Paused
import copy | |
import gc | |
import inspect | |
import json | |
import random | |
import shutil | |
import typing | |
from typing import Optional, Union, List, Literal | |
import os | |
from collections import OrderedDict | |
import copy | |
import yaml | |
from PIL import Image | |
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg | |
from torch.nn import Parameter | |
from tqdm import tqdm | |
from torchvision.transforms import Resize, transforms | |
from toolkit.clip_vision_adapter import ClipVisionAdapter | |
from toolkit.custom_adapter import CustomAdapter | |
from toolkit.ip_adapter import IPAdapter | |
from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch | |
from toolkit.models.decorator import Decorator | |
from toolkit.paths import KEYMAPS_ROOT | |
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds | |
from toolkit.reference_adapter import ReferenceAdapter | |
from toolkit.sd_device_states_presets import empty_preset | |
from toolkit.train_tools import get_torch_dtype, apply_noise_offset | |
import torch | |
from toolkit.pipelines import CustomStableDiffusionXLPipeline | |
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \ | |
LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel | |
import diffusers | |
from diffusers import \ | |
AutoencoderKL, \ | |
UNet2DConditionModel | |
from diffusers import PixArtAlphaPipeline | |
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection | |
from toolkit.accelerator import get_accelerator, unwrap_model | |
from typing import TYPE_CHECKING | |
from toolkit.print import print_acc | |
if TYPE_CHECKING: | |
from toolkit.lora_special import LoRASpecialNetwork | |
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO | |
# Reduce diffusers log output: set its verbosity to the ERROR level.
diffusers.logging.set_verbosity(diffusers.logging.ERROR)
# Prefix strings, one per sub-model.
SD_PREFIX_VAE = "vae"
SD_PREFIX_UNET = "unet"
SD_PREFIX_REFINER_UNET = "refiner_unet"
SD_PREFIX_TEXT_ENCODER = "te"
SD_PREFIX_TEXT_ENCODER1 = "te0"
SD_PREFIX_TEXT_ENCODER2 = "te1"

# Prefixed diffusers keys: the time-embedding linear layers of the unet and
# of the refiner unet (bias and weight of linear_1 and linear_2 for each).
DO_NOT_TRAIN_WEIGHTS = [
    f"{prefix}_time_embedding.{layer}.{param}"
    for prefix in (SD_PREFIX_UNET, SD_PREFIX_REFINER_UNET)
    for layer in ("linear_1", "linear_2")
    for param in ("bias", "weight")
]

# Names of the device-state presets accepted in this module.
DeviceStatePreset = Literal['cache_latents', 'generate']
class BlankNetwork:
    """Do-nothing stand-in used when no real network is attached.

    It exposes the attributes and methods that the sampling code touches on a
    network (``multiplier``, ``is_active``, ``is_merged_in``, ``can_merge_in``,
    ``train`` and the context-manager protocol) so callers do not need to
    special-case a missing network.
    """

    def __init__(self):
        self.multiplier = 1.0
        self.is_active = True
        self.is_merged_in = False
        self.can_merge_in = False

    def __enter__(self):
        # Entering the context marks the (blank) network as active.
        self.is_active = True
        # Return self so ``with BlankNetwork() as net`` binds the instance
        # rather than None.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Leaving the context marks it inactive; exceptions are not suppressed.
        self.is_active = False

    def train(self):
        """No-op; present so callers can call ``train()`` unconditionally."""
        pass
def flush():
    """Run the garbage collector, then release cached CUDA memory.

    The collector runs first so that tensors kept alive only by reference
    cycles are freed before ``torch.cuda.empty_cache()`` is asked to release
    unused cached blocks.
    """
    gc.collect()
    torch.cuda.empty_cache()
UNET_IN_CHANNELS = 4  # Stable Diffusion's in_channels is fixed at 4. The same holds for XL.
# VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
class BaseModel: | |
# override these in child classes | |
arch = None | |
def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype='fp16',
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs
):
    """Set up device/dtype bookkeeping and empty component slots.

    No weights are loaded here; ``is_loaded`` starts out ``False`` and the
    component attributes (pipeline, vae, model, text encoder, tokenizer)
    start as ``None``.

    Args:
        device: Not used by this constructor; the device reported by the
            accelerator is used for every device attribute instead.
        model_config: Configuration object. This constructor reads
            ``vae_dtype``, ``te_dtype``, ``is_v_pred``, ``arch``,
            ``use_text_encoder_1``, ``use_text_encoder_2`` and ``low_vram``.
        dtype: Dtype name stored on ``self.dtype`` and converted with
            ``get_torch_dtype`` into ``self.torch_dtype``.
        custom_pipeline: Stored as ``self.custom_pipeline``.
        noise_scheduler: Stored as ``self.noise_scheduler``.
        **kwargs: Accepted and ignored here.
    """
    self.accelerator = get_accelerator()
    self.custom_pipeline = custom_pipeline
    self.device = str(self.accelerator.device)
    self.dtype = dtype
    self.torch_dtype = get_torch_dtype(dtype)
    self.device_torch = self.accelerator.device

    self.vae_device_torch = self.accelerator.device
    self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype)

    self.te_device_torch = self.accelerator.device
    self.te_torch_dtype = get_torch_dtype(model_config.te_dtype)

    self.model_config = model_config
    self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"

    self.device_state = None

    # Component slots. They are explicitly initialised to None (instead of
    # being annotation-only) so that `is None` checks elsewhere in the class
    # work before a model has been loaded.
    self.pipeline: Union[None, 'StableDiffusionPipeline',
                         'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline'] = None
    self.vae: Union[None, 'AutoencoderKL'] = None
    self.model: Union[None, 'Transformer2DModel', 'UNet2DConditionModel'] = None
    self.text_encoder: Union[None, 'CLIPTextModel',
                             List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]] = None
    self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']] = None
    self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler

    self.refiner_unet: Union[None, 'UNet2DConditionModel'] = None
    self.assistant_lora: Union[None, 'LoRASpecialNetwork'] = None

    # sdxl stuff
    self.logit_scale = None
    self.ckppt_info = None
    self.is_loaded = False

    # to hold network if there is one
    self.network = None
    self.adapter: Union['ControlNetModel', 'T2IAdapter',
                        'IPAdapter', 'ReferenceAdapter', None] = None
    self.decorator: Union[Decorator, None] = None
    # instance-level arch comes from the config and shadows the class-level default
    self.arch: ModelArch = model_config.arch

    self.use_text_encoder_1 = model_config.use_text_encoder_1
    self.use_text_encoder_2 = model_config.use_text_encoder_2

    self.config_file = None

    self.is_flow_matching = False

    self.quantize_device = self.device_torch
    self.low_vram = self.model_config.low_vram

    # merge in and preview active with -1 weight
    self.invert_assistant_lora = False
    # callbacks invoked by _after_sample_image / _status_update
    self._after_sample_img_hooks = []
    self._status_update_hooks = []
    self.is_transformer = False
# properties for old arch for backwards compatibility | |
@property
def unet(self):
    """Backwards-compatible alias that reads ``self.model``."""
    return self.model

# set unet to model
@unet.setter
def unet(self, value):
    # Assigning to ``unet`` stores the value on ``self.model``.
    self.model = value
@property
def transformer(self):
    """Alias that reads ``self.model``."""
    return self.model

@transformer.setter
def transformer(self, value):
    # Assigning to ``transformer`` stores the value on ``self.model``.
    self.model = value
@property
def unet_unwrapped(self):
    """Return ``self.model`` passed through ``unwrap_model``.

    Exposed as a property because this class reads it as an attribute
    (``self.unet_unwrapped.config[...]``).
    """
    return unwrap_model(self.model)
@property
def model_unwrapped(self):
    """Return ``self.model`` passed through ``unwrap_model``.

    Exposed as a property, matching the attribute-style access this class
    uses for its other unwrapped-model accessor.
    """
    return unwrap_model(self.model)
@property
def is_xl(self):
    """True when ``self.arch`` equals ``'sdxl'``.

    A property, because this class tests it as ``self.is_xl`` without
    calling it; a plain method would always be truthy there.
    """
    return self.arch == 'sdxl'
@property
def is_v2(self):
    """True when ``self.arch`` equals ``'sd2'`` (property, like the other arch flags)."""
    return self.arch == 'sd2'
@property
def is_ssd(self):
    """True when ``self.arch`` equals ``'ssd'`` (property, like the other arch flags)."""
    return self.arch == 'ssd'
@property
def is_v3(self):
    """True when ``self.arch`` equals ``'sd3'`` (property, like the other arch flags)."""
    return self.arch == 'sd3'
@property
def is_vega(self):
    """True when ``self.arch`` equals ``'vega'`` (property, like the other arch flags)."""
    return self.arch == 'vega'
@property
def is_pixart(self):
    """True when ``self.arch`` equals ``'pixart'`` (property, like the other arch flags)."""
    return self.arch == 'pixart'
@property
def is_auraflow(self):
    """True when ``self.arch`` equals ``'auraflow'`` (property, like the other arch flags)."""
    return self.arch == 'auraflow'
@property
def is_flux(self):
    """True when ``self.arch`` equals ``'flux'``.

    A property, because this class tests it as ``self.is_flux`` without
    calling it; a plain method would always be truthy there.
    """
    return self.arch == 'flux'
@property
def is_lumina2(self):
    """True when ``self.arch`` equals ``'lumina2'`` (property, like the other arch flags)."""
    return self.arch == 'lumina2'
def get_bucket_divisibility(self):
    """Return the divisibility value derived from the VAE configuration.

    Returns:
        ``8`` when no VAE is set (returned immediately, without the flux
        doubling). Otherwise ``2 ** (len(block_out_channels) - 1)`` read from
        ``self.vae.config``, falling back to ``8`` if that cannot be read,
        and doubled when ``self.is_flux`` is truthy.
    """
    if self.vae is None:
        return 8
    try:
        divisibility = 2 ** (len(self.vae.config['block_out_channels']) - 1)
    except Exception:
        # if we have a custom vae, it might not have this.
        # `Exception` (not a bare except) so KeyboardInterrupt / SystemExit
        # are never swallowed here.
        divisibility = 8

    # flux packs this again,
    if self.is_flux:
        divisibility = divisibility * 2
    return divisibility
# these must be implemented in child classes | |
def load_model(self):
    """Placeholder; child classes must provide the implementation.

    Raises:
        NotImplementedError: always.
    """
    message = "load_model must be implemented in child classes"
    raise NotImplementedError(message)
def get_generation_pipeline(self):
    """Placeholder; child classes must provide the implementation.

    Raises:
        NotImplementedError: always.
    """
    message = "get_generation_pipeline must be implemented in child classes"
    raise NotImplementedError(message)
def generate_single_image(
        self,
        pipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
):
    """Placeholder; child classes must provide the implementation.

    ``generate_images`` calls this once per image config with the pipeline,
    the config, the conditional and unconditional embeds, the seeded
    generator and the ``extra`` dict, and hands the result to
    ``gen_config.save_image`` / ``gen_config.log_image``.

    Raises:
        NotImplementedError: always.
    """
    message = "generate_single_image must be implemented in child classes"
    raise NotImplementedError(message)
def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: 'PromptEmbeds',
        **kwargs
):
    """Placeholder; child classes must provide the implementation.

    ``predict_noise`` calls this as a bound method with
    ``latent_model_input``, ``timestep`` and ``text_embeddings`` passed by
    keyword, so ``self`` must be the first parameter; without it that call
    fails with a ``TypeError`` instead of reaching this body.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "get_noise_prediction must be implemented in child classes")
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
    """Placeholder; child classes must provide the implementation.

    Raises:
        NotImplementedError: always.
    """
    message = "get_prompt_embeds must be implemented in child classes"
    raise NotImplementedError(message)
def get_model_has_grad(self):
    """Placeholder; child classes must provide the implementation.

    Raises:
        NotImplementedError: always.
    """
    message = "get_model_has_grad must be implemented in child classes"
    raise NotImplementedError(message)
def get_te_has_grad(self):
    """Placeholder; child classes must provide the implementation.

    Raises:
        NotImplementedError: always.
    """
    message = "get_te_has_grad must be implemented in child classes"
    raise NotImplementedError(message)
def save_model(self, output_path, meta, save_dtype):
    """Save the unwrapped pipeline and a YAML copy of ``meta``.

    Args:
        output_path: Directory handed to ``save_pretrained``; the metadata
            file ``aitk_meta.yaml`` is written inside it.
        meta: Object dumped with ``yaml.dump``.
        save_dtype: Currently unused (see the todo below).
    """
    # todo handle dtype without overloading anything (vram, cpu, etc)
    pipeline = unwrap_model(self.pipeline)
    pipeline.save_pretrained(save_directory=output_path, safe_serialization=True)

    # save out meta config
    with open(os.path.join(output_path, 'aitk_meta.yaml'), 'w') as meta_file:
        yaml.dump(meta, meta_file)
# end must be implemented in child classes | |
def te_train(self):
    """Call ``train()`` on the text encoder, or on each one if it is a list.

    Does nothing when ``self.text_encoder`` is ``None``.
    """
    encoders = self.text_encoder
    if encoders is None:
        return
    # normalise the single-encoder case to a one-element list
    if not isinstance(encoders, list):
        encoders = [encoders]
    for encoder in encoders:
        encoder.train()
def te_eval(self):
    """Call ``eval()`` on the text encoder, or on each one if it is a list.

    Does nothing when ``self.text_encoder`` is ``None``.
    """
    encoders = self.text_encoder
    if encoders is None:
        return
    # normalise the single-encoder case to a one-element list
    if not isinstance(encoders, list):
        encoders = [encoders]
    for encoder in encoders:
        encoder.eval()
def _after_sample_image(self, img_num, total_imgs):
    """Invoke every registered after-sample hook, in registration order.

    Each hook is called as ``hook(img_num, total_imgs)``.
    """
    for callback in self._after_sample_img_hooks:
        callback(img_num, total_imgs)
def add_after_sample_image_hook(self, func):
    """Register ``func`` to be called by ``_after_sample_image``."""
    hooks = self._after_sample_img_hooks
    hooks.append(func)
def _status_update(self, status: str):
    """Pass ``status`` to every registered status hook, in registration order."""
    for callback in self._status_update_hooks:
        callback(status)
def print_and_status_update(self, status: str):
    """Print ``status`` with ``print_acc``, then forward it to the status hooks."""
    print_acc(status)
    # hooks run after the message has been printed
    self._status_update(status)
def add_status_update_hook(self, func):
    """Register ``func`` to be called by ``_status_update``."""
    hooks = self._status_update_hooks
    hooks.append(func)
def generate_images(
        self,
        image_configs: List[GenerateImageConfig],
        sampler=None,
        pipeline: Union[None, StableDiffusionPipeline,
                        StableDiffusionXLPipeline] = None,
):
    """Generate one image per config, then put the training state back.

    The method toggles the assistant / inference LoRA, optionally merges the
    attached network into the model, saves the device state and RNG state,
    generates and saves an image for every entry of ``image_configs`` through
    ``generate_single_image``, and finally restores RNG state, device state,
    network multiplier and LoRA activation.

    Args:
        image_configs: Configs to render. Each supplies the prompts, seed,
            network multiplier, optional adapter image path and the
            ``save_image`` / ``log_image`` callbacks used below.
        sampler: Not used in this method.
        pipeline: Pipeline to generate with; when ``None`` one is obtained
            from ``get_generation_pipeline()``.

    Raises:
        ValueError: if an adapter image whose path contains ``.inpaint.``
            is not in RGBA mode, or if a refiner is configured for a model
            where ``is_xl`` is false.
    """
    network = unwrap_model(self.network)
    merge_multiplier = 1.0
    flush()
    # if using assistant, unfuse it
    if self.model_config.assistant_lora_path is not None:
        print_acc("Unloading assistant lora")
        if self.invert_assistant_lora:
            self.assistant_lora.is_active = True
            # move weights on to the device
            self.assistant_lora.force_to(
                self.device_torch, self.torch_dtype)
        else:
            self.assistant_lora.is_active = False

    if self.model_config.inference_lora_path is not None:
        print_acc("Loading inference lora")
        self.assistant_lora.is_active = True
        # move weights on to the device
        self.assistant_lora.force_to(self.device_torch, self.torch_dtype)

    if network is not None:
        network.eval()
        # check if we have the same network weight for all samples. If we do, we can merge in
        # the network to drastically speed up inference
        unique_network_weights = set(
            [x.network_multiplier for x in image_configs])
        if len(unique_network_weights) == 1 and network.can_merge_in:
            can_merge_in = True
            merge_multiplier = unique_network_weights.pop()
            network.merge_in(merge_weight=merge_multiplier)
    else:
        # no network attached: use a no-op stand-in so the code below
        # does not need to special-case it
        network = BlankNetwork()

    self.save_device_state()
    self.set_device_state_preset('generate')

    # save current seed state for training
    rng_state = torch.get_rng_state()
    cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None

    if pipeline is None:
        pipeline = self.get_generation_pipeline()
    try:
        pipeline.set_progress_bar_config(disable=True)
    except:
        # best effort: any failure to disable the progress bar is ignored
        pass

    start_multiplier = 1.0
    if network is not None:
        # remembered so the multiplier can be restored after sampling
        start_multiplier = network.multiplier

    # pipeline.to(self.device_torch)

    with network:
        with torch.no_grad():
            if network is not None:
                assert network.is_active

            for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False):
                gen_config = image_configs[i]

                # extra keyword values handed to generate_single_image
                extra = {}
                validation_image = None
                if self.adapter is not None and gen_config.adapter_image_path is not None:
                    validation_image = Image.open(gen_config.adapter_image_path)
                    if ".inpaint." not in gen_config.adapter_image_path:
                        validation_image = validation_image.convert("RGB")
                    else:
                        # make sure it has an alpha
                        if validation_image.mode != "RGBA":
                            raise ValueError("Inpainting images must have an alpha channel")
                    if isinstance(self.adapter, T2IAdapter):
                        # not sure why this is double??
                        validation_image = validation_image.resize(
                            (gen_config.width * 2, gen_config.height * 2))
                        extra['image'] = validation_image
                        extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale
                    if isinstance(self.adapter, ControlNetModel):
                        validation_image = validation_image.resize(
                            (gen_config.width, gen_config.height))
                        extra['image'] = validation_image
                        extra['controlnet_conditioning_scale'] = gen_config.adapter_conditioning_scale
                    if isinstance(self.adapter, CustomAdapter) and self.adapter.control_lora is not None:
                        validation_image = validation_image.resize((gen_config.width, gen_config.height))
                        extra['control_image'] = validation_image
                        extra['control_image_idx'] = gen_config.ctrl_idx
                    if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter):
                        transform = transforms.Compose([
                            transforms.ToTensor(),
                        ])
                        validation_image = transform(validation_image)
                    if isinstance(self.adapter, CustomAdapter):
                        # todo allow loading multiple
                        transform = transforms.Compose([
                            transforms.ToTensor(),
                        ])
                        validation_image = transform(validation_image)
                        self.adapter.num_images = 1
                    if isinstance(self.adapter, ReferenceAdapter):
                        # need -1 to 1
                        validation_image = transforms.ToTensor()(validation_image)
                        validation_image = validation_image * 2.0 - 1.0
                        validation_image = validation_image.unsqueeze(0)
                        self.adapter.set_reference_images(validation_image)

                if network is not None:
                    network.multiplier = gen_config.network_multiplier
                # seed the global CPU and CUDA generators from the config
                torch.manual_seed(gen_config.seed)
                torch.cuda.manual_seed(gen_config.seed)

                # torch.manual_seed returns the default generator; it is
                # passed on to generate_single_image
                generator = torch.manual_seed(gen_config.seed)

                if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter) \
                        and gen_config.adapter_image_path is not None:
                    # run through the adapter to saturate the embeds
                    conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
                        validation_image)
                    self.adapter(conditional_clip_embeds)

                if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
                    # handle condition the prompts
                    # NOTE: this overwrites the prompt fields on gen_config in place
                    gen_config.prompt = self.adapter.condition_prompt(
                        gen_config.prompt,
                        is_unconditional=False,
                    )
                    gen_config.prompt_2 = gen_config.prompt
                    gen_config.negative_prompt = self.adapter.condition_prompt(
                        gen_config.negative_prompt,
                        is_unconditional=True,
                    )
                    gen_config.negative_prompt_2 = gen_config.negative_prompt

                if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None:
                    self.adapter.trigger_pre_te(
                        tensors_0_1=validation_image,
                        is_training=False,
                        has_been_preprocessed=False,
                        quad_count=4
                    )

                # encode the prompt ourselves so we can do fun stuff with embeddings
                if isinstance(self.adapter, CustomAdapter):
                    self.adapter.is_unconditional_run = False
                conditional_embeds = self.encode_prompt(
                    gen_config.prompt, gen_config.prompt_2, force_all=True)

                if isinstance(self.adapter, CustomAdapter):
                    self.adapter.is_unconditional_run = True
                unconditional_embeds = self.encode_prompt(
                    gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
                )
                if isinstance(self.adapter, CustomAdapter):
                    # reset the flag once both prompts are encoded
                    self.adapter.is_unconditional_run = False

                # allow any manipulations to take place to embeddings
                gen_config.post_process_embeddings(
                    conditional_embeds,
                    unconditional_embeds,
                )

                if self.decorator is not None:
                    # apply the decorator to the embeddings
                    conditional_embeds.text_embeds = self.decorator(
                        conditional_embeds.text_embeds)
                    unconditional_embeds.text_embeds = self.decorator(
                        unconditional_embeds.text_embeds, is_unconditional=True)

                if self.adapter is not None and isinstance(self.adapter, IPAdapter) \
                        and gen_config.adapter_image_path is not None:
                    # apply the image projection
                    conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
                        validation_image)
                    unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image,
                                                                                                True)
                    conditional_embeds = self.adapter(
                        conditional_embeds, conditional_clip_embeds, is_unconditional=False)
                    unconditional_embeds = self.adapter(
                        unconditional_embeds, unconditional_clip_embeds, is_unconditional=True)

                if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
                    # NOTE(review): validation_image may still be None here when the
                    # config has no adapter image — confirm the adapter accepts that
                    conditional_embeds = self.adapter.condition_encoded_embeds(
                        tensors_0_1=validation_image,
                        prompt_embeds=conditional_embeds,
                        is_training=False,
                        has_been_preprocessed=False,
                        is_generating_samples=True,
                    )
                    unconditional_embeds = self.adapter.condition_encoded_embeds(
                        tensors_0_1=validation_image,
                        prompt_embeds=unconditional_embeds,
                        is_training=False,
                        has_been_preprocessed=False,
                        is_unconditional=True,
                        is_generating_samples=True,
                    )

                if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and len(
                        gen_config.extra_values) > 0:
                    extra_values = torch.tensor([gen_config.extra_values], device=self.device_torch,
                                                dtype=self.torch_dtype)
                    # apply extra values to the embeddings
                    self.adapter.add_extra_values(
                        extra_values, is_unconditional=False)
                    # the unconditional side receives zeros of the same shape
                    self.adapter.add_extra_values(torch.zeros_like(
                        extra_values), is_unconditional=True)
                    pass  # todo remove, for debugging

                if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0:
                    # if we have a refiner loaded, set the denoising end at the refiner start
                    extra['denoising_end'] = gen_config.refiner_start_at
                    extra['output_type'] = 'latent'
                    if not self.is_xl:
                        raise ValueError(
                            "Refiner is only supported for XL models")

                # match the embeds to the model's device and dtype
                conditional_embeds = conditional_embeds.to(
                    self.device_torch, dtype=self.unet.dtype)
                unconditional_embeds = unconditional_embeds.to(
                    self.device_torch, dtype=self.unet.dtype)

                img = self.generate_single_image(
                    pipeline,
                    gen_config,
                    conditional_embeds,
                    unconditional_embeds,
                    generator,
                    extra,
                )

                gen_config.save_image(img, i)
                gen_config.log_image(img, i)
                self._after_sample_image(i, len(image_configs))
                flush()

            if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter):
                self.adapter.clear_memory()

    # clear pipeline and cache to reduce vram usage
    del pipeline
    torch.cuda.empty_cache()

    # restore training state
    torch.set_rng_state(rng_state)
    if cuda_rng_state is not None:
        torch.cuda.set_rng_state(cuda_rng_state)

    self.restore_device_state()
    if network is not None:
        network.train()
        network.multiplier = start_multiplier

    self.unet.to(self.device_torch, dtype=self.torch_dtype)
    if network.is_merged_in:
        # undo the merge performed before sampling, using the same multiplier
        network.merge_out(merge_multiplier)
    # self.tokenizer.to(original_device_dict['tokenizer'])

    # put the assistant / inference lora flags back to their training values
    if self.model_config.assistant_lora_path is not None:
        print_acc("Loading assistant lora")
        if self.invert_assistant_lora:
            self.assistant_lora.is_active = False
            # move weights off the device
            self.assistant_lora.force_to('cpu', self.torch_dtype)
        else:
            self.assistant_lora.is_active = True

    if self.model_config.inference_lora_path is not None:
        print_acc("Unloading inference lora")
        self.assistant_lora.is_active = False
        # move weights off the device
        self.assistant_lora.force_to('cpu', self.torch_dtype)

    flush()
def get_latent_noise(
        self,
        height=None,
        width=None,
        pixel_height=None,
        pixel_width=None,
        batch_size=1,
        noise_offset=0.0,
):
    """Create random latent noise on the model's device.

    Either the latent size (``height`` / ``width``) or the pixel size
    (``pixel_height`` / ``pixel_width``) must be given for each axis; pixel
    sizes are floor-divided by the VAE scale factor derived from
    ``self.vae.config['block_out_channels']``.

    Returns:
        A ``torch.randn`` tensor of shape
        ``(batch_size, channels, height, width)`` after
        ``apply_noise_offset`` has been applied with ``noise_offset``.

    Raises:
        ValueError: if neither size is given for the height or the width.
    """
    vae_scale = 2 ** (len(self.vae.config['block_out_channels']) - 1)

    if height is None and pixel_height is None:
        raise ValueError("height or pixel_height must be specified")
    if width is None and pixel_width is None:
        raise ValueError("width or pixel_width must be specified")

    latent_height = pixel_height // vae_scale if height is None else height
    latent_width = pixel_width // vae_scale if width is None else width

    channels = self.unet_unwrapped.config['in_channels']
    if self.is_flux:
        # has 64 channels in for some reason
        channels = 16

    shape = (batch_size, channels, latent_height, latent_width)
    latent_noise = torch.randn(shape, device=self.unet.device)
    return apply_noise_offset(latent_noise, noise_offset)
def get_latent_noise_from_latents(
        self,
        latents: torch.Tensor,
        noise_offset=0.0
):
    """Return ``torch.randn_like(latents)`` passed through ``apply_noise_offset``."""
    return apply_noise_offset(torch.randn_like(latents), noise_offset)
def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
        **kwargs,
) -> torch.FloatTensor:
    """Add noise sample by sample using ``self.noise_scheduler.add_noise``.

    The inputs are split along dim 0 into single-item chunks and the
    scheduler is called once per chunk. When ``timesteps`` holds exactly one
    entry and there is more than one sample, that entry is reused for every
    sample. The per-sample results are concatenated back along dim 0.
    ``**kwargs`` is accepted and ignored.
    """
    batch = original_samples.shape[0]
    sample_parts = torch.chunk(original_samples, batch, dim=0)
    noise_parts = torch.chunk(noise, noise.shape[0], dim=0)
    timestep_parts = torch.chunk(timesteps, timesteps.shape[0], dim=0)

    # a single timestep is shared by the whole batch
    if len(timestep_parts) == 1 and len(sample_parts) != 1:
        timestep_parts = [timestep_parts[0]] * len(sample_parts)

    noised_parts = [
        self.noise_scheduler.add_noise(
            sample_parts[pos], noise_parts[pos], timestep_parts[pos])
        for pos in range(batch)
    ]
    return torch.cat(noised_parts, dim=0)
def predict_noise(
        self,
        latents: torch.Tensor,
        text_embeddings: Union[PromptEmbeds, None] = None,
        timestep: Union[int, torch.Tensor] = 1,
        guidance_scale=7.5,
        guidance_rescale=0,
        add_time_ids=None,
        conditional_embeddings: Union[PromptEmbeds, None] = None,
        unconditional_embeddings: Union[PromptEmbeds, None] = None,
        is_input_scaled=False,
        detach_unconditional=False,
        rescale_cfg=None,
        return_conditional_pred=False,
        guidance_embedding_scale=1.0,
        bypass_guidance_embedding=False,
        batch: Union[None, 'DataLoaderBatchDTO'] = None,
        **kwargs,
):
    """Predict noise for ``latents``, applying classifier-free guidance if set up.

    Classifier-free guidance is used when the text-embedding batch is twice
    the latent batch (unconditional first, conditional second); when the two
    batch sizes match, a single un-guided prediction is made.

    Args:
        latents: Latents to predict on; moved to ``self.device_torch``.
        text_embeddings: Ready-to-use embeddings. When ``None`` they are
            built from ``conditional_embeddings`` (and
            ``unconditional_embeddings`` if given).
        timestep: Tensor or plain number; plain numbers are converted to a
            tensor. A single timestep is repeated for the whole batch.
        guidance_scale: Weight of the guidance term.
        guidance_rescale: When > 0, ``rescale_noise_cfg`` is applied.
        add_time_ids: Not used in this method.
        conditional_embeddings: Positive embeddings (see ``text_embeddings``).
        unconditional_embeddings: Negative embeddings (see ``text_embeddings``).
        is_input_scaled: When true, ``scale_model_input`` is skipped.
        detach_unconditional: Detach the unconditional prediction before
            combining.
        rescale_cfg: When set and different from ``guidance_scale``, the
            guided prediction is renormalised to the per-sample mean/std of
            guidance computed at this scale.
        return_conditional_pred: Also return the conditional prediction.
        guidance_embedding_scale, bypass_guidance_embedding, batch:
            Forwarded to ``get_noise_prediction`` only if its signature
            declares a parameter of that name.
        **kwargs: Forwarded to ``get_noise_prediction``. Entries named
            ``down_intrablock_additional_residuals``,
            ``down_block_additional_residuals`` and
            ``mid_block_additional_residual`` are duplicated for guidance
            when their batch size does not match the embeddings.

    Returns:
        The noise prediction, or ``(noise_pred, conditional_pred)`` when
        ``return_conditional_pred`` is true.

    Raises:
        ValueError: if no embeddings are given, or if batch sizes of
            latents / embeddings / timesteps are incompatible.
    """
    conditional_pred = None
    # get the embeddings
    if text_embeddings is None and conditional_embeddings is None:
        raise ValueError(
            "Either text_embeddings or conditional_embeddings must be specified")
    if text_embeddings is None and unconditional_embeddings is not None:
        text_embeddings = concat_prompt_embeds([
            unconditional_embeddings,  # negative embedding
            conditional_embeddings,  # positive embedding
        ])
    elif text_embeddings is None and conditional_embeddings is not None:
        # not doing cfg
        text_embeddings = conditional_embeddings

    # CFG is comparing neg and positive, if we have concatenated embeddings
    # then we are doing it, otherwise we are not and takes half the time.
    do_classifier_free_guidance = True

    # check if batch size of embeddings matches batch size of latents
    if isinstance(text_embeddings.text_embeds, list):
        te_batch_size = text_embeddings.text_embeds[0].shape[0]
    else:
        te_batch_size = text_embeddings.text_embeds.shape[0]
    if latents.shape[0] == te_batch_size:
        do_classifier_free_guidance = False
    elif latents.shape[0] * 2 != te_batch_size:
        raise ValueError(
            "Batch size of latents must be the same or half the batch size of text embeddings")
    latents = latents.to(self.device_torch)
    text_embeddings = text_embeddings.to(self.device_torch)
    # the signature allows a plain number (and defaults to one); make it a
    # tensor so the tensor operations below work
    if not isinstance(timestep, torch.Tensor):
        timestep = torch.tensor(timestep)
    timestep = timestep.to(self.device_torch)

    # if timestep is zero dim, unsqueeze it
    if len(timestep.shape) == 0:
        timestep = timestep.unsqueeze(0)

    # if we only have 1 timestep, we can just use the same timestep for all
    if timestep.shape[0] == 1 and latents.shape[0] > 1:
        # repeat along the batch dim only; every other dim keeps its size
        # (a repeat factor of 0 would produce an empty tensor)
        trailing = [1] * (len(timestep.shape) - 1)
        timestep = timestep.repeat(latents.shape[0], *trailing)

    # handle t2i adapters
    if 'down_intrablock_additional_residuals' in kwargs:
        # go through each item and concat if doing cfg and it doesnt have the same shape
        for idx, item in enumerate(kwargs['down_intrablock_additional_residuals']):
            if do_classifier_free_guidance and item.shape[0] != te_batch_size:
                kwargs['down_intrablock_additional_residuals'][idx] = torch.cat([
                    item] * 2, dim=0)

    # handle controlnet
    if 'down_block_additional_residuals' in kwargs and 'mid_block_additional_residual' in kwargs:
        # go through each item and concat if doing cfg and it doesnt have the same shape
        for idx, item in enumerate(kwargs['down_block_additional_residuals']):
            if do_classifier_free_guidance and item.shape[0] != te_batch_size:
                kwargs['down_block_additional_residuals'][idx] = torch.cat([
                    item] * 2, dim=0)
        # NOTE(review): this iterates over the mid-block value itself; if it is a
        # single tensor rather than a list, the loop walks its first dim — confirm
        for idx, item in enumerate(kwargs['mid_block_additional_residual']):
            if do_classifier_free_guidance and item.shape[0] != te_batch_size:
                kwargs['mid_block_additional_residual'][idx] = torch.cat(
                    [item] * 2, dim=0)

    def scale_model_input(model_input, timestep_tensor):
        # Scale each sample on its own with the scheduler, unless the caller
        # says the input is already scaled.
        if is_input_scaled:
            return model_input
        mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
        timestep_chunks = torch.chunk(
            timestep_tensor, timestep_tensor.shape[0], dim=0)
        out_chunks = []
        for idx in range(model_input.shape[0]):
            # if scheduler has step_index, reset it before every call
            if hasattr(self.noise_scheduler, '_step_index'):
                self.noise_scheduler._step_index = None
            out_chunks.append(
                self.noise_scheduler.scale_model_input(
                    mi_chunks[idx], timestep_chunks[idx])
            )
        return torch.cat(out_chunks, dim=0)

    with torch.no_grad():
        if do_classifier_free_guidance:
            # if we are doing classifier free guidance, need to double up
            latent_model_input = torch.cat([latents] * 2, dim=0)
            timestep = torch.cat([timestep] * 2)
        else:
            latent_model_input = latents

        latent_model_input = scale_model_input(
            latent_model_input, timestep)

        # check if we need to concat timesteps
        if isinstance(timestep, torch.Tensor) and len(timestep.shape) > 1:
            ts_bs = timestep.shape[0]
            if ts_bs != latent_model_input.shape[0]:
                if ts_bs == 1:
                    timestep = torch.cat(
                        [timestep] * latent_model_input.shape[0])
                elif ts_bs * 2 == latent_model_input.shape[0]:
                    timestep = torch.cat([timestep] * 2, dim=0)
                else:
                    raise ValueError(
                        f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}")

    # predict the noise residual
    if self.unet.device != self.device_torch:
        self.unet.to(self.device_torch)
    if self.unet.dtype != self.torch_dtype:
        self.unet = self.unet.to(dtype=self.torch_dtype)

    # check if get_noise prediction has guidance_embedding_scale
    # if it does not, we dont pass it
    signatures = inspect.signature(self.get_noise_prediction).parameters

    if 'guidance_embedding_scale' in signatures:
        kwargs['guidance_embedding_scale'] = guidance_embedding_scale
    if 'bypass_guidance_embedding' in signatures:
        kwargs['bypass_guidance_embedding'] = bypass_guidance_embedding
    if 'batch' in signatures:
        kwargs['batch'] = batch

    noise_pred = self.get_noise_prediction(
        latent_model_input=latent_model_input,
        timestep=timestep,
        text_embeddings=text_embeddings,
        **kwargs
    )

    conditional_pred = noise_pred

    if do_classifier_free_guidance:
        # perform guidance: first half is unconditional, second half conditional
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2, dim=0)
        conditional_pred = noise_pred_text
        if detach_unconditional:
            noise_pred_uncond = noise_pred_uncond.detach()
        noise_pred = noise_pred_uncond + guidance_scale * (
            noise_pred_text - noise_pred_uncond
        )
        if rescale_cfg is not None and rescale_cfg != guidance_scale:
            with torch.no_grad():
                # do cfg at the target rescale so we can match it
                target_pred_mean_std = noise_pred_uncond + rescale_cfg * (
                    noise_pred_text - noise_pred_uncond
                )
                target_mean = target_pred_mean_std.mean(
                    [1, 2, 3], keepdim=True).detach()
                target_std = target_pred_mean_std.std(
                    [1, 2, 3], keepdim=True).detach()

                pred_mean = noise_pred.mean(
                    [1, 2, 3], keepdim=True).detach()
                pred_std = noise_pred.std([1, 2, 3], keepdim=True).detach()

            # match the mean and std
            noise_pred = (noise_pred - pred_mean) / pred_std
            noise_pred = (noise_pred * target_std) + target_mean

        # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
        if guidance_rescale > 0.0:
            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
            noise_pred = rescale_noise_cfg(
                noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

    if return_conditional_pred:
        return noise_pred, conditional_pred
    return noise_pred
def step_scheduler(self, model_input, latent_input, timestep_tensor, noise_scheduler=None):
    """Run one scheduler step per sample and concatenate the results.

    Args:
        model_input: Passed to ``scheduler.step`` as its first argument,
            one dim-0 chunk at a time.
        latent_input: Passed as the third argument, chunked the same way.
        timestep_tensor: Timesteps; a single entry is reused for every
            sample when there is more than one sample.
        noise_scheduler: Scheduler to use; defaults to
            ``self.noise_scheduler``.

    Returns:
        The first element of each ``step(..., return_dict=False)`` result,
        concatenated along dim 0.
    """
    scheduler = self.noise_scheduler if noise_scheduler is None else noise_scheduler

    # // sometimes they are on the wrong device, no idea why
    if isinstance(scheduler, (DDPMScheduler, LCMScheduler)):
        try:
            for attr_name in ('betas', 'alphas', 'alphas_cumprod'):
                moved = getattr(scheduler, attr_name).to(self.device_torch)
                setattr(scheduler, attr_name, moved)
        except Exception:
            # best effort only; a failed move is ignored
            pass

    batch = model_input.shape[0]
    pred_parts = torch.chunk(model_input, batch, dim=0)
    latent_parts = torch.chunk(latent_input, latent_input.shape[0], dim=0)
    timestep_parts = torch.chunk(
        timestep_tensor, timestep_tensor.shape[0], dim=0)

    if len(timestep_parts) == 1 and len(pred_parts) > 1:
        # expand timestep to match
        timestep_parts = timestep_parts * len(pred_parts)

    stepped = []
    for pos in range(batch):
        # reset per-call scheduler state so every sample is stepped independently
        if hasattr(scheduler, '_step_index'):
            scheduler._step_index = None
        if hasattr(scheduler, 'is_scale_input_called'):
            scheduler.is_scale_input_called = True
        step_result = scheduler.step(
            pred_parts[pos], timestep_parts[pos], latent_parts[pos], return_dict=False)
        stepped.append(step_result[0])

    return torch.cat(stepped, dim=0)
# ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746 | |
def diffuse_some_steps(
        self,
        latents: torch.FloatTensor,
        text_embeddings: PromptEmbeds,
        total_timesteps: int = 1000,
        start_timesteps=0,
        guidance_scale=1,
        add_time_ids=None,
        bleed_ratio: float = 0.5,
        bleed_latents: torch.FloatTensor = None,
        is_input_scaled=False,
        return_first_prediction=False,
        **kwargs,
):
    """Denoise ``latents`` over a slice of the scheduler's timesteps.

    For every timestep in
    ``self.noise_scheduler.timesteps[start_timesteps:total_timesteps]`` the
    noise is predicted with ``predict_noise`` and the latents are advanced
    with ``step_scheduler``. When ``bleed_latents`` is given, it is blended
    in with weight ``bleed_ratio`` after every step except the one at the
    scheduler's last timestep. ``is_input_scaled`` only applies to the first
    step.

    Returns:
        The final latents, or ``(latents, first_prediction)`` when
        ``return_first_prediction`` is true, where ``first_prediction`` is
        the conditional prediction of the first step.
    """
    schedule = self.noise_scheduler.timesteps[start_timesteps:total_timesteps]
    first_prediction = None

    for timestep in tqdm(schedule, leave=False):
        current_t = timestep.unsqueeze_(0)
        noise_pred, cond_pred = self.predict_noise(
            latents,
            text_embeddings,
            current_t,
            guidance_scale=guidance_scale,
            add_time_ids=add_time_ids,
            is_input_scaled=is_input_scaled,
            return_conditional_pred=True,
            **kwargs,
        )
        # some schedulers need to run separately, so do that. (euler for example)

        if first_prediction is None and return_first_prediction:
            first_prediction = cond_pred

        latents = self.step_scheduler(noise_pred, latents, current_t)

        # if not last step, and bleeding, bleed in some latents
        if bleed_latents is not None and current_t != self.noise_scheduler.timesteps[-1]:
            kept = latents * (1 - bleed_ratio)
            blended_in = bleed_latents * bleed_ratio
            latents = kept + blended_in

        # only skip first scaling
        is_input_scaled = False

    # return latents_steps
    if not return_first_prediction:
        return latents
    return latents, first_prediction
def encode_prompt(
        self,
        prompt,
        prompt2=None,
        num_images_per_prompt=1,
        force_all=False,
        long_prompts=False,
        max_length=None,
        dropout_prob=0.0,
) -> PromptEmbeds:
    """Encode one prompt or a list of prompts into prompt embeddings.

    A single prompt is wrapped in a list before being handed to
    ``self.get_prompt_embeds``. ``prompt2`` is normalised to a list in the
    same way but is not forwarded; the remaining keyword parameters are
    accepted and not used by this implementation.

    Returns:
        Whatever ``self.get_prompt_embeds`` returns for the prompt list.
    """
    # sd1.5 embeddings are (bs, 77, 768)
    prompts = prompt if isinstance(prompt, list) else [prompt]

    # normalised for parity with the first prompt; not consumed below
    if not (prompt2 is None or isinstance(prompt2, list)):
        prompt2 = [prompt2]

    return self.get_prompt_embeds(prompts)
def encode_images(
        self,
        image_list: List[torch.Tensor],
        device=None,
        dtype=None
):
    """Encode a list of image tensors into scaled VAE latents.

    Args:
        image_list: Image tensors. Dimensions 1 and 2 of each tensor are
            treated as height and width (assumes (C, H, W) - TODO confirm
            against callers).
        device: Device for the inputs and the returned latents. Defaults to
            ``self.vae_device_torch``.
        dtype: Dtype for the inputs and the returned latents. Defaults to
            ``self.vae_torch_dtype``.

    Returns:
        Latents computed as ``scaling_factor * (z - shift_factor)``, moved
        to ``device`` / ``dtype``.
    """
    if device is None:
        device = self.vae_device_torch
    if dtype is None:
        dtype = self.vae_torch_dtype

    # Move the vae to the target device if it is on the cpu. A torch.device
    # never compares equal to a plain string, so inspect the device type.
    if torch.device(self.vae.device).type == 'cpu':
        self.vae.to(device)
    self.vae.eval()
    self.vae.requires_grad_(False)

    # move to device and dtype
    image_list = [image.to(device, dtype=dtype) for image in image_list]

    vae_scale_factor = 2 ** (len(self.vae.config['block_out_channels']) - 1)

    # shrink any image whose height/width is not a multiple of the vae scale
    # factor down to the nearest multiple
    for i, image in enumerate(image_list):
        height, width = image.shape[1], image.shape[2]
        if height % vae_scale_factor != 0 or width % vae_scale_factor != 0:
            target_size = (
                height // vae_scale_factor * vae_scale_factor,
                width // vae_scale_factor * vae_scale_factor,
            )
            image_list[i] = Resize(target_size)(image)

    images = torch.stack(image_list)
    if isinstance(self.vae, AutoencoderTiny):
        latents = self.vae.encode(images, return_dict=False)[0]
    else:
        latents = self.vae.encode(images).latent_dist.sample()

    # a missing shift factor means "no shift"
    shift = self.vae.config['shift_factor']
    if shift is None:
        shift = 0

    # flux ref https://github.com/black-forest-labs/flux/blob/c23ae247225daba30fbd56058d247cc1b1fc20a3/src/flux/modules/autoencoder.py#L303
    # z = self.scale_factor * (z - self.shift_factor)
    latents = self.vae.config['scaling_factor'] * (latents - shift)
    latents = latents.to(device, dtype=dtype)

    return latents
def decode_latents(
        self,
        latents: torch.Tensor,
        device=None,
        dtype=None
):
    """Decode scaled latents back into images with the VAE.

    Reverses the ``scaling_factor * (z - shift_factor)`` transform applied
    when encoding, then runs ``self.vae.decode``.

    Args:
        latents: Scaled latents to decode.
        device: Device for the latents and the returned images. Defaults to
            ``self.device``.
        dtype: Dtype for the latents and the returned images. Defaults to
            ``self.torch_dtype``.

    Returns:
        The ``sample`` attribute of the VAE decode output, moved to
        ``device`` / ``dtype``.
    """
    if device is None:
        device = self.device
    if dtype is None:
        dtype = self.torch_dtype

    # Move the vae to the model device if it is on the cpu. A torch.device
    # never compares equal to a plain string, so inspect the device type.
    if torch.device(self.vae.device).type == 'cpu':
        self.vae.to(self.device)

    latents = latents.to(device, dtype=dtype)

    # a missing shift factor means "no shift"; adding None would raise
    shift = self.vae.config['shift_factor']
    if shift is None:
        shift = 0

    latents = (latents / self.vae.config['scaling_factor']) + shift
    images = self.vae.decode(latents).sample
    images = images.to(device, dtype=dtype)

    return images
def encode_image_prompt_pairs(
        self,
        prompt_list: List[str],
        image_list: List[torch.Tensor],
        device=None,
        dtype=None
):
    """Encode prompts that are paired with images.

    Only the prompts are encoded at the moment; ``image_list`` is accepted
    but not processed, so the returned latent list is always empty.

    Args:
        prompt_list: Prompts to encode, one ``self.encode_prompt`` call each.
        image_list: Images paired with the prompts (currently unused).
        device: Device for the returned embeddings. Defaults to
            ``self.device_torch``.
        dtype: Dtype for the returned embeddings. Defaults to
            ``self.torch_dtype``.

    Returns:
        ``(embedding_list, latent_list)`` where ``latent_list`` is empty.
    """
    # todo check image types and expand and rescale as needed
    # device and dtype are for outputs
    if device is None:
        # same placement the embeddings got before the device argument was honoured
        device = self.device_torch
    if dtype is None:
        dtype = self.torch_dtype

    # embed the prompts
    embedding_list = [
        self.encode_prompt(prompt).to(device, dtype=dtype)
        for prompt in prompt_list
    ]
    latent_list = []

    return embedding_list, latent_list
def get_weight_by_name(self, name):
    """Look up a single weight tensor by its prefixed name.

    Args:
        name: ``te{te_num}`` followed by one separator character and the
            state-dict key for a text encoder weight, or ``unet`` followed
            by one separator character and the key for a unet weight.

    Returns:
        The matching entry of the module's ``state_dict()``.

    Raises:
        ValueError: If the name has neither a ``te{te_num}`` nor a ``unet``
            prefix.
        KeyError: If the key is not present in the module's state dict.
    """
    # weights begin with te{te_num}_ for text encoder
    # weights begin with unet_ for unet_
    if name.startswith('te'):
        # read every digit after "te" so multi-digit indices work too
        digits_end = 2
        while digits_end < len(name) and name[digits_end].isdecimal():
            digits_end += 1
        if digits_end == 2:
            # e.g. "text_..." - looks like a te name but carries no index
            raise ValueError(f"Unknown weight name: {name}")
        te_num = int(name[2:digits_end])
        # skip the single separator character after the index
        key = name[digits_end + 1:]
        # text encoder
        if isinstance(self.text_encoder, list):
            return self.text_encoder[te_num].state_dict()[key]
        # a single text encoder ignores the index
        return self.text_encoder.state_dict()[key]
    if name.startswith('unet'):
        # drop "unet" plus its separator character
        key = name[5:]
        # unet
        return self.unet.state_dict()[key]

    raise ValueError(f"Unknown weight name: {name}")
def inject_trigger_into_prompt(self, prompt, trigger=None, to_replace_list=None, add_if_not_present=False):
    """Forward to the ``inject_trigger_into_prompt`` helper imported from
    ``toolkit.prompt_utils`` and return its result unchanged."""
    options = {
        'trigger': trigger,
        'to_replace_list': to_replace_list,
        'add_if_not_present': add_if_not_present,
    }
    return inject_trigger_into_prompt(prompt, **options)
def state_dict(self, vae=True, text_encoder=True, unet=True):
    """Collect the weights of the selected sub-models into one ordered dict.

    Every key is namespaced with ``{prefix}_``: the vae prefix, the text
    encoder prefix (followed by the encoder index when ``self.text_encoder``
    is a list) or the unet prefix. Keys that already start with that
    namespace are kept as they are.

    Args:
        vae: Include the vae weights.
        text_encoder: Include the text encoder weights.
        unet: Include the unet weights.

    Returns:
        An ``OrderedDict`` with vae entries first, then text encoder
        entries, then unet entries.
    """
    merged = OrderedDict()

    def add_prefixed(module, prefix):
        # one shared rule for every sub-model: "<prefix>_<key>" unless the
        # key already carries that exact namespace (underscore included)
        namespace = f"{prefix}_"
        for key, value in module.state_dict().items():
            new_key = key if key.startswith(namespace) else f"{namespace}{key}"
            merged[new_key] = value

    if vae:
        add_prefixed(self.vae, f"{SD_PREFIX_VAE}")
    if text_encoder:
        if isinstance(self.text_encoder, list):
            for i, encoder in enumerate(self.text_encoder):
                add_prefixed(encoder, f"{SD_PREFIX_TEXT_ENCODER}{i}")
        else:
            add_prefixed(self.text_encoder, f"{SD_PREFIX_TEXT_ENCODER}")
    if unet:
        add_prefixed(self.unet, f"{SD_PREFIX_UNET}")
    return merged
def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> \
        OrderedDict[
            str, Parameter]:
    """Gather named parameters from the selected sub-models.

    Parameters are collected in the order vae, text encoder(s), unet. The
    ``ignore_if_contains`` / ``only_if_contains`` filters from
    ``self.model_config`` are applied to that collection, and refiner
    parameters (when requested) are appended afterwards, unfiltered.

    Args:
        vae: Include vae parameters.
        text_encoder: Include text encoder parameters. For xl models an
            encoder is skipped when its ``use_text_encoder_1`` /
            ``use_text_encoder_2`` config flag is off.
        unet: Include unet parameters; these use the ``transformer`` prefix
            when any of ``is_flux``, ``is_lumina2`` or ``is_transformer``
            is set.
        refiner: Include ``self.refiner_unet`` parameters.
        state_dict_keys: Replace the first ``.`` of each name with ``_``.

    Returns:
        An ``OrderedDict`` mapping parameter name to parameter.
    """
    collected: OrderedDict[str, Parameter] = OrderedDict()

    if vae:
        collected.update(
            self.vae.named_parameters(recurse=True, prefix=f"{SD_PREFIX_VAE}"))

    if text_encoder:
        if isinstance(self.text_encoder, list):
            for i, encoder in enumerate(self.text_encoder):
                # dont add params of an xl text encoder that is switched off
                if (self.is_xl and not self.model_config.use_text_encoder_1 and i == 0) or \
                        (self.is_xl and not self.model_config.use_text_encoder_2 and i == 1):
                    continue
                collected.update(
                    encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}{i}"))
        else:
            collected.update(
                self.text_encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}"))

    if unet:
        if self.is_flux or self.is_lumina2 or self.is_transformer:
            unet_prefix = "transformer"
        else:
            unet_prefix = f"{SD_PREFIX_UNET}"
        collected.update(self.unet.named_parameters(recurse=True, prefix=unet_prefix))

    if self.model_config.ignore_if_contains is not None:
        # drop every param whose name contains one of the ignore fragments
        rejected = [
            name for name in collected
            if any(fragment in name for fragment in self.model_config.ignore_if_contains)
        ]
        for name in rejected:
            del collected[name]

    if self.model_config.only_if_contains is not None:
        # drop every param whose name contains none of the required fragments
        rejected = [
            name for name in collected
            if not any(fragment in name for fragment in self.model_config.only_if_contains)
        ]
        for name in rejected:
            del collected[name]

    if refiner:
        collected.update(
            self.refiner_unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_REFINER_UNET}"))

    # convert to state dict keys: replace only the first . with an _
    if state_dict_keys:
        collected = OrderedDict(
            (name.replace('.', '_', 1), param) for name, param in collected.items())

    return collected
def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
    """Save the model by delegating to ``self.save_model``.

    ``output_file`` is passed on as ``output_path`` together with ``meta``
    and ``save_dtype``. ``logit_scale`` is accepted but not forwarded.
    """
    save_kwargs = {
        'output_path': output_file,
        'meta': meta,
        'save_dtype': save_dtype,
    }
    self.save_model(**save_kwargs)
def prepare_optimizer_params(
        self,
        unet=False,
        text_encoder=False,
        text_encoder_lr=None,
        unet_lr=None,
        refiner_lr=None,
        refiner=False,
        default_lr=1e-6,
):
    """Build optimizer parameter groups for the requested sub-models.

    The keymap JSON for the model version (``sd1``, ``sdxl`` or ``sd2``) is
    always loaded from ``KEYMAPS_ROOT``. Unet parameters are selected purely
    by ``requires_grad``; text encoder and refiner parameters must also be
    listed in the keymap's ``ldm_diffusers_keymap`` values and be absent
    from ``DO_NOT_TRAIN_WEIGHTS``.

    Args:
        unet: Add a group for the unet.
        text_encoder: Add a group for the text encoder(s).
        text_encoder_lr: Learning rate for the text encoder group.
        unet_lr: Learning rate for the unet group.
        refiner_lr: Learning rate for the refiner group.
        refiner: Add a group for the refiner unet.
        default_lr: Learning rate used when a group's own rate is ``None``.

    Returns:
        A list of ``{"params": [...], "lr": ...}`` dicts, in the order unet,
        text encoder, refiner.
    """
    # todo maybe only get locon ones?
    # not all items are saved, to make it match, we need to match out save mappings
    # and not train anything not mapped. Also add learning rate
    if self.is_v2:
        version = 'sd2'
    elif self.is_xl:
        version = 'sdxl'
    else:
        version = 'sd1'

    mapping_path = os.path.join(KEYMAPS_ROOT, f"stable_diffusion_{version}.json")
    with open(mapping_path, 'r') as f:
        mapping = json.load(f)
    mapped_keys = mapping['ldm_diffusers_keymap']

    groups = []

    # we use state dict to find params
    if unet:
        candidates = self.named_parameters(
            vae=False, unet=unet, text_encoder=False, state_dict_keys=True)
        group_lr = default_lr if unet_lr is None else unet_lr
        params = [param for param in candidates.values() if param.requires_grad]
        groups.append({"params": params, "lr": group_lr})
        print_acc(f"Found {len(params)} trainable parameter in unet")

    if text_encoder:
        candidates = self.named_parameters(
            vae=False, unet=False, text_encoder=text_encoder, state_dict_keys=True)
        group_lr = default_lr if text_encoder_lr is None else text_encoder_lr
        params = [
            candidates[diffusers_key]
            for diffusers_key in mapped_keys.values()
            if diffusers_key in candidates
            and diffusers_key not in DO_NOT_TRAIN_WEIGHTS
            and candidates[diffusers_key].requires_grad
        ]
        groups.append({"params": params, "lr": group_lr})
        print_acc(
            f"Found {len(params)} trainable parameter in text encoder")

    if refiner:
        candidates = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True,
                                           state_dict_keys=True)
        group_lr = default_lr if refiner_lr is None else refiner_lr
        # refiner params are looked up under a "refiner_" prefixed key
        refiner_keys = [f"refiner_{diffusers_key}" for diffusers_key in mapped_keys.values()]
        params = [
            candidates[refiner_key]
            for refiner_key in refiner_keys
            if refiner_key in candidates
            and refiner_key not in DO_NOT_TRAIN_WEIGHTS
            and candidates[refiner_key].requires_grad
        ]
        groups.append({"params": params, "lr": group_lr})
        print_acc(f"Found {len(params)} trainable parameter in refiner")

    return groups
def save_device_state(self):
    """Snapshot training mode, device and grad state of every sub-module.

    The snapshot is stored on ``self.device_state`` so the state can be
    altered and restored later. It starts from ``empty_preset`` and records
    the vae, unet and text encoder(s), plus the adapter and refiner unet
    when they are present.

    Raises:
        ValueError: If ``self.adapter`` is set to an unrecognised type.
    """
    unet_has_grad = self.get_model_has_grad()

    snapshot = self.device_state = {
        **empty_preset,
        'vae': {
            'training': self.vae.training,
            'device': self.vae.device,
        },
        'unet': {
            'training': self.unet.training,
            'device': self.unet.device,
            'requires_grad': unet_has_grad,
        },
    }

    if isinstance(self.text_encoder, list):
        encoder_states: List[dict] = []
        snapshot['text_encoder'] = encoder_states
        for encoder in self.text_encoder:
            te_has_grad = self.get_te_has_grad()
            encoder_states.append({
                'training': encoder.training,
                'device': encoder.device,
                # todo there has to be a better way to do this
                'requires_grad': te_has_grad
            })
    else:
        te_has_grad = self.get_te_has_grad()
        snapshot['text_encoder'] = {
            'training': self.text_encoder.training,
            'device': self.text_encoder.device,
            'requires_grad': te_has_grad
        }

    adapter = self.adapter
    if adapter is not None:
        # each adapter type exposes its grad flag and device differently
        if isinstance(adapter, IPAdapter):
            adapter_grad = adapter.image_proj_model.training
            adapter_device = self.unet.device
        elif isinstance(adapter, T2IAdapter):
            adapter_grad = adapter.adapter.conv_in.weight.requires_grad
            adapter_device = adapter.device
        elif isinstance(adapter, ControlNetModel):
            adapter_grad = adapter.conv_in.training
            adapter_device = adapter.device
        elif isinstance(adapter, ClipVisionAdapter):
            adapter_grad = adapter.embedder.training
            adapter_device = adapter.device
        elif isinstance(adapter, CustomAdapter):
            adapter_grad = adapter.training
            adapter_device = adapter.device
        elif isinstance(adapter, ReferenceAdapter):
            # todo update this!!
            adapter_grad = True
            adapter_device = adapter.device
        else:
            raise ValueError(f"Unknown adapter type: {type(self.adapter)}")
        snapshot['adapter'] = {
            'training': adapter.training,
            'device': adapter_device,
            'requires_grad': adapter_grad,
        }

    refiner = self.refiner_unet
    if refiner is not None:
        snapshot['refiner_unet'] = {
            'training': refiner.training,
            'device': refiner.device,
            'requires_grad': refiner.conv_in.weight.requires_grad,
        }
def restore_device_state(self):
    """Re-apply the snapshot stored in ``self.device_state``, then clear it.

    Does nothing when no snapshot is stored.
    """
    saved_state = self.device_state
    if saved_state is not None:
        self.set_device_state(saved_state)
        self.device_state = None
def set_device_state(self, state):
    """Apply a device-state dict to the vae, unet, text encoder(s), adapter
    and refiner unet, then call ``flush()``.

    Args:
        state: Dict with a per-module entry holding ``training``, ``device``
            and (except for the vae) ``requires_grad``. ``text_encoder`` may
            be a single entry or a list with one entry per encoder.
    """

    def apply_mode(module, training):
        # switch between train and eval mode according to the stored flag
        if training:
            module.train()
        else:
            module.eval()

    vae_state = state['vae']
    apply_mode(self.vae, vae_state['training'])
    self.vae.to(vae_state['device'])

    unet_state = state['unet']
    apply_mode(self.unet, unet_state['training'])
    self.unet.to(unet_state['device'])
    self.unet.requires_grad_(bool(unet_state['requires_grad']))

    if isinstance(self.text_encoder, list):
        for i, encoder in enumerate(self.text_encoder):
            # a list carries one entry per encoder; a single entry is
            # shared by all encoders
            if isinstance(state['text_encoder'], list):
                te_state = state['text_encoder'][i]
            else:
                te_state = state['text_encoder']
            apply_mode(encoder, te_state['training'])
            encoder.to(te_state['device'])
            encoder.requires_grad_(te_state['requires_grad'])
    else:
        te_state = state['text_encoder']
        apply_mode(self.text_encoder, te_state['training'])
        self.text_encoder.to(te_state['device'])
        self.text_encoder.requires_grad_(te_state['requires_grad'])

    if self.adapter is not None:
        adapter_state = state['adapter']
        self.adapter.to(adapter_state['device'])
        self.adapter.requires_grad_(adapter_state['requires_grad'])
        apply_mode(self.adapter, adapter_state['training'])

    if self.refiner_unet is not None:
        refiner_state = state['refiner_unet']
        self.refiner_unet.to(refiner_state['device'])
        self.refiner_unet.requires_grad_(refiner_state['requires_grad'])
        apply_mode(self.refiner_unet, refiner_state['training'])

    flush()
def set_device_state_preset(self, device_state_preset: DeviceStatePreset):
    """Save the current device state, then apply a named preset.

    The preset decides which modules are active. Active modules are placed
    on their working device and all others on ``'cpu'``. No module is put
    in training mode or given gradients, because the set of training
    modules is always empty here.

    Args:
        device_state_preset: ``'cache_latents'`` activates the vae,
            ``'cache_clip'`` activates ``'clip'``, and ``'generate'``
            activates vae, unet, text encoder, adapter and refiner unet.
            Any other value leaves every module inactive.
    """
    # save current state first
    self.save_device_state()

    training_modules = []
    if device_state_preset == 'generate':
        active_modules = ['vae', 'unet',
                          'text_encoder', 'adapter', 'refiner_unet']
    elif device_state_preset == 'cache_clip':
        active_modules = ['clip']
    elif device_state_preset == 'cache_latents':
        active_modules = ['vae']
    else:
        active_modules = []

    def module_state(module_name, device_attr):
        # the device attribute is only read for active modules
        is_training = module_name in training_modules
        return {
            'training': is_training,
            'device': getattr(self, device_attr) if module_name in active_modules else 'cpu',
            'requires_grad': is_training,
        }

    state = copy.deepcopy(empty_preset)

    # vae
    state['vae'] = module_state('vae', 'vae_device_torch')

    # unet
    state['unet'] = module_state('unet', 'device_torch')

    if self.refiner_unet is not None:
        state['refiner_unet'] = module_state('refiner_unet', 'device_torch')

    # text encoder: one entry per encoder when there are several
    if isinstance(self.text_encoder, list):
        state['text_encoder'] = [
            module_state('text_encoder', 'te_device_torch')
            for _ in self.text_encoder
        ]
    else:
        state['text_encoder'] = module_state('text_encoder', 'te_device_torch')

    if self.adapter is not None:
        state['adapter'] = module_state('adapter', 'device_torch')

    self.set_device_state(state)
def text_encoder_to(self, *args, **kwargs):
    """Call ``.to(*args, **kwargs)`` on the text encoder, or on each one
    when ``self.text_encoder`` is a list."""
    if isinstance(self.text_encoder, list):
        encoders = self.text_encoder
    else:
        encoders = [self.text_encoder]
    for encoder in encoders:
        encoder.to(*args, **kwargs)
def convert_lora_weights_before_save(self, state_dict):
    """Hook for child classes to convert lora weights before saving.

    This base implementation returns ``state_dict`` unchanged.
    """
    return state_dict
def convert_lora_weights_before_load(self, state_dict):
    """Hook for child classes to convert lora weights before loading.

    This base implementation returns ``state_dict`` unchanged.
    """
    return state_dict
def condition_noisy_latents(self, latents: torch.Tensor, batch: 'DataLoaderBatchDTO'):
    """Hook for child classes to condition latents before noise prediction.

    This base implementation ignores ``batch`` and returns ``latents``
    unchanged.
    """
    return latents
def get_transformer_block_names(self) -> Optional[List[str]]:
    """Hook for child classes to supply transformer block names for lora
    targeting.

    This base implementation returns ``None``.
    """
    return None
def get_base_model_version(self) -> str:
    """Hook for child classes to report the base model version.

    This base implementation returns ``"unknown"``.
    """
    return "unknown"