# ramimu's picture
# Upload 586 files
# 1c72248 verified
import os
import random
from collections import OrderedDict
from typing import Union, Literal, List, Optional
import numpy as np
from diffusers import T2IAdapter, AutoencoderTiny, ControlNetModel
import torch.functional as F
from safetensors.torch import load_file
from torch.utils.data import DataLoader, ConcatDataset
from toolkit import train_tools
from toolkit.basic import value_map, adain, get_mean_std
from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.config_modules import GuidanceConfig
from toolkit.data_loader import get_dataloader_datasets
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss, GuidanceType
from toolkit.image_utils import show_tensors, show_latents
from toolkit.ip_adapter import IPAdapter
from toolkit.custom_adapter import CustomAdapter
from toolkit.print import print_acc
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
from toolkit.reference_adapter import ReferenceAdapter
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
from toolkit.train_tools import get_torch_dtype, apply_snr_weight, add_all_snr_to_noise_scheduler, \
apply_learnable_snr_gos, LearnableSNRGamma
import gc
import torch
from jobs.process import BaseSDTrainProcess
from torchvision import transforms
from diffusers import EMAModel
import math
from toolkit.train_tools import precondition_model_outputs_flow_match
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
from toolkit.util.wavelet_loss import wavelet_loss
def flush():
    """Free unreferenced Python objects, then release cached CUDA memory.

    The garbage collector runs first so that tensors kept alive only by
    reference cycles are destroyed before the CUDA caching allocator is asked
    to give its unused blocks back. Running the two steps the other way
    around leaves the blocks freed by the collector cached until the next
    call.
    """
    gc.collect()
    torch.cuda.empty_cache()
# Module-level transform pipeline with a single ToTensor step.
_adapter_transform_steps = [transforms.ToTensor()]
adapter_transforms = transforms.Compose(_adapter_transform_steps)
class SDTrainer(BaseSDTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
    """Set up trainer state on top of the base training process.

    Args:
        process_id: Forwarded unchanged to the base class constructor.
        job: Forwarded unchanged to the base class constructor.
        config: Forwarded unchanged to the base class constructor.
        **kwargs: Forwarded unchanged to the base class constructor.

    Raises:
        ValueError: If ``diff_output_preservation`` is enabled without a
            trigger word, without a network config, or together with
            ``train_text_encoder``.
    """
    super().__init__(process_id, job, config, **kwargs)
    # Give the attribute a value right away: a bare annotation does not
    # create it, so reading it before before_dataset_load() ran would raise
    # AttributeError.
    self.assistant_adapter: Union['T2IAdapter', 'ControlNetModel', None] = None
    self.do_prior_prediction = False
    self.do_long_prompts = False
    self.do_guided_loss = False
    self.taesd: Optional[AutoencoderTiny] = None

    self._clip_image_embeds_unconditional: Union[List[str], None] = None
    self.negative_prompt_pool: Union[List[str], None] = None
    self.batch_negative_prompt: Union[List[str], None] = None

    self.is_bfloat = self.train_config.dtype in ("bfloat16", "bf16")

    # Gradient scaling is on by default and switched off for bfloat16
    # fine-tuning and whenever an adapter is being trained.
    self.do_grad_scale = True
    if self.is_fine_tuning and self.is_bfloat:
        self.do_grad_scale = False
    if self.adapter_config is not None and self.adapter_config.train:
        self.do_grad_scale = False

    # Prompt embeddings cached when the text encoder is unloaded
    # (filled in hook_before_train_loop).
    self.cached_blank_embeds: Optional[PromptEmbeds] = None
    self.cached_trigger_embeds: Optional[PromptEmbeds] = None
    self.diff_output_preservation_embeds: Optional[PromptEmbeds] = None

    self.dfe: Optional[DiffusionFeatureExtractor] = None

    if self.train_config.diff_output_preservation:
        if self.trigger_word is None:
            raise ValueError("diff_output_preservation requires a trigger_word to be set")
        if self.network_config is None:
            raise ValueError("diff_output_preservation requires a network to be set")
        if self.train_config.train_text_encoder:
            raise ValueError("diff_output_preservation is not supported with train_text_encoder")

        # always do a prior prediction when doing diff output preservation
        self.do_prior_prediction = True
def before_model_load(self):
    """Hook with no work to do in this trainer; returns None."""
    return None
def before_dataset_load(self):
    """Load the optional frozen helper models configured for training.

    Resets ``self.assistant_adapter`` and, when an assist path is configured,
    loads a T2I adapter or ControlNet from it, puts it in eval mode and
    disables its gradients. When turbo training with ``show_turbo_outputs``
    is enabled, also loads the matching TAESD tiny autoencoder.

    Raises:
        ValueError: If ``adapter_assist_type`` is neither ``"t2i"`` nor
            ``"control_net"``.
    """
    cfg = self.train_config
    self.assistant_adapter = None

    # get adapter assistant if one is set
    assist_path = cfg.adapter_assist_name_or_path
    if assist_path is not None:
        assist_kind = cfg.adapter_assist_type
        if assist_kind == "t2i":
            # dont name this adapter since we are not training it
            loaded = T2IAdapter.from_pretrained(
                assist_path, torch_dtype=get_torch_dtype(cfg.dtype)
            )
            self.assistant_adapter = loaded.to(self.device_torch)
        elif assist_kind == "control_net":
            loaded = ControlNetModel.from_pretrained(
                assist_path, torch_dtype=get_torch_dtype(cfg.dtype)
            )
            self.assistant_adapter = loaded.to(
                self.device_torch, dtype=get_torch_dtype(cfg.dtype)
            )
        else:
            raise ValueError(f"Unknown adapter assist type {self.train_config.adapter_assist_type}")

        # the assistant is only used for inference
        self.assistant_adapter.eval()
        self.assistant_adapter.requires_grad_(False)
        flush()

    if cfg.train_turbo and cfg.show_turbo_outputs:
        taesd_repo = "madebyollin/taesdxl" if self.model_config.is_xl else "madebyollin/taesd"
        self.taesd = AutoencoderTiny.from_pretrained(
            taesd_repo, torch_dtype=get_torch_dtype(cfg.dtype)
        )
        self.taesd.to(dtype=get_torch_dtype(cfg.dtype), device=self.device_torch)
        self.taesd.eval()
        self.taesd.requires_grad_(False)
def hook_before_train_loop(self):
    """Prepare models, caches and prompt pools right before the train loop.

    In order: runs the base-class hook, enables prior prediction when
    ``do_prior_divergence`` is set, moves the VAE to the training device (or
    offloads it to CPU when latents are cached), registers SNR values on the
    noise scheduler, moves the adapter to the device and collects the
    unconditional CLIP image embeds, builds the negative prompt pool,
    optionally caches prompt embeddings and unloads the text encoder, and
    optionally loads the diffusion feature extractor.

    Raises:
        ValueError: If regularisation datasets are used with cached CLIP
            vision embeds but no unconditional embeds are found, or if
            ``unload_text_encoder`` is combined with ``train_text_encoder``.
    """
    super().hook_before_train_loop()

    if self.train_config.do_prior_divergence:
        self.do_prior_prediction = True

    # move vae to device if we did not cache latents
    if not self.is_latents_cached:
        self.sd.vae.eval()
        self.sd.vae.to(self.device_torch)
    else:
        # offload it. Already cached
        self.sd.vae.to('cpu')
        flush()

    add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch)

    if self.adapter is not None:
        self.adapter.to(self.device_torch)

        # check if we have regs and using adapter and caching clip embeddings
        has_reg = self.datasets_reg is not None and len(self.datasets_reg) > 0
        is_caching_clip_embeddings = self.datasets is not None and any(
            dataset.cache_clip_vision_to_disk for dataset in self.datasets
        )

        if has_reg and is_caching_clip_embeddings:
            # we need a list of unconditional clip image embeds from other datasets to handle regs
            unconditional_clip_image_embeds = []
            for dataset in get_dataloader_datasets(self.data_loader):
                unconditional_clip_image_embeds += dataset.clip_vision_unconditional_cache

            if not unconditional_clip_image_embeds:
                raise ValueError("No unconditional clip image embeds found. This should not happen")

            self._clip_image_embeds_unconditional = unconditional_clip_image_embeds

    negative_prompt = self.train_config.negative_prompt
    if negative_prompt is not None:
        if os.path.exists(negative_prompt):
            # The setting names a file: one prompt per non-empty line.
            # Decode explicitly as UTF-8 so the pool does not depend on the
            # platform's default encoding.
            with open(negative_prompt, 'r', encoding='utf-8') as f:
                stripped_lines = [line.strip() for line in f]
            self.negative_prompt_pool = [line for line in stripped_lines if line != ""]
        else:
            # single prompt
            self.negative_prompt_pool = [negative_prompt]

    # handle unload text encoder
    if self.train_config.unload_text_encoder:
        with torch.no_grad():
            if self.train_config.train_text_encoder:
                raise ValueError("Cannot unload text encoder if training text encoder")

            # cache embeddings
            print_acc("\n***** UNLOADING TEXT ENCODER *****")
            print_acc("This will train only with a blank prompt or trigger word, if set")
            print_acc("If this is not what you want, remove the unload_text_encoder flag")
            print_acc("***********************************")
            print_acc("")
            self.sd.text_encoder_to(self.device_torch)
            self.cached_blank_embeds = self.sd.encode_prompt("")
            if self.trigger_word is not None:
                self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word)
            if self.train_config.diff_output_preservation:
                self.diff_output_preservation_embeds = self.sd.encode_prompt(
                    self.train_config.diff_output_preservation_class
                )

            # move back to cpu
            self.sd.text_encoder_to('cpu')
            flush()

    if self.train_config.diffusion_feature_extractor_path is not None:
        # The VAE is handed to the extractor unless the arch is "flux";
        # an AutoencoderPixelMixer VAE is always handed over.
        vae = None
        if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer":
            vae = self.sd.vae
        self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path, vae=vae)
        self.dfe.to(self.device_torch)
        if hasattr(self.dfe, 'vision_encoder') and self.train_config.gradient_checkpointing:
            # must be set to train for gradient checkpointing to work
            self.dfe.vision_encoder.train()
            self.dfe.vision_encoder.gradient_checkpointing = True
        else:
            self.dfe.eval()

        # enable gradient checkpointing on the vae
        if vae is not None and self.train_config.gradient_checkpointing:
            vae.enable_gradient_checkpointing()
            vae.train()
