import os import random from collections import OrderedDict from typing import Union, Literal, List, Optional import numpy as np from diffusers import T2IAdapter, AutoencoderTiny, ControlNetModel import torch.functional as F from safetensors.torch import load_file from torch.utils.data import DataLoader, ConcatDataset from toolkit import train_tools from toolkit.basic import value_map, adain, get_mean_std from toolkit.clip_vision_adapter import ClipVisionAdapter from toolkit.config_modules import GuidanceConfig from toolkit.data_loader import get_dataloader_datasets from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss, GuidanceType from toolkit.image_utils import show_tensors, show_latents from toolkit.ip_adapter import IPAdapter from toolkit.custom_adapter import CustomAdapter from toolkit.print import print_acc from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds from toolkit.reference_adapter import ReferenceAdapter from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork from toolkit.train_tools import get_torch_dtype, apply_snr_weight, add_all_snr_to_noise_scheduler, \ apply_learnable_snr_gos, LearnableSNRGamma import gc import torch from jobs.process import BaseSDTrainProcess from torchvision import transforms from diffusers import EMAModel import math from toolkit.train_tools import precondition_model_outputs_flow_match from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe from toolkit.util.wavelet_loss import wavelet_loss def flush(): torch.cuda.empty_cache() gc.collect() adapter_transforms = transforms.Compose([ transforms.ToTensor(), ]) class SDTrainer(BaseSDTrainProcess): def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): super().__init__(process_id, job, config, **kwargs) self.assistant_adapter: Union['T2IAdapter', 'ControlNetModel', None] self.do_prior_prediction = False self.do_long_prompts = False self.do_guided_loss = False self.taesd: Optional[AutoencoderTiny] = None self._clip_image_embeds_unconditional: Union[List[str], None] = None self.negative_prompt_pool: Union[List[str], None] = None self.batch_negative_prompt: Union[List[str], None] = None self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16" self.do_grad_scale = True if self.is_fine_tuning and self.is_bfloat: self.do_grad_scale = False if self.adapter_config is not None: if self.adapter_config.train: self.do_grad_scale = False # if self.train_config.dtype in ["fp16", "float16"]: # # patch the scaler to allow fp16 training # org_unscale_grads = self.scaler._unscale_grads_ # def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): # return org_unscale_grads(optimizer, inv_scale, found_inf, True) # self.scaler._unscale_grads_ = _unscale_grads_replacer self.cached_blank_embeds: Optional[PromptEmbeds] = None self.cached_trigger_embeds: Optional[PromptEmbeds] = None self.diff_output_preservation_embeds: Optional[PromptEmbeds] = None self.dfe: Optional[DiffusionFeatureExtractor] = None if self.train_config.diff_output_preservation: if self.trigger_word is None: raise ValueError("diff_output_preservation requires a trigger_word to be set") if self.network_config is None: raise ValueError("diff_output_preservation requires a network to be set") if self.train_config.train_text_encoder: raise ValueError("diff_output_preservation is not supported with train_text_encoder") # always do a prior prediction when doing diff output preservation self.do_prior_prediction = True def before_model_load(self): pass def before_dataset_load(self): self.assistant_adapter = None # get adapter assistant if one is set if self.train_config.adapter_assist_name_or_path is not None: adapter_path = self.train_config.adapter_assist_name_or_path if self.train_config.adapter_assist_type == "t2i": # dont name this adapter since we are not training it self.assistant_adapter = T2IAdapter.from_pretrained( adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype) ).to(self.device_torch) elif self.train_config.adapter_assist_type == "control_net": self.assistant_adapter = ControlNetModel.from_pretrained( adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype) ).to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)) else: raise ValueError(f"Unknown adapter assist type {self.train_config.adapter_assist_type}") self.assistant_adapter.eval() self.assistant_adapter.requires_grad_(False) flush() if self.train_config.train_turbo and self.train_config.show_turbo_outputs: if self.model_config.is_xl: self.taesd = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=get_torch_dtype(self.train_config.dtype)) else: self.taesd = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=get_torch_dtype(self.train_config.dtype)) self.taesd.to(dtype=get_torch_dtype(self.train_config.dtype), device=self.device_torch) self.taesd.eval() self.taesd.requires_grad_(False) def hook_before_train_loop(self): super().hook_before_train_loop() if self.train_config.do_prior_divergence: self.do_prior_prediction = True # move vae to device if we did not cache latents if not self.is_latents_cached: self.sd.vae.eval() self.sd.vae.to(self.device_torch) else: # offload it. Already cached self.sd.vae.to('cpu') flush() add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch) if self.adapter is not None: self.adapter.to(self.device_torch) # check if we have regs and using adapter and caching clip embeddings has_reg = self.datasets_reg is not None and len(self.datasets_reg) > 0 is_caching_clip_embeddings = self.datasets is not None and any([self.datasets[i].cache_clip_vision_to_disk for i in range(len(self.datasets))]) if has_reg and is_caching_clip_embeddings: # we need a list of unconditional clip image embeds from other datasets to handle regs unconditional_clip_image_embeds = [] datasets = get_dataloader_datasets(self.data_loader) for i in range(len(datasets)): unconditional_clip_image_embeds += datasets[i].clip_vision_unconditional_cache if len(unconditional_clip_image_embeds) == 0: raise ValueError("No unconditional clip image embeds found. This should not happen") self._clip_image_embeds_unconditional = unconditional_clip_image_embeds if self.train_config.negative_prompt is not None: if os.path.exists(self.train_config.negative_prompt): with open(self.train_config.negative_prompt, 'r') as f: self.negative_prompt_pool = f.readlines() # remove empty self.negative_prompt_pool = [x.strip() for x in self.negative_prompt_pool if x.strip() != ""] else: # single prompt self.negative_prompt_pool = [self.train_config.negative_prompt] # handle unload text encoder if self.train_config.unload_text_encoder: with torch.no_grad(): if self.train_config.train_text_encoder: raise ValueError("Cannot unload text encoder if training text encoder") # cache embeddings print_acc("\n***** UNLOADING TEXT ENCODER *****") print_acc("This will train only with a blank prompt or trigger word, if set") print_acc("If this is not what you want, remove the unload_text_encoder flag") print_acc("***********************************") print_acc("") self.sd.text_encoder_to(self.device_torch) self.cached_blank_embeds = self.sd.encode_prompt("") if self.trigger_word is not None: self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word) if self.train_config.diff_output_preservation: self.diff_output_preservation_embeds = self.sd.encode_prompt(self.train_config.diff_output_preservation_class) # move back to cpu self.sd.text_encoder_to('cpu') flush() if self.train_config.diffusion_feature_extractor_path is not None: vae = None if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer": vae = self.sd.vae self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path, vae=vae) self.dfe.to(self.device_torch) if hasattr(self.dfe, 'vision_encoder') and self.train_config.gradient_checkpointing: # must be set to train for gradient checkpointing to work self.dfe.vision_encoder.train() self.dfe.vision_encoder.gradient_checkpointing = True else: self.dfe.eval() # enable gradient checkpointing on the vae if vae is not None and self.train_config.gradient_checkpointing: vae.enable_gradient_checkpointing() vae.train() def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch): # to process turbo learning, we make one big step from our current timestep to the end # we then denoise the prediction on that remaining step and target our loss to our target latents # this currently only works on euler_a (that I know of). Would work on others, but needs to be coded to do so. # needs to be done on each item in batch as they may all have different timesteps batch_size = pred.shape[0] pred_chunks = torch.chunk(pred, batch_size, dim=0) noisy_latents_chunks = torch.chunk(noisy_latents, batch_size, dim=0) timesteps_chunks = torch.chunk(timesteps, batch_size, dim=0) latent_chunks = torch.chunk(batch.latents, batch_size, dim=0) noise_chunks = torch.chunk(noise, batch_size, dim=0) with torch.no_grad(): # set the timesteps to 1000 so we can capture them to calculate the sigmas self.sd.noise_scheduler.set_timesteps( self.sd.noise_scheduler.config.num_train_timesteps, device=self.device_torch ) train_timesteps = self.sd.noise_scheduler.timesteps.clone().detach() train_sigmas = self.sd.noise_scheduler.sigmas.clone().detach() # set the scheduler to one timestep, we build the step and sigmas for each item in batch for the partial step self.sd.noise_scheduler.set_timesteps( 1, device=self.device_torch ) denoised_pred_chunks = [] target_pred_chunks = [] for i in range(batch_size): pred_item = pred_chunks[i] noisy_latents_item = noisy_latents_chunks[i] timesteps_item = timesteps_chunks[i] latents_item = latent_chunks[i] noise_item = noise_chunks[i] with torch.no_grad(): timestep_idx = [(train_timesteps == t).nonzero().item() for t in timesteps_item][0] single_step_timestep_schedule = [timesteps_item.squeeze().item()] # extract the sigma idx for our midpoint timestep sigmas = train_sigmas[timestep_idx:timestep_idx + 1].to(self.device_torch) end_sigma_idx = random.randint(timestep_idx, len(train_sigmas) - 1) end_sigma = train_sigmas[end_sigma_idx:end_sigma_idx + 1].to(self.device_torch) # add noise to our target # build the big sigma step. The to step will now be to 0 giving it a full remaining denoising half step # self.sd.noise_scheduler.sigmas = torch.cat([sigmas, torch.zeros_like(sigmas)]).detach() self.sd.noise_scheduler.sigmas = torch.cat([sigmas, end_sigma]).detach() # set our single timstep self.sd.noise_scheduler.timesteps = torch.from_numpy( np.array(single_step_timestep_schedule, dtype=np.float32) ).to(device=self.device_torch) # set the step index to None so it will be recalculated on first step self.sd.noise_scheduler._step_index = None denoised_latent = self.sd.noise_scheduler.step( pred_item, timesteps_item, noisy_latents_item.detach(), return_dict=False )[0] residual_noise = (noise_item * end_sigma.flatten()).detach().to(self.device_torch, dtype=get_torch_dtype( self.train_config.dtype)) # remove the residual noise from the denoised latents. Output should be a clean prediction (theoretically) denoised_latent = denoised_latent - residual_noise denoised_pred_chunks.append(denoised_latent) denoised_latents = torch.cat(denoised_pred_chunks, dim=0) # set the scheduler back to the original timesteps self.sd.noise_scheduler.set_timesteps( self.sd.noise_scheduler.config.num_train_timesteps, device=self.device_torch ) output = denoised_latents / self.sd.vae.config['scaling_factor'] output = self.sd.vae.decode(output).sample if self.train_config.show_turbo_outputs: # since we are completely denoising, we can show them here with torch.no_grad(): show_tensors(output) # we return our big partial step denoised latents as our pred and our untouched latents as our target. # you can do mse against the two here or run the denoised through the vae for pixel space loss against the # input tensor images. return output, batch.tensor.to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)) # you can expand these in a child class to make customization easier def calculate_loss( self, noise_pred: torch.Tensor, noise: torch.Tensor, noisy_latents: torch.Tensor, timesteps: torch.Tensor, batch: 'DataLoaderBatchDTO', mask_multiplier: Union[torch.Tensor, float] = 1.0, prior_pred: Union[torch.Tensor, None] = None, **kwargs ): loss_target = self.train_config.loss_target is_reg = any(batch.get_is_reg_list()) additional_loss = 0.0 prior_mask_multiplier = None target_mask_multiplier = None dtype = get_torch_dtype(self.train_config.dtype) has_mask = batch.mask_tensor is not None with torch.no_grad(): loss_multiplier = torch.tensor(batch.loss_multiplier_list).to(self.device_torch, dtype=torch.float32) if self.train_config.match_noise_norm: # match the norm of the noise noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True) noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True) noise_pred = noise_pred * (noise_norm / noise_pred_norm) if self.train_config.pred_scaler != 1.0: noise_pred = noise_pred * self.train_config.pred_scaler target = None if self.train_config.target_noise_multiplier != 1.0: noise = noise * self.train_config.target_noise_multiplier if self.train_config.correct_pred_norm or (self.train_config.inverted_mask_prior and prior_pred is not None and has_mask): if self.train_config.correct_pred_norm and not is_reg: with torch.no_grad(): # this only works if doing a prior pred if prior_pred is not None: prior_mean = prior_pred.mean([2,3], keepdim=True) prior_std = prior_pred.std([2,3], keepdim=True) noise_mean = noise_pred.mean([2,3], keepdim=True) noise_std = noise_pred.std([2,3], keepdim=True) mean_adjust = prior_mean - noise_mean std_adjust = prior_std - noise_std mean_adjust = mean_adjust * self.train_config.correct_pred_norm_multiplier std_adjust = std_adjust * self.train_config.correct_pred_norm_multiplier target_mean = noise_mean + mean_adjust target_std = noise_std + std_adjust eps = 1e-5 # match the noise to the prior noise = (noise - noise_mean) / (noise_std + eps) noise = noise * (target_std + eps) + target_mean noise = noise.detach() if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask: assert not self.train_config.train_turbo with torch.no_grad(): prior_mask = batch.mask_tensor.to(self.device_torch, dtype=dtype) # resize to size of noise_pred prior_mask = torch.nn.functional.interpolate(prior_mask, size=(noise_pred.shape[2], noise_pred.shape[3]), mode='bicubic') # stack first channel to match channels of noise_pred prior_mask = torch.cat([prior_mask[:1]] * noise_pred.shape[1], dim=1) prior_mask_multiplier = 1.0 - prior_mask # scale so it is a mean of 1 prior_mask_multiplier = prior_mask_multiplier / prior_mask_multiplier.mean() if self.sd.is_flow_matching: target = (noise - batch.latents).detach() else: target = noise elif prior_pred is not None and not self.train_config.do_prior_divergence: assert not self.train_config.train_turbo # matching adapter prediction target = prior_pred elif self.sd.prediction_type == 'v_prediction': # v-parameterization training target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps) elif hasattr(self.sd, 'get_loss_target'): target = self.sd.get_loss_target( noise=noise, batch=batch, timesteps=timesteps, ).detach() elif self.sd.is_flow_matching: # forward ODE target = (noise - batch.latents).detach() # reverse ODE # target = (batch.latents - noise).detach() else: target = noise if self.dfe is not None: if self.dfe.version == 1: # do diffusion feature extraction on target with torch.no_grad(): rectified_flow_target = noise.float() - batch.latents.float() target_features = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1)) # do diffusion feature extraction on prediction pred_features = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1)) additional_loss += torch.nn.functional.mse_loss(pred_features, target_features, reduction="mean") * \ self.train_config.diffusion_feature_extractor_weight elif self.dfe.version == 2: # version 2 # do diffusion feature extraction on target with torch.no_grad(): rectified_flow_target = noise.float() - batch.latents.float() target_feature_list = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1)) # do diffusion feature extraction on prediction pred_feature_list = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1)) dfe_loss = 0.0 for i in range(len(target_feature_list)): dfe_loss += torch.nn.functional.mse_loss(pred_feature_list[i], target_feature_list[i], reduction="mean") additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0 elif self.dfe.version == 3: dfe_loss = self.dfe( noise=noise, noise_pred=noise_pred, noisy_latents=noisy_latents, timesteps=timesteps, batch=batch, scheduler=self.sd.noise_scheduler ) additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight else: raise ValueError(f"Unknown diffusion feature extractor version {self.dfe.version}") if target is None: target = noise pred = noise_pred if self.train_config.train_turbo: pred, target = self.process_output_for_turbo(pred, noisy_latents, timesteps, noise, batch) ignore_snr = False if loss_target == 'source' or loss_target == 'unaugmented': assert not self.train_config.train_turbo # ignore_snr = True if batch.sigmas is None: raise ValueError("Batch sigmas is None. This should not happen") # src https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1190 denoised_latents = noise_pred * (-batch.sigmas) + noisy_latents weighing = batch.sigmas ** -2.0 if loss_target == 'source': # denoise the latent and compare to the latent in the batch target = batch.latents elif loss_target == 'unaugmented': # we have to encode images into latents for now # we also denoise as the unaugmented tensor is not a noisy diffirental with torch.no_grad(): unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor).to(self.device_torch, dtype=dtype) unaugmented_latents = unaugmented_latents * self.train_config.latent_multiplier target = unaugmented_latents.detach() # Get the target for loss depending on the prediction type if self.sd.noise_scheduler.config.prediction_type == "epsilon": target = target # we are computing loss against denoise latents elif self.sd.noise_scheduler.config.prediction_type == "v_prediction": target = self.sd.noise_scheduler.get_velocity(target, noise, timesteps) else: raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}") # mse loss without reduction loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2) loss = loss_per_element else: if self.train_config.loss_type == "mae": loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none") elif self.train_config.loss_type == "wavelet": loss = wavelet_loss(pred, batch.latents, noise) else: loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none") # handle linear timesteps and only adjust the weight of the timesteps if self.sd.is_flow_matching and (self.train_config.linear_timesteps or self.train_config.linear_timesteps2): # calculate the weights for the timesteps timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps( timesteps, v2=self.train_config.linear_timesteps2 ).to(loss.device, dtype=loss.dtype) timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach() loss = loss * timestep_weight if self.train_config.do_prior_divergence and prior_pred is not None: loss = loss + (torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none") * -1.0) if self.train_config.train_turbo: mask_multiplier = mask_multiplier[:, 3:, :, :] # resize to the size of the loss mask_multiplier = torch.nn.functional.interpolate(mask_multiplier, size=(pred.shape[2], pred.shape[3]), mode='nearest') # multiply by our mask try: loss = loss * mask_multiplier except: # todo handle mask with video models pass prior_loss = None if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None: assert not self.train_config.train_turbo if self.train_config.loss_type == "mae": prior_loss = torch.nn.functional.l1_loss(pred.float(), prior_pred.float(), reduction="none") else: prior_loss = torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none") prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier if torch.isnan(prior_loss).any(): print_acc("Prior loss is nan") prior_loss = None else: prior_loss = prior_loss.mean([1, 2, 3]) # loss = loss + prior_loss # loss = loss + prior_loss # loss = loss + prior_loss loss = loss.mean([1, 2, 3]) # apply loss multiplier before prior loss # multiply by our mask try: loss = loss * loss_multiplier except: # todo handle mask with video models pass if prior_loss is not None: loss = loss + prior_loss if not self.train_config.train_turbo: if self.train_config.learnable_snr_gos: # add snr_gamma loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos) elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr: # add snr_gamma loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True) elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr: # add min_snr_gamma loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) loss = loss.mean() # check for additional losses if self.adapter is not None and hasattr(self.adapter, "additional_loss") and self.adapter.additional_loss is not None: loss = loss + self.adapter.additional_loss.mean() self.adapter.additional_loss = None if self.train_config.target_norm_std: # seperate out the batch and channels pred_std = noise_pred.std([2, 3], keepdim=True) norm_std_loss = torch.abs(self.train_config.target_norm_std_value - pred_std).mean() loss = loss + norm_std_loss return loss + additional_loss def get_diff_output_preservation_loss( self, noise_pred: torch.Tensor, noise: torch.Tensor, noisy_latents: torch.Tensor, timesteps: torch.Tensor, batch: 'DataLoaderBatchDTO', mask_multiplier: Union[torch.Tensor, float] = 1.0, prior_pred: Union[torch.Tensor, None] = None, **kwargs ): loss_target = self.train_config.loss_target def preprocess_batch(self, batch: 'DataLoaderBatchDTO'): return batch def get_guided_loss( self, noisy_latents: torch.Tensor, conditional_embeds: PromptEmbeds, match_adapter_assist: bool, network_weight_list: list, timesteps: torch.Tensor, pred_kwargs: dict, batch: 'DataLoaderBatchDTO', noise: torch.Tensor, unconditional_embeds: Optional[PromptEmbeds] = None, **kwargs ): loss = get_guidance_loss( noisy_latents=noisy_latents, conditional_embeds=conditional_embeds, match_adapter_assist=match_adapter_assist, network_weight_list=network_weight_list, timesteps=timesteps, pred_kwargs=pred_kwargs, batch=batch, noise=noise, sd=self.sd, unconditional_embeds=unconditional_embeds, train_config=self.train_config, **kwargs ) return loss def get_prior_prediction( self, noisy_latents: torch.Tensor, conditional_embeds: PromptEmbeds, match_adapter_assist: bool, network_weight_list: list, timesteps: torch.Tensor, pred_kwargs: dict, batch: 'DataLoaderBatchDTO', noise: torch.Tensor, unconditional_embeds: Optional[PromptEmbeds] = None, conditioned_prompts=None, **kwargs ): # todo for embeddings, we need to run without trigger words was_unet_training = self.sd.unet.training was_network_active = False if self.network is not None: was_network_active = self.network.is_active self.network.is_active = False can_disable_adapter = False was_adapter_active = False if self.adapter is not None and (isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ReferenceAdapter) or (isinstance(self.adapter, CustomAdapter)) ): can_disable_adapter = True was_adapter_active = self.adapter.is_active self.adapter.is_active = False if self.train_config.unload_text_encoder and self.adapter is not None and not isinstance(self.adapter, CustomAdapter): raise ValueError("Prior predictions currently do not support unloading text encoder with adapter") # do a prediction here so we can match its output with network multiplier set to 0.0 with torch.no_grad(): dtype = get_torch_dtype(self.train_config.dtype) embeds_to_use = conditional_embeds.clone().detach() # handle clip vision adapter by removing triggers from prompt and replacing with the class name if (self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter)) or self.embedding is not None: prompt_list = batch.get_caption_list() class_name = '' triggers = ['[trigger]', '[name]'] remove_tokens = [] if self.embed_config is not None: triggers.append(self.embed_config.trigger) for i in range(1, self.embed_config.tokens): remove_tokens.append(f"{self.embed_config.trigger}_{i}") if self.embed_config.trigger_class_name is not None: class_name = self.embed_config.trigger_class_name if self.adapter is not None: triggers.append(self.adapter_config.trigger) for i in range(1, self.adapter_config.num_tokens): remove_tokens.append(f"{self.adapter_config.trigger}_{i}") if self.adapter_config.trigger_class_name is not None: class_name = self.adapter_config.trigger_class_name for idx, prompt in enumerate(prompt_list): for remove_token in remove_tokens: prompt = prompt.replace(remove_token, '') for trigger in triggers: prompt = prompt.replace(trigger, class_name) prompt_list[idx] = prompt embeds_to_use = self.sd.encode_prompt( prompt_list, long_prompts=self.do_long_prompts).to( self.device_torch, dtype=dtype).detach() # dont use network on this # self.network.multiplier = 0.0 self.sd.unet.eval() if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not self.sd.is_flux and not self.sd.is_lumina2: # we need to remove the image embeds from the prompt except for flux embeds_to_use: PromptEmbeds = embeds_to_use.clone().detach() end_pos = embeds_to_use.text_embeds.shape[1] - self.adapter_config.num_tokens embeds_to_use.text_embeds = embeds_to_use.text_embeds[:, :end_pos, :] if unconditional_embeds is not None: unconditional_embeds = unconditional_embeds.clone().detach() unconditional_embeds.text_embeds = unconditional_embeds.text_embeds[:, :end_pos] if unconditional_embeds is not None: unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach() prior_pred = self.sd.predict_noise( latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(), conditional_embeddings=embeds_to_use.to(self.device_torch, dtype=dtype).detach(), unconditional_embeddings=unconditional_embeds, timestep=timesteps, guidance_scale=self.train_config.cfg_scale, rescale_cfg=self.train_config.cfg_rescale, batch=batch, **pred_kwargs # adapter residuals in here ) if was_unet_training: self.sd.unet.train() prior_pred = prior_pred.detach() # remove the residuals as we wont use them on prediction when matching control if match_adapter_assist and 'down_intrablock_additional_residuals' in pred_kwargs: del pred_kwargs['down_intrablock_additional_residuals'] if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs: del pred_kwargs['down_block_additional_residuals'] if match_adapter_assist and 'mid_block_additional_residual' in pred_kwargs: del pred_kwargs['mid_block_additional_residual'] if can_disable_adapter: self.adapter.is_active = was_adapter_active # restore network # self.network.multiplier = network_weight_list if self.network is not None: self.network.is_active = was_network_active return prior_pred def before_unet_predict(self): pass def after_unet_predict(self): pass def end_of_training_loop(self): pass def predict_noise( self, noisy_latents: torch.Tensor, timesteps: Union[int, torch.Tensor] = 1, conditional_embeds: Union[PromptEmbeds, None] = None, unconditional_embeds: Union[PromptEmbeds, None] = None, batch: Optional['DataLoaderBatchDTO'] = None, **kwargs, ): dtype = get_torch_dtype(self.train_config.dtype) return self.sd.predict_noise( latents=noisy_latents.to(self.device_torch, dtype=dtype), conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), unconditional_embeddings=unconditional_embeds, timestep=timesteps, guidance_scale=self.train_config.cfg_scale, guidance_embedding_scale=self.train_config.cfg_scale, detach_unconditional=False, rescale_cfg=self.train_config.cfg_rescale, bypass_guidance_embedding=self.train_config.bypass_guidance_embedding, batch=batch, **kwargs ) def train_single_accumulation(self, batch: DataLoaderBatchDTO): self.timer.start('preprocess_batch') if isinstance(self.adapter, CustomAdapter): batch = self.adapter.edit_batch_raw(batch) batch = self.preprocess_batch(batch) if isinstance(self.adapter, CustomAdapter): batch = self.adapter.edit_batch_processed(batch) dtype = get_torch_dtype(self.train_config.dtype) # sanity check if self.sd.vae.dtype != self.sd.vae_torch_dtype: self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype) if isinstance(self.sd.text_encoder, list): for encoder in self.sd.text_encoder: if encoder.dtype != self.sd.te_torch_dtype: encoder.to(self.sd.te_torch_dtype) else: if self.sd.text_encoder.dtype != self.sd.te_torch_dtype: self.sd.text_encoder.to(self.sd.te_torch_dtype) noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch) if self.train_config.do_cfg or self.train_config.do_random_cfg: # pick random negative prompts if self.negative_prompt_pool is not None: negative_prompts = [] for i in range(noisy_latents.shape[0]): num_neg = random.randint(1, self.train_config.max_negative_prompts) this_neg_prompts = [random.choice(self.negative_prompt_pool) for _ in range(num_neg)] this_neg_prompt = ', '.join(this_neg_prompts) negative_prompts.append(this_neg_prompt) self.batch_negative_prompt = negative_prompts else: self.batch_negative_prompt = ['' for _ in range(batch.latents.shape[0])] if self.adapter and isinstance(self.adapter, CustomAdapter): # condition the prompt # todo handle more than one adapter image conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts) network_weight_list = batch.get_network_weight_list() if self.train_config.single_item_batching: network_weight_list = network_weight_list + network_weight_list has_adapter_img = batch.control_tensor is not None has_clip_image = batch.clip_image_tensor is not None has_clip_image_embeds = batch.clip_image_embeds is not None # force it to be true if doing regs as we handle those differently if any([batch.file_items[idx].is_reg for idx in range(len(batch.file_items))]): has_clip_image = True if self._clip_image_embeds_unconditional is not None: has_clip_image_embeds = True # we are caching embeds, handle that differently has_clip_image = False if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img: raise ValueError( "IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ") match_adapter_assist = False # check if we are matching the adapter assistant if self.assistant_adapter: if self.train_config.match_adapter_chance == 1.0: match_adapter_assist = True elif self.train_config.match_adapter_chance > 0.0: match_adapter_assist = torch.rand( (1,), device=self.device_torch, dtype=dtype ) < self.train_config.match_adapter_chance self.timer.stop('preprocess_batch') is_reg = False with torch.no_grad(): loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype) for idx, file_item in enumerate(batch.file_items): if file_item.is_reg: loss_multiplier[idx] = loss_multiplier[idx] * self.train_config.reg_weight is_reg = True adapter_images = None sigmas = None if has_adapter_img and (self.adapter or self.assistant_adapter): with self.timer('get_adapter_images'): # todo move this to data loader if batch.control_tensor is not None: adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach() # match in channels if self.assistant_adapter is not None: in_channels = self.assistant_adapter.config.in_channels if adapter_images.shape[1] != in_channels: # we need to match the channels adapter_images = adapter_images[:, :in_channels, :, :] else: raise NotImplementedError("Adapter images now must be loaded with dataloader") clip_images = None if has_clip_image: with self.timer('get_clip_images'): # todo move this to data loader if batch.clip_image_tensor is not None: clip_images = batch.clip_image_tensor.to(self.device_torch, dtype=dtype).detach() mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype) if batch.mask_tensor is not None: with self.timer('get_mask_multiplier'): # upsampling no supported for bfloat16 mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach() # scale down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height) mask_multiplier = torch.nn.functional.interpolate( mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3]) ) # expand to match latents mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1) mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach() # make avg 1.0 mask_multiplier = mask_multiplier / mask_multiplier.mean() def get_adapter_multiplier(): if self.adapter and isinstance(self.adapter, T2IAdapter): # training a t2i adapter, not using as assistant. return 1.0 elif match_adapter_assist: # training a texture. We want it high adapter_strength_min = 0.9 adapter_strength_max = 1.0 else: # training with assistance, we want it low # adapter_strength_min = 0.4 # adapter_strength_max = 0.7 adapter_strength_min = 0.5 adapter_strength_max = 1.1 adapter_conditioning_scale = torch.rand( (1,), device=self.device_torch, dtype=dtype ) adapter_conditioning_scale = value_map( adapter_conditioning_scale, 0.0, 1.0, adapter_strength_min, adapter_strength_max ) return adapter_conditioning_scale # flush() with self.timer('grad_setup'): # text encoding grad_on_text_encoder = False if self.train_config.train_text_encoder: grad_on_text_encoder = True if self.embedding is not None: grad_on_text_encoder = True if self.adapter and isinstance(self.adapter, ClipVisionAdapter): grad_on_text_encoder = True if self.adapter_config and self.adapter_config.type == 'te_augmenter': grad_on_text_encoder = True # have a blank network so we can wrap it in a context and set multipliers without checking every time if self.network is not None: network = self.network else: network = BlankNetwork() # set the weights network.multiplier = network_weight_list # activate network if it exits prompts_1 = conditioned_prompts prompts_2 = None if self.train_config.short_and_long_captions_encoder_split and self.sd.is_xl: prompts_1 = batch.get_caption_short_list() prompts_2 = conditioned_prompts # make the batch splits if self.train_config.single_item_batching: if self.model_config.refiner_name_or_path is not None: raise ValueError("Single item batching is not supported when training the refiner") batch_size = noisy_latents.shape[0] # chunk/split everything noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0) noise_list = torch.chunk(noise, batch_size, dim=0) timesteps_list = torch.chunk(timesteps, batch_size, dim=0) conditioned_prompts_list = [[prompt] for prompt in prompts_1] if imgs is not None: imgs_list = torch.chunk(imgs, batch_size, dim=0) else: imgs_list = [None for _ in range(batch_size)] if adapter_images is not None: adapter_images_list = torch.chunk(adapter_images, batch_size, dim=0) else: adapter_images_list = [None for _ in range(batch_size)] if clip_images is not None: clip_images_list = torch.chunk(clip_images, batch_size, dim=0) else: clip_images_list = [None for _ in range(batch_size)] mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0) if prompts_2 is None: prompt_2_list = [None for _ in range(batch_size)] else: prompt_2_list = [[prompt] for prompt in prompts_2] else: noisy_latents_list = [noisy_latents] noise_list = [noise] timesteps_list = [timesteps] conditioned_prompts_list = [prompts_1] imgs_list = [imgs] adapter_images_list = [adapter_images] clip_images_list = [clip_images] mask_multiplier_list = [mask_multiplier] if prompts_2 is None: prompt_2_list = [None] else: prompt_2_list = [prompts_2] for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, clip_images, mask_multiplier, prompt_2 in zip( noisy_latents_list, noise_list, timesteps_list, conditioned_prompts_list, imgs_list, adapter_images_list, clip_images_list, mask_multiplier_list, prompt_2_list ): # if self.train_config.negative_prompt is not None: # # add negative prompt # conditioned_prompts = conditioned_prompts + [self.train_config.negative_prompt for x in # range(len(conditioned_prompts))] # if prompt_2 is not None: # prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))] with (network): # encode clip adapter here so embeds are active for tokenizer if self.adapter and isinstance(self.adapter, ClipVisionAdapter): with self.timer('encode_clip_vision_embeds'): if has_clip_image: conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( clip_images.detach().to(self.device_torch, dtype=dtype), is_training=True, has_been_preprocessed=True ) else: # just do a blank one conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( torch.zeros( (noisy_latents.shape[0], 3, 512, 512), device=self.device_torch, dtype=dtype ), is_training=True, has_been_preprocessed=True, drop=True ) # it will be injected into the tokenizer when called self.adapter(conditional_clip_embeds) # do the custom adapter after the prior prediction if self.adapter and isinstance(self.adapter, CustomAdapter) and (has_clip_image or is_reg): quad_count = random.randint(1, 4) self.adapter.train() self.adapter.trigger_pre_te( tensors_preprocessed=clip_images if not is_reg else None, # on regs we send none to get random noise is_training=True, has_been_preprocessed=True, quad_count=quad_count, batch_tensor=batch.tensor if not is_reg else None, batch_size=noisy_latents.shape[0] ) with self.timer('encode_prompt'): unconditional_embeds = None if self.train_config.unload_text_encoder: with torch.set_grad_enabled(False): embeds_to_use = self.cached_blank_embeds.clone().detach().to( self.device_torch, dtype=dtype ) if self.cached_trigger_embeds is not None and not is_reg: embeds_to_use = self.cached_trigger_embeds.clone().detach().to( self.device_torch, dtype=dtype ) conditional_embeds = concat_prompt_embeds( [embeds_to_use] * noisy_latents.shape[0] ) if self.train_config.do_cfg: unconditional_embeds = self.cached_blank_embeds.clone().detach().to( self.device_torch, dtype=dtype ) unconditional_embeds = concat_prompt_embeds( [unconditional_embeds] * noisy_latents.shape[0] ) if isinstance(self.adapter, CustomAdapter): self.adapter.is_unconditional_run = False elif grad_on_text_encoder: with torch.set_grad_enabled(True): if isinstance(self.adapter, CustomAdapter): self.adapter.is_unconditional_run = False conditional_embeds = self.sd.encode_prompt( conditioned_prompts, prompt_2, dropout_prob=self.train_config.prompt_dropout_prob, long_prompts=self.do_long_prompts).to( self.device_torch, dtype=dtype) if self.train_config.do_cfg: if isinstance(self.adapter, CustomAdapter): self.adapter.is_unconditional_run = True # todo only do one and repeat it unconditional_embeds = self.sd.encode_prompt( self.batch_negative_prompt, self.batch_negative_prompt, dropout_prob=self.train_config.prompt_dropout_prob, long_prompts=self.do_long_prompts).to( self.device_torch, dtype=dtype) if isinstance(self.adapter, CustomAdapter): self.adapter.is_unconditional_run = False else: with torch.set_grad_enabled(False): # make sure it is in eval mode if isinstance(self.sd.text_encoder, list): for te in self.sd.text_encoder: te.eval() else: self.sd.text_encoder.eval() if isinstance(self.adapter, CustomAdapter): self.adapter.is_unconditional_run = False conditional_embeds = self.sd.encode_prompt( conditioned_prompts, prompt_2, dropout_prob=self.train_config.prompt_dropout_prob, long_prompts=self.do_long_prompts).to( self.device_torch, dtype=dtype) if self.train_config.do_cfg: if isinstance(self.adapter, CustomAdapter): self.adapter.is_unconditional_run = True unconditional_embeds = self.sd.encode_prompt( self.batch_negative_prompt, dropout_prob=self.train_config.prompt_dropout_prob, long_prompts=self.do_long_prompts).to( self.device_torch, dtype=dtype) if isinstance(self.adapter, CustomAdapter): self.adapter.is_unconditional_run = False if self.train_config.diff_output_preservation: dop_prompts = [p.replace(self.trigger_word, self.train_config.diff_output_preservation_class) for p in conditioned_prompts] dop_prompts_2 = None if prompt_2 is not None: dop_prompts_2 = [p.replace(self.trigger_word, self.train_config.diff_output_preservation_class) for p in prompt_2] self.diff_output_preservation_embeds = self.sd.encode_prompt( dop_prompts, dop_prompts_2, dropout_prob=self.train_config.prompt_dropout_prob, long_prompts=self.do_long_prompts).to( self.device_torch, dtype=dtype) # detach the embeddings conditional_embeds = conditional_embeds.detach() if self.train_config.do_cfg: unconditional_embeds = unconditional_embeds.detach() if self.decorator: conditional_embeds.text_embeds = self.decorator( conditional_embeds.text_embeds ) if self.train_config.do_cfg: unconditional_embeds.text_embeds = self.decorator( unconditional_embeds.text_embeds, is_unconditional=True ) # flush() pred_kwargs = {} if has_adapter_img: if (self.adapter and isinstance(self.adapter, T2IAdapter)) or ( self.assistant_adapter and isinstance(self.assistant_adapter, T2IAdapter)): with torch.set_grad_enabled(self.adapter is not None): adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter adapter_multiplier = get_adapter_multiplier() with self.timer('encode_adapter'): down_block_additional_residuals = adapter(adapter_images) if self.assistant_adapter: # not training. detach down_block_additional_residuals = [ sample.to(dtype=dtype).detach() * adapter_multiplier for sample in down_block_additional_residuals ] else: down_block_additional_residuals = [ sample.to(dtype=dtype) * adapter_multiplier for sample in down_block_additional_residuals ] pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals if self.adapter and isinstance(self.adapter, IPAdapter): with self.timer('encode_adapter_embeds'): # number of images to do if doing a quad image quad_count = random.randint(1, 4) image_size = self.adapter.input_size if has_clip_image_embeds: # todo handle reg images better than this if is_reg: # get unconditional image embeds from cache embeds = [ load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in range(noisy_latents.shape[0]) ] conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( embeds, quad_count=quad_count ) if self.train_config.do_cfg: embeds = [ load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in range(noisy_latents.shape[0]) ] unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( embeds, quad_count=quad_count ) else: conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( batch.clip_image_embeds, quad_count=quad_count ) if self.train_config.do_cfg: unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( batch.clip_image_embeds_unconditional, quad_count=quad_count ) elif is_reg: # we will zero it out in the img embedder clip_images = torch.zeros( (noisy_latents.shape[0], 3, image_size, image_size), device=self.device_torch, dtype=dtype ).detach() # drop will zero it out conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( clip_images, drop=True, is_training=True, has_been_preprocessed=False, quad_count=quad_count ) if self.train_config.do_cfg: unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( torch.zeros( (noisy_latents.shape[0], 3, image_size, image_size), device=self.device_torch, dtype=dtype ).detach(), is_training=True, drop=True, has_been_preprocessed=False, quad_count=quad_count ) elif has_clip_image: conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( clip_images.detach().to(self.device_torch, dtype=dtype), is_training=True, has_been_preprocessed=True, quad_count=quad_count, # do cfg on clip embeds to normalize the embeddings for when doing cfg # cfg_embed_strength=3.0 if not self.train_config.do_cfg else None # cfg_embed_strength=3.0 if not self.train_config.do_cfg else None ) if self.train_config.do_cfg: unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( clip_images.detach().to(self.device_torch, dtype=dtype), is_training=True, drop=True, has_been_preprocessed=True, quad_count=quad_count ) else: print_acc("No Clip Image") print_acc([file_item.path for file_item in batch.file_items]) raise ValueError("Could not find clip image") if not self.adapter_config.train_image_encoder: # we are not training the image encoder, so we need to detach the embeds conditional_clip_embeds = conditional_clip_embeds.detach() if self.train_config.do_cfg: unconditional_clip_embeds = unconditional_clip_embeds.detach() with self.timer('encode_adapter'): self.adapter.train() conditional_embeds = self.adapter( conditional_embeds.detach(), conditional_clip_embeds, is_unconditional=False ) if self.train_config.do_cfg: unconditional_embeds = self.adapter( unconditional_embeds.detach(), unconditional_clip_embeds, is_unconditional=True ) else: # wipe out unconsitional self.adapter.last_unconditional = None if self.adapter and isinstance(self.adapter, ReferenceAdapter): # pass in our scheduler self.adapter.noise_scheduler = self.lr_scheduler if has_clip_image or has_adapter_img: img_to_use = clip_images if has_clip_image else adapter_images # currently 0-1 needs to be -1 to 1 reference_images = ((img_to_use - 0.5) * 2).detach().to(self.device_torch, dtype=dtype) self.adapter.set_reference_images(reference_images) self.adapter.noise_scheduler = self.sd.noise_scheduler elif is_reg: self.adapter.set_blank_reference_images(noisy_latents.shape[0]) else: self.adapter.set_reference_images(None) prior_pred = None do_reg_prior = False # if is_reg and (self.network is not None or self.adapter is not None): # # we are doing a reg image and we have a network or adapter # do_reg_prior = True do_inverted_masked_prior = False if self.train_config.inverted_mask_prior and batch.mask_tensor is not None: do_inverted_masked_prior = True do_correct_pred_norm_prior = self.train_config.correct_pred_norm do_guidance_prior = False if batch.unconditional_latents is not None: # for this not that, we need a prior pred to normalize guidance_type: GuidanceType = batch.file_items[0].dataset_config.guidance_type if guidance_type == 'tnt': do_guidance_prior = True if (( has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_guidance_prior or do_reg_prior or do_inverted_masked_prior or self.train_config.correct_pred_norm): with self.timer('prior predict'): prior_embeds_to_use = conditional_embeds # use diff_output_preservation embeds if doing dfe if self.train_config.diff_output_preservation: prior_embeds_to_use = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0]) prior_pred = self.get_prior_prediction( noisy_latents=noisy_latents, conditional_embeds=prior_embeds_to_use, match_adapter_assist=match_adapter_assist, network_weight_list=network_weight_list, timesteps=timesteps, pred_kwargs=pred_kwargs, noise=noise, batch=batch, unconditional_embeds=unconditional_embeds, conditioned_prompts=conditioned_prompts ) if prior_pred is not None: prior_pred = prior_pred.detach() # do the custom adapter after the prior prediction if self.adapter and isinstance(self.adapter, CustomAdapter) and (has_clip_image or self.adapter_config.type in ['llm_adapter', 'text_encoder']): quad_count = random.randint(1, 4) self.adapter.train() conditional_embeds = self.adapter.condition_encoded_embeds( tensors_0_1=clip_images, prompt_embeds=conditional_embeds, is_training=True, has_been_preprocessed=True, quad_count=quad_count ) if self.train_config.do_cfg and unconditional_embeds is not None: unconditional_embeds = self.adapter.condition_encoded_embeds( tensors_0_1=clip_images, prompt_embeds=unconditional_embeds, is_training=True, has_been_preprocessed=True, is_unconditional=True, quad_count=quad_count ) if self.adapter and isinstance(self.adapter, CustomAdapter) and batch.extra_values is not None: self.adapter.add_extra_values(batch.extra_values.detach()) if self.train_config.do_cfg: self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()), is_unconditional=True) if has_adapter_img: if (self.adapter and isinstance(self.adapter, ControlNetModel)) or ( self.assistant_adapter and isinstance(self.assistant_adapter, ControlNetModel)): if self.train_config.do_cfg: raise ValueError("ControlNetModel is not supported with CFG") with torch.set_grad_enabled(self.adapter is not None): adapter: ControlNetModel = self.assistant_adapter if self.assistant_adapter is not None else self.adapter adapter_multiplier = get_adapter_multiplier() with self.timer('encode_adapter'): # add_text_embeds is pooled_prompt_embeds for sdxl added_cond_kwargs = {} if self.sd.is_xl: added_cond_kwargs["text_embeds"] = conditional_embeds.pooled_embeds added_cond_kwargs['time_ids'] = self.sd.get_time_ids_from_latents(noisy_latents) down_block_res_samples, mid_block_res_sample = adapter( noisy_latents, timesteps, encoder_hidden_states=conditional_embeds.text_embeds, controlnet_cond=adapter_images, conditioning_scale=1.0, guess_mode=False, added_cond_kwargs=added_cond_kwargs, return_dict=False, ) pred_kwargs['down_block_additional_residuals'] = down_block_res_samples pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample self.before_unet_predict() # do a prior pred if we have an unconditional image, we will swap out the giadance later if batch.unconditional_latents is not None or self.do_guided_loss: # do guided loss loss = self.get_guided_loss( noisy_latents=noisy_latents, conditional_embeds=conditional_embeds, match_adapter_assist=match_adapter_assist, network_weight_list=network_weight_list, timesteps=timesteps, pred_kwargs=pred_kwargs, batch=batch, noise=noise, unconditional_embeds=unconditional_embeds, mask_multiplier=mask_multiplier, prior_pred=prior_pred, ) else: if unconditional_embeds is not None: unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach() with self.timer('condition_noisy_latents'): # do it for the model noisy_latents = self.sd.condition_noisy_latents(noisy_latents, batch) if self.adapter and isinstance(self.adapter, CustomAdapter): noisy_latents = self.adapter.condition_noisy_latents(noisy_latents, batch) with self.timer('predict_unet'): noise_pred = self.predict_noise( noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype), timesteps=timesteps, conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype), unconditional_embeds=unconditional_embeds, batch=batch, **pred_kwargs ) self.after_unet_predict() with self.timer('calculate_loss'): noise = noise.to(self.device_torch, dtype=dtype).detach() prior_to_calculate_loss = prior_pred # if we are doing diff_output_preservation and not noing inverted masked prior # then we need to send none here so it will not target the prior if self.train_config.diff_output_preservation and not do_inverted_masked_prior: prior_to_calculate_loss = None loss = self.calculate_loss( noise_pred=noise_pred, noise=noise, noisy_latents=noisy_latents, timesteps=timesteps, batch=batch, mask_multiplier=mask_multiplier, prior_pred=prior_to_calculate_loss, ) if self.train_config.diff_output_preservation: # send the loss backwards otherwise checkpointing will fail self.accelerator.backward(loss) normal_loss = loss.detach() # dont send backward again dop_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0]) dop_pred = self.predict_noise( noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype), timesteps=timesteps, conditional_embeds=dop_embeds.to(self.device_torch, dtype=dtype), unconditional_embeds=unconditional_embeds, batch=batch, **pred_kwargs ) dop_loss = torch.nn.functional.mse_loss(dop_pred, prior_pred) * self.train_config.diff_output_preservation_multiplier self.accelerator.backward(dop_loss) loss = normal_loss + dop_loss loss = loss.clone().detach() # require grad again so the backward wont fail loss.requires_grad_(True) # check if nan if torch.isnan(loss): print_acc("loss is nan") loss = torch.zeros_like(loss).requires_grad_(True) with self.timer('backward'): # todo we have multiplier seperated. works for now as res are not in same batch, but need to change loss = loss * loss_multiplier.mean() # IMPORTANT if gradient checkpointing do not leave with network when doing backward # it will destroy the gradients. This is because the network is a context manager # and will change the multipliers back to 0.0 when exiting. They will be # 0.0 for the backward pass and the gradients will be 0.0 # I spent weeks on fighting this. DON'T DO IT # with fsdp_overlap_step_with_backward(): # if self.is_bfloat: # loss.backward() # else: self.accelerator.backward(loss) return loss.detach() # flush() def hook_train_loop(self, batch: Union[DataLoaderBatchDTO, List[DataLoaderBatchDTO]]): if isinstance(batch, list): batch_list = batch else: batch_list = [batch] total_loss = None self.optimizer.zero_grad() for batch in batch_list: loss = self.train_single_accumulation(batch) if total_loss is None: total_loss = loss else: total_loss += loss if len(batch_list) > 1 and self.model_config.low_vram: torch.cuda.empty_cache() if not self.is_grad_accumulation_step: # fix this for multi params if self.train_config.optimizer != 'adafactor': if isinstance(self.params[0], dict): for i in range(len(self.params)): self.accelerator.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm) else: self.accelerator.clip_grad_norm_(self.params, self.train_config.max_grad_norm) # only step if we are not accumulating with self.timer('optimizer_step'): self.optimizer.step() self.optimizer.zero_grad(set_to_none=True) if self.adapter and isinstance(self.adapter, CustomAdapter): self.adapter.post_weight_update() if self.ema is not None: with self.timer('ema_update'): self.ema.update() else: # gradient accumulation. Just a place for breakpoint pass # TODO Should we only step scheduler on grad step? If so, need to recalculate last step with self.timer('scheduler_step'): self.lr_scheduler.step() if self.embedding is not None: with self.timer('restore_embeddings'): # Let's make sure we don't update any embedding weights besides the newly added token self.embedding.restore_embeddings() if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter): with self.timer('restore_adapter'): # Let's make sure we don't update any embedding weights besides the newly added token self.adapter.restore_embeddings() loss_dict = OrderedDict( {'loss': loss.item()} ) self.end_of_training_loop() return loss_dict