Spaces:
Paused
Paused
import random | |
from collections import OrderedDict | |
from torch.utils.data import DataLoader | |
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds | |
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork | |
from toolkit.train_tools import get_torch_dtype, apply_snr_weight | |
import gc | |
import torch | |
from jobs.process import BaseSDTrainProcess | |
def flush(): | |
torch.cuda.empty_cache() | |
gc.collect() | |
class ConceptReplacementConfig: | |
def __init__(self, **kwargs): | |
self.concept: str = kwargs.get('concept', '') | |
self.replacement: str = kwargs.get('replacement', '') | |
class ConceptReplacer(BaseSDTrainProcess): | |
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): | |
super().__init__(process_id, job, config, **kwargs) | |
replacement_list = self.config.get('replacements', []) | |
self.replacement_list = [ConceptReplacementConfig(**x) for x in replacement_list] | |
def before_model_load(self): | |
pass | |
def hook_before_train_loop(self): | |
self.sd.vae.eval() | |
self.sd.vae.to(self.device_torch) | |
# textual inversion | |
if self.embedding is not None: | |
# set text encoder to train. Not sure if this is necessary but diffusers example did it | |
self.sd.text_encoder.train() | |
def hook_train_loop(self, batch): | |
with torch.no_grad(): | |
dtype = get_torch_dtype(self.train_config.dtype) | |
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch) | |
network_weight_list = batch.get_network_weight_list() | |
# have a blank network so we can wrap it in a context and set multipliers without checking every time | |
if self.network is not None: | |
network = self.network | |
else: | |
network = BlankNetwork() | |
batch_replacement_list = [] | |
# get a random replacement for each prompt | |
for prompt in conditioned_prompts: | |
replacement = random.choice(self.replacement_list) | |
batch_replacement_list.append(replacement) | |
# build out prompts | |
concept_prompts = [] | |
replacement_prompts = [] | |
for idx, replacement in enumerate(batch_replacement_list): | |
prompt = conditioned_prompts[idx] | |
# insert shuffled concept at beginning and end of prompt | |
shuffled_concept = [x.strip() for x in replacement.concept.split(',')] | |
random.shuffle(shuffled_concept) | |
shuffled_concept = ', '.join(shuffled_concept) | |
concept_prompts.append(f"{shuffled_concept}, {prompt}, {shuffled_concept}") | |
# insert replacement at beginning and end of prompt | |
shuffled_replacement = [x.strip() for x in replacement.replacement.split(',')] | |
random.shuffle(shuffled_replacement) | |
shuffled_replacement = ', '.join(shuffled_replacement) | |
replacement_prompts.append(f"{shuffled_replacement}, {prompt}, {shuffled_replacement}") | |
# predict the replacement without network | |
conditional_embeds = self.sd.encode_prompt(replacement_prompts).to(self.device_torch, dtype=dtype) | |
replacement_pred = self.sd.predict_noise( | |
latents=noisy_latents.to(self.device_torch, dtype=dtype), | |
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), | |
timestep=timesteps, | |
guidance_scale=1.0, | |
) | |
del conditional_embeds | |
replacement_pred = replacement_pred.detach() | |
self.optimizer.zero_grad() | |
flush() | |
# text encoding | |
grad_on_text_encoder = False | |
if self.train_config.train_text_encoder: | |
grad_on_text_encoder = True | |
if self.embedding: | |
grad_on_text_encoder = True | |
# set the weights | |
network.multiplier = network_weight_list | |
# activate network if it exits | |
with network: | |
with torch.set_grad_enabled(grad_on_text_encoder): | |
# embed the prompts | |
conditional_embeds = self.sd.encode_prompt(concept_prompts).to(self.device_torch, dtype=dtype) | |
if not grad_on_text_encoder: | |
# detach the embeddings | |
conditional_embeds = conditional_embeds.detach() | |
self.optimizer.zero_grad() | |
flush() | |
noise_pred = self.sd.predict_noise( | |
latents=noisy_latents.to(self.device_torch, dtype=dtype), | |
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), | |
timestep=timesteps, | |
guidance_scale=1.0, | |
) | |
loss = torch.nn.functional.mse_loss(noise_pred.float(), replacement_pred.float(), reduction="none") | |
loss = loss.mean([1, 2, 3]) | |
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: | |
# add min_snr_gamma | |
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) | |
loss = loss.mean() | |
# back propagate loss to free ram | |
loss.backward() | |
flush() | |
# apply gradients | |
self.optimizer.step() | |
self.optimizer.zero_grad() | |
self.lr_scheduler.step() | |
if self.embedding is not None: | |
# Let's make sure we don't update any embedding weights besides the newly added token | |
self.embedding.restore_embeddings() | |
loss_dict = OrderedDict( | |
{'loss': loss.item()} | |
) | |
# reset network multiplier | |
network.multiplier = 1.0 | |
return loss_dict | |