def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch):
    """Convert a model prediction into decoded VAE output for turbo training.

    For every item in the batch the noise scheduler is reconfigured to take a
    single step from the item's timestep sigma to a randomly chosen later
    sigma, the step is applied to the prediction, and ``noise * end_sigma``
    is subtracted from the result. The per-item results are concatenated,
    divided by the VAE scaling factor and decoded by the VAE.

    Args:
        pred: Model prediction, split along dim 0 into one chunk per item.
        noisy_latents: Noisy latents passed to the scheduler step (detached).
        timesteps: One timestep per batch item; each must be present in the
            scheduler's full training timestep schedule.
        noise: Noise tensor used to build the residual that is subtracted.
        batch: Batch object; ``batch.latents`` and ``batch.tensor`` are read.

    Returns:
        A ``(output, target)`` tuple: the decoded VAE output and
        ``batch.tensor`` moved to the training device and dtype.

    Note:
        The scheduler's ``sigmas``, ``timesteps`` and ``_step_index`` are
        overwritten inside the loop; the timesteps are reset to the full
        training schedule before returning.
    """
    # to process turbo learning, we make one big step from our current timestep to the end
    # we then denoise the prediction on that remaining step and target our loss to our target latents
    # this currently only works on euler_a (that I know of). Would work on others, but needs to be coded to do so.
    # needs to be done on each item in batch as they may all have different timesteps
    batch_size = pred.shape[0]
    pred_chunks = torch.chunk(pred, batch_size, dim=0)
    noisy_latents_chunks = torch.chunk(noisy_latents, batch_size, dim=0)
    timesteps_chunks = torch.chunk(timesteps, batch_size, dim=0)
    latent_chunks = torch.chunk(batch.latents, batch_size, dim=0)
    noise_chunks = torch.chunk(noise, batch_size, dim=0)

    with torch.no_grad():
        # set the timesteps to 1000 so we can capture them to calculate the sigmas
        self.sd.noise_scheduler.set_timesteps(
            self.sd.noise_scheduler.config.num_train_timesteps,
            device=self.device_torch
        )
        train_timesteps = self.sd.noise_scheduler.timesteps.clone().detach()
        train_sigmas = self.sd.noise_scheduler.sigmas.clone().detach()

        # set the scheduler to one timestep, we build the step and sigmas for each item in batch for the partial step
        self.sd.noise_scheduler.set_timesteps(
            1,
            device=self.device_torch
        )

    denoised_pred_chunks = []
    # NOTE(review): target_pred_chunks is never used below.
    target_pred_chunks = []

    for i in range(batch_size):
        pred_item = pred_chunks[i]
        noisy_latents_item = noisy_latents_chunks[i]
        timesteps_item = timesteps_chunks[i]
        # NOTE(review): latents_item is never used below.
        latents_item = latent_chunks[i]
        noise_item = noise_chunks[i]
        with torch.no_grad():
            # position of this item's timestep in the full training schedule;
            # .item() raises unless exactly one schedule entry matches
            timestep_idx = [(train_timesteps == t).nonzero().item() for t in timesteps_item][0]
            single_step_timestep_schedule = [timesteps_item.squeeze().item()]
            # extract the sigma idx for our midpoint timestep
            sigmas = train_sigmas[timestep_idx:timestep_idx + 1].to(self.device_torch)

            # random end point between the current sigma index and the last one (inclusive)
            end_sigma_idx = random.randint(timestep_idx, len(train_sigmas) - 1)
            end_sigma = train_sigmas[end_sigma_idx:end_sigma_idx + 1].to(self.device_torch)

            # add noise to our target

            # build the big sigma step. The to step will now be to 0 giving it a full remaining denoising half step
            # self.sd.noise_scheduler.sigmas = torch.cat([sigmas, torch.zeros_like(sigmas)]).detach()
            self.sd.noise_scheduler.sigmas = torch.cat([sigmas, end_sigma]).detach()
            # set our single timestep
            self.sd.noise_scheduler.timesteps = torch.from_numpy(
                np.array(single_step_timestep_schedule, dtype=np.float32)
            ).to(device=self.device_torch)

            # set the step index to None so it will be recalculated on first step
            self.sd.noise_scheduler._step_index = None

        # NOTE(review): the step is assumed to run outside the no_grad block so
        # gradients can flow back to `pred` -- confirm against the upstream nesting.
        denoised_latent = self.sd.noise_scheduler.step(
            pred_item, timesteps_item, noisy_latents_item.detach(), return_dict=False
        )[0]

        residual_noise = (noise_item * end_sigma.flatten()).detach().to(self.device_torch, dtype=get_torch_dtype(
            self.train_config.dtype))
        # remove the residual noise from the denoised latents. Output should be a clean prediction (theoretically)
        denoised_latent = denoised_latent - residual_noise

        denoised_pred_chunks.append(denoised_latent)

    denoised_latents = torch.cat(denoised_pred_chunks, dim=0)
    # set the scheduler back to the original timesteps
    self.sd.noise_scheduler.set_timesteps(
        self.sd.noise_scheduler.config.num_train_timesteps,
        device=self.device_torch
    )

    # undo the latent scaling before decoding with the VAE
    output = denoised_latents / self.sd.vae.config['scaling_factor']
    output = self.sd.vae.decode(output).sample

    if self.train_config.show_turbo_outputs:
        # since we are completely denoising, we can show them here
        with torch.no_grad():
            show_tensors(output)

    # we return our big partial step denoised latents as our pred and our untouched latents as our target.
    # you can do mse against the two here or run the denoised through the vae for pixel space loss against the
    # input tensor images.

    return output, batch.tensor.to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))
# you can expand these in a child class to make customization easier
def calculate_loss(
        self,
        noise_pred: torch.Tensor,
        noise: torch.Tensor,
        noisy_latents: torch.Tensor,
        timesteps: torch.Tensor,
        batch: 'DataLoaderBatchDTO',
        mask_multiplier: Union[torch.Tensor, float] = 1.0,
        prior_pred: Union[torch.Tensor, None] = None,
        **kwargs
):
    """Compute the training loss for one batch.

    The method (1) optionally rescales the prediction and the noise,
    (2) picks the regression target according to the training config,
    (3) optionally adds a diffusion-feature-extractor loss, (4) computes an
    element-wise loss, (5) applies the mask, the per-item loss multipliers,
    the optional inverted-mask prior loss and SNR weighting, and (6) reduces
    to a mean and adds the auxiliary losses.

    Args:
        noise_pred: Model prediction. Indexed on dims 1, 2 and 3, so a
            4-D tensor is expected.
        noise: Noise that was added to the latents.
        noisy_latents: Noisy latents that were given to the model.
        timesteps: Timesteps used for each batch item.
        batch: Batch object; ``latents``, ``tensor``, ``mask_tensor``,
            ``sigmas``, ``unaugmented_tensor``, ``loss_multiplier_list`` and
            ``get_is_reg_list()`` are read depending on the configuration.
        mask_multiplier: Multiplier applied to the element-wise loss.
        prior_pred: Optional prediction made without the trained
            network/adapter; used as a target or for the prior losses.
        **kwargs: Accepted but not used in this method.

    Returns:
        ``loss.mean()`` plus any adapter, norm-std and feature-extractor
        losses.

    Raises:
        ValueError: For an unknown feature extractor version, a missing
            ``batch.sigmas`` with the 'source'/'unaugmented' loss target, or
            an unsupported scheduler prediction type in that mode.
    """
    loss_target = self.train_config.loss_target
    is_reg = any(batch.get_is_reg_list())
    additional_loss = 0.0

    prior_mask_multiplier = None
    # NOTE(review): target_mask_multiplier is never used below.
    target_mask_multiplier = None
    dtype = get_torch_dtype(self.train_config.dtype)

    has_mask = batch.mask_tensor is not None

    with torch.no_grad():
        # per-item loss weights, kept in float32
        loss_multiplier = torch.tensor(batch.loss_multiplier_list).to(self.device_torch, dtype=torch.float32)

    if self.train_config.match_noise_norm:
        # match the norm of the noise
        noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True)
        noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
        noise_pred = noise_pred * (noise_norm / noise_pred_norm)

    if self.train_config.pred_scaler != 1.0:
        noise_pred = noise_pred * self.train_config.pred_scaler

    target = None

    if self.train_config.target_noise_multiplier != 1.0:
        noise = noise * self.train_config.target_noise_multiplier

    # ---- target selection -------------------------------------------------
    if self.train_config.correct_pred_norm or (self.train_config.inverted_mask_prior and prior_pred is not None and has_mask):
        if self.train_config.correct_pred_norm and not is_reg:
            with torch.no_grad():
                # this only works if doing a prior pred
                if prior_pred is not None:
                    # per-channel spatial statistics (dims 2 and 3)
                    prior_mean = prior_pred.mean([2,3], keepdim=True)
                    prior_std = prior_pred.std([2,3], keepdim=True)
                    noise_mean = noise_pred.mean([2,3], keepdim=True)
                    noise_std = noise_pred.std([2,3], keepdim=True)

                    mean_adjust = prior_mean - noise_mean
                    std_adjust = prior_std - noise_std

                    mean_adjust = mean_adjust * self.train_config.correct_pred_norm_multiplier
                    std_adjust = std_adjust * self.train_config.correct_pred_norm_multiplier

                    target_mean = noise_mean + mean_adjust
                    target_std = noise_std + std_adjust

                    eps = 1e-5
                    # match the noise to the prior
                    noise = (noise - noise_mean) / (noise_std + eps)
                    noise = noise * (target_std + eps) + target_mean
                    noise = noise.detach()

        if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
            assert not self.train_config.train_turbo
            with torch.no_grad():
                prior_mask = batch.mask_tensor.to(self.device_torch, dtype=dtype)
                # resize to size of noise_pred
                prior_mask = torch.nn.functional.interpolate(prior_mask, size=(noise_pred.shape[2], noise_pred.shape[3]), mode='bicubic')
                # stack first channel to match channels of noise_pred
                # NOTE(review): `[:1]` slices dim 0, not the channel dim -- confirm this is intended.
                prior_mask = torch.cat([prior_mask[:1]] * noise_pred.shape[1], dim=1)

                prior_mask_multiplier = 1.0 - prior_mask

                # scale so it is a mean of 1
                prior_mask_multiplier = prior_mask_multiplier / prior_mask_multiplier.mean()
            if self.sd.is_flow_matching:
                target = (noise - batch.latents).detach()
            else:
                target = noise
    elif prior_pred is not None and not self.train_config.do_prior_divergence:
        assert not self.train_config.train_turbo
        # matching adapter prediction
        target = prior_pred
    elif self.sd.prediction_type == 'v_prediction':
        # v-parameterization training
        target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps)

    elif hasattr(self.sd, 'get_loss_target'):
        # the model supplies its own target
        target = self.sd.get_loss_target(
            noise=noise,
            batch=batch,
            timesteps=timesteps,
        ).detach()

    elif self.sd.is_flow_matching:
        # forward ODE
        target = (noise - batch.latents).detach()
        # reverse ODE
        # target = (batch.latents - noise).detach()
    else:
        target = noise

    # ---- optional diffusion feature extractor loss --------------------------
    if self.dfe is not None:
        if self.dfe.version == 1:
            # do diffusion feature extraction on target
            with torch.no_grad():
                rectified_flow_target = noise.float() - batch.latents.float()
                target_features = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1))

            # do diffusion feature extraction on prediction
            pred_features = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))
            additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean") * \
                self.train_config.diffusion_feature_extractor_weight
        elif self.dfe.version == 2:
            # version 2
            # do diffusion feature extraction on target
            with torch.no_grad():
                rectified_flow_target = noise.float() - batch.latents.float()
                target_feature_list = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1))

            # do diffusion feature extraction on prediction
            pred_feature_list = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))

            # version 2 returns a list of features; sum the per-feature MSE
            dfe_loss = 0.0
            for i in range(len(target_feature_list)):
                dfe_loss += torch.nn.functional.mse_loss(pred_feature_list[i], target_feature_list[i], reduction="mean")

            additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0
        elif self.dfe.version == 3:
            # version 3 computes its loss internally
            dfe_loss = self.dfe(
                noise=noise,
                noise_pred=noise_pred,
                noisy_latents=noisy_latents,
                timesteps=timesteps,
                batch=batch,
                scheduler=self.sd.noise_scheduler
            )
            additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight
        else:
            raise ValueError(f"Unknown diffusion feature extractor version {self.dfe.version}")

    # fallback when the first branch above was taken without assigning a target
    if target is None:
        target = noise

    pred = noise_pred

    if self.train_config.train_turbo:
        pred, target = self.process_output_for_turbo(pred, noisy_latents, timesteps, noise, batch)

    # NOTE(review): ignore_snr is never set to True (see the commented line below).
    ignore_snr = False

    # ---- element-wise loss --------------------------------------------------
    if loss_target == 'source' or loss_target == 'unaugmented':
        assert not self.train_config.train_turbo
        # ignore_snr = True
        if batch.sigmas is None:
            raise ValueError("Batch sigmas is None. This should not happen")

        # src https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1190
        denoised_latents = noise_pred * (-batch.sigmas) + noisy_latents
        weighing = batch.sigmas ** -2.0
        if loss_target == 'source':
            # denoise the latent and compare to the latent in the batch
            target = batch.latents
        elif loss_target == 'unaugmented':
            # we have to encode images into latents for now
            # we also denoise as the unaugmented tensor is not a noisy differential
            with torch.no_grad():
                unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor).to(self.device_torch, dtype=dtype)
                unaugmented_latents = unaugmented_latents * self.train_config.latent_multiplier
                target = unaugmented_latents.detach()

        # Get the target for loss depending on the prediction type
        if self.sd.noise_scheduler.config.prediction_type == "epsilon":
            target = target  # we are computing loss against denoise latents
        elif self.sd.noise_scheduler.config.prediction_type == "v_prediction":
            target = self.sd.noise_scheduler.get_velocity(target, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}")

        # mse loss without reduction
        loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2)
        loss = loss_per_element
    else:

        if self.train_config.loss_type == "mae":
            loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
        elif self.train_config.loss_type == "wavelet":
            loss = wavelet_loss(pred, batch.latents, noise)
        else:
            loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")

        # NOTE(review): this timestep weighting is assumed to belong to the
        # `else` branch only -- confirm against the upstream nesting.
        # handle linear timesteps and only adjust the weight of the timesteps
        if self.sd.is_flow_matching and (self.train_config.linear_timesteps or self.train_config.linear_timesteps2):
            # calculate the weights for the timesteps
            timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps(
                timesteps,
                v2=self.train_config.linear_timesteps2
            ).to(loss.device, dtype=loss.dtype)
            timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach()
            loss = loss * timestep_weight

    if self.train_config.do_prior_divergence and prior_pred is not None:
        # subtract the MSE between prediction and prior prediction
        loss = loss + (torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none") * -1.0)

    if self.train_config.train_turbo:
        # keep channels from index 3 onward of the mask
        mask_multiplier = mask_multiplier[:, 3:, :, :]
        # resize to the size of the loss
        mask_multiplier = torch.nn.functional.interpolate(mask_multiplier, size=(pred.shape[2], pred.shape[3]), mode='nearest')

    # multiply by our mask
    # NOTE(review): the bare except silently skips the mask on ANY error,
    # not only on a shape mismatch.
    try:
        loss = loss * mask_multiplier
    except:
        # todo handle mask with video models
        pass

    # ---- inverted-mask prior loss -------------------------------------------
    prior_loss = None
    if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None:
        assert not self.train_config.train_turbo
        if self.train_config.loss_type == "mae":
            prior_loss = torch.nn.functional.l1_loss(pred.float(), prior_pred.float(), reduction="none")
        else:
            prior_loss = torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none")

        prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier
        if torch.isnan(prior_loss).any():
            # drop the prior loss for this step rather than propagate NaNs
            print_acc("Prior loss is nan")
            prior_loss = None
        else:
            prior_loss = prior_loss.mean([1, 2, 3])
            # loss = loss + prior_loss
            # loss = loss + prior_loss
        # loss = loss + prior_loss

    # reduce to one value per batch item
    loss = loss.mean([1, 2, 3])
    # apply loss multiplier before prior loss
    # multiply by our mask
    # NOTE(review): same bare-except concern as for the mask above.
    try:
        loss = loss * loss_multiplier
    except:
        # todo handle mask with video models
        pass
    if prior_loss is not None:
        loss = loss + prior_loss

    # ---- SNR weighting (skipped for turbo training) ---------------------------
    if not self.train_config.train_turbo:
        if self.train_config.learnable_snr_gos:
            # add snr_gamma
            loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
        elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
            # add snr_gamma
            loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma,
                                    fixed=True)
        elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
            # add min_snr_gamma
            loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)

    loss = loss.mean()

    # check for additional losses
    if self.adapter is not None and hasattr(self.adapter, "additional_loss") and self.adapter.additional_loss is not None:

        loss = loss + self.adapter.additional_loss.mean()
        # reset the adapter's stored loss after adding it
        self.adapter.additional_loss = None

    if self.train_config.target_norm_std:
        # separate out the batch and channels
        pred_std = noise_pred.std([2, 3], keepdim=True)
        norm_std_loss = torch.abs(self.train_config.target_norm_std_value - pred_std).mean()
        loss = loss + norm_std_loss

    return loss + additional_loss
def get_diff_output_preservation_loss(
        self,
        noise_pred: torch.Tensor,
        noise: torch.Tensor,
        noisy_latents: torch.Tensor,
        timesteps: torch.Tensor,
        batch: 'DataLoaderBatchDTO',
        mask_multiplier: Union[torch.Tensor, float] = 1.0,
        prior_pred: Union[torch.Tensor, None] = None,
        **kwargs
):
    """Placeholder for the diff-output-preservation loss.

    The body only reads ``self.train_config.loss_target`` and computes
    nothing; the method returns ``None``. None of the arguments are used.
    """
    _unused_loss_target = self.train_config.loss_target
    return None
def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
    """Return ``batch`` untouched; this trainer does no batch preprocessing."""
    unchanged_batch = batch
    return unchanged_batch
def get_guided_loss(
        self,
        noisy_latents: torch.Tensor,
        conditional_embeds: PromptEmbeds,
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        unconditional_embeds: Optional[PromptEmbeds] = None,
        **kwargs
):
    """Delegate to ``get_guidance_loss`` with this trainer's model and config.

    Every argument is forwarded by keyword under the same name;
    ``sd=self.sd`` and ``train_config=self.train_config`` are added, and any
    extra ``**kwargs`` are passed through.

    Returns:
        Whatever ``get_guidance_loss`` returns.
    """
    return get_guidance_loss(
        noisy_latents=noisy_latents,
        conditional_embeds=conditional_embeds,
        match_adapter_assist=match_adapter_assist,
        network_weight_list=network_weight_list,
        timesteps=timesteps,
        pred_kwargs=pred_kwargs,
        batch=batch,
        noise=noise,
        sd=self.sd,
        unconditional_embeds=unconditional_embeds,
        train_config=self.train_config,
        **kwargs
    )
def get_prior_prediction(
        self,
        noisy_latents: torch.Tensor,
        conditional_embeds: PromptEmbeds,
        match_adapter_assist: bool,
        network_weight_list: list,
        timesteps: torch.Tensor,
        pred_kwargs: dict,
        batch: 'DataLoaderBatchDTO',
        noise: torch.Tensor,
        unconditional_embeds: Optional[PromptEmbeds] = None,
        conditioned_prompts=None,
        **kwargs
):
    """Predict noise with the trained network and adapter switched off.

    The network (if any) and an IPAdapter / ReferenceAdapter / CustomAdapter
    (if present) are deactivated, the unet is put in eval mode, a prediction
    is made under ``torch.no_grad()``, and the previous states are restored.

    Args:
        noisy_latents: Latents passed to ``self.sd.predict_noise``.
        conditional_embeds: Prompt embeddings; cloned, and replaced by freshly
            encoded prompts when a ClipVisionAdapter or an embedding is used.
        match_adapter_assist: When truthy, the adapter residual entries are
            deleted from ``pred_kwargs`` after the prediction.
        network_weight_list: Not used in this method (see the commented-out
            multiplier lines).
        timesteps: Timesteps for the prediction.
        pred_kwargs: Extra keyword arguments for ``predict_noise``; may be
            mutated in place (residual keys removed).
        batch: Batch object; its caption list is read when prompts are rebuilt.
        noise: Not used in this method.
        unconditional_embeds: Optional unconditional embeddings.
        conditioned_prompts: Not used in this method.
        **kwargs: Accepted but not used.

    Returns:
        The detached prior prediction.

    Raises:
        ValueError: If the text encoder is unloaded and an adapter other than
            a CustomAdapter is in use.
    """
    # todo for embeddings, we need to run without trigger words
    was_unet_training = self.sd.unet.training
    was_network_active = False
    if self.network is not None:
        was_network_active = self.network.is_active
        self.network.is_active = False
    can_disable_adapter = False
    was_adapter_active = False
    if self.adapter is not None and (isinstance(self.adapter, IPAdapter) or
                                     isinstance(self.adapter, ReferenceAdapter) or
                                     (isinstance(self.adapter, CustomAdapter))
    ):
        can_disable_adapter = True
        was_adapter_active = self.adapter.is_active
        self.adapter.is_active = False

    # NOTE(review): this raise happens after the network/adapter were
    # deactivated above, so they stay inactive if the error is caught.
    if self.train_config.unload_text_encoder and self.adapter is not None and not isinstance(self.adapter, CustomAdapter):
        raise ValueError("Prior predictions currently do not support unloading text encoder with adapter")
    # do a prediction here so we can match its output with network multiplier set to 0.0
    # NOTE(review): state is restored at the end of this block without
    # try/finally; an exception during prediction leaves the network/adapter
    # inactive and the unet in eval mode.
    with torch.no_grad():
        dtype = get_torch_dtype(self.train_config.dtype)

        embeds_to_use = conditional_embeds.clone().detach()
        # handle clip vision adapter by removing triggers from prompt and replacing with the class name
        if (self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter)) or self.embedding is not None:
            prompt_list = batch.get_caption_list()
            class_name = ''

            triggers = ['[trigger]', '[name]']
            remove_tokens = []

            if self.embed_config is not None:
                triggers.append(self.embed_config.trigger)
                # extra tokens are named "<trigger>_<i>" for i >= 1
                for i in range(1, self.embed_config.tokens):
                    remove_tokens.append(f"{self.embed_config.trigger}_{i}")
                if self.embed_config.trigger_class_name is not None:
                    class_name = self.embed_config.trigger_class_name

            if self.adapter is not None:
                triggers.append(self.adapter_config.trigger)
                for i in range(1, self.adapter_config.num_tokens):
                    remove_tokens.append(f"{self.adapter_config.trigger}_{i}")
                # the adapter's class name overrides the embedding's if both are set
                if self.adapter_config.trigger_class_name is not None:
                    class_name = self.adapter_config.trigger_class_name

            # numbered tokens are removed first, then triggers become the class name
            for idx, prompt in enumerate(prompt_list):
                for remove_token in remove_tokens:
                    prompt = prompt.replace(remove_token, '')
                for trigger in triggers:
                    prompt = prompt.replace(trigger, class_name)
                prompt_list[idx] = prompt

            embeds_to_use = self.sd.encode_prompt(
                prompt_list,
                long_prompts=self.do_long_prompts).to(
                self.device_torch,
                dtype=dtype).detach()

        # dont use network on this
        # self.network.multiplier = 0.0
        self.sd.unet.eval()

        if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not self.sd.is_flux and not self.sd.is_lumina2:
            # we need to remove the image embeds from the prompt except for flux
            embeds_to_use: PromptEmbeds = embeds_to_use.clone().detach()
            # drop the last `num_tokens` positions along dim 1
            end_pos = embeds_to_use.text_embeds.shape[1] - self.adapter_config.num_tokens
            embeds_to_use.text_embeds = embeds_to_use.text_embeds[:, :end_pos, :]
            if unconditional_embeds is not None:
                unconditional_embeds = unconditional_embeds.clone().detach()
                unconditional_embeds.text_embeds = unconditional_embeds.text_embeds[:, :end_pos]

        if unconditional_embeds is not None:
            unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()

        prior_pred = self.sd.predict_noise(
            latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
            conditional_embeddings=embeds_to_use.to(self.device_torch, dtype=dtype).detach(),
            unconditional_embeddings=unconditional_embeds,
            timestep=timesteps,
            guidance_scale=self.train_config.cfg_scale,
            rescale_cfg=self.train_config.cfg_rescale,
            batch=batch,
            **pred_kwargs  # adapter residuals in here
        )
        # put the unet back into train mode if it was training before
        if was_unet_training:
            self.sd.unet.train()
        prior_pred = prior_pred.detach()
        # remove the residuals as we wont use them on prediction when matching control
        if match_adapter_assist and 'down_intrablock_additional_residuals' in pred_kwargs:
            del pred_kwargs['down_intrablock_additional_residuals']
        if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs:
            del pred_kwargs['down_block_additional_residuals']
        if match_adapter_assist and 'mid_block_additional_residual' in pred_kwargs:
            del pred_kwargs['mid_block_additional_residual']

        # restore adapter
        if can_disable_adapter:
            self.adapter.is_active = was_adapter_active
        # restore network
        # self.network.multiplier = network_weight_list
        if self.network is not None:
            self.network.is_active = was_network_active
    return prior_pred
def before_unet_predict(self):
    """Hook with no work to do in this trainer; returns None."""
    return None
def after_unet_predict(self):
    """Hook with no work to do in this trainer; returns None."""
    return None
def end_of_training_loop(self):
    """Hook with no work to do in this trainer; returns None."""
    return None
def predict_noise(
        self,
        noisy_latents: torch.Tensor,
        timesteps: Union[int, torch.Tensor] = 1,
        conditional_embeds: Union[PromptEmbeds, None] = None,
        unconditional_embeds: Union[PromptEmbeds, None] = None,
        batch: Optional['DataLoaderBatchDTO'] = None,
        **kwargs,
):
    """Run ``self.sd.predict_noise`` with the trainer's guidance settings.

    The latents and the conditional embeddings are moved to the training
    device and dtype first; the unconditional embeddings are passed through
    as given. ``cfg_scale`` is used for both ``guidance_scale`` and
    ``guidance_embedding_scale``.

    Args:
        noisy_latents: Latents to predict on.
        timesteps: Timestep(s) for the prediction.
        conditional_embeds: Conditional prompt embeddings; ``.to()`` is
            called on them, so despite the default they must not be ``None``.
        unconditional_embeds: Optional unconditional prompt embeddings.
        batch: Optional batch object forwarded to the model.
        **kwargs: Forwarded unchanged to ``self.sd.predict_noise``.

    Returns:
        Whatever ``self.sd.predict_noise`` returns.
    """
    cfg = self.train_config
    target_dtype = get_torch_dtype(cfg.dtype)
    latents_on_device = noisy_latents.to(self.device_torch, dtype=target_dtype)
    cond_on_device = conditional_embeds.to(self.device_torch, dtype=target_dtype)
    return self.sd.predict_noise(
        latents=latents_on_device,
        conditional_embeddings=cond_on_device,
        unconditional_embeddings=unconditional_embeds,
        timestep=timesteps,
        guidance_scale=cfg.cfg_scale,
        guidance_embedding_scale=cfg.cfg_scale,
        detach_unconditional=False,
        rescale_cfg=cfg.cfg_rescale,
        bypass_guidance_embedding=cfg.bypass_guidance_embedding,
        batch=batch,
        **kwargs
    )
def train_single_accumulation(self, batch: DataLoaderBatchDTO):
self.timer.start('preprocess_batch')
if isinstance(self.adapter, CustomAdapter):
batch = self.adapter.edit_batch_raw(batch)
batch = self.preprocess_batch(batch)
if isinstance(self.adapter, CustomAdapter):
batch = self.adapter.edit_batch_processed(batch)
dtype = get_torch_dtype(self.train_config.dtype)
# sanity check
if self.sd.vae.dtype != self.sd.vae_torch_dtype:
self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype)
if isinstance(self.sd.text_encoder, list):
for encoder in self.sd.text_encoder:
if encoder.dtype != self.sd.te_torch_dtype:
encoder.to(self.sd.te_torch_dtype)
else:
if self.sd.text_encoder.dtype != self.sd.te_torch_dtype:
self.sd.text_encoder.to(self.sd.te_torch_dtype)
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
if self.train_config.do_cfg or self.train_config.do_random_cfg:
# pick random negative prompts
if self.negative_prompt_pool is not None:
negative_prompts = []
for i in range(noisy_latents.shape[0]):
num_neg = random.randint(1, self.train_config.max_negative_prompts)
this_neg_prompts = [random.choice(self.negative_prompt_pool) for _ in range(num_neg)]
this_neg_prompt = ', '.join(this_neg_prompts)
negative_prompts.append(this_neg_prompt)
self.batch_negative_prompt = negative_prompts
else:
self.batch_negative_prompt = ['' for _ in range(batch.latents.shape[0])]
if self.adapter and isinstance(self.adapter, CustomAdapter):
# condition the prompt
# todo handle more than one adapter image
conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts)
network_weight_list = batch.get_network_weight_list()
if self.train_config.single_item_batching:
network_weight_list = network_weight_list + network_weight_list
has_adapter_img = batch.control_tensor is not None
has_clip_image = batch.clip_image_tensor is not None
has_clip_image_embeds = batch.clip_image_embeds is not None
# force it to be true if doing regs as we handle those differently
if any([batch.file_items[idx].is_reg for idx in range(len(batch.file_items))]):
has_clip_image = True
if self._clip_image_embeds_unconditional is not None:
has_clip_image_embeds = True # we are caching embeds, handle that differently
has_clip_image = False
if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
raise ValueError(
"IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ")
match_adapter_assist = False
# check if we are matching the adapter assistant
if self.assistant_adapter:
if self.train_config.match_adapter_chance == 1.0:
match_adapter_assist = True
elif self.train_config.match_adapter_chance > 0.0:
match_adapter_assist = torch.rand(
(1,), device=self.device_torch, dtype=dtype
) < self.train_config.match_adapter_chance
self.timer.stop('preprocess_batch')
is_reg = False
with torch.no_grad():
loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
for idx, file_item in enumerate(batch.file_items):
if file_item.is_reg:
loss_multiplier[idx] = loss_multiplier[idx] * self.train_config.reg_weight
is_reg = True
adapter_images = None
sigmas = None
if has_adapter_img and (self.adapter or self.assistant_adapter):
with self.timer('get_adapter_images'):
# todo move this to data loader
if batch.control_tensor is not None:
adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach()
# match in channels
if self.assistant_adapter is not None:
in_channels = self.assistant_adapter.config.in_channels
if adapter_images.shape[1] != in_channels:
# we need to match the channels
adapter_images = adapter_images[:, :in_channels, :, :]
else:
raise NotImplementedError("Adapter images now must be loaded with dataloader")
clip_images = None
if has_clip_image:
with self.timer('get_clip_images'):
# todo move this to data loader
if batch.clip_image_tensor is not None:
clip_images = batch.clip_image_tensor.to(self.device_torch, dtype=dtype).detach()
mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
if batch.mask_tensor is not None:
with self.timer('get_mask_multiplier'):
# upsampling no supported for bfloat16
mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
# scale down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height)
mask_multiplier = torch.nn.functional.interpolate(
mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3])
)
# expand to match latents
mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()
# make avg 1.0
mask_multiplier = mask_multiplier / mask_multiplier.mean()
def get_adapter_multiplier():
    """Pick the strength multiplier applied to the adapter residuals.

    Returns the plain float ``1.0`` when ``self.adapter`` is a T2IAdapter
    (the adapter itself is being trained, not used as an assistant).
    Otherwise returns a 1-element tensor on ``self.device_torch`` built by
    passing a ``torch.rand`` sample through ``value_map`` with a strength
    range chosen from ``match_adapter_assist``.
    """
    if self.adapter and isinstance(self.adapter, T2IAdapter):
        # training a t2i adapter, not using it as an assistant.
        return 1.0

    # matching the assistant (training a texture) -> keep the strength high;
    # plain assisted training -> wider range (previously tried: 0.4 - 0.7)
    strength_low, strength_high = (0.9, 1.0) if match_adapter_assist else (0.5, 1.1)

    # torch.rand samples from [0, 1); value_map presumably rescales that
    # linearly into [strength_low, strength_high] -- confirm in toolkit.basic
    unit_sample = torch.rand((1,), device=self.device_torch, dtype=dtype)
    return value_map(unit_sample, 0.0, 1.0, strength_low, strength_high)
# flush()
with self.timer('grad_setup'):
# text encoding
grad_on_text_encoder = False
if self.train_config.train_text_encoder:
grad_on_text_encoder = True
if self.embedding is not None:
grad_on_text_encoder = True
if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
grad_on_text_encoder = True
if self.adapter_config and self.adapter_config.type == 'te_augmenter':
grad_on_text_encoder = True
# have a blank network so we can wrap it in a context and set multipliers without checking every time
if self.network is not None:
network = self.network
else:
network = BlankNetwork()
# set the weights
network.multiplier = network_weight_list
# activate network if it exists
prompts_1 = conditioned_prompts
prompts_2 = None
if self.train_config.short_and_long_captions_encoder_split and self.sd.is_xl:
prompts_1 = batch.get_caption_short_list()
prompts_2 = conditioned_prompts
# make the batch splits
if self.train_config.single_item_batching:
if self.model_config.refiner_name_or_path is not None:
raise ValueError("Single item batching is not supported when training the refiner")
batch_size = noisy_latents.shape[0]
# chunk/split everything
noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0)
noise_list = torch.chunk(noise, batch_size, dim=0)
timesteps_list = torch.chunk(timesteps, batch_size, dim=0)
conditioned_prompts_list = [[prompt] for prompt in prompts_1]
if imgs is not None:
imgs_list = torch.chunk(imgs, batch_size, dim=0)
else:
imgs_list = [None for _ in range(batch_size)]
if adapter_images is not None:
adapter_images_list = torch.chunk(adapter_images, batch_size, dim=0)
else:
adapter_images_list = [None for _ in range(batch_size)]
if clip_images is not None:
clip_images_list = torch.chunk(clip_images, batch_size, dim=0)
else:
clip_images_list = [None for _ in range(batch_size)]
mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0)
if prompts_2 is None:
prompt_2_list = [None for _ in range(batch_size)]
else:
prompt_2_list = [[prompt] for prompt in prompts_2]
else:
noisy_latents_list = [noisy_latents]
noise_list = [noise]
timesteps_list = [timesteps]
conditioned_prompts_list = [prompts_1]
imgs_list = [imgs]
adapter_images_list = [adapter_images]
clip_images_list = [clip_images]
mask_multiplier_list = [mask_multiplier]
if prompts_2 is None:
prompt_2_list = [None]
else:
prompt_2_list = [prompts_2]
for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, clip_images, mask_multiplier, prompt_2 in zip(
noisy_latents_list,
noise_list,
timesteps_list,
conditioned_prompts_list,
imgs_list,
adapter_images_list,
clip_images_list,
mask_multiplier_list,
prompt_2_list
):
# if self.train_config.negative_prompt is not None:
# # add negative prompt
# conditioned_prompts = conditioned_prompts + [self.train_config.negative_prompt for x in
# range(len(conditioned_prompts))]
# if prompt_2 is not None:
# prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))]
with (network):
# encode clip adapter here so embeds are active for tokenizer
if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
with self.timer('encode_clip_vision_embeds'):
if has_clip_image:
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
clip_images.detach().to(self.device_torch, dtype=dtype),
is_training=True,
has_been_preprocessed=True
)
else:
# just do a blank one
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
torch.zeros(
(noisy_latents.shape[0], 3, 512, 512),
device=self.device_torch, dtype=dtype
),
is_training=True,
has_been_preprocessed=True,
drop=True
)
# it will be injected into the tokenizer when called
self.adapter(conditional_clip_embeds)
# do the custom adapter after the prior prediction
if self.adapter and isinstance(self.adapter, CustomAdapter) and (has_clip_image or is_reg):
quad_count = random.randint(1, 4)
self.adapter.train()
self.adapter.trigger_pre_te(
tensors_preprocessed=clip_images if not is_reg else None, # on regs we send none to get random noise
is_training=True,
has_been_preprocessed=True,
quad_count=quad_count,
batch_tensor=batch.tensor if not is_reg else None,
batch_size=noisy_latents.shape[0]
)
with self.timer('encode_prompt'):
unconditional_embeds = None
if self.train_config.unload_text_encoder:
with torch.set_grad_enabled(False):
embeds_to_use = self.cached_blank_embeds.clone().detach().to(
self.device_torch, dtype=dtype
)
if self.cached_trigger_embeds is not None and not is_reg:
embeds_to_use = self.cached_trigger_embeds.clone().detach().to(
self.device_torch, dtype=dtype
)
conditional_embeds = concat_prompt_embeds(
[embeds_to_use] * noisy_latents.shape[0]
)
if self.train_config.do_cfg:
unconditional_embeds = self.cached_blank_embeds.clone().detach().to(
self.device_torch, dtype=dtype
)
unconditional_embeds = concat_prompt_embeds(
[unconditional_embeds] * noisy_latents.shape[0]
)
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
elif grad_on_text_encoder:
with torch.set_grad_enabled(True):
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
conditional_embeds = self.sd.encode_prompt(
conditioned_prompts, prompt_2,
dropout_prob=self.train_config.prompt_dropout_prob,
long_prompts=self.do_long_prompts).to(
self.device_torch,
dtype=dtype)
if self.train_config.do_cfg:
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = True
# todo only do one and repeat it
unconditional_embeds = self.sd.encode_prompt(
self.batch_negative_prompt,
self.batch_negative_prompt,
dropout_prob=self.train_config.prompt_dropout_prob,
long_prompts=self.do_long_prompts).to(
self.device_torch,
dtype=dtype)
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
else:
with torch.set_grad_enabled(False):
# make sure it is in eval mode
if isinstance(self.sd.text_encoder, list):
for te in self.sd.text_encoder:
te.eval()
else:
self.sd.text_encoder.eval()
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
conditional_embeds = self.sd.encode_prompt(
conditioned_prompts, prompt_2,
dropout_prob=self.train_config.prompt_dropout_prob,
long_prompts=self.do_long_prompts).to(
self.device_torch,
dtype=dtype)
if self.train_config.do_cfg:
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = True
unconditional_embeds = self.sd.encode_prompt(
self.batch_negative_prompt,
dropout_prob=self.train_config.prompt_dropout_prob,
long_prompts=self.do_long_prompts).to(
self.device_torch,
dtype=dtype)
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
if self.train_config.diff_output_preservation:
dop_prompts = [p.replace(self.trigger_word, self.train_config.diff_output_preservation_class) for p in conditioned_prompts]
dop_prompts_2 = None
if prompt_2 is not None:
dop_prompts_2 = [p.replace(self.trigger_word, self.train_config.diff_output_preservation_class) for p in prompt_2]
self.diff_output_preservation_embeds = self.sd.encode_prompt(
dop_prompts, dop_prompts_2,
dropout_prob=self.train_config.prompt_dropout_prob,
long_prompts=self.do_long_prompts).to(
self.device_torch,
dtype=dtype)
# detach the embeddings
conditional_embeds = conditional_embeds.detach()
if self.train_config.do_cfg:
unconditional_embeds = unconditional_embeds.detach()
if self.decorator:
conditional_embeds.text_embeds = self.decorator(
conditional_embeds.text_embeds
)
if self.train_config.do_cfg:
unconditional_embeds.text_embeds = self.decorator(
unconditional_embeds.text_embeds,
is_unconditional=True
)
# flush()
pred_kwargs = {}
if has_adapter_img:
if (self.adapter and isinstance(self.adapter, T2IAdapter)) or (
self.assistant_adapter and isinstance(self.assistant_adapter, T2IAdapter)):
with torch.set_grad_enabled(self.adapter is not None):
adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
adapter_multiplier = get_adapter_multiplier()
with self.timer('encode_adapter'):
down_block_additional_residuals = adapter(adapter_images)
if self.assistant_adapter:
# not training. detach
down_block_additional_residuals = [
sample.to(dtype=dtype).detach() * adapter_multiplier for sample in
down_block_additional_residuals
]
else:
down_block_additional_residuals = [
sample.to(dtype=dtype) * adapter_multiplier for sample in
down_block_additional_residuals
]
pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals
if self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter_embeds'):
# number of images to do if doing a quad image
quad_count = random.randint(1, 4)
image_size = self.adapter.input_size
if has_clip_image_embeds:
# todo handle reg images better than this
if is_reg:
# get unconditional image embeds from cache
embeds = [
load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in
range(noisy_latents.shape[0])
]
conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
embeds,
quad_count=quad_count
)
if self.train_config.do_cfg:
embeds = [
load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in
range(noisy_latents.shape[0])
]
unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
embeds,
quad_count=quad_count
)
else:
conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
batch.clip_image_embeds,
quad_count=quad_count
)
if self.train_config.do_cfg:
unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
batch.clip_image_embeds_unconditional,
quad_count=quad_count
)
elif is_reg:
# we will zero it out in the img embedder
clip_images = torch.zeros(
(noisy_latents.shape[0], 3, image_size, image_size),
device=self.device_torch, dtype=dtype
).detach()
# drop will zero it out
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
clip_images,
drop=True,
is_training=True,
has_been_preprocessed=False,
quad_count=quad_count
)
if self.train_config.do_cfg:
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
torch.zeros(
(noisy_latents.shape[0], 3, image_size, image_size),
device=self.device_torch, dtype=dtype
).detach(),
is_training=True,
drop=True,
has_been_preprocessed=False,
quad_count=quad_count
)
elif has_clip_image:
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
clip_images.detach().to(self.device_torch, dtype=dtype),
is_training=True,
has_been_preprocessed=True,
quad_count=quad_count,
# do cfg on clip embeds to normalize the embeddings for when doing cfg
# cfg_embed_strength=3.0 if not self.train_config.do_cfg else None
# cfg_embed_strength=3.0 if not self.train_config.do_cfg else None
)
if self.train_config.do_cfg:
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
clip_images.detach().to(self.device_torch, dtype=dtype),
is_training=True,
drop=True,
has_been_preprocessed=True,
quad_count=quad_count
)
else:
print_acc("No Clip Image")
print_acc([file_item.path for file_item in batch.file_items])
raise ValueError("Could not find clip image")
if not self.adapter_config.train_image_encoder:
# we are not training the image encoder, so we need to detach the embeds
conditional_clip_embeds = conditional_clip_embeds.detach()
if self.train_config.do_cfg:
unconditional_clip_embeds = unconditional_clip_embeds.detach()
with self.timer('encode_adapter'):
self.adapter.train()
conditional_embeds = self.adapter(
conditional_embeds.detach(),
conditional_clip_embeds,
is_unconditional=False
)
if self.train_config.do_cfg:
unconditional_embeds = self.adapter(
unconditional_embeds.detach(),
unconditional_clip_embeds,
is_unconditional=True
)
else:
# wipe out unconditional
self.adapter.last_unconditional = None
if self.adapter and isinstance(self.adapter, ReferenceAdapter):
# pass in our scheduler
self.adapter.noise_scheduler = self.lr_scheduler
if has_clip_image or has_adapter_img:
img_to_use = clip_images if has_clip_image else adapter_images
# currently 0-1 needs to be -1 to 1
reference_images = ((img_to_use - 0.5) * 2).detach().to(self.device_torch, dtype=dtype)
self.adapter.set_reference_images(reference_images)
self.adapter.noise_scheduler = self.sd.noise_scheduler
elif is_reg:
self.adapter.set_blank_reference_images(noisy_latents.shape[0])
else:
self.adapter.set_reference_images(None)
prior_pred = None
do_reg_prior = False
# if is_reg and (self.network is not None or self.adapter is not None):
# # we are doing a reg image and we have a network or adapter
# do_reg_prior = True
do_inverted_masked_prior = False
if self.train_config.inverted_mask_prior and batch.mask_tensor is not None:
do_inverted_masked_prior = True
do_correct_pred_norm_prior = self.train_config.correct_pred_norm
do_guidance_prior = False
if batch.unconditional_latents is not None:
# for this not that, we need a prior pred to normalize
guidance_type: GuidanceType = batch.file_items[0].dataset_config.guidance_type
if guidance_type == 'tnt':
do_guidance_prior = True
if ((
has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_guidance_prior or do_reg_prior or do_inverted_masked_prior or self.train_config.correct_pred_norm):
with self.timer('prior predict'):
prior_embeds_to_use = conditional_embeds
# use diff_output_preservation embeds if doing dfe
if self.train_config.diff_output_preservation:
prior_embeds_to_use = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
prior_pred = self.get_prior_prediction(
noisy_latents=noisy_latents,
conditional_embeds=prior_embeds_to_use,
match_adapter_assist=match_adapter_assist,
network_weight_list=network_weight_list,
timesteps=timesteps,
pred_kwargs=pred_kwargs,
noise=noise,
batch=batch,
unconditional_embeds=unconditional_embeds,
conditioned_prompts=conditioned_prompts
)
if prior_pred is not None:
prior_pred = prior_pred.detach()
# do the custom adapter after the prior prediction
if self.adapter and isinstance(self.adapter, CustomAdapter) and (has_clip_image or self.adapter_config.type in ['llm_adapter', 'text_encoder']):
quad_count = random.randint(1, 4)
self.adapter.train()
conditional_embeds = self.adapter.condition_encoded_embeds(
tensors_0_1=clip_images,
prompt_embeds=conditional_embeds,
is_training=True,
has_been_preprocessed=True,
quad_count=quad_count
)
if self.train_config.do_cfg and unconditional_embeds is not None:
unconditional_embeds = self.adapter.condition_encoded_embeds(
tensors_0_1=clip_images,
prompt_embeds=unconditional_embeds,
is_training=True,
has_been_preprocessed=True,
is_unconditional=True,
quad_count=quad_count
)
if self.adapter and isinstance(self.adapter, CustomAdapter) and batch.extra_values is not None:
self.adapter.add_extra_values(batch.extra_values.detach())
if self.train_config.do_cfg:
self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()),
is_unconditional=True)
if has_adapter_img:
if (self.adapter and isinstance(self.adapter, ControlNetModel)) or (
self.assistant_adapter and isinstance(self.assistant_adapter, ControlNetModel)):
if self.train_config.do_cfg:
raise ValueError("ControlNetModel is not supported with CFG")
with torch.set_grad_enabled(self.adapter is not None):
adapter: ControlNetModel = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
adapter_multiplier = get_adapter_multiplier()
with self.timer('encode_adapter'):
# add_text_embeds is pooled_prompt_embeds for sdxl
added_cond_kwargs = {}
if self.sd.is_xl:
added_cond_kwargs["text_embeds"] = conditional_embeds.pooled_embeds
added_cond_kwargs['time_ids'] = self.sd.get_time_ids_from_latents(noisy_latents)
down_block_res_samples, mid_block_res_sample = adapter(
noisy_latents,
timesteps,
encoder_hidden_states=conditional_embeds.text_embeds,
controlnet_cond=adapter_images,
conditioning_scale=1.0,
guess_mode=False,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)
pred_kwargs['down_block_additional_residuals'] = down_block_res_samples
pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample
self.before_unet_predict()
# do a prior pred if we have an unconditional image, we will swap out the guidance later
if batch.unconditional_latents is not None or self.do_guided_loss:
# do guided loss
loss = self.get_guided_loss(
noisy_latents=noisy_latents,
conditional_embeds=conditional_embeds,
match_adapter_assist=match_adapter_assist,
network_weight_list=network_weight_list,
timesteps=timesteps,
pred_kwargs=pred_kwargs,
batch=batch,
noise=noise,
unconditional_embeds=unconditional_embeds,
mask_multiplier=mask_multiplier,
prior_pred=prior_pred,
)
else:
if unconditional_embeds is not None:
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
with self.timer('condition_noisy_latents'):
# do it for the model
noisy_latents = self.sd.condition_noisy_latents(noisy_latents, batch)
if self.adapter and isinstance(self.adapter, CustomAdapter):
noisy_latents = self.adapter.condition_noisy_latents(noisy_latents, batch)
with self.timer('predict_unet'):
noise_pred = self.predict_noise(
noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
timesteps=timesteps,
conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
unconditional_embeds=unconditional_embeds,
batch=batch,
**pred_kwargs
)
self.after_unet_predict()
with self.timer('calculate_loss'):
noise = noise.to(self.device_torch, dtype=dtype).detach()
prior_to_calculate_loss = prior_pred
# if we are doing diff_output_preservation and not doing inverted masked prior
# then we need to send none here so it will not target the prior
if self.train_config.diff_output_preservation and not do_inverted_masked_prior:
prior_to_calculate_loss = None
loss = self.calculate_loss(
noise_pred=noise_pred,
noise=noise,
noisy_latents=noisy_latents,
timesteps=timesteps,
batch=batch,
mask_multiplier=mask_multiplier,
prior_pred=prior_to_calculate_loss,
)
if self.train_config.diff_output_preservation:
# send the loss backwards otherwise checkpointing will fail
self.accelerator.backward(loss)
normal_loss = loss.detach() # dont send backward again
dop_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0])
dop_pred = self.predict_noise(
noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
timesteps=timesteps,
conditional_embeds=dop_embeds.to(self.device_torch, dtype=dtype),
unconditional_embeds=unconditional_embeds,
batch=batch,
**pred_kwargs
)
dop_loss = torch.nn.functional.mse_loss(dop_pred, prior_pred) * self.train_config.diff_output_preservation_multiplier
self.accelerator.backward(dop_loss)
loss = normal_loss + dop_loss
loss = loss.clone().detach()
# require grad again so the backward wont fail
loss.requires_grad_(True)
# check if nan
if torch.isnan(loss):
print_acc("loss is nan")
loss = torch.zeros_like(loss).requires_grad_(True)
with self.timer('backward'):
# todo we have multiplier separated. works for now as res are not in same batch, but need to change
loss = loss * loss_multiplier.mean()
# IMPORTANT if gradient checkpointing do not leave with network when doing backward
# it will destroy the gradients. This is because the network is a context manager
# and will change the multipliers back to 0.0 when exiting. They will be
# 0.0 for the backward pass and the gradients will be 0.0
# I spent weeks on fighting this. DON'T DO IT
# with fsdp_overlap_step_with_backward():
# if self.is_bfloat:
# loss.backward()
# else:
self.accelerator.backward(loss)
return loss.detach()
# flush()
def hook_train_loop(self, batch: Union[DataLoaderBatchDTO, List[DataLoaderBatchDTO]]):
    """Run one training-loop iteration over a batch or a list of batches.

    Each batch is passed to ``train_single_accumulation`` (gradients are
    zeroed once beforehand). Unless ``self.is_grad_accumulation_step`` is
    set, gradients are then clipped (skipped for the 'adafactor' optimizer),
    the optimizer is stepped and its gradients are cleared, a CustomAdapter
    gets ``post_weight_update()`` and the EMA (if any) is updated. The LR
    scheduler is stepped on every call, and embedding / ClipVisionAdapter
    weights are restored afterwards.

    Args:
        batch: a single batch, or a list of batches processed back to back.

    Returns:
        OrderedDict with key ``'loss'``: the mean of the losses returned by
        ``train_single_accumulation`` over all processed batches (equal to
        the single loss when only one batch is given).

    Raises:
        ValueError: if ``batch`` is an empty list.
    """
    batch_list = batch if isinstance(batch, list) else [batch]
    if not batch_list:
        # fail early with a clear message instead of an unbound `loss` at the end
        raise ValueError("hook_train_loop received an empty list of batches")

    total_loss = None
    self.optimizer.zero_grad()
    for sub_batch in batch_list:
        loss = self.train_single_accumulation(sub_batch)
        # out-of-place add so the tensor returned for the first batch is not mutated
        total_loss = loss if total_loss is None else total_loss + loss
        if len(batch_list) > 1 and self.model_config.low_vram:
            torch.cuda.empty_cache()

    # when accumulating gradients, the optimizer step is skipped entirely
    if not self.is_grad_accumulation_step:
        # fix this for multi params
        if self.train_config.optimizer != 'adafactor':
            if isinstance(self.params[0], dict):
                # params given as optimizer param groups: clip each group separately
                for param_group in self.params:
                    self.accelerator.clip_grad_norm_(param_group['params'], self.train_config.max_grad_norm)
            else:
                self.accelerator.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
        # only step if we are not accumulating
        with self.timer('optimizer_step'):
            self.optimizer.step()
            self.optimizer.zero_grad(set_to_none=True)
        if self.adapter and isinstance(self.adapter, CustomAdapter):
            self.adapter.post_weight_update()
        if self.ema is not None:
            with self.timer('ema_update'):
                self.ema.update()

    # TODO Should we only step scheduler on grad step? If so, need to recalculate last step
    with self.timer('scheduler_step'):
        self.lr_scheduler.step()

    if self.embedding is not None:
        with self.timer('restore_embeddings'):
            # Let's make sure we don't update any embedding weights besides the newly added token
            self.embedding.restore_embeddings()
    if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter):
        with self.timer('restore_adapter'):
            # Let's make sure we don't update any embedding weights besides the newly added token
            self.adapter.restore_embeddings()

    # report the average over every processed batch, not just the last one
    mean_loss = total_loss / len(batch_list)
    loss_dict = OrderedDict(
        {'loss': mean_loss.item()}
    )
    self.end_of_training_loop()
    return loss_dict